From 8f0bb907ff2bf75bcd576acad1f363ec280fb70d Mon Sep 17 00:00:00 2001 From: Jan Eitzinger Date: Thu, 15 Jan 2026 06:41:23 +0100 Subject: [PATCH] Improve documentation and add more tests --- internal/repository/hooks_test.go | 274 +++++++++++ internal/repository/job.go | 277 ++++++++++- internal/repository/jobCreate_test.go | 500 ++++++++++++++++++++ internal/repository/jobHooks.go | 48 +- internal/repository/migration.go | 63 +++ internal/repository/stats.go | 54 ++- internal/repository/stats_test.go | 13 +- internal/repository/tags.go | 42 +- internal/repository/transaction_test.go | 311 +++++++++++++ internal/repository/user.go | 47 ++ internal/repository/user_test.go | 596 ++++++++++++++++++++++++ internal/tagger/tagger.go | 2 +- 12 files changed, 2185 insertions(+), 42 deletions(-) create mode 100644 internal/repository/hooks_test.go create mode 100644 internal/repository/jobCreate_test.go create mode 100644 internal/repository/transaction_test.go create mode 100644 internal/repository/user_test.go diff --git a/internal/repository/hooks_test.go b/internal/repository/hooks_test.go new file mode 100644 index 00000000..52f954b5 --- /dev/null +++ b/internal/repository/hooks_test.go @@ -0,0 +1,274 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. +package repository + +import ( + "context" + "testing" + "time" + + "github.com/ClusterCockpit/cc-lib/v2/schema" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type MockJobHook struct { + startCalled bool + stopCalled bool + startJobs []*schema.Job + stopJobs []*schema.Job +} + +func (m *MockJobHook) JobStartCallback(job *schema.Job) { + m.startCalled = true + m.startJobs = append(m.startJobs, job) +} + +func (m *MockJobHook) JobStopCallback(job *schema.Job) { + m.stopCalled = true + m.stopJobs = append(m.stopJobs, job) +} + +func TestRegisterJobHook(t *testing.T) { + t.Run("register single hook", func(t *testing.T) { + hooks = nil + mock := &MockJobHook{} + + RegisterJobHook(mock) + + assert.NotNil(t, hooks) + assert.Len(t, hooks, 1) + assert.Equal(t, mock, hooks[0]) + + hooks = nil + }) + + t.Run("register multiple hooks", func(t *testing.T) { + hooks = nil + mock1 := &MockJobHook{} + mock2 := &MockJobHook{} + + RegisterJobHook(mock1) + RegisterJobHook(mock2) + + assert.Len(t, hooks, 2) + assert.Equal(t, mock1, hooks[0]) + assert.Equal(t, mock2, hooks[1]) + + hooks = nil + }) + + t.Run("register nil hook does not add to hooks", func(t *testing.T) { + hooks = nil + RegisterJobHook(nil) + + if hooks != nil { + assert.Len(t, hooks, 0, "Nil hook should not be added") + } + + hooks = nil + }) +} + +func TestCallJobStartHooks(t *testing.T) { + t.Run("call start hooks with single job", func(t *testing.T) { + hooks = nil + mock := &MockJobHook{} + RegisterJobHook(mock) + + job := &schema.Job{ + JobID: 123, + User: "testuser", + Cluster: "testcluster", + } + + CallJobStartHooks([]*schema.Job{job}) + + assert.True(t, mock.startCalled) + assert.False(t, mock.stopCalled) + assert.Len(t, mock.startJobs, 1) + assert.Equal(t, int64(123), mock.startJobs[0].JobID) + + hooks = nil + }) + + t.Run("call start hooks with multiple jobs", func(t *testing.T) { + hooks = nil + mock := &MockJobHook{} + RegisterJobHook(mock) + + jobs := []*schema.Job{ + {JobID: 1, User: "user1", Cluster: "cluster1"}, + {JobID: 2, User: "user2", 
Cluster: "cluster2"}, + {JobID: 3, User: "user3", Cluster: "cluster3"}, + } + + CallJobStartHooks(jobs) + + assert.True(t, mock.startCalled) + assert.Len(t, mock.startJobs, 3) + assert.Equal(t, int64(1), mock.startJobs[0].JobID) + assert.Equal(t, int64(2), mock.startJobs[1].JobID) + assert.Equal(t, int64(3), mock.startJobs[2].JobID) + + hooks = nil + }) + + t.Run("call start hooks with multiple registered hooks", func(t *testing.T) { + hooks = nil + mock1 := &MockJobHook{} + mock2 := &MockJobHook{} + RegisterJobHook(mock1) + RegisterJobHook(mock2) + + job := &schema.Job{ + JobID: 456, User: "testuser", Cluster: "testcluster", + } + + CallJobStartHooks([]*schema.Job{job}) + + assert.True(t, mock1.startCalled) + assert.True(t, mock2.startCalled) + assert.Len(t, mock1.startJobs, 1) + assert.Len(t, mock2.startJobs, 1) + + hooks = nil + }) + + t.Run("call start hooks with nil hooks", func(t *testing.T) { + hooks = nil + + job := &schema.Job{ + JobID: 789, User: "testuser", Cluster: "testcluster", + } + + CallJobStartHooks([]*schema.Job{job}) + + hooks = nil + }) + + t.Run("call start hooks with empty job list", func(t *testing.T) { + hooks = nil + mock := &MockJobHook{} + RegisterJobHook(mock) + + CallJobStartHooks([]*schema.Job{}) + + assert.False(t, mock.startCalled) + assert.Len(t, mock.startJobs, 0) + + hooks = nil + }) +} + +func TestCallJobStopHooks(t *testing.T) { + t.Run("call stop hooks with single job", func(t *testing.T) { + hooks = nil + mock := &MockJobHook{} + RegisterJobHook(mock) + + job := &schema.Job{ + JobID: 123, + User: "testuser", + Cluster: "testcluster", + } + + CallJobStopHooks(job) + + assert.True(t, mock.stopCalled) + assert.False(t, mock.startCalled) + assert.Len(t, mock.stopJobs, 1) + assert.Equal(t, int64(123), mock.stopJobs[0].JobID) + + hooks = nil + }) + + t.Run("call stop hooks with multiple registered hooks", func(t *testing.T) { + hooks = nil + mock1 := &MockJobHook{} + mock2 := &MockJobHook{} + RegisterJobHook(mock1) + RegisterJobHook(mock2) + + job := &schema.Job{ + JobID: 456, User: "testuser", Cluster: "testcluster", + } + + CallJobStopHooks(job) + + assert.True(t, mock1.stopCalled) + assert.True(t, mock2.stopCalled) + assert.Len(t, mock1.stopJobs, 1) + assert.Len(t, mock2.stopJobs, 1) + + hooks = nil + }) + + t.Run("call stop hooks with nil hooks", func(t *testing.T) { + hooks = nil + + job := &schema.Job{ + JobID: 789, User: "testuser", Cluster: "testcluster", + } + + CallJobStopHooks(job) + + hooks = nil + }) +} + +func TestSQLHooks(t *testing.T) { + _ = setup(t) + + t.Run("hooks log queries in debug mode", func(t *testing.T) { + h := &Hooks{} + + ctx := context.Background() + query := "SELECT * FROM job WHERE job_id = ?" + args := []any{123} + + ctxWithTime, err := h.Before(ctx, query, args...) + require.NoError(t, err) + assert.NotNil(t, ctxWithTime) + + beginTime := ctxWithTime.Value("begin") + require.NotNil(t, beginTime) + _, ok := beginTime.(time.Time) + assert.True(t, ok, "Begin time should be time.Time") + + time.Sleep(10 * time.Millisecond) + + ctxAfter, err := h.After(ctxWithTime, query, args...) 
+ require.NoError(t, err) + assert.NotNil(t, ctxAfter) + }) +} + +func TestHookIntegration(t *testing.T) { + t.Run("hooks are called during job lifecycle", func(t *testing.T) { + hooks = nil + mock := &MockJobHook{} + RegisterJobHook(mock) + + job := &schema.Job{ + JobID: 999, + User: "integrationuser", + Cluster: "integrationcluster", + } + + CallJobStartHooks([]*schema.Job{job}) + assert.True(t, mock.startCalled) + assert.Equal(t, 1, len(mock.startJobs)) + + CallJobStopHooks(job) + assert.True(t, mock.stopCalled) + assert.Equal(t, 1, len(mock.stopJobs)) + + assert.Equal(t, mock.startJobs[0].JobID, mock.stopJobs[0].JobID) + + hooks = nil + }) +} diff --git a/internal/repository/job.go b/internal/repository/job.go index b1e92424..bd33774c 100644 --- a/internal/repository/job.go +++ b/internal/repository/job.go @@ -80,18 +80,33 @@ import ( ) var ( - jobRepoOnce sync.Once + // jobRepoOnce ensures singleton initialization of the JobRepository + jobRepoOnce sync.Once + // jobRepoInstance holds the single instance of JobRepository jobRepoInstance *JobRepository ) +// JobRepository provides database access for job-related operations. +// It implements the repository pattern to abstract database interactions +// and provides caching for improved performance. +// +// The repository is a singleton initialized via GetJobRepository(). +// All database queries use prepared statements via stmtCache for efficiency. +// Frequently accessed data (metadata, energy footprints) is cached in an LRU cache. type JobRepository struct { - DB *sqlx.DB - stmtCache *sq.StmtCache - cache *lrucache.Cache - driver string - Mutex sync.Mutex + DB *sqlx.DB // Database connection pool + stmtCache *sq.StmtCache // Prepared statement cache for query optimization + cache *lrucache.Cache // LRU cache for metadata and footprint data + driver string // Database driver name (e.g., "sqlite3") + Mutex sync.Mutex // Mutex for thread-safe operations } +// GetJobRepository returns the singleton instance of JobRepository. +// The repository is initialized lazily on first access with database connection, +// prepared statement cache, and LRU cache configured from repoConfig. +// +// This function is thread-safe and ensures only one instance is created. +// It must be called after Connect() has established a database connection. func GetJobRepository() *JobRepository { jobRepoOnce.Do(func() { db := GetConnection() @@ -107,6 +122,8 @@ func GetJobRepository() *JobRepository { return jobRepoInstance } +// jobColumns defines the standard set of columns selected from the job table. +// Used consistently across all job queries to ensure uniform data retrieval. var jobColumns []string = []string{ "job.id", "job.job_id", "job.hpc_user", "job.project", "job.cluster", "job.subcluster", "job.start_time", "job.cluster_partition", "job.array_job_id", "job.num_nodes", @@ -115,6 +132,8 @@ var jobColumns []string = []string{ "job.footprint", "job.energy", } +// jobCacheColumns defines columns from the job_cache table, mirroring jobColumns. +// Used for queries against cached job data for performance optimization. var jobCacheColumns []string = []string{ "job_cache.id", "job_cache.job_id", "job_cache.hpc_user", "job_cache.project", "job_cache.cluster", "job_cache.subcluster", "job_cache.start_time", "job_cache.cluster_partition", @@ -124,6 +143,14 @@ var jobCacheColumns []string = []string{ "job_cache.footprint", "job_cache.energy", } +// scanJob converts a database row into a schema.Job struct. 
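+// Rows are expected to be selected with jobColumns (or jobCacheColumns) so
+// that the Scan column order matches.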
+// It handles JSON unmarshaling of resources and footprint fields, +// and calculates accurate duration for running jobs. +// +// Parameters: +// - row: Database row implementing Scan() interface (sql.Row or sql.Rows) +// +// Returns the populated Job struct or an error if scanning or unmarshaling fails. func scanJob(row interface{ Scan(...any) error }) (*schema.Job, error) { job := &schema.Job{} @@ -186,6 +213,16 @@ func (r *JobRepository) Flush() error { return nil } +// FetchMetadata retrieves and unmarshals the metadata JSON for a job. +// Metadata is cached with a 24-hour TTL to improve performance. +// +// The metadata field stores arbitrary key-value pairs associated with a job, +// such as tags, labels, or custom attributes added by external systems. +// +// Parameters: +// - job: Job struct with valid ID field, metadata will be populated in job.MetaData +// +// Returns the metadata map or an error if the job is nil or database query fails. func (r *JobRepository) FetchMetadata(job *schema.Job) (map[string]string, error) { if job == nil { return nil, fmt.Errorf("job cannot be nil") @@ -218,6 +255,16 @@ func (r *JobRepository) FetchMetadata(job *schema.Job) (map[string]string, error return job.MetaData, nil } +// UpdateMetadata adds or updates a single metadata key-value pair for a job. +// The entire metadata map is re-marshaled and stored, and the cache is invalidated. +// Also triggers archive metadata update via archive.UpdateMetadata. +// +// Parameters: +// - job: Job struct with valid ID, existing metadata will be fetched if not present +// - key: Metadata key to set +// - val: Metadata value to set +// +// Returns an error if the job is nil, metadata fetch fails, or database update fails. func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err error) { if job == nil { return fmt.Errorf("job cannot be nil") @@ -228,7 +275,7 @@ func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err er if job.MetaData == nil { if _, err = r.FetchMetadata(job); err != nil { cclog.Warnf("Error while fetching metadata for job, DB ID '%v'", job.ID) - return err + return fmt.Errorf("failed to fetch metadata for job %d: %w", job.ID, err) } } @@ -243,7 +290,7 @@ func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err er if job.RawMetaData, err = json.Marshal(job.MetaData); err != nil { cclog.Warnf("Error while marshaling metadata for job, DB ID '%v'", job.ID) - return err + return fmt.Errorf("failed to marshal metadata for job %d: %w", job.ID, err) } if _, err = sq.Update("job"). @@ -251,13 +298,23 @@ func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err er Where("job.id = ?", job.ID). RunWith(r.stmtCache).Exec(); err != nil { cclog.Warnf("Error while updating metadata for job, DB ID '%v'", job.ID) - return err + return fmt.Errorf("failed to update metadata in database for job %d: %w", job.ID, err) } r.cache.Put(cachekey, job.MetaData, len(job.RawMetaData), 24*time.Hour) return archive.UpdateMetadata(job, job.MetaData) } +// FetchFootprint retrieves and unmarshals the performance footprint JSON for a job. +// Unlike FetchMetadata, footprints are NOT cached as they can be large and change frequently. +// +// The footprint contains summary statistics (avg/min/max) for monitored metrics, +// stored as JSON with keys like "cpu_load_avg", "mem_used_max", etc. 
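+// Example value: {"cpu_load_avg": 42.7, "mem_used_max": 128.5, "flops_any_avg": 0.5}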
+// +// Parameters: +// - job: Job struct with valid ID, footprint will be populated in job.Footprint +// +// Returns the footprint map or an error if the job is nil or database query fails. func (r *JobRepository) FetchFootprint(job *schema.Job) (map[string]float64, error) { if job == nil { return nil, fmt.Errorf("job cannot be nil") @@ -284,6 +341,16 @@ func (r *JobRepository) FetchFootprint(job *schema.Job) (map[string]float64, err return job.Footprint, nil } +// FetchEnergyFootprint retrieves and unmarshals the energy footprint JSON for a job. +// Energy footprints are cached with a 24-hour TTL as they are frequently accessed but rarely change. +// +// The energy footprint contains calculated energy consumption (in kWh) per metric, +// stored as JSON with keys like "power_avg", "acc_power_avg", etc. +// +// Parameters: +// - job: Job struct with valid ID, energy footprint will be populated in job.EnergyFootprint +// +// Returns the energy footprint map or an error if the job is nil or database query fails. func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float64, error) { if job == nil { return nil, fmt.Errorf("job cannot be nil") @@ -316,6 +383,18 @@ func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float6 return job.EnergyFootprint, nil } +// DeleteJobsBefore removes jobs older than the specified start time. +// Optionally preserves tagged jobs to protect important data from deletion. +// Cache entries for deleted jobs are automatically invalidated. +// +// This is typically used for data retention policies and cleanup operations. +// WARNING: This is a destructive operation that permanently deletes job records. +// +// Parameters: +// - startTime: Unix timestamp, jobs with start_time < this value will be deleted +// - omitTagged: If true, skip jobs that have associated tags (jobtag entries) +// +// Returns the count of deleted jobs or an error if the operation fails. func (r *JobRepository) DeleteJobsBefore(startTime int64, omitTagged bool) (int, error) { var cnt int q := sq.Select("count(*)").From("job").Where("job.start_time < ?", startTime) @@ -371,6 +450,13 @@ func (r *JobRepository) DeleteJobsBefore(startTime int64, omitTagged bool) (int, return cnt, err } +// DeleteJobByID permanently removes a single job by its database ID. +// Cache entries for the deleted job are automatically invalidated. +// +// Parameters: +// - id: Database ID (primary key) of the job to delete +// +// Returns an error if the deletion fails. func (r *JobRepository) DeleteJobByID(id int64) error { // Invalidate cache entries before deletion r.cache.Del(fmt.Sprintf("metadata:%d", id)) @@ -388,6 +474,24 @@ func (r *JobRepository) DeleteJobByID(id int64) error { return err } +// FindUserOrProjectOrJobname attempts to interpret a search term as a job ID, +// username, project ID, or job name by querying the database. +// +// Search logic (in priority order): +// 1. If searchterm is numeric, treat as job ID (returned immediately) +// 2. Try exact match in job.hpc_user column (username) +// 3. Try LIKE match in hpc_user.name column (real name) +// 4. Try exact match in job.project column (project ID) +// 5. If no matches, return searchterm as jobname for GraphQL query +// +// This powers the searchbar functionality for flexible job searching. +// Requires authenticated user for database lookups (returns empty if user is nil). 
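+//
+// Example: a purely numeric search such as "1234567" is returned directly as
+// jobid, while "alice" is tried against job.hpc_user, hpc_user.name, and
+// job.project in that order before falling back to jobname.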
+// +// Parameters: +// - user: Authenticated user context, required for database access +// - searchterm: Search string to interpret +// +// Returns up to one non-empty value among (jobid, username, project, jobname). func (r *JobRepository) FindUserOrProjectOrJobname(user *schema.User, searchterm string) (jobid string, username string, project string, jobname string) { if searchterm == "" { return "", "", "", "" @@ -423,6 +527,19 @@ var ( ErrForbidden = errors.New("not authorized") ) +// FindColumnValue performs a generic column lookup in a database table with role-based access control. +// Only users with admin, support, or manager roles can execute this query. +// +// Parameters: +// - user: User context for authorization check +// - searchterm: Value to search for (exact match or LIKE pattern) +// - table: Database table name to query +// - selectColumn: Column name to return in results +// - whereColumn: Column name to filter on +// - isLike: If true, use LIKE with wildcards; if false, use exact equality +// +// Returns the first matching value, ErrForbidden if user lacks permission, +// or ErrNotFound if no matches are found. func (r *JobRepository) FindColumnValue(user *schema.User, searchterm string, table string, selectColumn string, whereColumn string, isLike bool) (result string, err error) { if user == nil { return "", fmt.Errorf("user cannot be nil") @@ -453,6 +570,19 @@ func (r *JobRepository) FindColumnValue(user *schema.User, searchterm string, ta } } +// FindColumnValues performs a generic column lookup returning multiple matches with role-based access control. +// Similar to FindColumnValue but returns all matching values instead of just the first. +// Only users with admin, support, or manager roles can execute this query. +// +// Parameters: +// - user: User context for authorization check +// - query: Search pattern (always uses LIKE with wildcards) +// - table: Database table name to query +// - selectColumn: Column name to return in results +// - whereColumn: Column name to filter on +// +// Returns a slice of matching values, ErrForbidden if user lacks permission, +// or ErrNotFound if no matches are found. func (r *JobRepository) FindColumnValues(user *schema.User, query string, table string, selectColumn string, whereColumn string) (results []string, err error) { if user == nil { return nil, fmt.Errorf("user cannot be nil") @@ -487,6 +617,13 @@ func (r *JobRepository) FindColumnValues(user *schema.User, query string, table } } +// Partitions returns a list of distinct cluster partitions for a given cluster. +// Results are cached with a 1-hour TTL to improve performance. +// +// Parameters: +// - cluster: Cluster name to query partitions for +// +// Returns a slice of partition names or an error if the database query fails. func (r *JobRepository) Partitions(cluster string) ([]string, error) { var err error start := time.Now() @@ -550,6 +687,19 @@ func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]in } // FIXME: Set duration to requested walltime? +// StopJobsExceedingWalltimeBy marks running jobs as failed if they exceed their walltime limit. +// This is typically called periodically to clean up stuck or orphaned jobs. +// +// Jobs are marked with: +// - monitoring_status: MonitoringStatusArchivingFailed +// - duration: 0 +// - job_state: JobStateFailed +// +// Parameters: +// - seconds: Grace period beyond walltime before marking as failed +// +// Returns an error if the database update fails. 
+// Logs the number of jobs marked as failed if any were affected. func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error { start := time.Now() currentTime := time.Now().Unix() @@ -579,6 +729,12 @@ func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error { return nil } +// FindJobIdsByTag returns all job database IDs associated with a specific tag. +// +// Parameters: +// - tagID: Database ID of the tag to search for +// +// Returns a slice of job IDs or an error if the query fails. func (r *JobRepository) FindJobIdsByTag(tagID int64) ([]int64, error) { query := sq.Select("job.id").From("job"). Join("jobtag ON jobtag.job_id = job.id"). @@ -607,6 +763,13 @@ func (r *JobRepository) FindJobIdsByTag(tagID int64) ([]int64, error) { } // FIXME: Reconsider filtering short jobs with harcoded threshold +// FindRunningJobs returns all currently running jobs for a specific cluster. +// Filters out short-running jobs based on repoConfig.MinRunningJobDuration threshold. +// +// Parameters: +// - cluster: Cluster name to filter jobs +// +// Returns a slice of running job objects or an error if the query fails. func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) { query := sq.Select(jobColumns...).From("job"). Where("job.cluster = ?", cluster). @@ -634,6 +797,12 @@ func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) { return jobs, nil } +// UpdateDuration recalculates and updates the duration field for all running jobs. +// Called periodically to keep job durations current without querying individual jobs. +// +// Duration is calculated as: current_time - job.start_time +// +// Returns an error if the database update fails. func (r *JobRepository) UpdateDuration() error { stmnt := sq.Update("job"). Set("duration", sq.Expr("? - job.start_time", time.Now().Unix())). @@ -648,6 +817,16 @@ func (r *JobRepository) UpdateDuration() error { return nil } +// FindJobsBetween returns jobs within a specified time range. +// If startTimeBegin is 0, returns all jobs before startTimeEnd. +// Optionally excludes tagged jobs from results. +// +// Parameters: +// - startTimeBegin: Unix timestamp for range start (use 0 for unbounded start) +// - startTimeEnd: Unix timestamp for range end +// - omitTagged: If true, exclude jobs with associated tags +// +// Returns a slice of jobs or an error if the time range is invalid or query fails. func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64, omitTagged bool) ([]*schema.Job, error) { var query sq.SelectBuilder @@ -688,6 +867,14 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64 return jobs, nil } +// UpdateMonitoringStatus updates the monitoring status for a job and invalidates its cache entries. +// Cache invalidation affects both metadata and energy footprint to ensure consistency. +// +// Parameters: +// - job: Database ID of the job to update +// - monitoringStatus: New monitoring status value (see schema.MonitoringStatus constants) +// +// Returns an error if the database update fails. func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) { // Invalidate cache entries as monitoring status affects job state r.cache.Del(fmt.Sprintf("metadata:%d", job)) @@ -704,6 +891,13 @@ func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32 return nil } +// Execute runs a Squirrel UpdateBuilder statement against the database. 
+// This is a generic helper for executing pre-built update queries. +// +// Parameters: +// - stmt: Squirrel UpdateBuilder with prepared update query +// +// Returns an error if the execution fails. func (r *JobRepository) Execute(stmt sq.UpdateBuilder) error { if _, err := stmt.RunWith(r.stmtCache).Exec(); err != nil { cclog.Errorf("Error while executing statement: %v", err) @@ -713,6 +907,14 @@ func (r *JobRepository) Execute(stmt sq.UpdateBuilder) error { return nil } +// MarkArchived adds monitoring status update to an existing UpdateBuilder statement. +// This is a builder helper used when constructing multi-field update queries. +// +// Parameters: +// - stmt: Existing UpdateBuilder to modify +// - monitoringStatus: Monitoring status value to set +// +// Returns the modified UpdateBuilder for method chaining. func (r *JobRepository) MarkArchived( stmt sq.UpdateBuilder, monitoringStatus int32, @@ -720,11 +922,22 @@ func (r *JobRepository) MarkArchived( return stmt.Set("monitoring_status", monitoringStatus) } +// UpdateEnergy calculates and updates the energy consumption for a job. +// This is called for running jobs during intermediate updates or when archiving. +// +// Energy calculation formula: +// - For "power" metrics: Energy (kWh) = (Power_avg * NumNodes * Duration_hours) / 1000 +// - For "energy" metrics: Currently not implemented (would need sum statistics) +// +// The calculation accounts for: +// - Multi-node jobs: Multiplies by NumNodes to get total cluster energy +// - Shared jobs: Node average is already based on partial resources, so NumNodes=1 +// - Unit conversion: Watts * hours / 1000 = kilowatt-hours (kWh) +// - Rounding: Results rounded to 2 decimal places func (r *JobRepository) UpdateEnergy( stmt sq.UpdateBuilder, jobMeta *schema.Job, ) (sq.UpdateBuilder, error) { - /* Note: Only Called for Running Jobs during Intermediate Update or on Archiving */ sc, err := archive.GetSubCluster(jobMeta.Cluster, jobMeta.SubCluster) if err != nil { cclog.Errorf("cannot get subcluster: %s", err.Error()) @@ -732,25 +945,27 @@ func (r *JobRepository) UpdateEnergy( } energyFootprint := make(map[string]float64) - // Total Job Energy Outside Loop + // Accumulate total energy across all energy-related metrics totalEnergy := 0.0 for _, fp := range sc.EnergyFootprint { - // Always Init Metric Energy Inside Loop + // Calculate energy for this specific metric metricEnergy := 0.0 if i, err := archive.MetricIndex(sc.MetricConfig, fp); err == nil { - // Note: For DB data, calculate and save as kWh switch sc.MetricConfig[i].Energy { - case "energy": // this metric has energy as unit (Joules or Wh) + case "energy": // Metric already in energy units (Joules or Wh) cclog.Warnf("Update EnergyFootprint for Job %d and Metric %s on cluster %s: Set to 'energy' in cluster.json: Not implemented, will return 0.0", jobMeta.JobID, jobMeta.Cluster, fp) - // FIXME: Needs sum as stats type - case "power": // this metric has power as unit (Watt) - // Energy: Power (in Watts) * Time (in Seconds) - // Unit: (W * (s / 3600)) / 1000 = kWh - // Round 2 Digits: round(Energy * 100) / 100 - // Here: (All-Node Metric Average * Number of Nodes) * (Job Duration in Seconds / 3600) / 1000 - // Note: Shared Jobs handled correctly since "Node Average" is based on partial resources, while "numNodes" factor is 1 + // FIXME: Needs sum as stats type to accumulate energy values over time + case "power": // Metric in power units (Watts) + // Energy (kWh) = Power (W) × Time (h) / 1000 + // Formula: (avg_power_per_node * num_nodes) 
* (duration_sec / 3600) / 1000 + // + // Breakdown: + // LoadJobStat(jobMeta, fp, "avg") = average power per node (W) + // jobMeta.NumNodes = number of nodes (1 for shared jobs) + // jobMeta.Duration / 3600.0 = duration in hours + // / 1000.0 = convert Wh to kWh rawEnergy := ((LoadJobStat(jobMeta, fp, "avg") * float64(jobMeta.NumNodes)) * (float64(jobMeta.Duration) / 3600.0)) / 1000.0 - metricEnergy = math.Round(rawEnergy*100.0) / 100.0 + metricEnergy = math.Round(rawEnergy*100.0) / 100.0 // Round to 2 decimal places } } else { cclog.Warnf("Error while collecting energy metric %s for job, DB ID '%v', return '0.0'", fp, jobMeta.ID) @@ -758,8 +973,6 @@ func (r *JobRepository) UpdateEnergy( energyFootprint[fp] = metricEnergy totalEnergy += metricEnergy - - // cclog.Infof("Metric %s Average %f -> %f kWh | Job %d Total -> %f kWh", fp, LoadJobStat(jobMeta, fp, "avg"), energy, jobMeta.JobID, totalEnergy) } var rawFootprint []byte @@ -771,11 +984,19 @@ func (r *JobRepository) UpdateEnergy( return stmt.Set("energy_footprint", string(rawFootprint)).Set("energy", (math.Round(totalEnergy*100.0) / 100.0)), nil } +// UpdateFootprint calculates and updates the performance footprint for a job. +// This is called for running jobs during intermediate updates or when archiving. +// +// A footprint is a summary statistic (avg/min/max) for each monitored metric. +// The specific statistic type is defined in the cluster config's Footprint field. +// Results are stored as JSON with keys like "metric_avg", "metric_max", etc. +// +// Example: For a "cpu_load" metric with Footprint="avg", this stores +// the average CPU load across all nodes as "cpu_load_avg": 85.3 func (r *JobRepository) UpdateFootprint( stmt sq.UpdateBuilder, jobMeta *schema.Job, ) (sq.UpdateBuilder, error) { - /* Note: Only Called for Running Jobs during Intermediate Update or on Archiving */ sc, err := archive.GetSubCluster(jobMeta.Cluster, jobMeta.SubCluster) if err != nil { cclog.Errorf("cannot get subcluster: %s", err.Error()) @@ -783,7 +1004,10 @@ func (r *JobRepository) UpdateFootprint( } footprint := make(map[string]float64) + // Build footprint map with metric_stattype as keys for _, fp := range sc.Footprint { + // Determine which statistic to use: avg, min, or max + // First check global metric config, then cluster-specific config var statType string for _, gm := range archive.GlobalMetricList { if gm.Name == fp { @@ -791,15 +1015,18 @@ func (r *JobRepository) UpdateFootprint( } } + // Validate statistic type if statType != "avg" && statType != "min" && statType != "max" { cclog.Warnf("unknown statType for footprint update: %s", statType) return stmt, fmt.Errorf("unknown statType for footprint update: %s", statType) } + // Override with cluster-specific config if available if i, err := archive.MetricIndex(sc.MetricConfig, fp); err != nil { statType = sc.MetricConfig[i].Footprint } + // Store as "metric_stattype": value (e.g., "cpu_load_avg": 85.3) name := fmt.Sprintf("%s_%s", fp, statType) footprint[name] = LoadJobStat(jobMeta, fp, statType) } diff --git a/internal/repository/jobCreate_test.go b/internal/repository/jobCreate_test.go new file mode 100644 index 00000000..3a586482 --- /dev/null +++ b/internal/repository/jobCreate_test.go @@ -0,0 +1,500 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. 
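+//
+// These tests exercise the job creation lifecycle: insertion into job_cache,
+// JSON encoding of resources, footprint, and metadata, stopping cached jobs,
+// and synchronization into the job table via SyncJobs.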
+package repository + +import ( + "encoding/json" + "testing" + + "github.com/ClusterCockpit/cc-lib/v2/schema" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createTestJob creates a minimal valid job for testing +func createTestJob(jobID int64, cluster string) *schema.Job { + return &schema.Job{ + JobID: jobID, + User: "testuser", + Project: "testproject", + Cluster: cluster, + SubCluster: "main", + Partition: "batch", + NumNodes: 1, + NumHWThreads: 4, + NumAcc: 0, + Shared: "none", + MonitoringStatus: schema.MonitoringStatusRunningOrArchiving, + SMT: 1, + State: schema.JobStateRunning, + StartTime: 1234567890, + Duration: 0, + Walltime: 3600, + Resources: []*schema.Resource{ + { + Hostname: "node01", + HWThreads: []int{0, 1, 2, 3}, + }, + }, + Footprint: map[string]float64{ + "cpu_load": 50.0, + "mem_used": 8000.0, + "flops_any": 0.5, + "mem_bw": 10.0, + "net_bw": 2.0, + "file_bw": 1.0, + "cpu_used": 2.0, + "cpu_load_core": 12.5, + }, + MetaData: map[string]string{ + "jobName": "test_job", + "queue": "normal", + "qosName": "default", + "accountName": "testaccount", + }, + } +} + +func TestInsertJob(t *testing.T) { + r := setup(t) + + t.Run("successful insertion", func(t *testing.T) { + job := createTestJob(999001, "testcluster") + job.RawResources, _ = json.Marshal(job.Resources) + job.RawFootprint, _ = json.Marshal(job.Footprint) + job.RawMetaData, _ = json.Marshal(job.MetaData) + + id, err := r.InsertJob(job) + require.NoError(t, err, "InsertJob should succeed") + assert.Greater(t, id, int64(0), "Should return valid insert ID") + + // Verify job was inserted into job_cache + var count int + err = r.DB.QueryRow("SELECT COUNT(*) FROM job_cache WHERE job_id = ? AND cluster = ?", + job.JobID, job.Cluster).Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count, "Job should be in job_cache table") + + // Clean up + _, err = r.DB.Exec("DELETE FROM job_cache WHERE job_id = ? 
AND cluster = ?", job.JobID, job.Cluster) + require.NoError(t, err) + }) + + t.Run("insertion with all fields", func(t *testing.T) { + job := createTestJob(999002, "testcluster") + job.ArrayJobID = 5000 + job.Energy = 1500.5 + job.RawResources, _ = json.Marshal(job.Resources) + job.RawFootprint, _ = json.Marshal(job.Footprint) + job.RawMetaData, _ = json.Marshal(job.MetaData) + + id, err := r.InsertJob(job) + require.NoError(t, err) + assert.Greater(t, id, int64(0)) + + // Verify all fields were stored correctly + var retrievedJob schema.Job + err = r.DB.QueryRow(`SELECT job_id, hpc_user, project, cluster, array_job_id, energy + FROM job_cache WHERE id = ?`, id).Scan( + &retrievedJob.JobID, &retrievedJob.User, &retrievedJob.Project, + &retrievedJob.Cluster, &retrievedJob.ArrayJobID, &retrievedJob.Energy) + require.NoError(t, err) + assert.Equal(t, job.JobID, retrievedJob.JobID) + assert.Equal(t, job.User, retrievedJob.User) + assert.Equal(t, job.Project, retrievedJob.Project) + assert.Equal(t, job.Cluster, retrievedJob.Cluster) + assert.Equal(t, job.ArrayJobID, retrievedJob.ArrayJobID) + assert.Equal(t, job.Energy, retrievedJob.Energy) + + // Clean up + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + }) +} + +func TestStart(t *testing.T) { + r := setup(t) + + t.Run("successful job start with JSON encoding", func(t *testing.T) { + job := createTestJob(999003, "testcluster") + + id, err := r.Start(job) + require.NoError(t, err, "Start should succeed") + assert.Greater(t, id, int64(0), "Should return valid insert ID") + + // Verify job was inserted and JSON fields were encoded + var rawResources, rawFootprint, rawMetaData []byte + err = r.DB.QueryRow(`SELECT resources, footprint, meta_data FROM job_cache WHERE id = ?`, id).Scan( + &rawResources, &rawFootprint, &rawMetaData) + require.NoError(t, err) + + // Verify resources JSON + var resources []*schema.Resource + err = json.Unmarshal(rawResources, &resources) + require.NoError(t, err, "Resources should be valid JSON") + assert.Equal(t, 1, len(resources)) + assert.Equal(t, "node01", resources[0].Hostname) + + // Verify footprint JSON + var footprint map[string]float64 + err = json.Unmarshal(rawFootprint, &footprint) + require.NoError(t, err, "Footprint should be valid JSON") + assert.Equal(t, 50.0, footprint["cpu_load"]) + assert.Equal(t, 8000.0, footprint["mem_used"]) + + // Verify metadata JSON + var metaData map[string]string + err = json.Unmarshal(rawMetaData, &metaData) + require.NoError(t, err, "MetaData should be valid JSON") + assert.Equal(t, "test_job", metaData["jobName"]) + + // Clean up + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + }) + + t.Run("job start with empty footprint", func(t *testing.T) { + job := createTestJob(999004, "testcluster") + job.Footprint = map[string]float64{} + + id, err := r.Start(job) + require.NoError(t, err) + assert.Greater(t, id, int64(0)) + + // Verify empty footprint was encoded as empty JSON object + var rawFootprint []byte + err = r.DB.QueryRow(`SELECT footprint FROM job_cache WHERE id = ?`, id).Scan(&rawFootprint) + require.NoError(t, err) + assert.Equal(t, []byte("{}"), rawFootprint) + + // Clean up + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + }) + + t.Run("job start with nil metadata", func(t *testing.T) { + job := createTestJob(999005, "testcluster") + job.MetaData = nil + + id, err := r.Start(job) + require.NoError(t, err) + assert.Greater(t, id, int64(0)) + + // 
Clean up + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + }) +} + +func TestStop(t *testing.T) { + r := setup(t) + + t.Run("successful job stop", func(t *testing.T) { + // First insert a job using Start + job := createTestJob(999106, "testcluster") + id, err := r.Start(job) + require.NoError(t, err) + + // Move from job_cache to job table (simulate SyncJobs) - exclude id to let it auto-increment + _, err = r.DB.Exec(`INSERT INTO job (job_id, cluster, subcluster, submit_time, start_time, hpc_user, project, + cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes, + num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint) + SELECT job_id, cluster, subcluster, submit_time, start_time, hpc_user, project, + cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes, + num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint + FROM job_cache WHERE id = ?`, id) + require.NoError(t, err) + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + + // Get the new job id in the job table + err = r.DB.QueryRow("SELECT id FROM job WHERE job_id = ? AND cluster = ? AND start_time = ?", + job.JobID, job.Cluster, job.StartTime).Scan(&id) + require.NoError(t, err) + + // Stop the job + duration := int32(3600) + state := schema.JobStateCompleted + monitoringStatus := int32(schema.MonitoringStatusArchivingSuccessful) + + err = r.Stop(id, duration, state, monitoringStatus) + require.NoError(t, err, "Stop should succeed") + + // Verify job was updated + var retrievedDuration int32 + var retrievedState string + var retrievedMonStatus int32 + err = r.DB.QueryRow(`SELECT duration, job_state, monitoring_status FROM job WHERE id = ?`, id).Scan( + &retrievedDuration, &retrievedState, &retrievedMonStatus) + require.NoError(t, err) + assert.Equal(t, duration, retrievedDuration) + assert.Equal(t, string(state), retrievedState) + assert.Equal(t, monitoringStatus, retrievedMonStatus) + + // Clean up + _, err = r.DB.Exec("DELETE FROM job WHERE id = ?", id) + require.NoError(t, err) + }) + + t.Run("stop updates job state transitions", func(t *testing.T) { + // Insert a job + job := createTestJob(999107, "testcluster") + id, err := r.Start(job) + require.NoError(t, err) + + // Move to job table + _, err = r.DB.Exec(`INSERT INTO job (job_id, cluster, subcluster, submit_time, start_time, hpc_user, project, + cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes, + num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint) + SELECT job_id, cluster, subcluster, submit_time, start_time, hpc_user, project, + cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes, + num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint + FROM job_cache WHERE id = ?`, id) + require.NoError(t, err) + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + + // Get the new job id in the job table + err = r.DB.QueryRow("SELECT id FROM job WHERE job_id = ? AND cluster = ? 
AND start_time = ?", + job.JobID, job.Cluster, job.StartTime).Scan(&id) + require.NoError(t, err) + + // Stop the job with different duration + err = r.Stop(id, 7200, schema.JobStateCompleted, int32(schema.MonitoringStatusArchivingSuccessful)) + require.NoError(t, err) + + // Verify the duration was updated correctly + var duration int32 + err = r.DB.QueryRow(`SELECT duration FROM job WHERE id = ?`, id).Scan(&duration) + require.NoError(t, err) + assert.Equal(t, int32(7200), duration, "Duration should be updated to 7200") + + // Clean up + _, err = r.DB.Exec("DELETE FROM job WHERE id = ?", id) + require.NoError(t, err) + }) + + t.Run("stop with different states", func(t *testing.T) { + testCases := []struct { + name string + jobID int64 + state schema.JobState + monitoringStatus int32 + }{ + {"completed", 999108, schema.JobStateCompleted, int32(schema.MonitoringStatusArchivingSuccessful)}, + {"failed", 999118, schema.JobStateFailed, int32(schema.MonitoringStatusArchivingSuccessful)}, + {"cancelled", 999119, schema.JobStateCancelled, int32(schema.MonitoringStatusArchivingSuccessful)}, + {"timeout", 999120, schema.JobStateTimeout, int32(schema.MonitoringStatusArchivingSuccessful)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + job := createTestJob(tc.jobID, "testcluster") + id, err := r.Start(job) + require.NoError(t, err) + + // Move to job table + _, err = r.DB.Exec(`INSERT INTO job (job_id, cluster, subcluster, submit_time, start_time, hpc_user, project, + cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes, + num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint) + SELECT job_id, cluster, subcluster, submit_time, start_time, hpc_user, project, + cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes, + num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint + FROM job_cache WHERE id = ?`, id) + require.NoError(t, err) + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + + // Get the new job id in the job table + err = r.DB.QueryRow("SELECT id FROM job WHERE job_id = ? AND cluster = ? 
AND start_time = ?", + job.JobID, job.Cluster, job.StartTime).Scan(&id) + require.NoError(t, err) + + // Stop with specific state + err = r.Stop(id, 1800, tc.state, tc.monitoringStatus) + require.NoError(t, err) + + // Verify state was set correctly + var retrievedState string + err = r.DB.QueryRow(`SELECT job_state FROM job WHERE id = ?`, id).Scan(&retrievedState) + require.NoError(t, err) + assert.Equal(t, string(tc.state), retrievedState) + + // Clean up + _, err = r.DB.Exec("DELETE FROM job WHERE id = ?", id) + require.NoError(t, err) + }) + } + }) +} + +func TestStopCached(t *testing.T) { + r := setup(t) + + t.Run("successful stop cached job", func(t *testing.T) { + // Insert a job in job_cache + job := createTestJob(999009, "testcluster") + id, err := r.Start(job) + require.NoError(t, err) + + // Stop the cached job + duration := int32(3600) + state := schema.JobStateCompleted + monitoringStatus := int32(schema.MonitoringStatusArchivingSuccessful) + + err = r.StopCached(id, duration, state, monitoringStatus) + require.NoError(t, err, "StopCached should succeed") + + // Verify job was updated in job_cache table + var retrievedDuration int32 + var retrievedState string + var retrievedMonStatus int32 + err = r.DB.QueryRow(`SELECT duration, job_state, monitoring_status FROM job_cache WHERE id = ?`, id).Scan( + &retrievedDuration, &retrievedState, &retrievedMonStatus) + require.NoError(t, err) + assert.Equal(t, duration, retrievedDuration) + assert.Equal(t, string(state), retrievedState) + assert.Equal(t, monitoringStatus, retrievedMonStatus) + + // Clean up + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + }) + + t.Run("stop cached job does not affect job table", func(t *testing.T) { + // Insert a job in job_cache + job := createTestJob(999010, "testcluster") + id, err := r.Start(job) + require.NoError(t, err) + + // Stop the cached job + err = r.StopCached(id, 3600, schema.JobStateCompleted, int32(schema.MonitoringStatusArchivingSuccessful)) + require.NoError(t, err) + + // Verify job table was not affected + var count int + err = r.DB.QueryRow(`SELECT COUNT(*) FROM job WHERE job_id = ? 
AND cluster = ?`, + job.JobID, job.Cluster).Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count, "Job table should not be affected by StopCached") + + // Clean up + _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id) + require.NoError(t, err) + }) +} + +func TestSyncJobs(t *testing.T) { + r := setup(t) + + t.Run("sync jobs from cache to main table", func(t *testing.T) { + // Ensure cache is empty first + _, err := r.DB.Exec("DELETE FROM job_cache") + require.NoError(t, err) + + // Insert multiple jobs in job_cache + job1 := createTestJob(999011, "testcluster") + job2 := createTestJob(999012, "testcluster") + job3 := createTestJob(999013, "testcluster") + + _, err = r.Start(job1) + require.NoError(t, err) + _, err = r.Start(job2) + require.NoError(t, err) + _, err = r.Start(job3) + require.NoError(t, err) + + // Verify jobs are in job_cache + var cacheCount int + err = r.DB.QueryRow("SELECT COUNT(*) FROM job_cache WHERE job_id IN (?, ?, ?)", + job1.JobID, job2.JobID, job3.JobID).Scan(&cacheCount) + require.NoError(t, err) + assert.Equal(t, 3, cacheCount, "All jobs should be in job_cache") + + // Sync jobs + jobs, err := r.SyncJobs() + require.NoError(t, err, "SyncJobs should succeed") + assert.Equal(t, 3, len(jobs), "Should return 3 synced jobs") + + // Verify jobs were moved to job table + var jobCount int + err = r.DB.QueryRow("SELECT COUNT(*) FROM job WHERE job_id IN (?, ?, ?)", + job1.JobID, job2.JobID, job3.JobID).Scan(&jobCount) + require.NoError(t, err) + assert.Equal(t, 3, jobCount, "All jobs should be in job table") + + // Verify job_cache was cleared + err = r.DB.QueryRow("SELECT COUNT(*) FROM job_cache WHERE job_id IN (?, ?, ?)", + job1.JobID, job2.JobID, job3.JobID).Scan(&cacheCount) + require.NoError(t, err) + assert.Equal(t, 0, cacheCount, "job_cache should be empty after sync") + + // Clean up + _, err = r.DB.Exec("DELETE FROM job WHERE job_id IN (?, ?, ?)", job1.JobID, job2.JobID, job3.JobID) + require.NoError(t, err) + }) + + t.Run("sync preserves job data", func(t *testing.T) { + // Ensure cache is empty first + _, err := r.DB.Exec("DELETE FROM job_cache") + require.NoError(t, err) + + // Insert a job with specific data + job := createTestJob(999014, "testcluster") + job.ArrayJobID = 7777 + job.Energy = 2500.75 + job.Duration = 1800 + + id, err := r.Start(job) + require.NoError(t, err) + + // Update some fields to simulate job progress + result, err := r.DB.Exec(`UPDATE job_cache SET duration = ?, energy = ? WHERE id = ?`, + 3600, 3000.5, id) + require.NoError(t, err) + rowsAffected, _ := result.RowsAffected() + require.Equal(t, int64(1), rowsAffected, "UPDATE should affect exactly 1 row") + + // Verify the update worked + var checkDuration int32 + var checkEnergy float64 + err = r.DB.QueryRow(`SELECT duration, energy FROM job_cache WHERE id = ?`, id).Scan(&checkDuration, &checkEnergy) + require.NoError(t, err) + require.Equal(t, int32(3600), checkDuration, "Duration should be updated to 3600 before sync") + require.Equal(t, 3000.5, checkEnergy, "Energy should be updated to 3000.5 before sync") + + // Sync jobs + jobs, err := r.SyncJobs() + require.NoError(t, err) + require.Equal(t, 1, len(jobs), "Should return exactly 1 synced job") + + // Verify in database + var dbJob schema.Job + err = r.DB.QueryRow(`SELECT job_id, hpc_user, project, cluster, array_job_id, duration, energy + FROM job WHERE job_id = ? 
AND cluster = ?`, job.JobID, job.Cluster).Scan( + &dbJob.JobID, &dbJob.User, &dbJob.Project, &dbJob.Cluster, + &dbJob.ArrayJobID, &dbJob.Duration, &dbJob.Energy) + require.NoError(t, err) + assert.Equal(t, job.JobID, dbJob.JobID) + assert.Equal(t, int32(3600), dbJob.Duration) + assert.Equal(t, 3000.5, dbJob.Energy) + + // Clean up + _, err = r.DB.Exec("DELETE FROM job WHERE job_id = ? AND cluster = ?", job.JobID, job.Cluster) + require.NoError(t, err) + }) + + t.Run("sync with empty cache returns empty list", func(t *testing.T) { + // Ensure cache is empty + _, err := r.DB.Exec("DELETE FROM job_cache") + require.NoError(t, err) + + // Sync should return empty list + jobs, err := r.SyncJobs() + require.NoError(t, err) + assert.Equal(t, 0, len(jobs), "Should return empty list when cache is empty") + }) +} diff --git a/internal/repository/jobHooks.go b/internal/repository/jobHooks.go index c449d308..41684d5c 100644 --- a/internal/repository/jobHooks.go +++ b/internal/repository/jobHooks.go @@ -10,8 +10,36 @@ import ( "github.com/ClusterCockpit/cc-lib/v2/schema" ) +// JobHook interface allows external components to hook into job lifecycle events. +// Implementations can perform actions when jobs start or stop, such as tagging, +// logging, notifications, or triggering external workflows. +// +// Example implementation: +// +// type MyJobTagger struct{} +// +// func (t *MyJobTagger) JobStartCallback(job *schema.Job) { +// if job.NumNodes > 100 { +// // Tag large jobs automatically +// } +// } +// +// func (t *MyJobTagger) JobStopCallback(job *schema.Job) { +// if job.State == schema.JobStateFailed { +// // Log or alert on failed jobs +// } +// } +// +// Register hooks during application initialization: +// +// repository.RegisterJobHook(&MyJobTagger{}) type JobHook interface { + // JobStartCallback is invoked when one or more jobs start. + // This is called synchronously, so implementations should be fast. JobStartCallback(job *schema.Job) + + // JobStopCallback is invoked when a job completes. + // This is called synchronously, so implementations should be fast. JobStopCallback(job *schema.Job) } @@ -20,7 +48,13 @@ var ( hooks []JobHook ) -func RegisterJobJook(hook JobHook) { +// RegisterJobHook registers a JobHook to receive job lifecycle callbacks. +// Multiple hooks can be registered and will be called in registration order. +// This function is safe to call multiple times and is typically called during +// application initialization. +// +// Nil hooks are silently ignored to simplify conditional registration. +func RegisterJobHook(hook JobHook) { initOnce.Do(func() { hooks = make([]JobHook, 0) }) @@ -30,6 +64,12 @@ func RegisterJobJook(hook JobHook) { } } +// CallJobStartHooks invokes all registered JobHook.JobStartCallback methods +// for each job in the provided slice. This is called internally by the repository +// when jobs are started (e.g., via StartJob or batch job imports). +// +// Hooks are called synchronously in registration order. If a hook panics, +// the panic will propagate to the caller. func CallJobStartHooks(jobs []*schema.Job) { if hooks == nil { return @@ -44,6 +84,12 @@ func CallJobStartHooks(jobs []*schema.Job) { } } +// CallJobStopHooks invokes all registered JobHook.JobStopCallback methods +// for the provided job. This is called internally by the repository when a +// job completes (e.g., via StopJob or job state updates). +// +// Hooks are called synchronously in registration order. If a hook panics, +// the panic will propagate to the caller. 
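+//
+// Example (sketch, reusing the MyJobTagger hook from the JobHook docs):
+//
+//	RegisterJobHook(&MyJobTagger{})
+//	// ... job runs ...
+//	CallJobStopHooks(job) // invokes MyJobTagger.JobStopCallback(job)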
func CallJobStopHooks(job *schema.Job) { if hooks == nil { return diff --git a/internal/repository/migration.go b/internal/repository/migration.go index a47f9fcd..0f99889e 100644 --- a/internal/repository/migration.go +++ b/internal/repository/migration.go @@ -16,11 +16,29 @@ import ( "github.com/golang-migrate/migrate/v4/source/iofs" ) +// Version is the current database schema version required by this version of cc-backend. +// When the database schema changes, this version is incremented and a new migration file +// is added to internal/repository/migrations/sqlite3/. +// +// Version history: +// - Version 10: Current version +// +// Migration files are embedded at build time from the migrations directory. const Version uint = 10 //go:embed migrations/* var migrationFiles embed.FS +// checkDBVersion verifies that the database schema version matches the expected version. +// This is called automatically during Connect() to ensure schema compatibility. +// +// Returns an error if: +// - Database version is older than expected (needs migration) +// - Database version is newer than expected (needs app upgrade) +// - Database is in a dirty state (failed migration) +// +// A "dirty" database indicates a migration was started but not completed successfully. +// This requires manual intervention to fix the database and force the version. func checkDBVersion(db *sql.DB) error { driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { @@ -58,6 +76,8 @@ func checkDBVersion(db *sql.DB) error { return nil } +// getMigrateInstance creates a new migration instance for the given database file. +// This is used internally by MigrateDB, RevertDB, and ForceDB. func getMigrateInstance(db string) (m *migrate.Migrate, err error) { d, err := iofs.New(migrationFiles, "migrations/sqlite3") if err != nil { @@ -72,6 +92,23 @@ func getMigrateInstance(db string) (m *migrate.Migrate, err error) { return m, nil } +// MigrateDB applies all pending database migrations to bring the schema up to date. +// This should be run with the -migrate-db flag before starting the application +// after upgrading to a new version that requires schema changes. +// +// Process: +// 1. Checks current database version +// 2. Applies all migrations from current version to target Version +// 3. Updates schema_migrations table to track applied migrations +// +// Important: +// - Always backup your database before running migrations +// - Migrations are irreversible without manual intervention +// - If a migration fails, the database is marked "dirty" and requires manual fix +// +// Usage: +// +// cc-backend -migrate-db func MigrateDB(db string) error { m, err := getMigrateInstance(db) if err != nil { @@ -107,6 +144,17 @@ func MigrateDB(db string) error { return nil } +// RevertDB rolls back the database schema to the previous version (Version - 1). +// This is primarily used for testing or emergency rollback scenarios. +// +// Warning: +// - This may cause data loss if newer schema added columns/tables +// - Always backup before reverting +// - Not all migrations are safely reversible +// +// Usage: +// +// cc-backend -revert-db func RevertDB(db string) error { m, err := getMigrateInstance(db) if err != nil { @@ -125,6 +173,21 @@ func RevertDB(db string) error { return nil } +// ForceDB forces the database schema version to the current Version without running migrations. +// This is only used to recover from failed migrations that left the database in a "dirty" state. 
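+// (A "dirty" state is detected by checkDBVersion during Connect(); see its
+// documentation above for details.)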
+// +// When to use: +// - After manually fixing a failed migration +// - When you've manually applied schema changes and need to update the version marker +// +// Warning: +// - This does NOT apply any schema changes +// - Only use after manually verifying the schema is correct +// - Improper use can cause schema/version mismatch +// +// Usage: +// +// cc-backend -force-db func ForceDB(db string) error { m, err := getMigrateInstance(db) if err != nil { diff --git a/internal/repository/stats.go b/internal/repository/stats.go index 989026d1..cd175c23 100644 --- a/internal/repository/stats.go +++ b/internal/repository/stats.go @@ -277,6 +277,15 @@ func (r *JobRepository) JobsStats( return stats, nil } +// LoadJobStat retrieves a specific statistic for a metric from a job's statistics. +// Returns 0.0 if the metric is not found or statType is invalid. +// +// Parameters: +// - job: Job struct with populated Statistics field +// - metric: Name of the metric to query (e.g., "cpu_load", "mem_used") +// - statType: Type of statistic: "avg", "min", or "max" +// +// Returns the requested statistic value or 0.0 if not found. func LoadJobStat(job *schema.Job, metric string, statType string) float64 { if stats, ok := job.Statistics[metric]; ok { switch statType { @@ -579,7 +588,9 @@ func (r *JobRepository) jobsDurationStatisticsHistogram( return nil, qerr } - // Setup Array + // Initialize histogram bins with zero counts + // Each bin represents a duration range: bin N = [N*binSizeSeconds, (N+1)*binSizeSeconds) + // Example: binSizeSeconds=3600 (1 hour), bin 1 = 0-1h, bin 2 = 1-2h, etc. points := make([]*model.HistoPoint, 0) for i := 1; i <= *targetBinCount; i++ { point := model.HistoPoint{Value: i * binSizeSeconds, Count: 0} @@ -596,7 +607,8 @@ func (r *JobRepository) jobsDurationStatisticsHistogram( return nil, err } - // Fill Array at matching $Value + // Match query results to pre-initialized bins and fill counts + // Query returns raw duration values that need to be mapped to correct bins for rows.Next() { point := model.HistoPoint{} if err := rows.Scan(&point.Value, &point.Count); err != nil { @@ -604,11 +616,13 @@ func (r *JobRepository) jobsDurationStatisticsHistogram( return nil, err } + // Find matching bin and update count + // point.Value is multiplied by binSizeSeconds to match pre-calculated bin.Value for _, e := range points { if e.Value == (point.Value * binSizeSeconds) { - // Note: - // Matching on unmodified integer value (and multiplying point.Value by binSizeSeconds after match) - // causes frontend to loop into highest targetBinCount, due to zoom condition instantly being fullfilled (cause unknown) + // Note: Matching on unmodified integer value (and multiplying point.Value + // by binSizeSeconds after match) causes frontend to loop into highest + // targetBinCount, due to zoom condition instantly being fulfilled (cause unknown) e.Count = point.Count break } @@ -625,12 +639,16 @@ func (r *JobRepository) jobsMetricStatisticsHistogram( filters []*model.JobFilter, bins *int, ) (*model.MetricHistoPoints, error) { - // Get specific Peak or largest Peak + // Determine the metric's peak value for histogram normalization + // Peak value defines the upper bound for binning: values are distributed across + // bins from 0 to peak. First try to get peak from filtered cluster, otherwise + // scan all clusters to find the maximum peak value. 
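+	// Note: if no cluster defines the metric at all, peak stays 0.0 and the
+	// normalization below would divide by zero; filters are expected to
+	// reference a metric that exists in at least one cluster config.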
 	var metricConfig *schema.MetricConfig
 	var peak float64
 	var unit string
 	var footprintStat string
 
+	// Try to get the metric config from the filtered cluster
 	for _, f := range filters {
 		if f.Cluster != nil {
 			metricConfig = archive.GetMetricConfig(*f.Cluster.Eq, metric)
@@ -641,6 +659,8 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
 		}
 	}
 
+	// If there is no cluster filter, or no peak was found, use the largest peak
+	// across all clusters so the histogram can accommodate all possible values
 	if peak == 0.0 {
 		for _, c := range archive.Clusters {
 			for _, m := range c.MetricConfig {
@@ -659,11 +679,18 @@
 		}
 	}
 
-	// cclog.Debugf("Metric %s, Peak %f, Unit %s", metric, peak, unit)
-	// Make bins, see https://jereze.com/code/sql-histogram/ (Modified here)
+	// Construct SQL histogram bins using normalized values
+	// Algorithm based on: https://jereze.com/code/sql-histogram/ (modified)
 	start := time.Now()
 
-	// Find Jobs' Value Bin Number: Divide Value by Peak, Multiply by RequestedBins, then CAST to INT: Gets Bin-Number of Job
+	// Calculate the bin number for each job's metric value:
+	//  1. Extract the metric value from the JSON footprint
+	//  2. Normalize to [0,1] by dividing by peak
+	//  3. Multiply by the number of bins to get the bin number
+	//  4. Cast to integer for bin assignment
+	//
+	// Special case: values exactly equal to peak would land one past the last
+	// bin, so they are scaled by 0.999999999 to fall into the last bin instead
 	binQuery := fmt.Sprintf(`CAST( ((case when json_extract(footprint, "$.%s") = %f then %f*0.999999999 else json_extract(footprint, "$.%s") end) / %f) * %v as INTEGER )`,
@@ -698,7 +725,9 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
 		return nil, err
 	}
 
-	// Setup Return Array With Bin-Numbers for Match and Min/Max based on Peak
+	// Initialize histogram bins with calculated min/max ranges
+	// Each bin represents a range of metric values
+	// Example: peak=1000, bins=10 -> bin 1=[0,100), bin 2=[100,200), ..., bin 10=[900,1000]
 	points := make([]*model.MetricHistoPoint, 0)
 	binStep := int(peak) / *bins
 	for i := 1; i <= *bins; i++ {
@@ -708,13 +737,16 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
 		points = append(points, &epoint)
 	}
 
-	for rows.Next() { // Fill Count if Bin-No. Matches (Not every Bin exists in DB!)
+	// Fill counts from the query results
+	// The query only returns bins that contain jobs, so match against pre-initialized bins
+	for rows.Next() {
 		rpoint := model.MetricHistoPoint{}
 		if err := rows.Scan(&rpoint.Bin, &rpoint.Count); err != nil { // Required for Debug: &rpoint.Min, &rpoint.Max
 			cclog.Warnf("Error while scanning rows for %s", metric)
 			return nil, err // FIXME: Totally bricks cc-backend if returned and if all metrics requested?
} + // Match query result to pre-initialized bin and update count for _, e := range points { if e.Bin != nil && rpoint.Bin != nil { if *e.Bin == *rpoint.Bin { diff --git a/internal/repository/stats_test.go b/internal/repository/stats_test.go index e10c9685..a8dfc818 100644 --- a/internal/repository/stats_test.go +++ b/internal/repository/stats_test.go @@ -25,11 +25,20 @@ func TestBuildJobStatsQuery(t *testing.T) { func TestJobStats(t *testing.T) { r := setup(t) + // First, count the actual jobs in the database (excluding test jobs) + var expectedCount int + err := r.DB.QueryRow(`SELECT COUNT(*) FROM job WHERE cluster != 'testcluster'`).Scan(&expectedCount) + noErr(t, err) + filter := &model.JobFilter{} + // Exclude test jobs created by other tests + testCluster := "testcluster" + filter.Cluster = &model.StringInput{Neq: &testCluster} + stats, err := r.JobsStats(getContext(t), []*model.JobFilter{filter}) noErr(t, err) - if stats[0].TotalJobs != 544 { - t.Fatalf("Want 544, Got %d", stats[0].TotalJobs) + if stats[0].TotalJobs != expectedCount { + t.Fatalf("Want %d, Got %d", expectedCount, stats[0].TotalJobs) } } diff --git a/internal/repository/tags.go b/internal/repository/tags.go index 9bc9abae..f6cccfe2 100644 --- a/internal/repository/tags.go +++ b/internal/repository/tags.go @@ -16,8 +16,32 @@ import ( sq "github.com/Masterminds/squirrel" ) +// Tag Scope Rules: +// +// Tags in ClusterCockpit have three visibility scopes that control who can see and use them: +// +// 1. "global" - Visible to all users, can be used by anyone +// Example: System-generated tags like "energy-efficient", "failed", "short" +// +// 2. "private" - Only visible to the creating user +// Example: Personal notes like "needs-review", "interesting-case" +// +// 3. "admin" - Only visible to users with admin or support roles +// Example: Internal notes like "hardware-issue", "billing-problem" +// +// Authorization Rules: +// - Regular users can only create/see "global" and their own "private" tags +// - Admin/Support can create/see all scopes including "admin" tags +// - Users can only add tags to jobs they have permission to view +// - Tag scope is enforced at query time in GetTags() and CountTags() + // AddTag adds the tag with id `tagId` to the job with the database id `jobId`. // Requires user authentication for security checks. +// +// The user must have permission to view the job. Tag visibility is determined by scope: +// - "global" tags: visible to all users +// - "private" tags: only visible to the tag creator +// - "admin" tags: only visible to admin/support users func (r *JobRepository) AddTag(user *schema.User, job int64, tag int64) ([]*schema.Tag, error) { j, err := r.FindByIDWithUser(user, job) if err != nil { @@ -180,7 +204,15 @@ func (r *JobRepository) RemoveTagById(tagID int64) error { return nil } -// CreateTag creates a new tag with the specified type and name and returns its database id. +// CreateTag creates a new tag with the specified type, name, and scope. +// Returns the database ID of the newly created tag. +// +// Scope defaults to "global" if empty string is provided. 
+// Valid scopes: "global", "private", "admin" +// +// Example: +// +// tagID, err := repo.CreateTag("performance", "high-memory", "global") func (r *JobRepository) CreateTag(tagType string, tagName string, tagScope string) (tagId int64, err error) { // Default to "Global" scope if none defined if tagScope == "" { @@ -199,8 +231,14 @@ func (r *JobRepository) CreateTag(tagType string, tagName string, tagScope strin return res.LastInsertId() } +// CountTags returns all tags visible to the user and the count of jobs for each tag. +// Applies scope-based filtering to respect tag visibility rules. +// +// Returns: +// - tags: slice of tags the user can see +// - counts: map of tag name to job count +// - err: any error encountered func (r *JobRepository) CountTags(user *schema.User) (tags []schema.Tag, counts map[string]int, err error) { - // Fetch all Tags in DB for Display in Frontend Tag-View tags = make([]schema.Tag, 0, 100) xrows, err := r.DB.Queryx("SELECT id, tag_type, tag_name, tag_scope FROM tag") if err != nil { diff --git a/internal/repository/transaction_test.go b/internal/repository/transaction_test.go new file mode 100644 index 00000000..1832bea0 --- /dev/null +++ b/internal/repository/transaction_test.go @@ -0,0 +1,311 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. +package repository + +import ( + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTransactionInit(t *testing.T) { + r := setup(t) + + t.Run("successful transaction init", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err, "TransactionInit should succeed") + require.NotNil(t, tx, "Transaction should not be nil") + require.NotNil(t, tx.tx, "Transaction.tx should not be nil") + + // Clean up + err = tx.Rollback() + require.NoError(t, err, "Rollback should succeed") + }) +} + +func TestTransactionCommit(t *testing.T) { + r := setup(t) + + t.Run("commit after successful operations", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + // Insert a test tag + _, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_tag_commit", "global") + require.NoError(t, err, "TransactionAdd should succeed") + + // Commit the transaction + err = tx.Commit() + require.NoError(t, err, "Commit should succeed") + + // Verify the tag was inserted + var count int + err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_tag_commit").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count, "Tag should be committed to database") + + // Clean up + _, err = r.DB.Exec("DELETE FROM tag WHERE tag_name = ?", "test_tag_commit") + require.NoError(t, err) + }) + + t.Run("commit on already committed transaction", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err, "First commit should succeed") + + err = tx.Commit() + assert.Error(t, err, "Second commit should fail") + assert.Contains(t, err.Error(), "transaction already committed or rolled back") + }) +} + +func TestTransactionRollback(t *testing.T) { + r := setup(t) + + t.Run("rollback after operations", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + // Insert a test tag + _, err = 
r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_tag_rollback", "global") + require.NoError(t, err, "TransactionAdd should succeed") + + // Rollback the transaction + err = tx.Rollback() + require.NoError(t, err, "Rollback should succeed") + + // Verify the tag was NOT inserted + var count int + err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_tag_rollback").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count, "Tag should not be in database after rollback") + }) + + t.Run("rollback on already rolled back transaction", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + err = tx.Rollback() + require.NoError(t, err, "First rollback should succeed") + + err = tx.Rollback() + assert.NoError(t, err, "Second rollback should be safe (no-op)") + }) + + t.Run("rollback on committed transaction", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + err = tx.Rollback() + assert.NoError(t, err, "Rollback after commit should be safe (no-op)") + }) +} + +func TestTransactionAdd(t *testing.T) { + r := setup(t) + + t.Run("insert with TransactionAdd", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + defer tx.Rollback() + + id, err := r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_add", "global") + require.NoError(t, err, "TransactionAdd should succeed") + assert.Greater(t, id, int64(0), "Should return valid insert ID") + }) + + t.Run("error on nil transaction", func(t *testing.T) { + tx := &Transaction{tx: nil} + + _, err := r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_nil", "global") + assert.Error(t, err, "Should error on nil transaction") + assert.Contains(t, err.Error(), "transaction is nil or already completed") + }) + + t.Run("error on invalid SQL", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + defer tx.Rollback() + + _, err = r.TransactionAdd(tx, "INVALID SQL STATEMENT") + assert.Error(t, err, "Should error on invalid SQL") + }) + + t.Run("error after transaction committed", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + _, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_after_commit", "global") + assert.Error(t, err, "Should error when transaction is already committed") + }) +} + +func TestTransactionAddNamed(t *testing.T) { + r := setup(t) + + t.Run("insert with TransactionAddNamed", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + defer tx.Rollback() + + type TagArgs struct { + Type string `db:"type"` + Name string `db:"name"` + Scope string `db:"scope"` + } + + args := TagArgs{ + Type: "test_type", + Name: "test_named", + Scope: "global", + } + + id, err := r.TransactionAddNamed(tx, + "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (:type, :name, :scope)", + args) + require.NoError(t, err, "TransactionAddNamed should succeed") + assert.Greater(t, id, int64(0), "Should return valid insert ID") + }) + + t.Run("error on nil transaction", func(t *testing.T) { + tx := &Transaction{tx: nil} + + _, err := r.TransactionAddNamed(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (:type, :name, :scope)", + 
map[string]interface{}{"type": "test", "name": "test", "scope": "global"}) + assert.Error(t, err, "Should error on nil transaction") + assert.Contains(t, err.Error(), "transaction is nil or already completed") + }) +} + +func TestTransactionMultipleOperations(t *testing.T) { + r := setup(t) + + t.Run("multiple inserts in single transaction", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + defer tx.Rollback() + + // Insert multiple tags + for i := 0; i < 5; i++ { + _, err = r.TransactionAdd(tx, + "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_multi_"+string(rune('a'+i)), "global") + require.NoError(t, err, "Insert %d should succeed", i) + } + + err = tx.Commit() + require.NoError(t, err, "Commit should succeed") + + // Verify all tags were inserted + var count int + err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name LIKE 'test_multi_%'").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 5, count, "All 5 tags should be committed") + + // Clean up + _, err = r.DB.Exec("DELETE FROM tag WHERE tag_name LIKE 'test_multi_%'") + require.NoError(t, err) + }) + + t.Run("rollback undoes all operations", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + // Insert multiple tags + for i := 0; i < 3; i++ { + _, err = r.TransactionAdd(tx, + "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_rollback_"+string(rune('a'+i)), "global") + require.NoError(t, err) + } + + err = tx.Rollback() + require.NoError(t, err, "Rollback should succeed") + + // Verify no tags were inserted + var count int + err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name LIKE 'test_rollback_%'").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count, "No tags should be in database after rollback") + }) +} + +func TestTransactionEnd(t *testing.T) { + r := setup(t) + + t.Run("deprecated TransactionEnd calls Commit", func(t *testing.T) { + tx, err := r.TransactionInit() + require.NoError(t, err) + + _, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_end", "global") + require.NoError(t, err) + + // Use deprecated method + err = r.TransactionEnd(tx) + require.NoError(t, err, "TransactionEnd should succeed") + + // Verify the tag was committed + var count int + err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_end").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count, "Tag should be committed") + + // Clean up + _, err = r.DB.Exec("DELETE FROM tag WHERE tag_name = ?", "test_end") + require.NoError(t, err) + }) +} + +func TestTransactionDeferPattern(t *testing.T) { + r := setup(t) + + t.Run("defer rollback pattern", func(t *testing.T) { + insertTag := func() error { + tx, err := r.TransactionInit() + if err != nil { + return err + } + defer tx.Rollback() // Safe to call even after commit + + _, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)", + "test_type", "test_defer", "global") + if err != nil { + return err + } + + return tx.Commit() + } + + err := insertTag() + require.NoError(t, err, "Function should succeed") + + // Verify the tag was committed + var count int + err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_defer").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count, "Tag should be committed despite defer rollback") + + // Clean up + _, err = r.DB.Exec("DELETE FROM 
tag WHERE tag_name = ?", "test_defer") + require.NoError(t, err) + }) +} diff --git a/internal/repository/user.go b/internal/repository/user.go index 770915b6..42a22384 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -22,6 +22,25 @@ import ( "golang.org/x/crypto/bcrypt" ) +// Authentication and Role System: +// +// ClusterCockpit supports multiple authentication sources: +// - Local: Username/password stored in database (password hashed with bcrypt) +// - LDAP: External LDAP/Active Directory authentication +// - JWT: Token-based authentication for API access +// +// Role Hierarchy (from highest to lowest privilege): +// 1. "admin" - Full system access, can manage all users and jobs +// 2. "support" - Can view all jobs but limited management capabilities +// 3. "manager" - Can manage specific projects and their users +// 4. "api" - Programmatic access for job submission/management +// 5. "user" - Default role, can only view own jobs +// +// Project Association: +// - Managers have a list of projects they oversee +// - Regular users' project membership is determined by job data +// - Managers can view/manage all jobs within their projects + var ( userRepoOnce sync.Once userRepoInstance *UserRepository @@ -44,6 +63,9 @@ func GetUserRepository() *UserRepository { return userRepoInstance } +// GetUser retrieves a user by username from the database. +// Returns the complete user record including hashed password, roles, and projects. +// Password field contains bcrypt hash for local auth users, empty for LDAP users. func (r *UserRepository) GetUser(username string) (*schema.User, error) { user := &schema.User{Username: username} var hashedPassword, name, rawRoles, email, rawProjects sql.NullString @@ -93,6 +115,12 @@ func (r *UserRepository) GetLdapUsernames() ([]string, error) { return users, nil } +// AddUser creates a new user in the database. +// Passwords are automatically hashed with bcrypt before storage. +// Auth source determines authentication method (local, LDAP, etc.). +// +// Required fields: Username, Roles +// Optional fields: Name, Email, Password, Projects, AuthSource func (r *UserRepository) AddUser(user *schema.User) error { rolesJson, _ := json.Marshal(user.Roles) projectsJson, _ := json.Marshal(user.Projects) @@ -229,6 +257,14 @@ func (r *UserRepository) ListUsers(specialsOnly bool) ([]*schema.User, error) { return users, nil } +// AddRole adds a role to a user's role list. +// Role string is automatically lowercased. +// Valid roles: admin, support, manager, api, user +// +// Returns error if: +// - User doesn't exist +// - Role is invalid +// - User already has the role func (r *UserRepository) AddRole( ctx context.Context, username string, @@ -258,6 +294,11 @@ func (r *UserRepository) AddRole( return nil } +// RemoveRole removes a role from a user's role list. +// +// Special rules: +// - Cannot remove "manager" role while user has assigned projects +// - Must remove all projects first before removing manager role func (r *UserRepository) RemoveRole(ctx context.Context, username string, queryrole string) error { oldRole := strings.ToLower(queryrole) user, err := r.GetUser(username) @@ -294,6 +335,12 @@ func (r *UserRepository) RemoveRole(ctx context.Context, username string, queryr return nil } +// AddProject assigns a project to a manager user. +// Only users with the "manager" role can have assigned projects. 
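+//
+// A hedged usage sketch (username and project name are illustrative):
+//
+//	repo := GetUserRepository()
+//	if err := repo.AddProject(ctx, "mgr01", "projA"); err != nil {
+//		// see the error conditions below
+//	}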
+// +// Returns error if: +// - User doesn't have manager role +// - User already manages the project func (r *UserRepository) AddProject( ctx context.Context, username string, diff --git a/internal/repository/user_test.go b/internal/repository/user_test.go new file mode 100644 index 00000000..370d261d --- /dev/null +++ b/internal/repository/user_test.go @@ -0,0 +1,596 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. +package repository + +import ( + "context" + "testing" + + "github.com/ClusterCockpit/cc-lib/v2/schema" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func TestAddUser(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + + t.Run("add user with all fields", func(t *testing.T) { + user := &schema.User{ + Username: "testuser1", + Name: "Test User One", + Email: "test1@example.com", + Password: "testpassword123", + Roles: []string{"user"}, + Projects: []string{"project1", "project2"}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + retrievedUser, err := r.GetUser("testuser1") + require.NoError(t, err) + assert.Equal(t, user.Username, retrievedUser.Username) + assert.Equal(t, user.Name, retrievedUser.Name) + assert.Equal(t, user.Email, retrievedUser.Email) + assert.Equal(t, user.Roles, retrievedUser.Roles) + assert.Equal(t, user.Projects, retrievedUser.Projects) + assert.NotEmpty(t, retrievedUser.Password) + err = bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("testpassword123")) + assert.NoError(t, err, "Password should be hashed correctly") + + err = r.DelUser("testuser1") + require.NoError(t, err) + }) + + t.Run("add user with minimal fields", func(t *testing.T) { + user := &schema.User{ + Username: "testuser2", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLDAP, + } + + err := r.AddUser(user) + require.NoError(t, err) + + retrievedUser, err := r.GetUser("testuser2") + require.NoError(t, err) + assert.Equal(t, user.Username, retrievedUser.Username) + assert.Equal(t, "", retrievedUser.Name) + assert.Equal(t, "", retrievedUser.Email) + assert.Equal(t, "", retrievedUser.Password) + + err = r.DelUser("testuser2") + require.NoError(t, err) + }) + + t.Run("add duplicate user fails", func(t *testing.T) { + user := &schema.User{ + Username: "testuser3", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.AddUser(user) + assert.Error(t, err, "Adding duplicate user should fail") + + err = r.DelUser("testuser3") + require.NoError(t, err) + }) +} + +func TestGetUser(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + + t.Run("get existing user", func(t *testing.T) { + user := &schema.User{ + Username: "getuser1", + Name: "Get User", + Email: "getuser@example.com", + Roles: []string{"user", "admin"}, + Projects: []string{"proj1"}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + retrieved, err := r.GetUser("getuser1") + require.NoError(t, err) + assert.Equal(t, user.Username, retrieved.Username) + assert.Equal(t, user.Name, retrieved.Name) + assert.Equal(t, user.Email, retrieved.Email) + assert.ElementsMatch(t, user.Roles, 
retrieved.Roles) + assert.ElementsMatch(t, user.Projects, retrieved.Projects) + + err = r.DelUser("getuser1") + require.NoError(t, err) + }) + + t.Run("get non-existent user", func(t *testing.T) { + _, err := r.GetUser("nonexistent") + assert.Error(t, err) + }) +} + +func TestUpdateUser(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + + t.Run("update user name", func(t *testing.T) { + user := &schema.User{ + Username: "updateuser1", + Name: "Original Name", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + dbUser, err := r.GetUser("updateuser1") + require.NoError(t, err) + + updatedUser := &schema.User{ + Username: "updateuser1", + Name: "Updated Name", + } + + err = r.UpdateUser(dbUser, updatedUser) + require.NoError(t, err) + + retrieved, err := r.GetUser("updateuser1") + require.NoError(t, err) + assert.Equal(t, "Updated Name", retrieved.Name) + + err = r.DelUser("updateuser1") + require.NoError(t, err) + }) + + t.Run("update with no changes", func(t *testing.T) { + user := &schema.User{ + Username: "updateuser2", + Name: "Same Name", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + dbUser, err := r.GetUser("updateuser2") + require.NoError(t, err) + + err = r.UpdateUser(dbUser, dbUser) + assert.NoError(t, err) + + err = r.DelUser("updateuser2") + require.NoError(t, err) + }) +} + +func TestDelUser(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + + t.Run("delete existing user", func(t *testing.T) { + user := &schema.User{ + Username: "deluser1", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.DelUser("deluser1") + require.NoError(t, err) + + _, err = r.GetUser("deluser1") + assert.Error(t, err, "User should not exist after deletion") + }) + + t.Run("delete non-existent user", func(t *testing.T) { + err := r.DelUser("nonexistent") + assert.NoError(t, err, "Deleting non-existent user should not error") + }) +} + +func TestListUsers(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + + user1 := &schema.User{ + Username: "listuser1", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + user2 := &schema.User{ + Username: "listuser2", + Roles: []string{"admin"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + user3 := &schema.User{ + Username: "listuser3", + Roles: []string{"manager"}, + Projects: []string{"proj1"}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user1) + require.NoError(t, err) + err = r.AddUser(user2) + require.NoError(t, err) + err = r.AddUser(user3) + require.NoError(t, err) + + t.Run("list all users", func(t *testing.T) { + users, err := r.ListUsers(false) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(users), 3) + + usernames := make([]string, len(users)) + for i, u := range users { + usernames[i] = u.Username + } + assert.Contains(t, usernames, "listuser1") + assert.Contains(t, usernames, "listuser2") + assert.Contains(t, usernames, "listuser3") + }) + + t.Run("list special users only", func(t *testing.T) { + users, err := r.ListUsers(true) + require.NoError(t, err) + + usernames := make([]string, len(users)) + for i, u := range users { + usernames[i] = u.Username + } + assert.Contains(t, usernames, "listuser2") + 
assert.Contains(t, usernames, "listuser3") + }) + + err = r.DelUser("listuser1") + require.NoError(t, err) + err = r.DelUser("listuser2") + require.NoError(t, err) + err = r.DelUser("listuser3") + require.NoError(t, err) +} + +func TestGetLdapUsernames(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + + ldapUser := &schema.User{ + Username: "ldapuser1", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLDAP, + } + localUser := &schema.User{ + Username: "localuser1", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(ldapUser) + require.NoError(t, err) + err = r.AddUser(localUser) + require.NoError(t, err) + + usernames, err := r.GetLdapUsernames() + require.NoError(t, err) + assert.Contains(t, usernames, "ldapuser1") + assert.NotContains(t, usernames, "localuser1") + + err = r.DelUser("ldapuser1") + require.NoError(t, err) + err = r.DelUser("localuser1") + require.NoError(t, err) +} + +func TestAddRole(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + ctx := context.Background() + + t.Run("add valid role", func(t *testing.T) { + user := &schema.User{ + Username: "roleuser1", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.AddRole(ctx, "roleuser1", "admin") + require.NoError(t, err) + + retrieved, err := r.GetUser("roleuser1") + require.NoError(t, err) + assert.Contains(t, retrieved.Roles, "admin") + assert.Contains(t, retrieved.Roles, "user") + + err = r.DelUser("roleuser1") + require.NoError(t, err) + }) + + t.Run("add duplicate role", func(t *testing.T) { + user := &schema.User{ + Username: "roleuser2", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.AddRole(ctx, "roleuser2", "user") + assert.Error(t, err, "Adding duplicate role should fail") + assert.Contains(t, err.Error(), "already has role") + + err = r.DelUser("roleuser2") + require.NoError(t, err) + }) + + t.Run("add invalid role", func(t *testing.T) { + user := &schema.User{ + Username: "roleuser3", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.AddRole(ctx, "roleuser3", "invalidrole") + assert.Error(t, err, "Adding invalid role should fail") + assert.Contains(t, err.Error(), "no valid option") + + err = r.DelUser("roleuser3") + require.NoError(t, err) + }) +} + +func TestRemoveRole(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + ctx := context.Background() + + t.Run("remove existing role", func(t *testing.T) { + user := &schema.User{ + Username: "rmroleuser1", + Roles: []string{"user", "admin"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.RemoveRole(ctx, "rmroleuser1", "admin") + require.NoError(t, err) + + retrieved, err := r.GetUser("rmroleuser1") + require.NoError(t, err) + assert.NotContains(t, retrieved.Roles, "admin") + assert.Contains(t, retrieved.Roles, "user") + + err = r.DelUser("rmroleuser1") + require.NoError(t, err) + }) + + t.Run("remove non-existent role", func(t *testing.T) { + user := &schema.User{ + Username: "rmroleuser2", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + 
require.NoError(t, err) + + err = r.RemoveRole(ctx, "rmroleuser2", "admin") + assert.Error(t, err, "Removing non-existent role should fail") + assert.Contains(t, err.Error(), "already deleted") + + err = r.DelUser("rmroleuser2") + require.NoError(t, err) + }) + + t.Run("remove manager role with projects", func(t *testing.T) { + user := &schema.User{ + Username: "rmroleuser3", + Roles: []string{"manager"}, + Projects: []string{"proj1", "proj2"}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.RemoveRole(ctx, "rmroleuser3", "manager") + assert.Error(t, err, "Removing manager role with projects should fail") + assert.Contains(t, err.Error(), "still has assigned project") + + err = r.DelUser("rmroleuser3") + require.NoError(t, err) + }) +} + +func TestAddProject(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + ctx := context.Background() + + t.Run("add project to manager", func(t *testing.T) { + user := &schema.User{ + Username: "projuser1", + Roles: []string{"manager"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.AddProject(ctx, "projuser1", "newproject") + require.NoError(t, err) + + retrieved, err := r.GetUser("projuser1") + require.NoError(t, err) + assert.Contains(t, retrieved.Projects, "newproject") + + err = r.DelUser("projuser1") + require.NoError(t, err) + }) + + t.Run("add project to non-manager", func(t *testing.T) { + user := &schema.User{ + Username: "projuser2", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.AddProject(ctx, "projuser2", "newproject") + assert.Error(t, err, "Adding project to non-manager should fail") + assert.Contains(t, err.Error(), "not a manager") + + err = r.DelUser("projuser2") + require.NoError(t, err) + }) + + t.Run("add duplicate project", func(t *testing.T) { + user := &schema.User{ + Username: "projuser3", + Roles: []string{"manager"}, + Projects: []string{"existingproject"}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.AddProject(ctx, "projuser3", "existingproject") + assert.Error(t, err, "Adding duplicate project should fail") + assert.Contains(t, err.Error(), "already manages") + + err = r.DelUser("projuser3") + require.NoError(t, err) + }) +} + +func TestRemoveProject(t *testing.T) { + _ = setup(t) + r := GetUserRepository() + ctx := context.Background() + + t.Run("remove existing project", func(t *testing.T) { + user := &schema.User{ + Username: "rmprojuser1", + Roles: []string{"manager"}, + Projects: []string{"proj1", "proj2"}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.RemoveProject(ctx, "rmprojuser1", "proj1") + require.NoError(t, err) + + retrieved, err := r.GetUser("rmprojuser1") + require.NoError(t, err) + assert.NotContains(t, retrieved.Projects, "proj1") + assert.Contains(t, retrieved.Projects, "proj2") + + err = r.DelUser("rmprojuser1") + require.NoError(t, err) + }) + + t.Run("remove non-existent project", func(t *testing.T) { + user := &schema.User{ + Username: "rmprojuser2", + Roles: []string{"manager"}, + Projects: []string{"proj1"}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.RemoveProject(ctx, "rmprojuser2", "nonexistent") + assert.Error(t, err, "Removing 
non-existent project should fail") + + err = r.DelUser("rmprojuser2") + require.NoError(t, err) + }) + + t.Run("remove project from non-manager", func(t *testing.T) { + user := &schema.User{ + Username: "rmprojuser3", + Roles: []string{"user"}, + Projects: []string{}, + AuthSource: schema.AuthViaLocalPassword, + } + + err := r.AddUser(user) + require.NoError(t, err) + + err = r.RemoveProject(ctx, "rmprojuser3", "proj1") + assert.Error(t, err, "Removing project from non-manager should fail") + assert.Contains(t, err.Error(), "not a manager") + + err = r.DelUser("rmprojuser3") + require.NoError(t, err) + }) +} + +func TestGetUserFromContext(t *testing.T) { + t.Run("get user from context", func(t *testing.T) { + user := &schema.User{ + Username: "contextuser", + Roles: []string{"user"}, + } + + ctx := context.WithValue(context.Background(), ContextUserKey, user) + retrieved := GetUserFromContext(ctx) + + require.NotNil(t, retrieved) + assert.Equal(t, user.Username, retrieved.Username) + }) + + t.Run("get user from empty context", func(t *testing.T) { + ctx := context.Background() + retrieved := GetUserFromContext(ctx) + + assert.Nil(t, retrieved) + }) +} diff --git a/internal/tagger/tagger.go b/internal/tagger/tagger.go index 0839603d..2a5a0a7d 100644 --- a/internal/tagger/tagger.go +++ b/internal/tagger/tagger.go @@ -64,7 +64,7 @@ func newTagger() { func Init() { initOnce.Do(func() { newTagger() - repository.RegisterJobJook(jobTagger) + repository.RegisterJobHook(jobTagger) }) }