package peachy

import (
	"errors"
	"fmt"
	"log"
	"os"
	"strings"

	"git.codemonkeysoftware.net/b/peachy-go/csexp"
	"git.codemonkeysoftware.net/b/peachy-go/csexp/match"
	"zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"
)

// AppID is the SQLite application_id that Create stamps on every new
// database and that Open requires. Its value is the code point of '🍑'.
const AppID = '🍑'

var (
	// ErrInvalidDB is returned by Open when the file is not an SQLite
	// database or its application_id is not AppID.
	ErrInvalidDB = errors.New("invalid database file")
	// ErrFileExists is returned by Create when something already exists at
	// the requested path.
	ErrFileExists = errors.New("database file already exists")
	// ErrFileNotExist is returned by Open when SQLite reports that it cannot
	// open the file (SQLITE_CANTOPEN).
	ErrFileNotExist = errors.New("database file does not exist")
)

// DBError wraps an unexpected error reported by the underlying SQLite
// connection.
type DBError struct{ error }

// Error implements the error interface.
func (dbe DBError) Error() string {
	return "database error: " + dbe.error.Error()
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect the underlying SQLite error.
func (dbe DBError) Unwrap() error {
	return dbe.error
}

// DB is an open Peachy database backed by a single SQLite connection.
type DB struct {
	conn *sqlite.Conn
}

// Close closes the underlying SQLite connection.
func (db *DB) Close() {
	// Close has no error result, so any failure to close is discarded.
	_ = db.conn.Close()
}

// tableNameMatcher decodes a table name built by AddCompositeType, i.e. a
// canonical S-expression list of (class name), into its class and name.
// NOTE(review): exact pattern semantics are defined by the match package.
var tableNameMatcher = match.MustCompile("(%s%s)")

// parseTableName implements the SQL function parse_table_name(raw, part).
//
// raw is a table name. If it parses as an S-expression matching
// tableNameMatcher, the function returns the class when part is 'class' or
// the name when part is 'name'. Table names that do not have that shape
// yield NULL (the zero sqlite.Value), so the function can safely be applied
// to every row of sqlite_schema. Any other value of part is an error.
func parseTableName(ctx sqlite.Context, args []sqlite.Value) (sqlite.Value, error) {
	rawName := args[0].Text()

	var wantClass bool
	switch args[1].Text() {
	case "class":
		wantClass = true
	case "name":
		wantClass = false
	default:
		return sqlite.Value{}, errors.New(`parse_table_name: 2nd arg must be "class" or "name"`)
	}

	sexp, err := csexp.ParseString(rawName)
	if err != nil {
		// Not one of our S-expression table names; report NULL rather than
		// failing the whole query.
		return sqlite.Value{}, nil
	}

	var class, name string
	if err := tableNameMatcher.Match(sexp, &class, &name); err != nil {
		log.Printf("parseTableName: unexpected sexp structure for %q", rawName)
		return sqlite.Value{}, nil
	}

	if wantClass {
		return sqlite.TextValue(class), nil
	}
	return sqlite.TextValue(name), nil
}

// setupConn registers the package's custom SQL functions (parse_table_name)
// on conn. Queries such as getCompositeTypesQuery depend on it.
func setupConn(conn *sqlite.Conn) error {
	return conn.CreateFunction("parse_table_name", &sqlite.FunctionImpl{
		NArgs:         2,
		Deterministic: true,
		AllowIndirect: true,
		Scalar:        parseTableName,
	})
}

// Open opens the existing Peachy database at path for reading and writing.
//
// It returns ErrFileNotExist if SQLite cannot open the file, ErrInvalidDB if
// the file is not an SQLite database or its application_id is not AppID, and
// a DBError for any other SQLite failure.
func Open(path string) (db *DB, err error) {
	var conn *sqlite.Conn
	defer func() {
		// Release the connection on any failure after it was obtained.
		// OpenConn may return a nil conn alongside its error.
		if err != nil && conn != nil {
			conn.Close()
		}
	}()

	conn, err = sqlite.OpenConn(path, sqlite.OpenReadWrite|sqlite.OpenWAL)
	switch sqlite.ErrCode(err) {
	case sqlite.ResultOK:
	case sqlite.ResultCantOpen:
		return nil, ErrFileNotExist
	case sqlite.ResultNotADB:
		return nil, ErrInvalidDB
	default:
		return nil, DBError{err}
	}

	var appID int32
	err = sqlitex.ExecuteTransient(conn, "PRAGMA application_id", &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			appID = stmt.ColumnInt32(0)
			return nil
		},
	})
	switch {
	case sqlite.ErrCode(err) == sqlite.ResultNotADB:
		return nil, ErrInvalidDB
	case err != nil:
		// Don't misreport e.g. a busy or I/O error as an invalid file.
		return nil, DBError{err}
	case appID != AppID:
		// Also covers the (unexpected) case of the pragma returning no row.
		return nil, ErrInvalidDB
	}

	if err = setupConn(conn); err != nil {
		return nil, err
	}
	return &DB{conn: conn}, nil
}

// Create creates a new Peachy database at path, stamps it with AppID, and
// returns it opened for reading and writing.
//
// It returns ErrFileExists if something already exists at path. The
// existence check happens before SQLite creates the file, so a file created
// concurrently by another process in between is not detected.
func Create(path string) (db *DB, err error) {
	var conn *sqlite.Conn
	defer func() {
		// Release the connection on any failure after it was obtained.
		if err != nil && conn != nil {
			conn.Close()
		}
	}()

	if _, statErr := os.Stat(path); statErr == nil {
		return nil, ErrFileExists
	} else if !errors.Is(statErr, os.ErrNotExist) {
		return nil, fmt.Errorf("could not create database: %w", statErr)
	}

	conn, err = sqlite.OpenConn(path, sqlite.OpenCreate|sqlite.OpenReadWrite|sqlite.OpenWAL)
	if err != nil {
		return nil, fmt.Errorf("could not create database: %w", err)
	}

	// AppID is a compile-time constant, so formatting it into the SQL text
	// is safe.
	query := fmt.Sprintf("PRAGMA application_id=%d", AppID)
	if err = sqlitex.ExecuteTransient(conn, query, nil); err != nil {
		return nil, DBError{err}
	}

	if err = setupConn(conn); err != nil {
		return nil, err
	}
	return &DB{conn: conn}, nil
}

// quoteName quotes name as an SQL identifier by wrapping it in double quotes
// and doubling any embedded double quotes.
func quoteName(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// CompositeKind distinguishes the kinds of composite type.
type CompositeKind int

const (
	// Record is the kind of record composite types.
	Record CompositeKind = iota
	// Variant is the kind of variant composite types.
	Variant
)

// CompositeType describes a composite type stored in the database.
type CompositeType struct {
	Name string
	Kind CompositeKind
}

// addConcreteCompositeTypeQuery creates the kind-specific ("record-value" or
// "variant-value") table for a composite type. The verb is the quoted table
// name.
const addConcreteCompositeTypeQuery = `CREATE TABLE %s (
	id INTEGER PRIMARY KEY
);`

// addAbstractCompositeTypeQuery creates the "composite-value" table for a
// composite type, whose rows reference a row of the concrete table. The verbs
// are: quoted table name, kind ("record" or "variant"), quoted concrete table
// name.
const addAbstractCompositeTypeQuery = `CREATE TABLE %s (
	id INTEGER PRIMARY KEY,
	%s_value_id INTEGER NOT NULL REFERENCES %s(id)
);`

// AddCompositeType creates the tables backing a new composite type called
// name of the given kind. Both tables are created inside a single savepoint,
// so either both exist afterwards or neither does.
func (db *DB) AddCompositeType(name string, kind CompositeKind) (err error) {
	var kindStr string
	switch kind {
	case Record:
		kindStr = "record"
	case Variant:
		kindStr = "variant"
	default:
		return errors.New("invalid kind")
	}

	abstractTableName := quoteName(csexp.List{
		csexp.Atom("composite-value"),
		csexp.Atom(name),
	}.String())
	concreteTableName := quoteName(csexp.List{
		csexp.Atom(kindStr + "-value"),
		csexp.Atom(name),
	}.String())

	// Without a savepoint, a failure creating the second table would leave
	// the first one behind. The release func rolls back when err is non-nil.
	defer sqlitex.Save(db.conn)(&err)

	// These statements differ for every type name, so run them without
	// adding them to the connection's prepared-statement cache.
	err = sqlitex.ExecuteTransient(db.conn, fmt.Sprintf(addConcreteCompositeTypeQuery, concreteTableName), nil)
	if err != nil {
		return fmt.Errorf("AddCompositeType: %w", err)
	}
	err = sqlitex.ExecuteTransient(db.conn, fmt.Sprintf(addAbstractCompositeTypeQuery, abstractTableName, kindStr, concreteTableName), nil)
	if err != nil {
		return fmt.Errorf("AddCompositeType: %w", err)
	}
	return nil
}

// getCompositeTypesQuery lists the name and class of every record or variant
// table, decoding table names with parse_table_name, ordered by name.
const getCompositeTypesQuery = `SELECT parse_table_name(name, 'name'), parse_table_name(name, 'class')
FROM sqlite_schema
WHERE parse_table_name(name, 'class') IN ('record-value', 'variant-value')
ORDER BY parse_table_name(name, 'name') ASC;`

// GetCompositeTypes returns every composite type in the database, sorted by
// name. It returns a nil slice when there are none.
func (db *DB) GetCompositeTypes() ([]CompositeType, error) {
	var results []CompositeType
	err := sqlitex.Execute(db.conn, getCompositeTypesQuery, &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			// The WHERE clause restricts class to record-value or
			// variant-value, so anything that is not a record is a variant.
			kind := Variant
			if stmt.ColumnText(1) == "record-value" {
				kind = Record
			}
			results = append(results, CompositeType{Name: stmt.ColumnText(0), Kind: kind})
			return nil
		},
	})
	if err != nil {
		return nil, fmt.Errorf("GetCompositeTypes: %w", err)
	}
	return results, nil
}