mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2025-12-15 19:56:16 +01:00
Merge branch 'ai-review' into dev
This commit is contained in:
68
internal/repository/config.go
Normal file
68
internal/repository/config.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package repository
|
||||
|
||||
import "time"
|
||||
|
||||
// RepositoryConfig holds configuration for repository operations.
|
||||
// All fields have sensible defaults, so this configuration is optional.
|
||||
type RepositoryConfig struct {
|
||||
// CacheSize is the LRU cache size in bytes for job metadata and energy footprints.
|
||||
// Default: 1MB (1024 * 1024 bytes)
|
||||
CacheSize int
|
||||
|
||||
// MaxOpenConnections is the maximum number of open database connections.
|
||||
// Default: 4
|
||||
MaxOpenConnections int
|
||||
|
||||
// MaxIdleConnections is the maximum number of idle database connections.
|
||||
// Default: 4
|
||||
MaxIdleConnections int
|
||||
|
||||
// ConnectionMaxLifetime is the maximum amount of time a connection may be reused.
|
||||
// Default: 1 hour
|
||||
ConnectionMaxLifetime time.Duration
|
||||
|
||||
// ConnectionMaxIdleTime is the maximum amount of time a connection may be idle.
|
||||
// Default: 1 hour
|
||||
ConnectionMaxIdleTime time.Duration
|
||||
|
||||
// MinRunningJobDuration is the minimum duration in seconds for a job to be
|
||||
// considered in "running jobs" queries. This filters out very short jobs.
|
||||
// Default: 600 seconds (10 minutes)
|
||||
MinRunningJobDuration int
|
||||
}
|
||||
|
||||
// DefaultConfig returns the default repository configuration.
|
||||
// These values are optimized for typical deployments.
|
||||
func DefaultConfig() *RepositoryConfig {
|
||||
return &RepositoryConfig{
|
||||
CacheSize: 1 * 1024 * 1024, // 1MB
|
||||
MaxOpenConnections: 4,
|
||||
MaxIdleConnections: 4,
|
||||
ConnectionMaxLifetime: time.Hour,
|
||||
ConnectionMaxIdleTime: time.Hour,
|
||||
MinRunningJobDuration: 600, // 10 minutes
|
||||
}
|
||||
}
|
||||
|
||||
// repoConfig is the package-level configuration instance.
|
||||
// It is initialized with defaults and can be overridden via SetConfig.
|
||||
var repoConfig *RepositoryConfig = DefaultConfig()
|
||||
|
||||
// SetConfig sets the repository configuration.
|
||||
// This must be called before any repository initialization (Connect, GetJobRepository, etc.).
|
||||
// If not called, default values from DefaultConfig() are used.
|
||||
func SetConfig(cfg *RepositoryConfig) {
|
||||
if cfg != nil {
|
||||
repoConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig returns the current repository configuration.
|
||||
func GetConfig() *RepositoryConfig {
|
||||
return repoConfig
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
@@ -35,21 +36,15 @@ type DatabaseOptions struct {
|
||||
ConnectionMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func setupSqlite(db *sql.DB) (err error) {
|
||||
func setupSqlite(db *sql.DB) error {
|
||||
pragmas := []string{
|
||||
// "journal_mode = WAL",
|
||||
// "busy_timeout = 5000",
|
||||
// "synchronous = NORMAL",
|
||||
// "cache_size = 1000000000", // 1GB
|
||||
// "foreign_keys = true",
|
||||
"temp_store = memory",
|
||||
// "mmap_size = 3000000000",
|
||||
}
|
||||
|
||||
for _, pragma := range pragmas {
|
||||
_, err = db.Exec("PRAGMA " + pragma)
|
||||
_, err := db.Exec("PRAGMA " + pragma)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,24 +58,24 @@ func Connect(driver string, db string) {
|
||||
dbConnOnce.Do(func() {
|
||||
opts := DatabaseOptions{
|
||||
URL: db,
|
||||
MaxOpenConnections: 4,
|
||||
MaxIdleConnections: 4,
|
||||
ConnectionMaxLifetime: time.Hour,
|
||||
ConnectionMaxIdleTime: time.Hour,
|
||||
MaxOpenConnections: repoConfig.MaxOpenConnections,
|
||||
MaxIdleConnections: repoConfig.MaxIdleConnections,
|
||||
ConnectionMaxLifetime: repoConfig.ConnectionMaxLifetime,
|
||||
ConnectionMaxIdleTime: repoConfig.ConnectionMaxIdleTime,
|
||||
}
|
||||
|
||||
switch driver {
|
||||
case "sqlite3":
|
||||
// TODO: Have separate DB handles for Writes and Reads
|
||||
// Optimize SQLite connection: https://kerkour.com/sqlite-for-servers
|
||||
connectionUrlParams := make(url.Values)
|
||||
connectionUrlParams.Add("_txlock", "immediate")
|
||||
connectionUrlParams.Add("_journal_mode", "WAL")
|
||||
connectionUrlParams.Add("_busy_timeout", "5000")
|
||||
connectionUrlParams.Add("_synchronous", "NORMAL")
|
||||
connectionUrlParams.Add("_cache_size", "1000000000")
|
||||
connectionUrlParams.Add("_foreign_keys", "true")
|
||||
opts.URL = fmt.Sprintf("file:%s?%s", opts.URL, connectionUrlParams.Encode())
|
||||
connectionURLParams := make(url.Values)
|
||||
connectionURLParams.Add("_txlock", "immediate")
|
||||
connectionURLParams.Add("_journal_mode", "WAL")
|
||||
connectionURLParams.Add("_busy_timeout", "5000")
|
||||
connectionURLParams.Add("_synchronous", "NORMAL")
|
||||
connectionURLParams.Add("_cache_size", "1000000000")
|
||||
connectionURLParams.Add("_foreign_keys", "true")
|
||||
opts.URL = fmt.Sprintf("file:%s?%s", opts.URL, connectionURLParams.Encode())
|
||||
|
||||
if cclog.Loglevel() == "debug" {
|
||||
sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
|
||||
@@ -89,7 +84,10 @@ func Connect(driver string, db string) {
|
||||
dbHandle, err = sqlx.Open("sqlite3", opts.URL)
|
||||
}
|
||||
|
||||
setupSqlite(dbHandle.DB)
|
||||
err = setupSqlite(dbHandle.DB)
|
||||
if err != nil {
|
||||
cclog.Abortf("Failed sqlite db setup.\nError: %s\n", err.Error())
|
||||
}
|
||||
case "mysql":
|
||||
opts.URL += "?multiStatements=true"
|
||||
dbHandle, err = sqlx.Open("mysql", opts.URL)
|
||||
|
||||
@@ -2,6 +2,63 @@
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package repository provides the data access layer for cc-backend using the repository pattern.
|
||||
//
|
||||
// The repository pattern abstracts database operations and provides a clean interface for
|
||||
// data access. Each major entity (Job, User, Node, Tag) has its own repository with CRUD
|
||||
// operations and specialized queries.
|
||||
//
|
||||
// # Database Connection
|
||||
//
|
||||
// Initialize the database connection before using any repository:
|
||||
//
|
||||
// repository.Connect("sqlite3", "./var/job.db")
|
||||
// // or for MySQL:
|
||||
// repository.Connect("mysql", "user:password@tcp(localhost:3306)/dbname")
|
||||
//
|
||||
// # Configuration
|
||||
//
|
||||
// Optional: Configure repository settings before initialization:
|
||||
//
|
||||
// repository.SetConfig(&repository.RepositoryConfig{
|
||||
// CacheSize: 2 * 1024 * 1024, // 2MB cache
|
||||
// MaxOpenConnections: 8, // Connection pool size
|
||||
// MinRunningJobDuration: 300, // Filter threshold
|
||||
// })
|
||||
//
|
||||
// If not configured, sensible defaults are used automatically.
|
||||
//
|
||||
// # Repositories
|
||||
//
|
||||
// - JobRepository: Job lifecycle management and querying
|
||||
// - UserRepository: User management and authentication
|
||||
// - NodeRepository: Cluster node state tracking
|
||||
// - Tags: Job tagging and categorization
|
||||
//
|
||||
// # Caching
|
||||
//
|
||||
// Repositories use LRU caching to improve performance. Cache keys are constructed
|
||||
// as "type:id" (e.g., "metadata:123"). Cache is automatically invalidated on
|
||||
// mutations to maintain consistency.
|
||||
//
|
||||
// # Transaction Support
|
||||
//
|
||||
// For batch operations, use transactions:
|
||||
//
|
||||
// t, err := jobRepo.TransactionInit()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer t.Rollback() // Rollback if not committed
|
||||
//
|
||||
// // Perform operations...
|
||||
// jobRepo.TransactionAdd(t, query, args...)
|
||||
//
|
||||
// // Commit when done
|
||||
// if err := t.Commit(); err != nil {
|
||||
// return err
|
||||
// }
|
||||
package repository
|
||||
|
||||
import (
|
||||
@@ -45,7 +102,7 @@ func GetJobRepository() *JobRepository {
|
||||
driver: db.Driver,
|
||||
|
||||
stmtCache: sq.NewStmtCache(db.DB),
|
||||
cache: lrucache.New(1024 * 1024),
|
||||
cache: lrucache.New(repoConfig.CacheSize),
|
||||
}
|
||||
})
|
||||
return jobRepoInstance
|
||||
@@ -267,7 +324,31 @@ func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float6
|
||||
func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
|
||||
var cnt int
|
||||
q := sq.Select("count(*)").From("job").Where("job.start_time < ?", startTime)
|
||||
q.RunWith(r.DB).QueryRow().Scan(cnt)
|
||||
if err := q.RunWith(r.DB).QueryRow().Scan(&cnt); err != nil {
|
||||
cclog.Errorf("Error counting jobs before %d: %v", startTime, err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Invalidate cache for jobs being deleted (get job IDs first)
|
||||
if cnt > 0 {
|
||||
var jobIds []int64
|
||||
rows, err := sq.Select("id").From("job").Where("job.start_time < ?", startTime).RunWith(r.DB).Query()
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err == nil {
|
||||
jobIds = append(jobIds, id)
|
||||
}
|
||||
}
|
||||
// Invalidate cache entries
|
||||
for _, id := range jobIds {
|
||||
r.cache.Del(fmt.Sprintf("metadata:%d", id))
|
||||
r.cache.Del(fmt.Sprintf("energyFootprint:%d", id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
qd := sq.Delete("job").Where("job.start_time < ?", startTime)
|
||||
_, err := qd.RunWith(r.DB).Exec()
|
||||
|
||||
@@ -281,6 +362,10 @@ func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
|
||||
}
|
||||
|
||||
func (r *JobRepository) DeleteJobById(id int64) error {
|
||||
// Invalidate cache entries before deletion
|
||||
r.cache.Del(fmt.Sprintf("metadata:%d", id))
|
||||
r.cache.Del(fmt.Sprintf("energyFootprint:%d", id))
|
||||
|
||||
qd := sq.Delete("job").Where("job.id = ?", id)
|
||||
_, err := qd.RunWith(r.DB).Exec()
|
||||
|
||||
@@ -450,13 +535,14 @@ func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]in
|
||||
// FIXME: Set duration to requested walltime?
|
||||
func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error {
|
||||
start := time.Now()
|
||||
currentTime := time.Now().Unix()
|
||||
res, err := sq.Update("job").
|
||||
Set("monitoring_status", schema.MonitoringStatusArchivingFailed).
|
||||
Set("duration", 0).
|
||||
Set("job_state", schema.JobStateFailed).
|
||||
Where("job.job_state = 'running'").
|
||||
Where("job.walltime > 0").
|
||||
Where(fmt.Sprintf("(%d - job.start_time) > (job.walltime + %d)", time.Now().Unix(), seconds)).
|
||||
Where("(? - job.start_time) > (job.walltime + ?)", currentTime, seconds).
|
||||
RunWith(r.DB).Exec()
|
||||
if err != nil {
|
||||
cclog.Warn("Error while stopping jobs exceeding walltime")
|
||||
@@ -505,21 +591,21 @@ func (r *JobRepository) FindJobIdsByTag(tagId int64) ([]int64, error) {
|
||||
// FIXME: Reconsider filtering short jobs with harcoded threshold
|
||||
func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) {
|
||||
query := sq.Select(jobColumns...).From("job").
|
||||
Where(fmt.Sprintf("job.cluster = '%s'", cluster)).
|
||||
Where("job.cluster = ?", cluster).
|
||||
Where("job.job_state = 'running'").
|
||||
Where("job.duration > 600")
|
||||
Where("job.duration > ?", repoConfig.MinRunningJobDuration)
|
||||
|
||||
rows, err := query.RunWith(r.stmtCache).Query()
|
||||
if err != nil {
|
||||
cclog.Error("Error while running query")
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jobs := make([]*schema.Job, 0, 50)
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
cclog.Warn("Error while scanning rows")
|
||||
return nil, err
|
||||
}
|
||||
@@ -552,12 +638,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
|
||||
|
||||
if startTimeBegin == 0 {
|
||||
cclog.Infof("Find jobs before %d", startTimeEnd)
|
||||
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
|
||||
"job.start_time < %d", startTimeEnd))
|
||||
query = sq.Select(jobColumns...).From("job").Where("job.start_time < ?", startTimeEnd)
|
||||
} else {
|
||||
cclog.Infof("Find jobs between %d and %d", startTimeBegin, startTimeEnd)
|
||||
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
|
||||
"job.start_time BETWEEN %d AND %d", startTimeBegin, startTimeEnd))
|
||||
query = sq.Select(jobColumns...).From("job").Where("job.start_time BETWEEN ? AND ?", startTimeBegin, startTimeEnd)
|
||||
}
|
||||
|
||||
rows, err := query.RunWith(r.stmtCache).Query()
|
||||
@@ -565,12 +649,12 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
|
||||
cclog.Error("Error while running query")
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jobs := make([]*schema.Job, 0, 50)
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
cclog.Warn("Error while scanning rows")
|
||||
return nil, err
|
||||
}
|
||||
@@ -582,6 +666,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
|
||||
}
|
||||
|
||||
func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) {
|
||||
// Invalidate cache entries as monitoring status affects job state
|
||||
r.cache.Del(fmt.Sprintf("metadata:%d", job))
|
||||
r.cache.Del(fmt.Sprintf("energyFootprint:%d", job))
|
||||
|
||||
stmt := sq.Update("job").
|
||||
Set("monitoring_status", monitoringStatus).
|
||||
Where("job.id = ?", job)
|
||||
|
||||
@@ -31,8 +31,9 @@ const NamedJobInsert string = `INSERT INTO job (
|
||||
|
||||
func (r *JobRepository) InsertJob(job *schema.Job) (int64, error) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
|
||||
res, err := r.DB.NamedExec(NamedJobCacheInsert, job)
|
||||
r.Mutex.Unlock()
|
||||
if err != nil {
|
||||
cclog.Warn("Error while NamedJobInsert")
|
||||
return 0, err
|
||||
@@ -57,12 +58,12 @@ func (r *JobRepository) SyncJobs() ([]*schema.Job, error) {
|
||||
cclog.Errorf("Error while running query %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jobs := make([]*schema.Job, 0, 50)
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
cclog.Warn("Error while scanning rows")
|
||||
return nil, err
|
||||
}
|
||||
@@ -113,6 +114,10 @@ func (r *JobRepository) Stop(
|
||||
state schema.JobState,
|
||||
monitoringStatus int32,
|
||||
) (err error) {
|
||||
// Invalidate cache entries as job state is changing
|
||||
r.cache.Del(fmt.Sprintf("metadata:%d", jobId))
|
||||
r.cache.Del(fmt.Sprintf("energyFootprint:%d", jobId))
|
||||
|
||||
stmt := sq.Update("job").
|
||||
Set("job_state", state).
|
||||
Set("duration", duration).
|
||||
@@ -129,11 +134,13 @@ func (r *JobRepository) StopCached(
|
||||
state schema.JobState,
|
||||
monitoringStatus int32,
|
||||
) (err error) {
|
||||
// Note: StopCached updates job_cache table, not the main job table
|
||||
// Cache invalidation happens when job is synced to main table
|
||||
stmt := sq.Update("job_cache").
|
||||
Set("job_state", state).
|
||||
Set("duration", duration).
|
||||
Set("monitoring_status", monitoringStatus).
|
||||
Where("job.id = ?", jobId)
|
||||
Where("job_cache.id = ?", jobId)
|
||||
|
||||
_, err = stmt.RunWith(r.stmtCache).Exec()
|
||||
return err
|
||||
|
||||
@@ -89,6 +89,7 @@ func (r *JobRepository) FindAll(
|
||||
cclog.Error("Error while running query")
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jobs := make([]*schema.Job, 0, 10)
|
||||
for rows.Next() {
|
||||
@@ -103,25 +104,31 @@ func (r *JobRepository) FindAll(
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// Get complete joblist only consisting of db ids.
|
||||
// GetJobList returns job IDs for non-running jobs.
|
||||
// This is useful to process large job counts and intended to be used
|
||||
// together with FindById to process jobs one by one
|
||||
func (r *JobRepository) GetJobList() ([]int64, error) {
|
||||
// together with FindById to process jobs one by one.
|
||||
// Use limit and offset for pagination. Use limit=0 to get all results (not recommended for large datasets).
|
||||
func (r *JobRepository) GetJobList(limit int, offset int) ([]int64, error) {
|
||||
query := sq.Select("id").From("job").
|
||||
Where("job.job_state != 'running'")
|
||||
|
||||
// Add pagination if limit is specified
|
||||
if limit > 0 {
|
||||
query = query.Limit(uint64(limit)).Offset(uint64(offset))
|
||||
}
|
||||
|
||||
rows, err := query.RunWith(r.stmtCache).Query()
|
||||
if err != nil {
|
||||
cclog.Error("Error while running query")
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
jl := make([]int64, 0, 1000)
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
err := rows.Scan(&id)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
cclog.Warn("Error while scanning rows")
|
||||
return nil, err
|
||||
}
|
||||
@@ -256,6 +263,7 @@ func (r *JobRepository) FindConcurrentJobs(
|
||||
cclog.Errorf("Error while running query: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]*model.JobLink, 0, 10)
|
||||
queryString := fmt.Sprintf("cluster=%s", job.Cluster)
|
||||
@@ -283,6 +291,7 @@ func (r *JobRepository) FindConcurrentJobs(
|
||||
cclog.Errorf("Error while running query: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, jobId, startTime sql.NullInt64
|
||||
|
||||
@@ -43,7 +43,7 @@ func GetNodeRepository() *NodeRepository {
|
||||
driver: db.Driver,
|
||||
|
||||
stmtCache: sq.NewStmtCache(db.DB),
|
||||
cache: lrucache.New(1024 * 1024),
|
||||
cache: lrucache.New(repoConfig.CacheSize),
|
||||
}
|
||||
})
|
||||
return nodeRepoInstance
|
||||
@@ -77,43 +77,6 @@ func (r *NodeRepository) FetchMetadata(hostname string, cluster string) (map[str
|
||||
return MetaData, nil
|
||||
}
|
||||
|
||||
//
|
||||
// func (r *NodeRepository) UpdateMetadata(node *schema.Node, key, val string) (err error) {
|
||||
// cachekey := fmt.Sprintf("metadata:%d", node.ID)
|
||||
// r.cache.Del(cachekey)
|
||||
// if node.MetaData == nil {
|
||||
// if _, err = r.FetchMetadata(node); err != nil {
|
||||
// cclog.Warnf("Error while fetching metadata for node, DB ID '%v'", node.ID)
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if node.MetaData != nil {
|
||||
// cpy := make(map[string]string, len(node.MetaData)+1)
|
||||
// maps.Copy(cpy, node.MetaData)
|
||||
// cpy[key] = val
|
||||
// node.MetaData = cpy
|
||||
// } else {
|
||||
// node.MetaData = map[string]string{key: val}
|
||||
// }
|
||||
//
|
||||
// if node.RawMetaData, err = json.Marshal(node.MetaData); err != nil {
|
||||
// cclog.Warnf("Error while marshaling metadata for node, DB ID '%v'", node.ID)
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// if _, err = sq.Update("node").
|
||||
// Set("meta_data", node.RawMetaData).
|
||||
// Where("node.id = ?", node.ID).
|
||||
// RunWith(r.stmtCache).Exec(); err != nil {
|
||||
// cclog.Warnf("Error while updating metadata for node, DB ID '%v'", node.ID)
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// r.cache.Put(cachekey, node.MetaData, len(node.RawMetaData), 24*time.Hour)
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (r *NodeRepository) GetNode(hostname string, cluster string, withMeta bool) (*schema.Node, error) {
|
||||
node := &schema.Node{}
|
||||
var timestamp int
|
||||
|
||||
@@ -115,7 +115,7 @@ func nodeTestSetup(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(jobarchive, "version.txt"),
|
||||
fmt.Appendf(nil, "%d", 2), 0o666); err != nil {
|
||||
fmt.Appendf(nil, "%d", 3), 0o666); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -114,16 +114,6 @@ func (r *JobRepository) buildStatsQuery(
|
||||
return query
|
||||
}
|
||||
|
||||
// func (r *JobRepository) getUserName(ctx context.Context, id string) string {
|
||||
// user := GetUserFromContext(ctx)
|
||||
// name, _ := r.FindColumnValue(user, id, "hpc_user", "name", "username", false)
|
||||
// if name != "" {
|
||||
// return name
|
||||
// } else {
|
||||
// return "-"
|
||||
// }
|
||||
// }
|
||||
|
||||
func (r *JobRepository) getCastType() string {
|
||||
var castType string
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -14,65 +15,32 @@ import (
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
)
|
||||
|
||||
// Add the tag with id `tagId` to the job with the database id `jobId`.
|
||||
// AddTag adds the tag with id `tagId` to the job with the database id `jobId`.
|
||||
// Requires user authentication for security checks.
|
||||
func (r *JobRepository) AddTag(user *schema.User, job int64, tag int64) ([]*schema.Tag, error) {
|
||||
j, err := r.FindByIdWithUser(user, job)
|
||||
if err != nil {
|
||||
cclog.Warn("Error while finding job by id")
|
||||
cclog.Warnf("Error finding job %d for user %s: %v", job, user.Username, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := sq.Insert("jobtag").Columns("job_id", "tag_id").Values(job, tag)
|
||||
|
||||
if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
|
||||
s, _, _ := q.ToSql()
|
||||
cclog.Errorf("Error adding tag with %s: %v", s, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tags, err := r.GetTags(user, &job)
|
||||
if err != nil {
|
||||
cclog.Warn("Error while getting tags for job")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
archiveTags, err := r.getArchiveTags(&job)
|
||||
if err != nil {
|
||||
cclog.Warn("Error while getting tags for job")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tags, archive.UpdateTags(j, archiveTags)
|
||||
return r.addJobTag(job, tag, j, func() ([]*schema.Tag, error) {
|
||||
return r.GetTags(user, &job)
|
||||
})
|
||||
}
|
||||
|
||||
// AddTagDirect adds a tag without user security checks.
|
||||
// Use only for internal/admin operations.
|
||||
func (r *JobRepository) AddTagDirect(job int64, tag int64) ([]*schema.Tag, error) {
|
||||
j, err := r.FindByIdDirect(job)
|
||||
if err != nil {
|
||||
cclog.Warn("Error while finding job by id")
|
||||
cclog.Warnf("Error finding job %d: %v", job, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := sq.Insert("jobtag").Columns("job_id", "tag_id").Values(job, tag)
|
||||
|
||||
if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
|
||||
s, _, _ := q.ToSql()
|
||||
cclog.Errorf("Error adding tag with %s: %v", s, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tags, err := r.GetTagsDirect(&job)
|
||||
if err != nil {
|
||||
cclog.Warn("Error while getting tags for job")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
archiveTags, err := r.getArchiveTags(&job)
|
||||
if err != nil {
|
||||
cclog.Warn("Error while getting tags for job")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tags, archive.UpdateTags(j, archiveTags)
|
||||
return r.addJobTag(job, tag, j, func() ([]*schema.Tag, error) {
|
||||
return r.GetTagsDirect(&job)
|
||||
})
|
||||
}
|
||||
|
||||
// Removes a tag from a job by tag id.
|
||||
@@ -260,15 +228,18 @@ func (r *JobRepository) CountTags(user *schema.User) (tags []schema.Tag, counts
|
||||
LeftJoin("jobtag jt ON t.id = jt.tag_id").
|
||||
GroupBy("t.tag_name")
|
||||
|
||||
// Handle Scope Filtering
|
||||
scopeList := "\"global\""
|
||||
// Build scope list for filtering
|
||||
var scopeBuilder strings.Builder
|
||||
scopeBuilder.WriteString(`"global"`)
|
||||
if user != nil {
|
||||
scopeList += ",\"" + user.Username + "\""
|
||||
scopeBuilder.WriteString(`,"`)
|
||||
scopeBuilder.WriteString(user.Username)
|
||||
scopeBuilder.WriteString(`"`)
|
||||
if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) {
|
||||
scopeBuilder.WriteString(`,"admin"`)
|
||||
}
|
||||
}
|
||||
if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) {
|
||||
scopeList += ",\"admin\""
|
||||
}
|
||||
q = q.Where("t.tag_scope IN (" + scopeList + ")")
|
||||
q = q.Where("t.tag_scope IN (" + scopeBuilder.String() + ")")
|
||||
|
||||
// Handle Job Ownership
|
||||
if user != nil && user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) { // ADMIN || SUPPORT: Count all jobs
|
||||
@@ -302,6 +273,41 @@ func (r *JobRepository) CountTags(user *schema.User) (tags []schema.Tag, counts
|
||||
return tags, counts, err
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTagNotFound = errors.New("the tag does not exist")
|
||||
ErrJobNotOwned = errors.New("user is not owner of job")
|
||||
ErrTagNoAccess = errors.New("user not permitted to use that tag")
|
||||
ErrTagPrivateScope = errors.New("tag is private to another user")
|
||||
ErrTagAdminScope = errors.New("tag requires admin privileges")
|
||||
ErrTagsIncompatScopes = errors.New("combining admin and non-admin scoped tags not allowed")
|
||||
)
|
||||
|
||||
// addJobTag is a helper function that inserts a job-tag association and updates the archive.
|
||||
// Returns the updated tag list for the job.
|
||||
func (r *JobRepository) addJobTag(jobId int64, tagId int64, job *schema.Job, getTags func() ([]*schema.Tag, error)) ([]*schema.Tag, error) {
|
||||
q := sq.Insert("jobtag").Columns("job_id", "tag_id").Values(jobId, tagId)
|
||||
|
||||
if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
|
||||
s, _, _ := q.ToSql()
|
||||
cclog.Errorf("Error adding tag with %s: %v", s, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tags, err := getTags()
|
||||
if err != nil {
|
||||
cclog.Warnf("Error getting tags for job %d: %v", jobId, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
archiveTags, err := r.getArchiveTags(&jobId)
|
||||
if err != nil {
|
||||
cclog.Warnf("Error getting archive tags for job %d: %v", jobId, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tags, archive.UpdateTags(job, archiveTags)
|
||||
}
|
||||
|
||||
// AddTagOrCreate adds the tag with the specified type and name to the job with the database id `jobId`.
|
||||
// If such a tag does not yet exist, it is created.
|
||||
func (r *JobRepository) AddTagOrCreate(user *schema.User, jobId int64, tagType string, tagName string, tagScope string) (tagId int64, err error) {
|
||||
|
||||
@@ -5,84 +5,96 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// Transaction wraps a database transaction for job-related operations.
|
||||
type Transaction struct {
|
||||
tx *sqlx.Tx
|
||||
stmt *sqlx.NamedStmt
|
||||
tx *sqlx.Tx
|
||||
}
|
||||
|
||||
// TransactionInit begins a new transaction.
|
||||
func (r *JobRepository) TransactionInit() (*Transaction, error) {
|
||||
var err error
|
||||
t := new(Transaction)
|
||||
|
||||
t.tx, err = r.DB.Beginx()
|
||||
tx, err := r.DB.Beginx()
|
||||
if err != nil {
|
||||
cclog.Warn("Error while bundling transactions")
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("beginning transaction: %w", err)
|
||||
}
|
||||
return t, nil
|
||||
return &Transaction{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (r *JobRepository) TransactionCommit(t *Transaction) error {
|
||||
var err error
|
||||
if t.tx != nil {
|
||||
if err = t.tx.Commit(); err != nil {
|
||||
cclog.Warn("Error while committing transactions")
|
||||
return err
|
||||
}
|
||||
// Commit commits the transaction.
|
||||
// After calling Commit, the transaction should not be used again.
|
||||
func (t *Transaction) Commit() error {
|
||||
if t.tx == nil {
|
||||
return fmt.Errorf("transaction already committed or rolled back")
|
||||
}
|
||||
|
||||
t.tx, err = r.DB.Beginx()
|
||||
err := t.tx.Commit()
|
||||
t.tx = nil // Mark as completed
|
||||
if err != nil {
|
||||
cclog.Warn("Error while bundling transactions")
|
||||
return err
|
||||
return fmt.Errorf("committing transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback rolls back the transaction.
|
||||
// It's safe to call Rollback on an already committed or rolled back transaction.
|
||||
func (t *Transaction) Rollback() error {
|
||||
if t.tx == nil {
|
||||
return nil // Already committed/rolled back
|
||||
}
|
||||
err := t.tx.Rollback()
|
||||
t.tx = nil // Mark as completed
|
||||
if err != nil {
|
||||
return fmt.Errorf("rolling back transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TransactionEnd commits the transaction.
|
||||
// Deprecated: Use Commit() instead.
|
||||
func (r *JobRepository) TransactionEnd(t *Transaction) error {
|
||||
if err := t.tx.Commit(); err != nil {
|
||||
cclog.Warn("Error while committing SQL transactions")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return t.Commit()
|
||||
}
|
||||
|
||||
// TransactionAddNamed executes a named query within the transaction.
|
||||
func (r *JobRepository) TransactionAddNamed(
|
||||
t *Transaction,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (int64, error) {
|
||||
if t.tx == nil {
|
||||
return 0, fmt.Errorf("transaction is nil or already completed")
|
||||
}
|
||||
|
||||
res, err := t.tx.NamedExec(query, args)
|
||||
if err != nil {
|
||||
cclog.Errorf("Named Exec failed: %v", err)
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("named exec: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
cclog.Errorf("repository initDB(): %v", err)
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("getting last insert id: %w", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// TransactionAdd executes a query within the transaction.
|
||||
func (r *JobRepository) TransactionAdd(t *Transaction, query string, args ...interface{}) (int64, error) {
|
||||
if t.tx == nil {
|
||||
return 0, fmt.Errorf("transaction is nil or already completed")
|
||||
}
|
||||
|
||||
res, err := t.tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
cclog.Errorf("TransactionAdd(), Exec() Error: %v", err)
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("exec: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
cclog.Errorf("TransactionAdd(), LastInsertId() Error: %v", err)
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("getting last insert id: %w", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
|
||||
Reference in New Issue
Block a user