232 lines
5.4 KiB
Go
232 lines
5.4 KiB
Go
|
package repository
|
||
|
|
||
|
import (
	"context"
	"errors"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"runtime"
	"time"

	"google.golang.org/protobuf/types/known/timestamppb"
	"zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitemigration"
	"zombiezen.com/go/sqlite/sqlitex"
)
|
||
|
|
||
|
type Database struct {
|
||
|
filename string
|
||
|
migrations []string
|
||
|
writePool *sqlitex.Pool
|
||
|
readPool *sqlitex.Pool
|
||
|
}
|
||
|
|
||
|
// TxFn is the callback type accepted by WriteTX, ReadTX and WriteWithoutTx.
// It receives a connection taken from one of the Database pools and reports
// failure through its error result.
type TxFn func(conn *sqlite.Conn) error
|
||
|
|
||
|
// NewDatabase opens (creating it if needed) the SQLite database stored at
// dbFilename and applies migrations to it via Reset.
//
// It returns an error when dbFilename is empty or initialization fails. On
// failure, any connection pools Reset managed to open are closed, because
// the partially initialized Database is never handed to the caller.
func NewDatabase(ctx context.Context, dbFilename string, migrations []string) (*Database, error) {
	if dbFilename == "" {
		return nil, fmt.Errorf("database filename is required")
	}

	db := &Database{
		filename:   dbFilename,
		migrations: migrations,
	}

	if err := db.Reset(ctx, false); err != nil {
		// Reset can fail after one or both pools were opened; db is discarded
		// here, so release those pools instead of leaking their connections.
		if closeErr := db.Close(); closeErr != nil {
			err = errors.Join(err, closeErr)
		}
		return nil, fmt.Errorf("failed to reset database: %w", err)
	}

	return db, nil
}
|
||
|
|
||
|
// WriteWithoutTx takes the write connection and runs fn on it without
// opening a transaction first; fn is responsible for any transaction
// control it needs. The connection is returned to the pool afterwards, and
// an error from fn is wrapped before being returned.
func (db *Database) WriteWithoutTx(ctx context.Context, fn TxFn) error {
	pool := db.writePool

	conn, takeErr := pool.Take(ctx)
	switch {
	case takeErr != nil:
		return fmt.Errorf("failed to take write connection: %w", takeErr)
	case conn == nil:
		return fmt.Errorf("could not get write connection from pool")
	}
	defer pool.Put(conn)

	if fnErr := fn(conn); fnErr != nil {
		return fmt.Errorf("could not execute write transaction: %w", fnErr)
	}
	return nil
}
|
||
|
|
||
|
// Reset (re)initializes the database. It:
//   - closes any open pools;
//   - if shouldClear is set, deletes the database file together with its
//     SQLite journal/WAL/shared-memory side files;
//   - ensures the parent directory exists;
//   - opens a single-connection write pool and a runtime.NumCPU()-sized read
//     pool;
//   - enables foreign keys on the write connection;
//   - applies db.migrations.
func (db *Database) Reset(ctx context.Context, shouldClear bool) (err error) {
	if err := db.Close(); err != nil {
		return fmt.Errorf("could not close database: %w", err)
	}

	if shouldClear {
		// os.RemoveAll does not expand glob patterns, so the main file and its
		// side files are listed explicitly. RemoveAll returns nil for paths
		// that do not exist.
		for _, suffix := range []string{"", "-journal", "-wal", "-shm"} {
			if err := os.RemoveAll(db.filename + suffix); err != nil {
				return fmt.Errorf("could not remove database file: %w", err)
			}
		}
	}

	if err := os.MkdirAll(filepath.Dir(db.filename), 0755); err != nil {
		return fmt.Errorf("could not create database directory: %w", err)
	}

	// NOTE(review): _journal_mode/_synchronous are mattn/go-sqlite3 style DSN
	// parameters and SQLite ignores unknown URI query parameters; confirm
	// whether this driver honors them.
	uri := fmt.Sprintf("file:%s?_journal_mode=WAL&_synchronous=NORMAL", db.filename)

	// A single write connection serializes every write made through db.
	db.writePool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{PoolSize: 1})
	if err != nil {
		return fmt.Errorf("could not open write pool: %w", err)
	}

	db.readPool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{PoolSize: runtime.NumCPU()})
	if err != nil {
		return fmt.Errorf("could not open read pool: %w", err)
	}

	conn, err := db.writePool.Take(ctx)
	if err != nil {
		return fmt.Errorf("failed to take write connection: %w", err)
	}
	// The only Put for this connection; the error paths below must not
	// return it to the pool a second time.
	defer db.writePool.Put(conn)

	// PRAGMA foreign_keys is a no-op while a transaction is open, so it is run
	// on the bare connection rather than through WriteTX.
	if err := sqlitex.ExecuteTransient(conn, "PRAGMA foreign_keys = ON;", nil); err != nil {
		return fmt.Errorf("failed to initialize database: failed to enable foreign keys: %w", err)
	}

	schema := sqlitemigration.Schema{Migrations: db.migrations}
	if err := sqlitemigration.Migrate(ctx, conn, schema); err != nil {
		return fmt.Errorf("failed to migrate database: %w", err)
	}

	return nil
}
|
||
|
|
||
|
func (db *Database) Close() error {
|
||
|
errs := []error{}
|
||
|
if db.writePool != nil {
|
||
|
errs = append(errs, db.writePool.Close())
|
||
|
}
|
||
|
|
||
|
if db.readPool != nil {
|
||
|
errs = append(errs, db.readPool.Close())
|
||
|
}
|
||
|
|
||
|
return errors.Join(errs...)
|
||
|
}
|
||
|
|
||
|
// WriteTX runs fn on the write connection inside a transaction begun with
// sqlitex.ImmediateTransaction.
//
// The transaction's end function is deferred with a pointer to the named
// result err, so it observes fn's (wrapped) error. Per the sqlitex docs, it
// commits on nil and rolls back otherwise, and may itself set err if
// committing fails. It runs before the connection is returned to the pool,
// because deferred calls execute in LIFO order.
func (db *Database) WriteTX(ctx context.Context, fn TxFn) (err error) {
	pool := db.writePool

	conn, takeErr := pool.Take(ctx)
	switch {
	case takeErr != nil:
		return fmt.Errorf("failed to take write connection: %w", takeErr)
	case conn == nil:
		return fmt.Errorf("could not get write connection from pool")
	}
	defer pool.Put(conn)

	finish, beginErr := sqlitex.ImmediateTransaction(conn)
	if beginErr != nil {
		return fmt.Errorf("could not start transaction: %w", beginErr)
	}
	defer finish(&err)

	if fnErr := fn(conn); fnErr != nil {
		return fmt.Errorf("could not execute write transaction: %w", fnErr)
	}
	return nil
}
|
||
|
|
||
|
// ReadTX runs fn on a connection from the read pool inside a transaction
// begun with sqlitex.Transaction.
//
// The end function is deferred with a pointer to the named result err, so
// the transaction is finished according to fn's outcome. It runs before the
// connection goes back to the pool, because deferred calls execute in LIFO
// order.
func (db *Database) ReadTX(ctx context.Context, fn TxFn) (err error) {
	pool := db.readPool

	conn, takeErr := pool.Take(ctx)
	switch {
	case takeErr != nil:
		return fmt.Errorf("failed to take read connection: %w", takeErr)
	case conn == nil:
		return fmt.Errorf("could not get read connection from pool")
	}
	defer pool.Put(conn)

	finish := sqlitex.Transaction(conn)
	defer finish(&err)

	if fnErr := fn(conn); fnErr != nil {
		return fmt.Errorf("could not execute read transaction: %w", fnErr)
	}
	return nil
}
|
||
|
|
||
|
const (
|
||
|
secondsInADay = 86400
|
||
|
UnixEpochJulianDay = 2440587.5
|
||
|
)
|
||
|
|
||
|
var JulianZeroTime = JulianDayToTime(0)
|
||
|
|
||
|
// TimeToJulianDay converts a time.Time into a Julian day.
|
||
|
func TimeToJulianDay(t time.Time) float64 {
|
||
|
return float64(t.UTC().Unix())/secondsInADay + UnixEpochJulianDay
|
||
|
}
|
||
|
|
||
|
// JulianDayToTime converts a Julian day into a time.Time.
|
||
|
func JulianDayToTime(d float64) time.Time {
|
||
|
return time.Unix(int64((d-UnixEpochJulianDay)*secondsInADay), 0).UTC()
|
||
|
}
|
||
|
|
||
|
// JulianNow returns the current instant as a Julian day number, at the
// precision provided by TimeToJulianDay.
func JulianNow() float64 {
	now := time.Now()
	return TimeToJulianDay(now)
}
|
||
|
|
||
|
func TimestampJulian(ts *timestamppb.Timestamp) float64 {
|
||
|
return TimeToJulianDay(ts.AsTime())
|
||
|
}
|
||
|
|
||
|
func JulianDayToTimestamp(f float64) *timestamppb.Timestamp {
|
||
|
t := JulianDayToTime(f)
|
||
|
return timestamppb.New(t)
|
||
|
}
|
||
|
|
||
|
func StmtJulianToTimestamp(stmt *sqlite.Stmt, colName string) *timestamppb.Timestamp {
|
||
|
julianDays := stmt.GetFloat(colName)
|
||
|
return JulianDayToTimestamp(julianDays)
|
||
|
}
|
||
|
|
||
|
// StmtJulianToTime reads column colName of stmt's current row as a float
// Julian day number and returns it as a time.Time.
func StmtJulianToTime(stmt *sqlite.Stmt, colName string) time.Time {
	return JulianDayToTime(stmt.GetFloat(colName))
}
|
||
|
|
||
|
// DurationToMilliseconds converts d into whole milliseconds, truncating any
// sub-millisecond remainder toward zero.
func DurationToMilliseconds(d time.Duration) int64 {
	return d.Milliseconds()
}
|
||
|
|
||
|
// MillisecondsToDuration converts a count of milliseconds into a
// time.Duration.
func MillisecondsToDuration(ms int64) time.Duration {
	d := time.Duration(ms)
	return d * time.Millisecond
}
|
||
|
|
||
|
func StmtBytes(stmt *sqlite.Stmt, colName string) []byte {
|
||
|
bl := stmt.GetLen(colName)
|
||
|
if bl == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
buf := make([]byte, bl)
|
||
|
if writtent := stmt.GetBytes(colName, buf); writtent != bl {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return buf
|
||
|
}
|