package match import ( "bytes" "errors" "fmt" "io" "strconv" "strings" "git.codemonkeysoftware.net/b/gigaparsec" pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes" "git.codemonkeysoftware.net/b/peachy-go/csexp" ) type astNode interface { isAstNode() } type literalNode string func (l literalNode) isAstNode() {} type stringFormatNode struct{} func (s stringFormatNode) isAstNode() {} type listNode []astNode func (l listNode) isAstNode() {} var parseLength = gigaparsec.Map(pbytes.Regexp(`0|[1-9]\d*`), func(s string) uint64 { n, err := strconv.ParseUint(s, 10, 64) if err != nil { panic(err) } return n }) func acceptN(n uint64) gigaparsec.Parser[byte, []byte] { expected := fmt.Sprintf("%d bytes", n) return func(s gigaparsec.State[byte]) (gigaparsec.Result[byte, []byte], error) { if n == 0 { return gigaparsec.Succeed[byte, []byte](true, nil, s, gigaparsec.MessageOK(s.Pos())), nil } dst := make([]byte, n) _, next, err := s.Read(dst) if errors.Is(err, io.EOF) { return gigaparsec.Fail[byte, []byte](false, gigaparsec.MessageEnd(s.Pos(), expected)), nil } if err != nil { return gigaparsec.Result[byte, []byte]{}, err } return gigaparsec.Succeed(true, dst, next, gigaparsec.MessageOK(s.Pos())), nil } } var parseAtom = gigaparsec.Map(gigaparsec.Bind2( parseLength, gigaparsec.Pipe[byte, byte, uint64](gigaparsec.Match[byte](':')), acceptN, ), func(s []byte) astNode { return literalNode(s) }) var parseFormat = gigaparsec.Map(pbytes.MatchString("%s"), func(s string) astNode { return stringFormatNode{} }) func parseSexp() gigaparsec.Parser[byte, astNode] { return gigaparsec.Choose(parseAtom, parseList, parseFormat) } func parseRestOfList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, listNode], error) { return gigaparsec.Choose( gigaparsec.Map(gigaparsec.Match[byte](')'), func(byte) listNode { return nil }), gigaparsec.Bind( parseSexp(), func(s astNode) gigaparsec.Parser[byte, listNode] { return gigaparsec.Map( parseRestOfList, func(rest listNode) listNode { return append(listNode{s}, rest...) }, ) }, ), )(input) } func parseList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, astNode], error) { return gigaparsec.Map(gigaparsec.Seq2( gigaparsec.Match[byte]('('), parseRestOfList, func(_ byte, rest listNode) listNode { return rest }, ), func(sexps listNode) astNode { return sexps })(input) } type MismatchError struct { Got csexp.Sexp Expected string } func (e MismatchError) Error() string { return fmt.Sprintf("sexp mismatch: expected %s, got %v", e.Expected, e.Got) } type Matcher struct { pattern astNode } func Compile(pattern string) (*Matcher, error) { ast, err := gigaparsec.Run(parseSexp(), strings.NewReader(pattern)) if err != nil { return nil, fmt.Errorf("invalid pattern: %w", err) } return &Matcher{pattern: ast}, nil } func MustCompile(pattern string) *Matcher { m, err := Compile(pattern) if err != nil { panic(err) } return m } func (m Matcher) Match(sexp csexp.Sexp, dst ...any) error { dst, err := match(sexp, m.pattern, dst) if err != nil { return err } if len(dst) > 0 { return errors.New("too many operands") } return nil } func match(sexp csexp.Sexp, ast astNode, dst []any) ([]any, error) { mismatch := func(expected string) MismatchError { return MismatchError{Expected: expected, Got: sexp} } if sexp == nil { return dst, mismatch("non-nil input") } switch node := ast.(type) { case literalNode: if atom, ok := sexp.(csexp.Atom); ok { if bytes.Equal(atom, []byte(node)) { return dst, nil } } return dst, MismatchError{Got: sexp, Expected: csexp.Atom(node).String()} case stringFormatNode: if len(dst) == 0 { return dst, errors.New("too few operands") } atom, isAtom := sexp.(csexp.Atom) if !isAtom { return dst, mismatch("atom") } thisDst := dst[0] dst = dst[1:] switch ptr := thisDst.(type) { case *string: *ptr = string(atom) case *[]byte: *ptr = bytes.Clone([]byte(atom)) default: return dst, fmt.Errorf("can't scan type %T", thisDst) } return dst, nil case listNode: list, ok := sexp.(csexp.List) if !ok { return dst, mismatch("list") } var err error var i int for i = range node { if i == len(list) { return dst, mismatch(fmt.Sprintf("list of %d items", len(node))) } dst, err = match(list[i], node[i], dst) if err != nil { return dst, err } } if i < len(list)-1 { return dst, mismatch(fmt.Sprintf("list of %d items", len(node))) } return dst, nil } panic("unreachable") }