From 84f927b9182a97277edb807f029870940255d10b Mon Sep 17 00:00:00 2001 From: Brandon Dyck Date: Wed, 30 Oct 2024 15:22:55 -0600 Subject: [PATCH] Added a simple sexp matcher --- csexp/match/match.go | 190 ++++++++++++++++++++++++++++++++++++++ csexp/match/match_test.go | 95 +++++++++++++++++++ csexp/sexp.go | 3 +- 3 files changed, 286 insertions(+), 2 deletions(-) create mode 100644 csexp/match/match.go create mode 100644 csexp/match/match_test.go diff --git a/csexp/match/match.go b/csexp/match/match.go new file mode 100644 index 0000000..7991de8 --- /dev/null +++ b/csexp/match/match.go @@ -0,0 +1,190 @@ +package match + +import ( + "bytes" + "errors" + "fmt" + "io" + "strconv" + "strings" + + "git.codemonkeysoftware.net/b/gigaparsec" + pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes" + "git.codemonkeysoftware.net/b/peachy-go/csexp" +) + +type astNode interface { + isAstNode() +} + +type literalNode string + +func (l literalNode) isAstNode() {} + +type stringFormatNode struct{} + +func (s stringFormatNode) isAstNode() {} + +type listNode []astNode + +func (l listNode) isAstNode() {} + +var parseLength = gigaparsec.Map(pbytes.Regexp(`0|[1-9]\d*`), func(s string) uint64 { + n, err := strconv.ParseUint(s, 10, 64) + if err != nil { + panic(err) + } + return n +}) + +func acceptN(n uint64) gigaparsec.Parser[byte, []byte] { + expected := fmt.Sprintf("%d bytes", n) + return func(s gigaparsec.State[byte]) (gigaparsec.Result[byte, []byte], error) { + if n == 0 { + return gigaparsec.Succeed[byte, []byte](true, nil, s, gigaparsec.MessageOK(s.Pos())), nil + } + dst := make([]byte, n) + _, next, err := s.Read(dst) + if errors.Is(err, io.EOF) { + return gigaparsec.Fail[byte, []byte](false, gigaparsec.MessageEnd(s.Pos(), expected)), nil + } + if err != nil { + return gigaparsec.Result[byte, []byte]{}, err + } + return gigaparsec.Succeed(true, dst, next, gigaparsec.MessageOK(s.Pos())), nil + } +} + +var parseAtom = gigaparsec.Map(gigaparsec.Bind2( + parseLength, + gigaparsec.Pipe[byte, byte, uint64](gigaparsec.Match[byte](':')), + acceptN, +), func(s []byte) astNode { return literalNode(s) }) + +var parseFormat = gigaparsec.Map(pbytes.MatchString("%s"), + func(s string) astNode { return stringFormatNode{} }) + +func parseSexp() gigaparsec.Parser[byte, astNode] { + return gigaparsec.Choose(parseAtom, parseList, parseFormat) +} + +func parseRestOfList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, listNode], error) { + return gigaparsec.Choose( + gigaparsec.Map(gigaparsec.Match[byte](')'), func(byte) listNode { return nil }), + gigaparsec.Bind( + parseSexp(), + func(s astNode) gigaparsec.Parser[byte, listNode] { + return gigaparsec.Map( + parseRestOfList, + func(rest listNode) listNode { return append(listNode{s}, rest...) }, + ) + }, + ), + )(input) +} + +func parseList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, astNode], error) { + return gigaparsec.Map(gigaparsec.Seq2( + gigaparsec.Match[byte]('('), + parseRestOfList, + func(_ byte, rest listNode) listNode { return rest }, + ), func(sexps listNode) astNode { return sexps })(input) +} + +type MismatchError struct { + Got csexp.Sexp + Expected string +} + +func (e MismatchError) Error() string { + return fmt.Sprintf("sexp mismatch: expected %s, got %v", e.Expected, e.Got) +} + +type Matcher struct { + pattern astNode +} + +func Compile(pattern string) (*Matcher, error) { + ast, err := gigaparsec.Run(parseSexp(), strings.NewReader(pattern)) + if err != nil { + return nil, fmt.Errorf("invalid pattern: %w", err) + } + return &Matcher{pattern: ast}, nil +} + +func MustCompile(pattern string) *Matcher { + m, err := Compile(pattern) + if err != nil { + panic(err) + } + return m +} + +func (m Matcher) Match(sexp csexp.Sexp, dst ...any) error { + dst, err := match(sexp, m.pattern, dst) + if err != nil { + return err + } + if len(dst) > 0 { + return errors.New("too many operands") + } + return nil +} + +func match(sexp csexp.Sexp, ast astNode, dst []any) ([]any, error) { + mismatch := func(expected string) MismatchError { + return MismatchError{Expected: expected, Got: sexp} + } + if sexp == nil { + return dst, mismatch("non-nil input") + } + switch node := ast.(type) { + case literalNode: + if atom, ok := sexp.(csexp.Atom); ok { + if bytes.Equal(atom, []byte(node)) { + return dst, nil + } + } + return dst, MismatchError{Got: sexp, Expected: csexp.Atom(node).String()} + case stringFormatNode: + if len(dst) == 0 { + return dst, errors.New("too few operands") + } + atom, isAtom := sexp.(csexp.Atom) + if !isAtom { + return dst, mismatch("atom") + } + thisDst := dst[0] + dst = dst[1:] + switch ptr := thisDst.(type) { + case *string: + *ptr = string(atom) + case *[]byte: + *ptr = bytes.Clone([]byte(atom)) + default: + return dst, fmt.Errorf("can't scan type %T", thisDst) + } + return dst, nil + case listNode: + list, ok := sexp.(csexp.List) + if !ok { + return dst, mismatch("list") + } + var err error + var i int + for i = range node { + if i == len(list) { + return dst, mismatch(fmt.Sprintf("list of %d items", len(node))) + } + dst, err = match(list[i], node[i], dst) + if err != nil { + return dst, err + } + } + if i < len(list)-1 { + return dst, mismatch(fmt.Sprintf("list of %d items", len(node))) + } + return dst, nil + } + panic("unreachable") +} diff --git a/csexp/match/match_test.go b/csexp/match/match_test.go new file mode 100644 index 0000000..e5b6802 --- /dev/null +++ b/csexp/match/match_test.go @@ -0,0 +1,95 @@ +package match_test + +import ( + "testing" + + "git.codemonkeysoftware.net/b/peachy-go/csexp" + "git.codemonkeysoftware.net/b/peachy-go/csexp/match" + "github.com/shoenig/test" + "github.com/shoenig/test/must" +) + +func TestMatchLiteral(t *testing.T) { + t.Run("same", func(t *testing.T) { + atom := csexp.Atom("hello") + err := match.MustCompile(atom.String()).Match(atom) + must.NoError(t, err) + }) + t.Run("different", func(t *testing.T) { + atom := csexp.Atom("hello") + pattern := csexp.Atom("goodbye").String() + + err := match.MustCompile(pattern).Match(atom) + + var mismatch match.MismatchError + must.ErrorAs(t, err, &mismatch) + test.Equal(t, mismatch.Got, csexp.Sexp(atom)) + test.StrContains(t, mismatch.Expected, pattern) + }) +} + +func TestMatchList(t *testing.T) { + t.Run("both empty", func(t *testing.T) { + err := match.MustCompile("()").Match(csexp.List(nil)) + must.NoError(t, err) + }) + t.Run("same length", func(t *testing.T) { + list := csexp.List{csexp.Atom("a"), csexp.Atom("b"), csexp.Atom("c")} + err := match.MustCompile(list.String()).Match(list) + must.NoError(t, err) + }) + t.Run("sexp too long", func(t *testing.T) { + list := csexp.List{csexp.Atom("a"), csexp.Atom("b"), csexp.Atom("c")} + pattern := list[0:2].String() + + err := match.MustCompile(pattern).Match(list) + + var mismatch match.MismatchError + must.ErrorAs(t, err, &mismatch) + test.Equal(t, mismatch.Got, csexp.Sexp(list)) + test.StrContains(t, mismatch.Expected, "list of 2 items") + }) + t.Run("sexp too short", func(t *testing.T) { + list := csexp.List{csexp.Atom("a"), csexp.Atom("b"), csexp.Atom("c")} + shortList := list[0:2] + + err := match.MustCompile(list.String()).Match(shortList) + + var mismatch match.MismatchError + must.ErrorAs(t, err, &mismatch) + test.Equal(t, mismatch.Got, csexp.Sexp(shortList)) + test.StrContains(t, mismatch.Expected, "list of 3 items") + }) +} + +func TestExtract(t *testing.T) { + matcher := match.MustCompile("(3:abc%s(%s3:jkl)%s3:pqr%s)") + sexp, err := csexp.Parse([]byte("(3:abc3:def(3:ghi3:jkl)3:mno3:pqr3:stu)")) + must.NoError(t, err) + t.Run("correct number of operands", func(t *testing.T) { + var def, ghi, mno, stu string + err := matcher.Match(sexp, &def, &ghi, &mno, &stu) + must.NoError(t, err) + test.EqOp(t, "def", def) + test.EqOp(t, "ghi", ghi) + test.EqOp(t, "mno", mno) + test.EqOp(t, "stu", stu) + }) + t.Run("too many operands", func(t *testing.T) { + var def, ghi, mno, stu, extra string + err := matcher.Match(sexp, &def, &ghi, &mno, &stu, &extra) + must.ErrorContains(t, err, "too many operands") + test.EqOp(t, "def", def) + test.EqOp(t, "ghi", ghi) + test.EqOp(t, "mno", mno) + test.EqOp(t, "stu", stu) + }) + t.Run("too few operands", func(t *testing.T) { + var def, ghi, mno string + err := matcher.Match(sexp, &def, &ghi, &mno) + must.ErrorContains(t, err, "too few operands") + test.EqOp(t, "def", def) + test.EqOp(t, "ghi", ghi) + test.EqOp(t, "mno", mno) + }) +} diff --git a/csexp/sexp.go b/csexp/sexp.go index 69791cc..bb4a2ce 100644 --- a/csexp/sexp.go +++ b/csexp/sexp.go @@ -10,7 +10,6 @@ import ( "git.codemonkeysoftware.net/b/gigaparsec" pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes" - "git.codemonkeysoftware.net/b/gigaparsec/cursor" "git.codemonkeysoftware.net/b/peachy-go/shortcircuit" ) @@ -152,7 +151,7 @@ func Parse(data []byte) (Sexp, error) { gigaparsec.End[byte](), func(s Sexp, _ struct{}) Sexp { return s }, ) - result, err := gigaparsec.Run(parser, cursor.NewSlice(data)) + result, err := gigaparsec.Run(parser, bytes.NewReader(data)) if err != nil { return nil, fmt.Errorf("csexp.Parse: %w", err) }