go-http-skeleton/internal/repository/dbConn.go

232 lines
5.4 KiB
Go
Raw Normal View History

package repository
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitemigration"
"zombiezen.com/go/sqlite/sqlitex"
)
// Database wraps an on-disk SQLite database. Writes go through a pool holding
// a single connection; reads use a separate pool sized to runtime.NumCPU().
// Both pools are (re)created by Reset and released by Close.
type Database struct {
	filename   string   // path of the SQLite database file
	migrations []string // schema migrations applied by Reset

	writePool, readPool *sqlitex.Pool
}
// TxFn is the callback accepted by WriteTX, ReadTX and WriteWithoutTx. It
// receives the pooled connection to run statements on; a non-nil return is
// wrapped and propagated (and, for WriteTX/ReadTX, ends the transaction with
// that error).
type TxFn func(tx *sqlite.Conn) error
// NewDatabase opens the SQLite database stored at dbFilename (creating its
// directory if needed), sets up the read and write connection pools and
// applies migrations in order.
//
// If initialisation fails, any pools that were already opened are closed
// before the error is returned, so a failed call leaks no connections.
func NewDatabase(ctx context.Context, dbFilename string, migrations []string) (*Database, error) {
	if dbFilename == "" {
		return nil, errors.New("database filename is required")
	}
	db := &Database{
		filename:   dbFilename,
		migrations: migrations,
	}
	if err := db.Reset(ctx, false); err != nil {
		// Reset may have opened one or both pools before failing (e.g. when a
		// migration errors); release them since the caller never gets db.
		if closeErr := db.Close(); closeErr != nil {
			err = errors.Join(err, closeErr)
		}
		return nil, fmt.Errorf("failed to reset database: %w", err)
	}
	return db, nil
}
// WriteWithoutTx runs fn on the write pool's connection without wrapping it
// in a transaction; unless fn manages its own transaction, each statement it
// executes runs in SQLite's autocommit mode. The connection is returned to
// the pool when fn finishes.
func (db *Database) WriteWithoutTx(ctx context.Context, fn TxFn) error {
	pool := db.writePool
	conn, takeErr := pool.Take(ctx)
	switch {
	case takeErr != nil:
		return fmt.Errorf("failed to take write connection: %w", takeErr)
	case conn == nil:
		return fmt.Errorf("could not get write connection from pool")
	}
	defer pool.Put(conn)

	if fnErr := fn(conn); fnErr != nil {
		return fmt.Errorf("could not execute write transaction: %w", fnErr)
	}
	return nil
}
// Reset closes any open pools, optionally deletes the database from disk,
// then reopens the pools, configures the write connection and applies the
// schema migrations.
//
// When shouldClear is true, the database file and its SQLite sidecar files
// (-wal, -shm, -journal) are removed before reopening. Files that do not
// exist are not an error.
func (db *Database) Reset(ctx context.Context, shouldClear bool) (err error) {
	if err = db.Close(); err != nil {
		return fmt.Errorf("could not close database: %w", err)
	}
	if shouldClear {
		if err = removeDatabaseFiles(db.filename); err != nil {
			return fmt.Errorf("could not remove database file: %w", err)
		}
	}
	if err = os.MkdirAll(filepath.Dir(db.filename), 0755); err != nil {
		return fmt.Errorf("could not create database directory: %w", err)
	}
	if err = db.openPools(); err != nil {
		return err
	}
	return db.configureAndMigrate(ctx)
}

// removeDatabaseFiles deletes the SQLite database at filename together with
// its sidecar files, ignoring any that are missing.
func removeDatabaseFiles(filename string) error {
	var errs []error
	// os.RemoveAll does not expand globs, so every companion file is named
	// explicitly instead of relying on a "filename*" pattern.
	for _, suffix := range []string{"", "-wal", "-shm", "-journal"} {
		if err := os.Remove(filename + suffix); err != nil && !errors.Is(err, os.ErrNotExist) {
			errs = append(errs, err)
		}
	}
	return errors.Join(errs...)
}

// openPools opens the single-connection write pool and the read pool sized
// to runtime.NumCPU(), reporting an error for either one.
func (db *Database) openPools() (err error) {
	// NOTE(review): _journal_mode/_synchronous follow the mattn/go-sqlite3 DSN
	// convention and are not SQLite URI parameters, so SQLite ignores them.
	// WAL presumably comes from the pool's default open flags (confirm);
	// synchronous is applied explicitly in configureAndMigrate.
	uri := fmt.Sprintf("file:%s?_journal_mode=WAL&_synchronous=NORMAL", db.filename)

	db.writePool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{PoolSize: 1})
	if err != nil {
		return fmt.Errorf("could not open write pool: %w", err)
	}
	db.readPool, err = sqlitex.NewPool(uri, sqlitex.PoolOptions{PoolSize: runtime.NumCPU()})
	if err != nil {
		return fmt.Errorf("could not open read pool: %w", err)
	}
	return nil
}

// configureAndMigrate applies per-connection pragmas to the write connection
// and then runs db.migrations on it.
func (db *Database) configureAndMigrate(ctx context.Context) error {
	conn, err := db.writePool.Take(ctx)
	if err != nil {
		return fmt.Errorf("failed to take write connection: %w", err)
	}
	if conn == nil {
		return fmt.Errorf("could not get write connection from pool")
	}
	// The connection is returned exactly once, on every path.
	defer db.writePool.Put(conn)

	// These pragmas must run outside any transaction: PRAGMA foreign_keys is a
	// no-op inside one. The write pool holds a single connection, so the
	// settings apply to every write.
	for _, pragma := range []string{
		"PRAGMA foreign_keys = ON;",
		"PRAGMA synchronous = NORMAL;",
	} {
		if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
			return fmt.Errorf("failed to initialize database: %w", err)
		}
	}

	schema := sqlitemigration.Schema{Migrations: db.migrations}
	if err := sqlitemigration.Migrate(ctx, conn, schema); err != nil {
		return fmt.Errorf("failed to migrate database: %w", err)
	}
	return nil
}
// Close closes whichever of the write and read pools are open, in that order,
// and returns every close failure joined into one error. It returns nil when
// all closes succeed or when no pool is open.
func (db *Database) Close() error {
	var errs []error
	for _, pool := range []*sqlitex.Pool{db.writePool, db.readPool} {
		if pool == nil {
			continue
		}
		errs = append(errs, pool.Close())
	}
	return errors.Join(errs...)
}
// WriteTX runs fn on the write pool's connection inside a transaction started
// with sqlitex.ImmediateTransaction. The sqlitex end function is handed the
// named result, so an error from fn (or from beginning the work) ends the
// transaction with that error, while a nil result lets it commit.
func (db *Database) WriteTX(ctx context.Context, fn TxFn) (err error) {
	pool := db.writePool
	conn, takeErr := pool.Take(ctx)
	switch {
	case takeErr != nil:
		return fmt.Errorf("failed to take write connection: %w", takeErr)
	case conn == nil:
		return fmt.Errorf("could not get write connection from pool")
	}
	defer pool.Put(conn)

	finish, beginErr := sqlitex.ImmediateTransaction(conn)
	if beginErr != nil {
		return fmt.Errorf("could not start transaction: %w", beginErr)
	}
	// Deferred after Put, so it runs first: the transaction ends before the
	// connection goes back to the pool.
	defer finish(&err)

	if fnErr := fn(conn); fnErr != nil {
		return fmt.Errorf("could not execute write transaction: %w", fnErr)
	}
	return nil
}
// ReadTX runs fn on a connection borrowed from the read pool inside a
// transaction started with sqlitex.Transaction. The end function receives the
// named result, so fn's error (wrapped) decides how the transaction ends.
func (db *Database) ReadTX(ctx context.Context, fn TxFn) (err error) {
	pool := db.readPool
	conn, takeErr := pool.Take(ctx)
	if takeErr != nil {
		return fmt.Errorf("failed to take read connection: %w", takeErr)
	}
	if conn == nil {
		return fmt.Errorf("could not get read connection from pool")
	}
	defer pool.Put(conn)

	finish := sqlitex.Transaction(conn)
	defer finish(&err)

	if fnErr := fn(conn); fnErr != nil {
		return fmt.Errorf("could not execute read transaction: %w", fnErr)
	}
	return nil
}
// Constants used to convert between Unix time and Julian day numbers.
const (
	// secondsInADay is the length of a Unix-time day (no leap seconds).
	secondsInADay = 24 * 60 * 60
	// UnixEpochJulianDay is the Julian day number of 1970-01-01T00:00:00Z.
	UnixEpochJulianDay = 2440587.5
)
// JulianZeroTime is the instant corresponding to Julian day 0 (noon UTC on
// 24 November 4714 BC in the proleptic Gregorian calendar), as computed by
// JulianDayToTime.
var JulianZeroTime = JulianDayToTime(0)
// TimeToJulianDay converts a time.Time into a fractional Julian day number.
//
// The instant is taken at millisecond resolution, the resolution SQLite uses
// for its own julianday() values, instead of being truncated to whole seconds.
// For whole-second instants the result is bit-identical to a seconds-based
// conversion. The time zone of t does not affect the result.
func TimeToJulianDay(t time.Time) float64 {
	return float64(t.UnixMilli())/(secondsInADay*1000) + UnixEpochJulianDay
}
// JulianDayToTime converts a fractional Julian day number into a UTC
// time.Time, rounded to the nearest millisecond.
//
// Rounding, rather than truncating, matters. A day number computed from a
// whole-second instant usually carries a tiny floating-point error, e.g.
// 1699999999.99998 s. Truncating that would yield the previous second.
func JulianDayToTime(d float64) time.Time {
	ms := (d - UnixEpochJulianDay) * (secondsInADay * 1000)
	// Round half away from zero: the int64 conversion below truncates toward
	// zero, so shift by half a unit in the direction of the sign first.
	if ms < 0 {
		ms -= 0.5
	} else {
		ms += 0.5
	}
	return time.UnixMilli(int64(ms)).UTC()
}
// JulianNow returns the current instant as a Julian day number.
func JulianNow() float64 {
	now := time.Now()
	return TimeToJulianDay(now)
}
// TimestampJulian converts a protobuf Timestamp into a Julian day number.
// NOTE(review): AsTime reads the timestamp through its nil-safe getters, so
// a nil ts yields the Unix epoch rather than an error. Confirm callers want that.
func TimestampJulian(ts *timestamppb.Timestamp) float64 {
	asTime := ts.AsTime()
	return TimeToJulianDay(asTime)
}
// JulianDayToTimestamp converts a Julian day number into a protobuf
// Timestamp, going through JulianDayToTime.
func JulianDayToTimestamp(f float64) *timestamppb.Timestamp {
	return timestamppb.New(JulianDayToTime(f))
}
// StmtJulianToTimestamp reads the float column colName of stmt's current row
// as a Julian day number and returns it as a protobuf Timestamp.
func StmtJulianToTimestamp(stmt *sqlite.Stmt, colName string) *timestamppb.Timestamp {
	return JulianDayToTimestamp(stmt.GetFloat(colName))
}
// StmtJulianToTime reads the float column colName of stmt's current row as a
// Julian day number and returns it as a time.Time.
func StmtJulianToTime(stmt *sqlite.Stmt, colName string) time.Time {
	return JulianDayToTime(stmt.GetFloat(colName))
}
func DurationToMilliseconds(d time.Duration) int64 {
return int64(d / time.Millisecond)
}
func MillisecondsToDuration(ms int64) time.Duration {
return time.Duration(ms) * time.Millisecond
}
// StmtBytes returns a freshly allocated copy of the bytes in column colName
// of stmt's current row. It returns nil when the column has zero length
// (which includes SQL NULL). It also returns nil when GetBytes copies a
// different number of bytes than GetLen reported.
func StmtBytes(stmt *sqlite.Stmt, colName string) []byte {
	size := stmt.GetLen(colName)
	if size == 0 {
		return nil
	}
	out := make([]byte, size)
	if copied := stmt.GetBytes(colName, out); copied == size {
		return out
	}
	return nil
}