// Package peachy manages a SQLite database file identified by a fixed
// application_id (AppID) and the field-type tables stored inside it.
package peachy

import (
	"errors"
	"fmt"
	"os"
	"strings"

	"git.codemonkeysoftware.net/b/peachy-go/csexp"
	"git.codemonkeysoftware.net/b/peachy-go/csexp/match"
	"zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"
)

// AppID is the SQLite application_id written by Create and required by Open.
// Its value is the code point of the peach emoji.
const AppID = '🍑'

var (
	// ErrInvalidDB is returned by Open when the file is not a SQLite database
	// or its application_id is not AppID.
	ErrInvalidDB = errors.New("invalid database file")
	// ErrFileExists is returned by Create when path already exists.
	ErrFileExists = errors.New("database file already exists")
	// ErrFileNotExist is returned by Open when SQLite reports it cannot open
	// the file.
	ErrFileNotExist = errors.New("database file does not exist")
)

// DBError wraps an unexpected error reported by SQLite.
type DBError struct{ error }

// Error implements the error interface.
func (dbe DBError) Error() string { return "database error: " + dbe.error.Error() }

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect it.
func (dbe DBError) Unwrap() error { return dbe.error }

// DB is an open peachy database backed by a single SQLite connection.
type DB struct {
	conn *sqlite.Conn
}

// Close closes the underlying SQLite connection. Any error reported while
// closing is discarded.
func (db *DB) Close() {
	db.conn.Close()
}

// fieldTypeMatcher matches the csexp list ("field type" <name>) produced by
// AddFieldType and captures <name> as a string.
var fieldTypeMatcher = match.MustCompile("(10:field type%s)")

// parseFieldTypeName implements the SQL function parse_field_type_name(name).
// It returns the field type name encoded in a table name built by
// AddFieldType, or the zero sqlite.Value (NULL) when name is not such an
// encoding. Non-matching input is deliberately not an error, so the function
// can be applied to every row of sqlite_schema.
func parseFieldTypeName(ctx sqlite.Context, args []sqlite.Value) (sqlite.Value, error) {
	sexp, err := csexp.ParseString(args[0].Text())
	if err != nil {
		// Not a csexp at all: not a field type table.
		return sqlite.Value{}, nil
	}
	var name string
	if err := fieldTypeMatcher.Match(sexp, &name); err != nil {
		// A csexp of the wrong shape: not a field type table.
		return sqlite.Value{}, nil
	}
	return sqlite.TextValue(name), nil
}

// setupConn registers the package's custom SQL functions on conn.
func setupConn(conn *sqlite.Conn) error {
	return conn.CreateFunction("parse_field_type_name", &sqlite.FunctionImpl{
		NArgs:         1,
		Deterministic: true,
		AllowIndirect: true,
		Scalar:        parseFieldTypeName,
	})
}

// Open opens the existing peachy database at path for reading and writing,
// with the WAL journal mode flag.
//
// Errors:
//   - ErrFileNotExist if SQLite reports SQLITE_CANTOPEN for path;
//   - ErrInvalidDB if the file is not a SQLite database or its
//     application_id is not AppID;
//   - DBError for any other SQLite failure.
//
// On any error the connection, if one was opened, is closed.
func Open(path string) (db *DB, err error) {
	var conn *sqlite.Conn
	defer func() {
		// conn is nil when OpenConn itself failed; only close what was opened.
		if err != nil && conn != nil {
			conn.Close()
		}
	}()

	conn, err = sqlite.OpenConn(path, sqlite.OpenReadWrite|sqlite.OpenWAL)
	switch sqlite.ErrCode(err) {
	case sqlite.ResultOK:
		// Opened successfully; continue below.
	case sqlite.ResultCantOpen:
		// NOTE(review): SQLITE_CANTOPEN can also stem from causes other than a
		// missing file (e.g. permissions); it is reported as ErrFileNotExist.
		return nil, ErrFileNotExist
	case sqlite.ResultNotADB:
		return nil, ErrInvalidDB
	default:
		return nil, DBError{err}
	}

	var goodAppID bool
	err = sqlitex.ExecuteTransient(conn, "PRAGMA application_id", &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			goodAppID = stmt.ColumnInt32(0) == AppID
			return nil
		},
	})
	if err != nil {
		// SQLite may only detect a non-database file on first read.
		if sqlite.ErrCode(err) == sqlite.ResultNotADB {
			return nil, ErrInvalidDB
		}
		return nil, DBError{err}
	}
	if !goodAppID {
		return nil, ErrInvalidDB
	}

	if err = setupConn(conn); err != nil {
		return nil, err
	}
	return &DB{conn: conn}, nil
}

// Create creates a new peachy database at path, opens it for reading and
// writing with the WAL journal mode flag, and stamps it with AppID.
//
// It returns ErrFileExists if path already exists, a wrapped error if path
// cannot be inspected or the database cannot be created, and DBError if
// setting the application_id fails. On any error the connection, if one was
// opened, is closed.
func Create(path string) (db *DB, err error) {
	var conn *sqlite.Conn
	defer func() {
		// conn is nil on the early-return paths before OpenConn succeeds.
		if err != nil && conn != nil {
			conn.Close()
		}
	}()

	switch _, statErr := os.Stat(path); {
	case statErr == nil:
		return nil, ErrFileExists
	case !errors.Is(statErr, os.ErrNotExist):
		return nil, fmt.Errorf("could not create database: %w", statErr)
	}

	conn, err = sqlite.OpenConn(path, sqlite.OpenCreate|sqlite.OpenReadWrite|sqlite.OpenWAL)
	if err != nil {
		return nil, fmt.Errorf("could not create database: %w", err)
	}

	query := fmt.Sprintf("PRAGMA application_id=%d", AppID)
	if err = sqlitex.ExecuteTransient(conn, query, nil); err != nil {
		return nil, DBError{err}
	}

	if err = setupConn(conn); err != nil {
		return nil, err
	}
	return &DB{conn: conn}, nil
}

// quoteName quotes name as an SQL identifier by wrapping it in double quotes
// and doubling any embedded double quotes.
func quoteName(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// addFieldTypeQueryFmt is the DDL for a field type table; %s receives the
// already-quoted table name.
const addFieldTypeQueryFmt = `CREATE TABLE %s ( id INTEGER PRIMARY KEY );`

// FieldType describes a field type stored in the database.
type FieldType struct {
	Name string
}

// AddFieldType creates the table backing a new field type called name. The
// table is named with the csexp list ("field type" name), quoted as an SQL
// identifier.
func (db *DB) AddFieldType(name string) error {
	tableName := csexp.List{csexp.Atom("field type"), csexp.Atom(name)}.String()
	query := fmt.Sprintf(addFieldTypeQueryFmt, quoteName(tableName))
	// The query text differs for every name, so run it as a transient
	// statement instead of adding it to the connection's statement cache.
	if err := sqlitex.ExecuteTransient(db.conn, query, nil); err != nil {
		return fmt.Errorf("AddFieldType: %w", err)
	}
	return nil
}

// getFieldTypesQuery lists the decoded names of all field type tables in
// ascending order. Only tables are considered, since AddFieldType only ever
// creates tables.
const getFieldTypesQuery = `SELECT parse_field_type_name(name) FROM sqlite_schema WHERE type = 'table' AND parse_field_type_name(name) IS NOT NULL ORDER BY parse_field_type_name(name) ASC`

// GetFieldTypes returns every field type in the database in ascending order
// of name. It returns a nil slice when there are none.
func (db *DB) GetFieldTypes() ([]FieldType, error) {
	var results []FieldType
	err := sqlitex.Execute(db.conn, getFieldTypesQuery, &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			results = append(results, FieldType{Name: stmt.ColumnText(0)})
			return nil
		},
	})
	if err != nil {
		return nil, fmt.Errorf("GetFieldTypes: %w", err)
	}
	return results, nil
}