Added a simple sexp matcher

This commit is contained in:
Brandon Dyck 2024-10-30 15:22:55 -06:00
parent 478db3ad3c
commit 84f927b918
3 changed files with 286 additions and 2 deletions

190
csexp/match/match.go Normal file
View File

@ -0,0 +1,190 @@
package match
import (
"bytes"
"errors"
"fmt"
"io"
"strconv"
"strings"
"git.codemonkeysoftware.net/b/gigaparsec"
pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes"
"git.codemonkeysoftware.net/b/peachy-go/csexp"
)
type astNode interface {
isAstNode()
}
type literalNode string
func (l literalNode) isAstNode() {}
type stringFormatNode struct{}
func (s stringFormatNode) isAstNode() {}
type listNode []astNode
func (l listNode) isAstNode() {}
var parseLength = gigaparsec.Map(pbytes.Regexp(`0|[1-9]\d*`), func(s string) uint64 {
n, err := strconv.ParseUint(s, 10, 64)
if err != nil {
panic(err)
}
return n
})
func acceptN(n uint64) gigaparsec.Parser[byte, []byte] {
expected := fmt.Sprintf("%d bytes", n)
return func(s gigaparsec.State[byte]) (gigaparsec.Result[byte, []byte], error) {
if n == 0 {
return gigaparsec.Succeed[byte, []byte](true, nil, s, gigaparsec.MessageOK(s.Pos())), nil
}
dst := make([]byte, n)
_, next, err := s.Read(dst)
if errors.Is(err, io.EOF) {
return gigaparsec.Fail[byte, []byte](false, gigaparsec.MessageEnd(s.Pos(), expected)), nil
}
if err != nil {
return gigaparsec.Result[byte, []byte]{}, err
}
return gigaparsec.Succeed(true, dst, next, gigaparsec.MessageOK(s.Pos())), nil
}
}
var parseAtom = gigaparsec.Map(gigaparsec.Bind2(
parseLength,
gigaparsec.Pipe[byte, byte, uint64](gigaparsec.Match[byte](':')),
acceptN,
), func(s []byte) astNode { return literalNode(s) })
var parseFormat = gigaparsec.Map(pbytes.MatchString("%s"),
func(s string) astNode { return stringFormatNode{} })
func parseSexp() gigaparsec.Parser[byte, astNode] {
return gigaparsec.Choose(parseAtom, parseList, parseFormat)
}
func parseRestOfList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, listNode], error) {
return gigaparsec.Choose(
gigaparsec.Map(gigaparsec.Match[byte](')'), func(byte) listNode { return nil }),
gigaparsec.Bind(
parseSexp(),
func(s astNode) gigaparsec.Parser[byte, listNode] {
return gigaparsec.Map(
parseRestOfList,
func(rest listNode) listNode { return append(listNode{s}, rest...) },
)
},
),
)(input)
}
func parseList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, astNode], error) {
return gigaparsec.Map(gigaparsec.Seq2(
gigaparsec.Match[byte]('('),
parseRestOfList,
func(_ byte, rest listNode) listNode { return rest },
), func(sexps listNode) astNode { return sexps })(input)
}
type MismatchError struct {
Got csexp.Sexp
Expected string
}
func (e MismatchError) Error() string {
return fmt.Sprintf("sexp mismatch: expected %s, got %v", e.Expected, e.Got)
}
type Matcher struct {
pattern astNode
}
func Compile(pattern string) (*Matcher, error) {
ast, err := gigaparsec.Run(parseSexp(), strings.NewReader(pattern))
if err != nil {
return nil, fmt.Errorf("invalid pattern: %w", err)
}
return &Matcher{pattern: ast}, nil
}
func MustCompile(pattern string) *Matcher {
m, err := Compile(pattern)
if err != nil {
panic(err)
}
return m
}
func (m Matcher) Match(sexp csexp.Sexp, dst ...any) error {
dst, err := match(sexp, m.pattern, dst)
if err != nil {
return err
}
if len(dst) > 0 {
return errors.New("too many operands")
}
return nil
}
func match(sexp csexp.Sexp, ast astNode, dst []any) ([]any, error) {
mismatch := func(expected string) MismatchError {
return MismatchError{Expected: expected, Got: sexp}
}
if sexp == nil {
return dst, mismatch("non-nil input")
}
switch node := ast.(type) {
case literalNode:
if atom, ok := sexp.(csexp.Atom); ok {
if bytes.Equal(atom, []byte(node)) {
return dst, nil
}
}
return dst, MismatchError{Got: sexp, Expected: csexp.Atom(node).String()}
case stringFormatNode:
if len(dst) == 0 {
return dst, errors.New("too few operands")
}
atom, isAtom := sexp.(csexp.Atom)
if !isAtom {
return dst, mismatch("atom")
}
thisDst := dst[0]
dst = dst[1:]
switch ptr := thisDst.(type) {
case *string:
*ptr = string(atom)
case *[]byte:
*ptr = bytes.Clone([]byte(atom))
default:
return dst, fmt.Errorf("can't scan type %T", thisDst)
}
return dst, nil
case listNode:
list, ok := sexp.(csexp.List)
if !ok {
return dst, mismatch("list")
}
var err error
var i int
for i = range node {
if i == len(list) {
return dst, mismatch(fmt.Sprintf("list of %d items", len(node)))
}
dst, err = match(list[i], node[i], dst)
if err != nil {
return dst, err
}
}
if i < len(list)-1 {
return dst, mismatch(fmt.Sprintf("list of %d items", len(node)))
}
return dst, nil
}
panic("unreachable")
}

95
csexp/match/match_test.go Normal file
View File

@ -0,0 +1,95 @@
package match_test
import (
"testing"
"git.codemonkeysoftware.net/b/peachy-go/csexp"
"git.codemonkeysoftware.net/b/peachy-go/csexp/match"
"github.com/shoenig/test"
"github.com/shoenig/test/must"
)
func TestMatchLiteral(t *testing.T) {
t.Run("same", func(t *testing.T) {
atom := csexp.Atom("hello")
err := match.MustCompile(atom.String()).Match(atom)
must.NoError(t, err)
})
t.Run("different", func(t *testing.T) {
atom := csexp.Atom("hello")
pattern := csexp.Atom("goodbye").String()
err := match.MustCompile(pattern).Match(atom)
var mismatch match.MismatchError
must.ErrorAs(t, err, &mismatch)
test.Equal(t, mismatch.Got, csexp.Sexp(atom))
test.StrContains(t, mismatch.Expected, pattern)
})
}
func TestMatchList(t *testing.T) {
t.Run("both empty", func(t *testing.T) {
err := match.MustCompile("()").Match(csexp.List(nil))
must.NoError(t, err)
})
t.Run("same length", func(t *testing.T) {
list := csexp.List{csexp.Atom("a"), csexp.Atom("b"), csexp.Atom("c")}
err := match.MustCompile(list.String()).Match(list)
must.NoError(t, err)
})
t.Run("sexp too long", func(t *testing.T) {
list := csexp.List{csexp.Atom("a"), csexp.Atom("b"), csexp.Atom("c")}
pattern := list[0:2].String()
err := match.MustCompile(pattern).Match(list)
var mismatch match.MismatchError
must.ErrorAs(t, err, &mismatch)
test.Equal(t, mismatch.Got, csexp.Sexp(list))
test.StrContains(t, mismatch.Expected, "list of 2 items")
})
t.Run("sexp too short", func(t *testing.T) {
list := csexp.List{csexp.Atom("a"), csexp.Atom("b"), csexp.Atom("c")}
shortList := list[0:2]
err := match.MustCompile(list.String()).Match(shortList)
var mismatch match.MismatchError
must.ErrorAs(t, err, &mismatch)
test.Equal(t, mismatch.Got, csexp.Sexp(shortList))
test.StrContains(t, mismatch.Expected, "list of 3 items")
})
}
func TestExtract(t *testing.T) {
matcher := match.MustCompile("(3:abc%s(%s3:jkl)%s3:pqr%s)")
sexp, err := csexp.Parse([]byte("(3:abc3:def(3:ghi3:jkl)3:mno3:pqr3:stu)"))
must.NoError(t, err)
t.Run("correct number of operands", func(t *testing.T) {
var def, ghi, mno, stu string
err := matcher.Match(sexp, &def, &ghi, &mno, &stu)
must.NoError(t, err)
test.EqOp(t, "def", def)
test.EqOp(t, "ghi", ghi)
test.EqOp(t, "mno", mno)
test.EqOp(t, "stu", stu)
})
t.Run("too many operands", func(t *testing.T) {
var def, ghi, mno, stu, extra string
err := matcher.Match(sexp, &def, &ghi, &mno, &stu, &extra)
must.ErrorContains(t, err, "too many operands")
test.EqOp(t, "def", def)
test.EqOp(t, "ghi", ghi)
test.EqOp(t, "mno", mno)
test.EqOp(t, "stu", stu)
})
t.Run("too few operands", func(t *testing.T) {
var def, ghi, mno string
err := matcher.Match(sexp, &def, &ghi, &mno)
must.ErrorContains(t, err, "too few operands")
test.EqOp(t, "def", def)
test.EqOp(t, "ghi", ghi)
test.EqOp(t, "mno", mno)
})
}

View File

@ -10,7 +10,6 @@ import (
"git.codemonkeysoftware.net/b/gigaparsec"
pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes"
"git.codemonkeysoftware.net/b/gigaparsec/cursor"
"git.codemonkeysoftware.net/b/peachy-go/shortcircuit"
)
@ -152,7 +151,7 @@ func Parse(data []byte) (Sexp, error) {
gigaparsec.End[byte](),
func(s Sexp, _ struct{}) Sexp { return s },
)
result, err := gigaparsec.Run(parser, cursor.NewSlice(data))
result, err := gigaparsec.Run(parser, bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("csexp.Parse: %w", err)
}