Refactor db code. Add migration support

This commit is contained in:
2025-05-25 20:27:55 +02:00
parent dacdd3b826
commit 706744d657
9 changed files with 300 additions and 268 deletions

View File

@@ -1,231 +1,41 @@
package repository
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitemigration"
"zombiezen.com/go/sqlite/sqlitex"
"database/sql"
"log/slog"
"sync"
)
type Database struct {
filename string
migrations []string
writePool *sqlitex.Pool
readPool *sqlitex.Pool
}
type TxFn func(tx *sqlite.Conn) error
func NewDatabase(ctx context.Context, dbFilename string, migrations []string) (*Database, error) {
if dbFilename == "" {
return nil, fmt.Errorf("database filename is required")
}
db := &Database{
filename: dbFilename,
migrations: migrations,
}
if err := db.Reset(ctx, false); err != nil {
return nil, fmt.Errorf("failed to reset database: %w", err)
}
return db, nil
}
func (db *Database) WriteWithoutTx(ctx context.Context, fn TxFn) error {
conn, err := db.writePool.Take(ctx)
if err != nil {
return fmt.Errorf("failed to take write connection: %w", err)
}
if conn == nil {
return fmt.Errorf("could not get write connection from pool")
}
defer db.writePool.Put(conn)
if err := fn(conn); err != nil {
return fmt.Errorf("could not execute write transaction: %w", err)
}
return nil
}
func (db *Database) Reset(ctx context.Context, shouldClear bool) (err error) {
if err := db.Close(); err != nil {
return fmt.Errorf("could not close database: %w", err)
}
if shouldClear {
if err := os.RemoveAll(db.filename + "*"); err != nil {
return fmt.Errorf("could not remove database file: %w", err)
}
}
if err := os.MkdirAll(filepath.Dir(db.filename), 0755); err != nil {
return fmt.Errorf("could not create database directory: %w", err)
}
uri := fmt.Sprintf("file:%s?_journal_mode=WAL&_synchronous=NORMAL", db.filename)
db.writePool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{
PoolSize: 1,
})
if err != nil {
return fmt.Errorf("could not open write pool: %w", err)
}
db.readPool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{
PoolSize: runtime.NumCPU(),
})
if err := db.WriteTX(ctx, func(tx *sqlite.Conn) error {
foreignKeysStmt := tx.Prep("PRAGMA foreign_keys = ON")
defer foreignKeysStmt.Finalize()
if _, err := foreignKeysStmt.Step(); err != nil {
return fmt.Errorf("failed to enable foreign keys: %w", err)
}
return nil
}); err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
schema := sqlitemigration.Schema{Migrations: db.migrations}
conn, err := db.writePool.Take(ctx)
if err != nil {
return fmt.Errorf("failed to take write connection: %w", err)
}
defer db.writePool.Put(conn)
if err := sqlitemigration.Migrate(ctx, conn, schema); err != nil {
db.writePool.Put(conn)
return fmt.Errorf("failed to migrate database: %w", err)
}
return nil
}
func (db *Database) Close() error {
errs := []error{}
if db.writePool != nil {
errs = append(errs, db.writePool.Close())
}
if db.readPool != nil {
errs = append(errs, db.readPool.Close())
}
return errors.Join(errs...)
}
func (db *Database) WriteTX(ctx context.Context, fn TxFn) (err error) {
conn, err := db.writePool.Take(ctx)
if err != nil {
return fmt.Errorf("failed to take write connection: %w", err)
}
if conn == nil {
return fmt.Errorf("could not get write connection from pool")
}
defer db.writePool.Put(conn)
endFn, err := sqlitex.ImmediateTransaction(conn)
if err != nil {
return fmt.Errorf("could not start transaction: %w", err)
}
defer endFn(&err)
if err := fn(conn); err != nil {
return fmt.Errorf("could not execute write transaction: %w", err)
}
return nil
}
func (db *Database) ReadTX(ctx context.Context, fn TxFn) (err error) {
conn, err := db.readPool.Take(ctx)
if err != nil {
return fmt.Errorf("failed to take read connection: %w", err)
}
if conn == nil {
return fmt.Errorf("could not get read connection from pool")
}
defer db.readPool.Put(conn)
endFn := sqlitex.Transaction(conn)
defer endFn(&err)
if err := fn(conn); err != nil {
return fmt.Errorf("could not execute read transaction: %w", err)
}
return nil
}
const (
secondsInADay = 86400
UnixEpochJulianDay = 2440587.5
var (
dbConnOnce sync.Once
dbConn *sql.DB
repo *Queries
)
var JulianZeroTime = JulianDayToTime(0)
func Connect(dsnURI string) {
dbConnOnce.Do(func() {
var err error
dbConn, err = sql.Open("sqlite", dsnURI)
if err != nil {
slog.Error("Fatal error")
}
// TimeToJulianDay converts a time.Time into a Julian day.
func TimeToJulianDay(t time.Time) float64 {
return float64(t.UTC().Unix())/secondsInADay + UnixEpochJulianDay
repo = New(dbConn)
})
}
// JulianDayToTime converts a Julian day into a time.Time.
func JulianDayToTime(d float64) time.Time {
return time.Unix(int64((d-UnixEpochJulianDay)*secondsInADay), 0).UTC()
}
func JulianNow() float64 {
return TimeToJulianDay(time.Now())
}
func TimestampJulian(ts *timestamppb.Timestamp) float64 {
return TimeToJulianDay(ts.AsTime())
}
func JulianDayToTimestamp(f float64) *timestamppb.Timestamp {
t := JulianDayToTime(f)
return timestamppb.New(t)
}
func StmtJulianToTimestamp(stmt *sqlite.Stmt, colName string) *timestamppb.Timestamp {
julianDays := stmt.GetFloat(colName)
return JulianDayToTimestamp(julianDays)
}
func StmtJulianToTime(stmt *sqlite.Stmt, colName string) time.Time {
julianDays := stmt.GetFloat(colName)
return JulianDayToTime(julianDays)
}
func DurationToMilliseconds(d time.Duration) int64 {
return int64(d / time.Millisecond)
}
func MillisecondsToDuration(ms int64) time.Duration {
return time.Duration(ms) * time.Millisecond
}
func StmtBytes(stmt *sqlite.Stmt, colName string) []byte {
bl := stmt.GetLen(colName)
if bl == 0 {
return nil
func GetConnection() (*sql.DB, error) {
if dbConn == nil {
slog.Error("Database connection not initialized!")
}
buf := make([]byte, bl)
if writtent := stmt.GetBytes(colName, buf); writtent != bl {
return nil
return dbConn, nil
}
func GetRepository() (*Queries, error) {
if repo == nil {
slog.Error("Database connection not initialized!")
}
return buf
return repo, nil
}

View File

@@ -0,0 +1,137 @@
package repository
import (
"database/sql"
"embed"
"fmt"
"log/slog"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
)
const Version uint = 1
//go:embed migrations/*
var migrationFiles embed.FS
func checkDBVersion(db *sql.DB) error {
var m *migrate.Migrate
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
return err
}
d, err := iofs.New(migrationFiles, "migrations")
if err != nil {
return err
}
m, err = migrate.NewWithInstance("iofs", d, "sqlite3", driver)
if err != nil {
return err
}
v, dirty, err := m.Version()
if err != nil {
if err == migrate.ErrNilVersion {
slog.Warn("Legacy database without version or missing database file!")
} else {
return err
}
}
if v < Version {
return fmt.Errorf("unsupported database version %d, need %d.\nPlease backup your database file and run cc-backend -migrate-db", v, Version)
} else if v > Version {
return fmt.Errorf("unsupported database version %d, need %d.\nPlease refer to documentation how to downgrade db with external migrate tool", v, Version)
}
if dirty {
return fmt.Errorf("last migration to version %d has failed, please fix the db manually and force version with -force-db flag", Version)
}
return nil
}
func getMigrateInstance(dsnURI string) (m *migrate.Migrate, err error) {
d, err := iofs.New(migrationFiles, "migrations")
if err != nil {
slog.Error("failed to get instance", "Error", err)
}
m, err = migrate.NewWithSourceInstance("iofs", d, dsnURI)
if err != nil {
return m, err
}
return m, nil
}
func MigrateDB(db string) error {
m, err := getMigrateInstance(db)
if err != nil {
return err
}
v, dirty, err := m.Version()
if err != nil {
if err == migrate.ErrNilVersion {
slog.Warn("Legacy database without version or missing database file!")
} else {
return err
}
}
if v < Version {
slog.Info("unsupported database version %d, need %d.\nPlease backup your database file and run cc-backend -migrate-db", v, Version)
}
if dirty {
return fmt.Errorf("last migration to version %d has failed, please fix the db manually and force version with -force-db flag", Version)
}
if err := m.Up(); err != nil {
if err == migrate.ErrNoChange {
slog.Info("DB already up to date!")
} else {
return err
}
}
m.Close()
return nil
}
func RevertDB(db string) error {
m, err := getMigrateInstance(db)
if err != nil {
return err
}
if err := m.Migrate(Version - 1); err != nil {
if err == migrate.ErrNoChange {
slog.Info("DB already up to date!")
} else {
return err
}
}
m.Close()
return nil
}
func ForceDB(db string) error {
m, err := getMigrateInstance(db)
if err != nil {
return err
}
if err := m.Force(int(Version)); err != nil {
return err
}
m.Close()
return nil
}

View File

@@ -0,0 +1,5 @@
CREATE TABLE authors (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
bio TEXT
);

View File

@@ -0,0 +1,11 @@
CREATE TABLE news (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
bio TEXT
);
CREATE TABLE retailer (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
bio TEXT
);

View File

@@ -0,0 +1,25 @@
-- name: GetAuthor :one
SELECT * FROM authors
WHERE id = ? LIMIT 1;
-- name: ListAuthors :many
SELECT * FROM authors
ORDER BY name;
-- name: CreateAuthor :one
INSERT INTO authors (
name, bio
) VALUES (
?, ?
)
RETURNING *;
-- name: UpdateAuthor :exec
UPDATE authors
set name = ?,
bio = ?
WHERE id = ?;
-- name: DeleteAuthor :exec
DELETE FROM authors
WHERE id = ?;