peachy-go/csexp/match/match.go

191 lines
4.5 KiB
Go
Raw Normal View History

2024-10-30 21:22:55 +00:00
package match
import (
"bytes"
"errors"
"fmt"
"io"
"strconv"
"strings"
"git.codemonkeysoftware.net/b/gigaparsec"
pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes"
"git.codemonkeysoftware.net/b/peachy-go/csexp"
)
// astNode is the closed set of node types that make up a compiled pattern.
// The unexported marker method keeps the set limited to the types declared
// in this package.
type astNode interface {
	isAstNode()
}

// literalNode is a pattern atom; its value is the exact content that the
// corresponding input atom must have.
type literalNode string

// stringFormatNode is the "%s" placeholder, which captures the content of
// an input atom into a caller-supplied destination.
type stringFormatNode struct{}

// listNode is a parenthesized sequence of sub-patterns.
type listNode []astNode

func (literalNode) isAstNode()      {}
func (stringFormatNode) isAstNode() {}
func (listNode) isAstNode()         {}

// Compile-time checks that every node type satisfies astNode.
var (
	_ astNode = literalNode("")
	_ astNode = stringFormatNode{}
	_ astNode = listNode(nil)
)
// parseLength parses the decimal length prefix of an atom: either "0" or a
// run of digits that does not start with zero. The matched text is converted
// to a uint64 by mustParseLength.
var parseLength = gigaparsec.Map(pbytes.Regexp(`0|[1-9]\d*`), mustParseLength)

// mustParseLength converts a base-10 digit string to a uint64 and panics if
// strconv.ParseUint rejects it.
//
// NOTE(review): the regexp above accepts digit runs of any length, so a
// value that does not fit in 64 bits would make ParseUint fail and this
// function panic during pattern compilation — confirm whether that should be
// reported as a parse error instead.
func mustParseLength(digits string) uint64 {
	length, err := strconv.ParseUint(digits, 10, 64)
	if err == nil {
		return length
	}
	panic(err)
}
// acceptN returns a parser that reads n bytes from the input and yields them
// as a freshly allocated slice.
//
// When n is zero the parser succeeds immediately with a nil value and leaves
// the state unchanged. If the read hits io.EOF the parser fails with an
// "expected n bytes" message; any other read error is returned as a hard
// error.
//
// NOTE(review): the byte count returned by Read is discarded, which assumes
// Read reports io.EOF whenever fewer than n bytes are available — confirm
// against gigaparsec.State.Read.
func acceptN(n uint64) gigaparsec.Parser[byte, []byte] {
	expected := fmt.Sprintf("%d bytes", n)
	return func(s gigaparsec.State[byte]) (gigaparsec.Result[byte, []byte], error) {
		if n == 0 {
			return gigaparsec.Succeed[byte, []byte](true, nil, s, gigaparsec.MessageOK(s.Pos())), nil
		}

		buf := make([]byte, n)
		_, rest, readErr := s.Read(buf)
		switch {
		case errors.Is(readErr, io.EOF):
			// Ran out of input: an ordinary parse failure.
			return gigaparsec.Fail[byte, []byte](false, gigaparsec.MessageEnd(s.Pos(), expected)), nil
		case readErr != nil:
			// Anything else is an I/O problem, not a syntax problem.
			return gigaparsec.Result[byte, []byte]{}, readErr
		default:
			return gigaparsec.Succeed(true, buf, rest, gigaparsec.MessageOK(s.Pos())), nil
		}
	}
}
// parseAtom parses a literal atom written as "<length>:<bytes>" and wraps
// the bytes in a literalNode.
//
// It chains parseLength, a ':' separator and acceptN.
// NOTE(review): this reads Bind2/Pipe as "carry the parsed length past the
// colon and hand it to acceptN"; that is inferred from how they are used
// here — confirm against the gigaparsec documentation.
var parseAtom = gigaparsec.Map(
	gigaparsec.Bind2(
		parseLength,
		gigaparsec.Pipe[byte, byte, uint64](gigaparsec.Match[byte](':')),
		acceptN,
	),
	func(content []byte) astNode {
		return literalNode(content)
	},
)
// parseFormat parses the "%s" placeholder and yields a stringFormatNode.
// The matched text itself carries no information and is dropped.
var parseFormat = gigaparsec.Map(
	pbytes.MatchString("%s"),
	func(string) astNode {
		return stringFormatNode{}
	},
)
// parseSexp builds the parser for a single pattern expression by combining
// parseAtom, parseList and parseFormat with gigaparsec.Choose.
//
// NOTE(review): presumably this is a function rather than a package-level
// variable so that the mutually recursive list parsers do not create an
// initialization cycle — confirm.
func parseSexp() gigaparsec.Parser[byte, astNode] {
	return gigaparsec.Choose(parseAtom, parseList, parseFormat)
}
// parseRestOfList parses the remainder of a list after its opening '(' has
// been handled by parseList: either the closing ')' or one expression
// followed, recursively, by the rest of the list.
//
// The closing parenthesis yields a nil listNode; each recursive step puts
// its own element in front of the elements parsed after it.
func parseRestOfList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, listNode], error) {
	// Alternative 1: the list ends here.
	endOfList := gigaparsec.Map(
		gigaparsec.Match[byte](')'),
		func(byte) listNode { return nil },
	)

	// Alternative 2: one more element, then whatever remains.
	elementThenRest := gigaparsec.Bind(
		parseSexp(),
		func(head astNode) gigaparsec.Parser[byte, listNode] {
			prepend := func(tail listNode) listNode {
				return append(listNode{head}, tail...)
			}
			return gigaparsec.Map(parseRestOfList, prepend)
		},
	)

	return gigaparsec.Choose(endOfList, elementThenRest)(input)
}
// parseList parses a parenthesized list pattern: an opening '(' followed by
// whatever parseRestOfList accepts. The '(' byte is discarded and the
// resulting listNode is returned as an astNode.
func parseList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, astNode], error) {
	dropParen := func(_ byte, items listNode) listNode { return items }
	asNode := func(items listNode) astNode { return items }

	list := gigaparsec.Seq2(
		gigaparsec.Match[byte]('('),
		parseRestOfList,
		dropParen,
	)
	return gigaparsec.Map(list, asNode)(input)
}
// MismatchError reports that an S-expression did not have the shape or
// content required by the pattern it was matched against.
type MismatchError struct {
	// Got is the expression, or sub-expression, that failed to match.
	Got csexp.Sexp
	// Expected is a short description of what the pattern required there.
	Expected string
}

// Error implements the error interface.
func (e MismatchError) Error() string {
	const format = "sexp mismatch: expected %s, got %v"
	return fmt.Sprintf(format, e.Expected, e.Got)
}
// Matcher is a compiled pattern. Obtain one from Compile or MustCompile and
// apply it to S-expressions with Match.
type Matcher struct {
	// pattern is the root node of the parsed pattern.
	pattern astNode
}
// Compile parses pattern and returns a Matcher for it.
//
// A pattern is built from literal atoms written as "<length>:<bytes>",
// parenthesized lists of sub-patterns, and the placeholder "%s".
//
// If the pattern cannot be parsed, the returned error wraps the parser's
// error with the prefix "invalid pattern: ".
//
// NOTE(review): whether bytes following the first complete expression are
// rejected depends on gigaparsec.Run — confirm.
func Compile(pattern string) (*Matcher, error) {
	root, parseErr := gigaparsec.Run(parseSexp(), strings.NewReader(pattern))
	if parseErr == nil {
		return &Matcher{pattern: root}, nil
	}
	return nil, fmt.Errorf("invalid pattern: %w", parseErr)
}
// MustCompile is like Compile but panics with the compile error instead of
// returning it. It is meant for patterns that are known to be valid, such as
// those assigned to package-level variables.
func MustCompile(pattern string) *Matcher {
	matcher, err := Compile(pattern)
	if err == nil {
		return matcher
	}
	panic(err)
}
// Match checks sexp against the compiled pattern.
//
// Each "%s" placeholder in the pattern consumes the next element of dst, in
// order, and stores the matched atom's content through it; every element
// must be a *string or a *[]byte. Match returns an error if sexp does not
// fit the pattern, if dst has too few or unsupported operands, or — with the
// message "too many operands" — if operands are left over after a
// successful match.
func (m Matcher) Match(sexp csexp.Sexp, dst ...any) error {
	unused, err := match(sexp, m.pattern, dst)
	switch {
	case err != nil:
		return err
	case len(unused) > 0:
		return errors.New("too many operands")
	}
	return nil
}
// match recursively checks sexp against the pattern node ast, storing the
// content matched by each "%s" placeholder through successive elements of
// dst.
//
// It returns the operands that were not consumed, so that the caller can
// thread them through sibling sub-patterns and detect leftovers. The error
// is one of:
//   - a MismatchError when sexp is nil or does not have the shape or content
//     the pattern requires;
//   - "too few operands" when a placeholder is reached and dst is empty;
//   - "can't scan type T" when an operand is neither *string nor *[]byte.
//
// Operands written before a failure keep their new values.
func match(sexp csexp.Sexp, ast astNode, dst []any) ([]any, error) {
	mismatch := func(expected string) MismatchError {
		return MismatchError{Expected: expected, Got: sexp}
	}
	if sexp == nil {
		return dst, mismatch("non-nil input")
	}

	switch node := ast.(type) {
	case literalNode:
		if atom, ok := sexp.(csexp.Atom); ok && bytes.Equal(atom, []byte(node)) {
			return dst, nil
		}
		return dst, mismatch(csexp.Atom(node).String())

	case stringFormatNode:
		if len(dst) == 0 {
			return dst, errors.New("too few operands")
		}
		atom, isAtom := sexp.(csexp.Atom)
		if !isAtom {
			return dst, mismatch("atom")
		}
		thisDst := dst[0]
		dst = dst[1:]
		switch ptr := thisDst.(type) {
		case *string:
			*ptr = string(atom)
		case *[]byte:
			// Copy so the caller's slice does not alias the input.
			*ptr = bytes.Clone([]byte(atom))
		default:
			return dst, fmt.Errorf("can't scan type %T", thisDst)
		}
		return dst, nil

	case listNode:
		list, ok := sexp.(csexp.List)
		if !ok {
			return dst, mismatch("list")
		}
		wrongLength := func() MismatchError {
			return mismatch(fmt.Sprintf("list of %d items", len(node)))
		}
		for i, sub := range node {
			if i >= len(list) {
				// Input list is shorter than the pattern.
				return dst, wrongLength()
			}
			var err error
			dst, err = match(list[i], sub, dst)
			if err != nil {
				return dst, err
			}
		}
		// Compare the lengths directly rather than through the loop index:
		// ranging over an empty pattern never assigns the index, so an
		// index-based check would let "()" accept a one-element list.
		if len(list) > len(node) {
			return dst, wrongLength()
		}
		return dst, nil
	}
	panic("unreachable")
}