From 8f4ef1e274165e610f86a2e0d3904e3a817fe139 Mon Sep 17 00:00:00 2001 From: Jan Eitzinger Date: Thu, 20 Nov 2025 06:58:45 +0100 Subject: [PATCH] Refactor repository Fix issues Improve transaction API Make hardcoded constants configurable Make error messages consistent and always add context info --- internal/repository/config.go | 67 +++++++++++++++++++++++ internal/repository/dbConnection.go | 14 ++--- internal/repository/job.go | 53 +++++++++++++++---- internal/repository/jobCreate.go | 13 +++-- internal/repository/jobFind.go | 17 ++++-- internal/repository/node.go | 45 ++-------------- internal/repository/stats.go | 10 ---- internal/repository/transaction.go | 82 +++++++++++++++++------------ internal/tagger/tagger.go | 2 +- 9 files changed, 188 insertions(+), 115 deletions(-) create mode 100644 internal/repository/config.go diff --git a/internal/repository/config.go b/internal/repository/config.go new file mode 100644 index 0000000..b54d0ae --- /dev/null +++ b/internal/repository/config.go @@ -0,0 +1,67 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. +package repository + +import "time" + +// RepositoryConfig holds configuration for repository operations. +// All fields have sensible defaults, so this configuration is optional. +type RepositoryConfig struct { + // CacheSize is the LRU cache size in bytes for job metadata and energy footprints. + // Default: 1MB (1024 * 1024 bytes) + CacheSize int + + // MaxOpenConnections is the maximum number of open database connections. + // Default: 4 + MaxOpenConnections int + + // MaxIdleConnections is the maximum number of idle database connections. + // Default: 4 + MaxIdleConnections int + + // ConnectionMaxLifetime is the maximum amount of time a connection may be reused. + // Default: 1 hour + ConnectionMaxLifetime time.Duration + + // ConnectionMaxIdleTime is the maximum amount of time a connection may be idle. + // Default: 1 hour + ConnectionMaxIdleTime time.Duration + + // MinRunningJobDuration is the minimum duration in seconds for a job to be + // considered in "running jobs" queries. This filters out very short jobs. + // Default: 600 seconds (10 minutes) + MinRunningJobDuration int +} + +// DefaultConfig returns the default repository configuration. +// These values are optimized for typical deployments. +func DefaultConfig() *RepositoryConfig { + return &RepositoryConfig{ + CacheSize: 1 * 1024 * 1024, // 1MB + MaxOpenConnections: 4, + MaxIdleConnections: 4, + ConnectionMaxLifetime: time.Hour, + ConnectionMaxIdleTime: time.Hour, + MinRunningJobDuration: 600, // 10 minutes + } +} + +// repoConfig is the package-level configuration instance. +// It is initialized with defaults and can be overridden via SetConfig. +var repoConfig *RepositoryConfig = DefaultConfig() + +// SetConfig sets the repository configuration. +// This must be called before any repository initialization (Connect, GetJobRepository, etc.). +// If not called, default values from DefaultConfig() are used. +func SetConfig(cfg *RepositoryConfig) { + if cfg != nil { + repoConfig = cfg + } +} + +// GetConfig returns the current repository configuration. +func GetConfig() *RepositoryConfig { + return repoConfig +} diff --git a/internal/repository/dbConnection.go b/internal/repository/dbConnection.go index 872edf1..79de284 100644 --- a/internal/repository/dbConnection.go +++ b/internal/repository/dbConnection.go @@ -37,13 +37,7 @@ type DatabaseOptions struct { func setupSqlite(db *sql.DB) (err error) { pragmas := []string{ - // "journal_mode = WAL", - // "busy_timeout = 5000", - // "synchronous = NORMAL", - // "cache_size = 1000000000", // 1GB - // "foreign_keys = true", "temp_store = memory", - // "mmap_size = 3000000000", } for _, pragma := range pragmas { @@ -63,10 +57,10 @@ func Connect(driver string, db string) { dbConnOnce.Do(func() { opts := DatabaseOptions{ URL: db, - MaxOpenConnections: 4, - MaxIdleConnections: 4, - ConnectionMaxLifetime: time.Hour, - ConnectionMaxIdleTime: time.Hour, + MaxOpenConnections: repoConfig.MaxOpenConnections, + MaxIdleConnections: repoConfig.MaxIdleConnections, + ConnectionMaxLifetime: repoConfig.ConnectionMaxLifetime, + ConnectionMaxIdleTime: repoConfig.ConnectionMaxIdleTime, } switch driver { diff --git a/internal/repository/job.go b/internal/repository/job.go index 41c727b..032c342 100644 --- a/internal/repository/job.go +++ b/internal/repository/job.go @@ -45,7 +45,7 @@ func GetJobRepository() *JobRepository { driver: db.Driver, stmtCache: sq.NewStmtCache(db.DB), - cache: lrucache.New(1024 * 1024), + cache: lrucache.New(repoConfig.CacheSize), } }) return jobRepoInstance @@ -267,7 +267,31 @@ func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float6 func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) { var cnt int q := sq.Select("count(*)").From("job").Where("job.start_time < ?", startTime) - q.RunWith(r.DB).QueryRow().Scan(cnt) + if err := q.RunWith(r.DB).QueryRow().Scan(&cnt); err != nil { + cclog.Errorf("Error counting jobs before %d: %v", startTime, err) + return 0, err + } + + // Invalidate cache for jobs being deleted (get job IDs first) + if cnt > 0 { + var jobIds []int64 + rows, err := sq.Select("id").From("job").Where("job.start_time < ?", startTime).RunWith(r.DB).Query() + if err == nil { + defer rows.Close() + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err == nil { + jobIds = append(jobIds, id) + } + } + // Invalidate cache entries + for _, id := range jobIds { + r.cache.Del(fmt.Sprintf("metadata:%d", id)) + r.cache.Del(fmt.Sprintf("energyFootprint:%d", id)) + } + } + } + qd := sq.Delete("job").Where("job.start_time < ?", startTime) _, err := qd.RunWith(r.DB).Exec() @@ -281,6 +305,10 @@ func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) { } func (r *JobRepository) DeleteJobById(id int64) error { + // Invalidate cache entries before deletion + r.cache.Del(fmt.Sprintf("metadata:%d", id)) + r.cache.Del(fmt.Sprintf("energyFootprint:%d", id)) + qd := sq.Delete("job").Where("job.id = ?", id) _, err := qd.RunWith(r.DB).Exec() @@ -450,13 +478,14 @@ func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]in // FIXME: Set duration to requested walltime? func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error { start := time.Now() + currentTime := time.Now().Unix() res, err := sq.Update("job"). Set("monitoring_status", schema.MonitoringStatusArchivingFailed). Set("duration", 0). Set("job_state", schema.JobStateFailed). Where("job.job_state = 'running'"). Where("job.walltime > 0"). - Where(fmt.Sprintf("(%d - job.start_time) > (job.walltime + %d)", time.Now().Unix(), seconds)). + Where("(? - job.start_time) > (job.walltime + ?)", currentTime, seconds). RunWith(r.DB).Exec() if err != nil { cclog.Warn("Error while stopping jobs exceeding walltime") @@ -505,21 +534,21 @@ func (r *JobRepository) FindJobIdsByTag(tagId int64) ([]int64, error) { // FIXME: Reconsider filtering short jobs with harcoded threshold func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) { query := sq.Select(jobColumns...).From("job"). - Where(fmt.Sprintf("job.cluster = '%s'", cluster)). + Where("job.cluster = ?", cluster). Where("job.job_state = 'running'"). - Where("job.duration > 600") + Where("job.duration > ?", repoConfig.MinRunningJobDuration) rows, err := query.RunWith(r.stmtCache).Query() if err != nil { cclog.Error("Error while running query") return nil, err } + defer rows.Close() jobs := make([]*schema.Job, 0, 50) for rows.Next() { job, err := scanJob(rows) if err != nil { - rows.Close() cclog.Warn("Error while scanning rows") return nil, err } @@ -552,12 +581,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64 if startTimeBegin == 0 { cclog.Infof("Find jobs before %d", startTimeEnd) - query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf( - "job.start_time < %d", startTimeEnd)) + query = sq.Select(jobColumns...).From("job").Where("job.start_time < ?", startTimeEnd) } else { cclog.Infof("Find jobs between %d and %d", startTimeBegin, startTimeEnd) - query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf( - "job.start_time BETWEEN %d AND %d", startTimeBegin, startTimeEnd)) + query = sq.Select(jobColumns...).From("job").Where("job.start_time BETWEEN ? AND ?", startTimeBegin, startTimeEnd) } rows, err := query.RunWith(r.stmtCache).Query() @@ -565,12 +592,12 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64 cclog.Error("Error while running query") return nil, err } + defer rows.Close() jobs := make([]*schema.Job, 0, 50) for rows.Next() { job, err := scanJob(rows) if err != nil { - rows.Close() cclog.Warn("Error while scanning rows") return nil, err } @@ -582,6 +609,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64 } func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) { + // Invalidate cache entries as monitoring status affects job state + r.cache.Del(fmt.Sprintf("metadata:%d", job)) + r.cache.Del(fmt.Sprintf("energyFootprint:%d", job)) + stmt := sq.Update("job"). Set("monitoring_status", monitoringStatus). Where("job.id = ?", job) diff --git a/internal/repository/jobCreate.go b/internal/repository/jobCreate.go index 4c38169..2fcc69e 100644 --- a/internal/repository/jobCreate.go +++ b/internal/repository/jobCreate.go @@ -31,8 +31,9 @@ const NamedJobInsert string = `INSERT INTO job ( func (r *JobRepository) InsertJob(job *schema.Job) (int64, error) { r.Mutex.Lock() + defer r.Mutex.Unlock() + res, err := r.DB.NamedExec(NamedJobCacheInsert, job) - r.Mutex.Unlock() if err != nil { cclog.Warn("Error while NamedJobInsert") return 0, err @@ -57,12 +58,12 @@ func (r *JobRepository) SyncJobs() ([]*schema.Job, error) { cclog.Errorf("Error while running query %v", err) return nil, err } + defer rows.Close() jobs := make([]*schema.Job, 0, 50) for rows.Next() { job, err := scanJob(rows) if err != nil { - rows.Close() cclog.Warn("Error while scanning rows") return nil, err } @@ -113,6 +114,10 @@ func (r *JobRepository) Stop( state schema.JobState, monitoringStatus int32, ) (err error) { + // Invalidate cache entries as job state is changing + r.cache.Del(fmt.Sprintf("metadata:%d", jobId)) + r.cache.Del(fmt.Sprintf("energyFootprint:%d", jobId)) + stmt := sq.Update("job"). Set("job_state", state). Set("duration", duration). @@ -129,11 +134,13 @@ func (r *JobRepository) StopCached( state schema.JobState, monitoringStatus int32, ) (err error) { + // Note: StopCached updates job_cache table, not the main job table + // Cache invalidation happens when job is synced to main table stmt := sq.Update("job_cache"). Set("job_state", state). Set("duration", duration). Set("monitoring_status", monitoringStatus). - Where("job.id = ?", jobId) + Where("job_cache.id = ?", jobId) _, err = stmt.RunWith(r.stmtCache).Exec() return err diff --git a/internal/repository/jobFind.go b/internal/repository/jobFind.go index 39519d5..11f66c4 100644 --- a/internal/repository/jobFind.go +++ b/internal/repository/jobFind.go @@ -89,6 +89,7 @@ func (r *JobRepository) FindAll( cclog.Error("Error while running query") return nil, err } + defer rows.Close() jobs := make([]*schema.Job, 0, 10) for rows.Next() { @@ -103,25 +104,31 @@ func (r *JobRepository) FindAll( return jobs, nil } -// Get complete joblist only consisting of db ids. +// GetJobList returns job IDs for non-running jobs. // This is useful to process large job counts and intended to be used -// together with FindById to process jobs one by one -func (r *JobRepository) GetJobList() ([]int64, error) { +// together with FindById to process jobs one by one. +// Use limit and offset for pagination. Use limit=0 to get all results (not recommended for large datasets). +func (r *JobRepository) GetJobList(limit int, offset int) ([]int64, error) { query := sq.Select("id").From("job"). Where("job.job_state != 'running'") + // Add pagination if limit is specified + if limit > 0 { + query = query.Limit(uint64(limit)).Offset(uint64(offset)) + } + rows, err := query.RunWith(r.stmtCache).Query() if err != nil { cclog.Error("Error while running query") return nil, err } + defer rows.Close() jl := make([]int64, 0, 1000) for rows.Next() { var id int64 err := rows.Scan(&id) if err != nil { - rows.Close() cclog.Warn("Error while scanning rows") return nil, err } @@ -256,6 +263,7 @@ func (r *JobRepository) FindConcurrentJobs( cclog.Errorf("Error while running query: %v", err) return nil, err } + defer rows.Close() items := make([]*model.JobLink, 0, 10) queryString := fmt.Sprintf("cluster=%s", job.Cluster) @@ -283,6 +291,7 @@ func (r *JobRepository) FindConcurrentJobs( cclog.Errorf("Error while running query: %v", err) return nil, err } + defer rows.Close() for rows.Next() { var id, jobId, startTime sql.NullInt64 diff --git a/internal/repository/node.go b/internal/repository/node.go index c3152f4..f9c056e 100644 --- a/internal/repository/node.go +++ b/internal/repository/node.go @@ -43,7 +43,7 @@ func GetNodeRepository() *NodeRepository { driver: db.Driver, stmtCache: sq.NewStmtCache(db.DB), - cache: lrucache.New(1024 * 1024), + cache: lrucache.New(repoConfig.CacheSize), } }) return nodeRepoInstance @@ -77,49 +77,12 @@ func (r *NodeRepository) FetchMetadata(hostname string, cluster string) (map[str return MetaData, nil } -// -// func (r *NodeRepository) UpdateMetadata(node *schema.Node, key, val string) (err error) { -// cachekey := fmt.Sprintf("metadata:%d", node.ID) -// r.cache.Del(cachekey) -// if node.MetaData == nil { -// if _, err = r.FetchMetadata(node); err != nil { -// cclog.Warnf("Error while fetching metadata for node, DB ID '%v'", node.ID) -// return err -// } -// } -// -// if node.MetaData != nil { -// cpy := make(map[string]string, len(node.MetaData)+1) -// maps.Copy(cpy, node.MetaData) -// cpy[key] = val -// node.MetaData = cpy -// } else { -// node.MetaData = map[string]string{key: val} -// } -// -// if node.RawMetaData, err = json.Marshal(node.MetaData); err != nil { -// cclog.Warnf("Error while marshaling metadata for node, DB ID '%v'", node.ID) -// return err -// } -// -// if _, err = sq.Update("node"). -// Set("meta_data", node.RawMetaData). -// Where("node.id = ?", node.ID). -// RunWith(r.stmtCache).Exec(); err != nil { -// cclog.Warnf("Error while updating metadata for node, DB ID '%v'", node.ID) -// return err -// } -// -// r.cache.Put(cachekey, node.MetaData, len(node.RawMetaData), 24*time.Hour) -// return nil -// } - func (r *NodeRepository) GetNode(hostname string, cluster string, withMeta bool) (*schema.Node, error) { node := &schema.Node{} if err := sq.Select("node.hostname", "node.cluster", "node.subcluster", "node_state.node_state", "node_state.health_state", "MAX(node_state.time_stamp) as time"). From("node_state"). - Join("node ON nodes_state.node_id = node.id"). + Join("node ON node_state.node_id = node.id"). Where("node.hostname = ?", hostname). Where("node.cluster = ?", cluster). GroupBy("node_state.node_id"). @@ -147,7 +110,7 @@ func (r *NodeRepository) GetNodeById(id int64, withMeta bool) (*schema.Node, err if err := sq.Select("node.hostname", "node.cluster", "node.subcluster", "node_state.node_state", "node_state.health_state", "MAX(node_state.time_stamp) as time"). From("node_state"). - Join("node ON nodes_state.node_id = node.id"). + Join("node ON node_state.node_id = node.id"). Where("node.id = ?", id). GroupBy("node_state.node_id"). RunWith(r.DB). @@ -278,7 +241,7 @@ func (r *NodeRepository) QueryNodes( sq.Select("node.hostname", "node.cluster", "node.subcluster", "node_state.node_state", "node_state.health_state", "MAX(node_state.time_stamp) as time"). From("node"). - Join("node_state ON nodes_state.node_id = node.id")) + Join("node_state ON node_state.node_id = node.id")) if qerr != nil { return nil, qerr } diff --git a/internal/repository/stats.go b/internal/repository/stats.go index f6f3aa9..ba0d09f 100644 --- a/internal/repository/stats.go +++ b/internal/repository/stats.go @@ -114,16 +114,6 @@ func (r *JobRepository) buildStatsQuery( return query } -// func (r *JobRepository) getUserName(ctx context.Context, id string) string { -// user := GetUserFromContext(ctx) -// name, _ := r.FindColumnValue(user, id, "hpc_user", "name", "username", false) -// if name != "" { -// return name -// } else { -// return "-" -// } -// } - func (r *JobRepository) getCastType() string { var castType string diff --git a/internal/repository/transaction.go b/internal/repository/transaction.go index 39941c1..9074428 100644 --- a/internal/repository/transaction.go +++ b/internal/repository/transaction.go @@ -5,84 +5,96 @@ package repository import ( - cclog "github.com/ClusterCockpit/cc-lib/ccLogger" + "fmt" + "github.com/jmoiron/sqlx" ) +// Transaction wraps a database transaction for job-related operations. type Transaction struct { - tx *sqlx.Tx - stmt *sqlx.NamedStmt + tx *sqlx.Tx } +// TransactionInit begins a new transaction. func (r *JobRepository) TransactionInit() (*Transaction, error) { - var err error - t := new(Transaction) - - t.tx, err = r.DB.Beginx() + tx, err := r.DB.Beginx() if err != nil { - cclog.Warn("Error while bundling transactions") - return nil, err + return nil, fmt.Errorf("beginning transaction: %w", err) } - return t, nil + return &Transaction{tx: tx}, nil } -func (r *JobRepository) TransactionCommit(t *Transaction) error { - var err error - if t.tx != nil { - if err = t.tx.Commit(); err != nil { - cclog.Warn("Error while committing transactions") - return err - } +// Commit commits the transaction. +// After calling Commit, the transaction should not be used again. +func (t *Transaction) Commit() error { + if t.tx == nil { + return fmt.Errorf("transaction already committed or rolled back") } - - t.tx, err = r.DB.Beginx() + err := t.tx.Commit() + t.tx = nil // Mark as completed if err != nil { - cclog.Warn("Error while bundling transactions") - return err + return fmt.Errorf("committing transaction: %w", err) } - return nil } +// Rollback rolls back the transaction. +// It's safe to call Rollback on an already committed or rolled back transaction. +func (t *Transaction) Rollback() error { + if t.tx == nil { + return nil // Already committed/rolled back + } + err := t.tx.Rollback() + t.tx = nil // Mark as completed + if err != nil { + return fmt.Errorf("rolling back transaction: %w", err) + } + return nil +} + +// TransactionEnd commits the transaction. +// Deprecated: Use Commit() instead. func (r *JobRepository) TransactionEnd(t *Transaction) error { - if err := t.tx.Commit(); err != nil { - cclog.Warn("Error while committing SQL transactions") - return err - } - return nil + return t.Commit() } +// TransactionAddNamed executes a named query within the transaction. func (r *JobRepository) TransactionAddNamed( t *Transaction, query string, args ...interface{}, ) (int64, error) { + if t.tx == nil { + return 0, fmt.Errorf("transaction is nil or already completed") + } + res, err := t.tx.NamedExec(query, args) if err != nil { - cclog.Errorf("Named Exec failed: %v", err) - return 0, err + return 0, fmt.Errorf("named exec: %w", err) } id, err := res.LastInsertId() if err != nil { - cclog.Errorf("repository initDB(): %v", err) - return 0, err + return 0, fmt.Errorf("getting last insert id: %w", err) } return id, nil } +// TransactionAdd executes a query within the transaction. func (r *JobRepository) TransactionAdd(t *Transaction, query string, args ...interface{}) (int64, error) { + if t.tx == nil { + return 0, fmt.Errorf("transaction is nil or already completed") + } + res, err := t.tx.Exec(query, args...) if err != nil { - cclog.Errorf("TransactionAdd(), Exec() Error: %v", err) - return 0, err + return 0, fmt.Errorf("exec: %w", err) } id, err := res.LastInsertId() if err != nil { - cclog.Errorf("TransactionAdd(), LastInsertId() Error: %v", err) - return 0, err + return 0, fmt.Errorf("getting last insert id: %w", err) } return id, nil diff --git a/internal/tagger/tagger.go b/internal/tagger/tagger.go index 7558914..028d9ef 100644 --- a/internal/tagger/tagger.go +++ b/internal/tagger/tagger.go @@ -91,7 +91,7 @@ func (jt *JobTagger) JobStopCallback(job *schema.Job) { func RunTaggers() error { newTagger() r := repository.GetJobRepository() - jl, err := r.GetJobList() + jl, err := r.GetJobList(0, 0) // 0 limit means get all jobs (no pagination) if err != nil { cclog.Errorf("Error while getting job list %s", err) return err