From c6922d220a9ab4b5930da45703d4e16fdf24c412 Mon Sep 17 00:00:00 2001 From: Brandon Dyck Date: Thu, 31 Oct 2024 13:18:25 -0600 Subject: [PATCH] Create and get field types --- csexp/sexp.go | 11 +++++++- db.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++- db_test.go | 22 ++++++++++++++++ 3 files changed, 100 insertions(+), 2 deletions(-) diff --git a/csexp/sexp.go b/csexp/sexp.go index bb4a2ce..9b40ba3 100644 --- a/csexp/sexp.go +++ b/csexp/sexp.go @@ -7,6 +7,7 @@ import ( "io" "slices" "strconv" + "strings" "git.codemonkeysoftware.net/b/gigaparsec" pbytes "git.codemonkeysoftware.net/b/gigaparsec/bytes" @@ -146,12 +147,20 @@ func parseList(input gigaparsec.State[byte]) (gigaparsec.Result[byte, Sexp], err } func Parse(data []byte) (Sexp, error) { + return parse(bytes.NewReader(data)) +} + +func ParseString(data string) (Sexp, error) { + return parse(strings.NewReader(data)) +} + +func parse(r io.ReaderAt) (Sexp, error) { parser := gigaparsec.Seq2( parseSexp, gigaparsec.End[byte](), func(s Sexp, _ struct{}) Sexp { return s }, ) - result, err := gigaparsec.Run(parser, bytes.NewReader(data)) + result, err := gigaparsec.Run(parser, r) if err != nil { return nil, fmt.Errorf("csexp.Parse: %w", err) } diff --git a/db.go b/db.go index 433b554..ba00026 100644 --- a/db.go +++ b/db.go @@ -4,7 +4,10 @@ import ( "errors" "fmt" "os" + "strings" + "git.codemonkeysoftware.net/b/peachy-go/csexp" + "git.codemonkeysoftware.net/b/peachy-go/csexp/match" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" ) @@ -29,8 +32,29 @@ func (db *DB) Close() { db.conn.Close() } +var fieldTypeMatcher = match.MustCompile("(10:field type%s)") + +func parseFieldTypeName(ctx sqlite.Context, args []sqlite.Value) (sqlite.Value, error) { + rawName := args[0].Text() + sexp, err := csexp.ParseString(rawName) + if err != nil { + return sqlite.Value{}, nil + } + var name string + err = fieldTypeMatcher.Match(sexp, &name) + if err != nil { + return sqlite.Value{}, nil + } + return sqlite.TextValue(name), nil +} + func setupConn(conn *sqlite.Conn) error { - return nil + return conn.CreateFunction("parse_field_type_name", &sqlite.FunctionImpl{ + NArgs: 1, + Deterministic: true, + AllowIndirect: true, + Scalar: parseFieldTypeName, + }) } func Open(path string) (db *DB, err error) { @@ -96,3 +120,46 @@ func Create(path string) (db *DB, err error) { } return &DB{conn: conn}, nil } + +func quoteName(name string) string { + return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` +} + +const addFieldTypeQueryFmt = `CREATE TABLE %s ( + id INTEGER PRIMARY KEY +);` + +type FieldType struct { + Name string +} + +func (db *DB) AddFieldType(name string) error { + tableName := csexp.List{csexp.Atom("field type"), csexp.Atom(name)}.String() + quotedName := quoteName(tableName) + query := fmt.Sprintf(addFieldTypeQueryFmt, quotedName) + err := sqlitex.Execute(db.conn, query, nil) + if err != nil { + return fmt.Errorf("AddFieldType: %w", err) + } + return nil +} + +const getFieldTypesQuery = `SELECT parse_field_type_name(name) +FROM sqlite_schema +WHERE parse_field_type_name(name) IS NOT NULL +ORDER BY parse_field_type_name(name) ASC` + +func (db *DB) GetFieldTypes() ([]FieldType, error) { + var results []FieldType + err := sqlitex.Execute(db.conn, getFieldTypesQuery, &sqlitex.ExecOptions{ + ResultFunc: func(stmt *sqlite.Stmt) error { + name := stmt.ColumnText(0) + results = append(results, FieldType{Name: name}) + return nil + }, + }) + if err != nil { + return nil, fmt.Errorf("GetFieldTypes: %w", err) + } + return results, nil +} diff --git a/db_test.go b/db_test.go index 3394545..35611a9 100644 --- a/db_test.go +++ b/db_test.go @@ -6,6 +6,7 @@ import ( "testing" "git.codemonkeysoftware.net/b/peachy-go" + "github.com/shoenig/test" "github.com/shoenig/test/must" "zombiezen.com/go/sqlite" ) @@ -67,3 +68,24 @@ func TestOpen(t *testing.T) { must.ErrorIs(t, err, peachy.ErrInvalidDB, must.Sprint(sqlite.ErrCode(err))) }) } + +func TestFieldTypes(t *testing.T) { + db, err := peachy.Create(":memory:") + must.NoError(t, err) + defer db.Close() + + names := []string{"alpha", `"bravo`, " charl!e"} + for _, name := range names { + err := db.AddFieldType(name) + must.NoError(t, err) + } + + fieldTypes, err := db.GetFieldTypes() + must.NoError(t, err) + for _, name := range names { + test.SliceContainsFunc(t, fieldTypes, name, + func(ft peachy.FieldType, s string) bool { + return ft.Name == s + }) + } +}