191 lines
4.5 KiB
Go
191 lines
4.5 KiB
Go
package match
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"git.codemonkeysoftware.net/b/gigaparsec"
|
|
pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes"
|
|
"git.codemonkeysoftware.net/b/peachy-go/csexp"
|
|
)
|
|
|
|
// astNode is a node in the tree produced by compiling a match pattern.
// The unexported marker method keeps implementations inside this package.
type astNode interface {
	isAstNode()
}

type (
	// literalNode matches an atom whose bytes are exactly equal to the string.
	literalNode string

	// stringFormatNode is the "%s" placeholder: it matches any atom and
	// stores it into the next destination operand.
	stringFormatNode struct{}

	// listNode matches a list whose elements match the child nodes pairwise.
	listNode []astNode
)

func (literalNode) isAstNode()      {}
func (stringFormatNode) isAstNode() {}
func (listNode) isAstNode()         {}
|
|
|
|
var parseLength = gigaparsec.Map(pbytes.Regexp(`0|[1-9]\d*`), func(s string) uint64 {
|
|
n, err := strconv.ParseUint(s, 10, 64)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return n
|
|
})
|
|
|
|
func acceptN(n uint64) gigaparsec.Parser[byte, []byte] {
|
|
expected := fmt.Sprintf("%d bytes", n)
|
|
return func(s gigaparsec.State[byte]) (gigaparsec.Result[byte, []byte], error) {
|
|
if n == 0 {
|
|
return gigaparsec.Succeed[byte, []byte](true, nil, s, gigaparsec.MessageOK(s.Pos())), nil
|
|
}
|
|
dst := make([]byte, n)
|
|
_, next, err := s.Read(dst)
|
|
if errors.Is(err, io.EOF) {
|
|
return gigaparsec.Fail[byte, []byte](false, gigaparsec.MessageEnd(s.Pos(), expected)), nil
|
|
}
|
|
if err != nil {
|
|
return gigaparsec.Result[byte, []byte]{}, err
|
|
}
|
|
return gigaparsec.Succeed(true, dst, next, gigaparsec.MessageOK(s.Pos())), nil
|
|
}
|
|
}
|
|
|
|
var parseAtom = gigaparsec.Map(gigaparsec.Bind2(
|
|
parseLength,
|
|
gigaparsec.Pipe[byte, byte, uint64](gigaparsec.Match[byte](':')),
|
|
acceptN,
|
|
), func(s []byte) astNode { return literalNode(s) })
|
|
|
|
var parseFormat = gigaparsec.Map(pbytes.MatchString("%s"),
|
|
func(s string) astNode { return stringFormatNode{} })
|
|
|
|
func parseSexp() gigaparsec.Parser[byte, astNode] {
|
|
return gigaparsec.Choose(parseAtom, parseList, parseFormat)
|
|
}
|
|
|
|
func parseRestOfList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, listNode], error) {
|
|
return gigaparsec.Choose(
|
|
gigaparsec.Map(gigaparsec.Match[byte](')'), func(byte) listNode { return nil }),
|
|
gigaparsec.Bind(
|
|
parseSexp(),
|
|
func(s astNode) gigaparsec.Parser[byte, listNode] {
|
|
return gigaparsec.Map(
|
|
parseRestOfList,
|
|
func(rest listNode) listNode { return append(listNode{s}, rest...) },
|
|
)
|
|
},
|
|
),
|
|
)(input)
|
|
}
|
|
|
|
func parseList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, astNode], error) {
|
|
return gigaparsec.Map(gigaparsec.Seq2(
|
|
gigaparsec.Match[byte]('('),
|
|
parseRestOfList,
|
|
func(_ byte, rest listNode) listNode { return rest },
|
|
), func(sexps listNode) astNode { return sexps })(input)
|
|
}
|
|
|
|
type MismatchError struct {
|
|
Got csexp.Sexp
|
|
Expected string
|
|
}
|
|
|
|
func (e MismatchError) Error() string {
|
|
return fmt.Sprintf("sexp mismatch: expected %s, got %v", e.Expected, e.Got)
|
|
}
|
|
|
|
type Matcher struct {
|
|
pattern astNode
|
|
}
|
|
|
|
func Compile(pattern string) (*Matcher, error) {
|
|
ast, err := gigaparsec.Run(parseSexp(), strings.NewReader(pattern))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid pattern: %w", err)
|
|
}
|
|
return &Matcher{pattern: ast}, nil
|
|
}
|
|
|
|
func MustCompile(pattern string) *Matcher {
|
|
m, err := Compile(pattern)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return m
|
|
}
|
|
|
|
func (m Matcher) Match(sexp csexp.Sexp, dst ...any) error {
|
|
dst, err := match(sexp, m.pattern, dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(dst) > 0 {
|
|
return errors.New("too many operands")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func match(sexp csexp.Sexp, ast astNode, dst []any) ([]any, error) {
|
|
mismatch := func(expected string) MismatchError {
|
|
return MismatchError{Expected: expected, Got: sexp}
|
|
}
|
|
if sexp == nil {
|
|
return dst, mismatch("non-nil input")
|
|
}
|
|
switch node := ast.(type) {
|
|
case literalNode:
|
|
if atom, ok := sexp.(csexp.Atom); ok {
|
|
if bytes.Equal(atom, []byte(node)) {
|
|
return dst, nil
|
|
}
|
|
}
|
|
return dst, MismatchError{Got: sexp, Expected: csexp.Atom(node).String()}
|
|
case stringFormatNode:
|
|
if len(dst) == 0 {
|
|
return dst, errors.New("too few operands")
|
|
}
|
|
atom, isAtom := sexp.(csexp.Atom)
|
|
if !isAtom {
|
|
return dst, mismatch("atom")
|
|
}
|
|
thisDst := dst[0]
|
|
dst = dst[1:]
|
|
switch ptr := thisDst.(type) {
|
|
case *string:
|
|
*ptr = string(atom)
|
|
case *[]byte:
|
|
*ptr = bytes.Clone([]byte(atom))
|
|
default:
|
|
return dst, fmt.Errorf("can't scan type %T", thisDst)
|
|
}
|
|
return dst, nil
|
|
case listNode:
|
|
list, ok := sexp.(csexp.List)
|
|
if !ok {
|
|
return dst, mismatch("list")
|
|
}
|
|
var err error
|
|
var i int
|
|
for i = range node {
|
|
if i == len(list) {
|
|
return dst, mismatch(fmt.Sprintf("list of %d items", len(node)))
|
|
}
|
|
dst, err = match(list[i], node[i], dst)
|
|
if err != nil {
|
|
return dst, err
|
|
}
|
|
}
|
|
if i < len(list)-1 {
|
|
return dst, mismatch(fmt.Sprintf("list of %d items", len(node)))
|
|
}
|
|
return dst, nil
|
|
}
|
|
panic("unreachable")
|
|
}
|