Refactor repository

Fix issues
Improve transaction API
Make hardcoded constants configurable
Make error messages consistent and always add context info
This commit is contained in:
2025-11-20 06:58:45 +01:00
parent e1c7583670
commit 8f4ef1e274
9 changed files with 188 additions and 115 deletions

View File

@@ -45,7 +45,7 @@ func GetJobRepository() *JobRepository {
driver: db.Driver,
stmtCache: sq.NewStmtCache(db.DB),
cache: lrucache.New(1024 * 1024),
cache: lrucache.New(repoConfig.CacheSize),
}
})
return jobRepoInstance
@@ -267,7 +267,31 @@ func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float6
func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
var cnt int
q := sq.Select("count(*)").From("job").Where("job.start_time < ?", startTime)
q.RunWith(r.DB).QueryRow().Scan(cnt)
if err := q.RunWith(r.DB).QueryRow().Scan(&cnt); err != nil {
cclog.Errorf("Error counting jobs before %d: %v", startTime, err)
return 0, err
}
// Invalidate cache for jobs being deleted (get job IDs first)
if cnt > 0 {
var jobIds []int64
rows, err := sq.Select("id").From("job").Where("job.start_time < ?", startTime).RunWith(r.DB).Query()
if err == nil {
defer rows.Close()
for rows.Next() {
var id int64
if err := rows.Scan(&id); err == nil {
jobIds = append(jobIds, id)
}
}
// Invalidate cache entries
for _, id := range jobIds {
r.cache.Del(fmt.Sprintf("metadata:%d", id))
r.cache.Del(fmt.Sprintf("energyFootprint:%d", id))
}
}
}
qd := sq.Delete("job").Where("job.start_time < ?", startTime)
_, err := qd.RunWith(r.DB).Exec()
@@ -281,6 +305,10 @@ func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
}
func (r *JobRepository) DeleteJobById(id int64) error {
// Invalidate cache entries before deletion
r.cache.Del(fmt.Sprintf("metadata:%d", id))
r.cache.Del(fmt.Sprintf("energyFootprint:%d", id))
qd := sq.Delete("job").Where("job.id = ?", id)
_, err := qd.RunWith(r.DB).Exec()
@@ -450,13 +478,14 @@ func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]in
// FIXME: Set duration to requested walltime?
func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error {
start := time.Now()
currentTime := time.Now().Unix()
res, err := sq.Update("job").
Set("monitoring_status", schema.MonitoringStatusArchivingFailed).
Set("duration", 0).
Set("job_state", schema.JobStateFailed).
Where("job.job_state = 'running'").
Where("job.walltime > 0").
Where(fmt.Sprintf("(%d - job.start_time) > (job.walltime + %d)", time.Now().Unix(), seconds)).
Where("(? - job.start_time) > (job.walltime + ?)", currentTime, seconds).
RunWith(r.DB).Exec()
if err != nil {
cclog.Warn("Error while stopping jobs exceeding walltime")
@@ -505,21 +534,21 @@ func (r *JobRepository) FindJobIdsByTag(tagId int64) ([]int64, error) {
// FIXME: Reconsider filtering short jobs with hardcoded threshold
func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) {
query := sq.Select(jobColumns...).From("job").
Where(fmt.Sprintf("job.cluster = '%s'", cluster)).
Where("job.cluster = ?", cluster).
Where("job.job_state = 'running'").
Where("job.duration > 600")
Where("job.duration > ?", repoConfig.MinRunningJobDuration)
rows, err := query.RunWith(r.stmtCache).Query()
if err != nil {
cclog.Error("Error while running query")
return nil, err
}
defer rows.Close()
jobs := make([]*schema.Job, 0, 50)
for rows.Next() {
job, err := scanJob(rows)
if err != nil {
rows.Close()
cclog.Warn("Error while scanning rows")
return nil, err
}
@@ -552,12 +581,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
if startTimeBegin == 0 {
cclog.Infof("Find jobs before %d", startTimeEnd)
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
"job.start_time < %d", startTimeEnd))
query = sq.Select(jobColumns...).From("job").Where("job.start_time < ?", startTimeEnd)
} else {
cclog.Infof("Find jobs between %d and %d", startTimeBegin, startTimeEnd)
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
"job.start_time BETWEEN %d AND %d", startTimeBegin, startTimeEnd))
query = sq.Select(jobColumns...).From("job").Where("job.start_time BETWEEN ? AND ?", startTimeBegin, startTimeEnd)
}
rows, err := query.RunWith(r.stmtCache).Query()
@@ -565,12 +592,12 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
cclog.Error("Error while running query")
return nil, err
}
defer rows.Close()
jobs := make([]*schema.Job, 0, 50)
for rows.Next() {
job, err := scanJob(rows)
if err != nil {
rows.Close()
cclog.Warn("Error while scanning rows")
return nil, err
}
@@ -582,6 +609,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
}
func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) {
// Invalidate cache entries as monitoring status affects job state
r.cache.Del(fmt.Sprintf("metadata:%d", job))
r.cache.Del(fmt.Sprintf("energyFootprint:%d", job))
stmt := sq.Update("job").
Set("monitoring_status", monitoringStatus).
Where("job.id = ?", job)