mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2026-01-15 17:21:46 +01:00
Improve documentation and add more tests
This commit is contained in:
274
internal/repository/hooks_test.go
Normal file
274
internal/repository/hooks_test.go
Normal file
@@ -0,0 +1,274 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ClusterCockpit/cc-lib/v2/schema"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type MockJobHook struct {
|
||||
startCalled bool
|
||||
stopCalled bool
|
||||
startJobs []*schema.Job
|
||||
stopJobs []*schema.Job
|
||||
}
|
||||
|
||||
func (m *MockJobHook) JobStartCallback(job *schema.Job) {
|
||||
m.startCalled = true
|
||||
m.startJobs = append(m.startJobs, job)
|
||||
}
|
||||
|
||||
func (m *MockJobHook) JobStopCallback(job *schema.Job) {
|
||||
m.stopCalled = true
|
||||
m.stopJobs = append(m.stopJobs, job)
|
||||
}
|
||||
|
||||
func TestRegisterJobHook(t *testing.T) {
|
||||
t.Run("register single hook", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock := &MockJobHook{}
|
||||
|
||||
RegisterJobHook(mock)
|
||||
|
||||
assert.NotNil(t, hooks)
|
||||
assert.Len(t, hooks, 1)
|
||||
assert.Equal(t, mock, hooks[0])
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("register multiple hooks", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock1 := &MockJobHook{}
|
||||
mock2 := &MockJobHook{}
|
||||
|
||||
RegisterJobHook(mock1)
|
||||
RegisterJobHook(mock2)
|
||||
|
||||
assert.Len(t, hooks, 2)
|
||||
assert.Equal(t, mock1, hooks[0])
|
||||
assert.Equal(t, mock2, hooks[1])
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("register nil hook does not add to hooks", func(t *testing.T) {
|
||||
hooks = nil
|
||||
RegisterJobHook(nil)
|
||||
|
||||
if hooks != nil {
|
||||
assert.Len(t, hooks, 0, "Nil hook should not be added")
|
||||
}
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestCallJobStartHooks(t *testing.T) {
|
||||
t.Run("call start hooks with single job", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock := &MockJobHook{}
|
||||
RegisterJobHook(mock)
|
||||
|
||||
job := &schema.Job{
|
||||
JobID: 123,
|
||||
User: "testuser",
|
||||
Cluster: "testcluster",
|
||||
}
|
||||
|
||||
CallJobStartHooks([]*schema.Job{job})
|
||||
|
||||
assert.True(t, mock.startCalled)
|
||||
assert.False(t, mock.stopCalled)
|
||||
assert.Len(t, mock.startJobs, 1)
|
||||
assert.Equal(t, int64(123), mock.startJobs[0].JobID)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("call start hooks with multiple jobs", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock := &MockJobHook{}
|
||||
RegisterJobHook(mock)
|
||||
|
||||
jobs := []*schema.Job{
|
||||
{JobID: 1, User: "user1", Cluster: "cluster1"},
|
||||
{JobID: 2, User: "user2", Cluster: "cluster2"},
|
||||
{JobID: 3, User: "user3", Cluster: "cluster3"},
|
||||
}
|
||||
|
||||
CallJobStartHooks(jobs)
|
||||
|
||||
assert.True(t, mock.startCalled)
|
||||
assert.Len(t, mock.startJobs, 3)
|
||||
assert.Equal(t, int64(1), mock.startJobs[0].JobID)
|
||||
assert.Equal(t, int64(2), mock.startJobs[1].JobID)
|
||||
assert.Equal(t, int64(3), mock.startJobs[2].JobID)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("call start hooks with multiple registered hooks", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock1 := &MockJobHook{}
|
||||
mock2 := &MockJobHook{}
|
||||
RegisterJobHook(mock1)
|
||||
RegisterJobHook(mock2)
|
||||
|
||||
job := &schema.Job{
|
||||
JobID: 456, User: "testuser", Cluster: "testcluster",
|
||||
}
|
||||
|
||||
CallJobStartHooks([]*schema.Job{job})
|
||||
|
||||
assert.True(t, mock1.startCalled)
|
||||
assert.True(t, mock2.startCalled)
|
||||
assert.Len(t, mock1.startJobs, 1)
|
||||
assert.Len(t, mock2.startJobs, 1)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("call start hooks with nil hooks", func(t *testing.T) {
|
||||
hooks = nil
|
||||
|
||||
job := &schema.Job{
|
||||
JobID: 789, User: "testuser", Cluster: "testcluster",
|
||||
}
|
||||
|
||||
CallJobStartHooks([]*schema.Job{job})
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("call start hooks with empty job list", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock := &MockJobHook{}
|
||||
RegisterJobHook(mock)
|
||||
|
||||
CallJobStartHooks([]*schema.Job{})
|
||||
|
||||
assert.False(t, mock.startCalled)
|
||||
assert.Len(t, mock.startJobs, 0)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestCallJobStopHooks(t *testing.T) {
|
||||
t.Run("call stop hooks with single job", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock := &MockJobHook{}
|
||||
RegisterJobHook(mock)
|
||||
|
||||
job := &schema.Job{
|
||||
JobID: 123,
|
||||
User: "testuser",
|
||||
Cluster: "testcluster",
|
||||
}
|
||||
|
||||
CallJobStopHooks(job)
|
||||
|
||||
assert.True(t, mock.stopCalled)
|
||||
assert.False(t, mock.startCalled)
|
||||
assert.Len(t, mock.stopJobs, 1)
|
||||
assert.Equal(t, int64(123), mock.stopJobs[0].JobID)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("call stop hooks with multiple registered hooks", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock1 := &MockJobHook{}
|
||||
mock2 := &MockJobHook{}
|
||||
RegisterJobHook(mock1)
|
||||
RegisterJobHook(mock2)
|
||||
|
||||
job := &schema.Job{
|
||||
JobID: 456, User: "testuser", Cluster: "testcluster",
|
||||
}
|
||||
|
||||
CallJobStopHooks(job)
|
||||
|
||||
assert.True(t, mock1.stopCalled)
|
||||
assert.True(t, mock2.stopCalled)
|
||||
assert.Len(t, mock1.stopJobs, 1)
|
||||
assert.Len(t, mock2.stopJobs, 1)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
|
||||
t.Run("call stop hooks with nil hooks", func(t *testing.T) {
|
||||
hooks = nil
|
||||
|
||||
job := &schema.Job{
|
||||
JobID: 789, User: "testuser", Cluster: "testcluster",
|
||||
}
|
||||
|
||||
CallJobStopHooks(job)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestSQLHooks(t *testing.T) {
|
||||
_ = setup(t)
|
||||
|
||||
t.Run("hooks log queries in debug mode", func(t *testing.T) {
|
||||
h := &Hooks{}
|
||||
|
||||
ctx := context.Background()
|
||||
query := "SELECT * FROM job WHERE job_id = ?"
|
||||
args := []any{123}
|
||||
|
||||
ctxWithTime, err := h.Before(ctx, query, args...)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, ctxWithTime)
|
||||
|
||||
beginTime := ctxWithTime.Value("begin")
|
||||
require.NotNil(t, beginTime)
|
||||
_, ok := beginTime.(time.Time)
|
||||
assert.True(t, ok, "Begin time should be time.Time")
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ctxAfter, err := h.After(ctxWithTime, query, args...)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, ctxAfter)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHookIntegration(t *testing.T) {
|
||||
t.Run("hooks are called during job lifecycle", func(t *testing.T) {
|
||||
hooks = nil
|
||||
mock := &MockJobHook{}
|
||||
RegisterJobHook(mock)
|
||||
|
||||
job := &schema.Job{
|
||||
JobID: 999,
|
||||
User: "integrationuser",
|
||||
Cluster: "integrationcluster",
|
||||
}
|
||||
|
||||
CallJobStartHooks([]*schema.Job{job})
|
||||
assert.True(t, mock.startCalled)
|
||||
assert.Equal(t, 1, len(mock.startJobs))
|
||||
|
||||
CallJobStopHooks(job)
|
||||
assert.True(t, mock.stopCalled)
|
||||
assert.Equal(t, 1, len(mock.stopJobs))
|
||||
|
||||
assert.Equal(t, mock.startJobs[0].JobID, mock.stopJobs[0].JobID)
|
||||
|
||||
hooks = nil
|
||||
})
|
||||
}
|
||||
@@ -80,18 +80,33 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
jobRepoOnce sync.Once
|
||||
// jobRepoOnce ensures singleton initialization of the JobRepository
|
||||
jobRepoOnce sync.Once
|
||||
// jobRepoInstance holds the single instance of JobRepository
|
||||
jobRepoInstance *JobRepository
|
||||
)
|
||||
|
||||
// JobRepository provides database access for job-related operations.
|
||||
// It implements the repository pattern to abstract database interactions
|
||||
// and provides caching for improved performance.
|
||||
//
|
||||
// The repository is a singleton initialized via GetJobRepository().
|
||||
// All database queries use prepared statements via stmtCache for efficiency.
|
||||
// Frequently accessed data (metadata, energy footprints) is cached in an LRU cache.
|
||||
type JobRepository struct {
|
||||
DB *sqlx.DB
|
||||
stmtCache *sq.StmtCache
|
||||
cache *lrucache.Cache
|
||||
driver string
|
||||
Mutex sync.Mutex
|
||||
DB *sqlx.DB // Database connection pool
|
||||
stmtCache *sq.StmtCache // Prepared statement cache for query optimization
|
||||
cache *lrucache.Cache // LRU cache for metadata and footprint data
|
||||
driver string // Database driver name (e.g., "sqlite3")
|
||||
Mutex sync.Mutex // Mutex for thread-safe operations
|
||||
}
|
||||
|
||||
// GetJobRepository returns the singleton instance of JobRepository.
|
||||
// The repository is initialized lazily on first access with database connection,
|
||||
// prepared statement cache, and LRU cache configured from repoConfig.
|
||||
//
|
||||
// This function is thread-safe and ensures only one instance is created.
|
||||
// It must be called after Connect() has established a database connection.
|
||||
func GetJobRepository() *JobRepository {
|
||||
jobRepoOnce.Do(func() {
|
||||
db := GetConnection()
|
||||
@@ -107,6 +122,8 @@ func GetJobRepository() *JobRepository {
|
||||
return jobRepoInstance
|
||||
}
|
||||
|
||||
// jobColumns defines the standard set of columns selected from the job table.
|
||||
// Used consistently across all job queries to ensure uniform data retrieval.
|
||||
var jobColumns []string = []string{
|
||||
"job.id", "job.job_id", "job.hpc_user", "job.project", "job.cluster", "job.subcluster",
|
||||
"job.start_time", "job.cluster_partition", "job.array_job_id", "job.num_nodes",
|
||||
@@ -115,6 +132,8 @@ var jobColumns []string = []string{
|
||||
"job.footprint", "job.energy",
|
||||
}
|
||||
|
||||
// jobCacheColumns defines columns from the job_cache table, mirroring jobColumns.
|
||||
// Used for queries against cached job data for performance optimization.
|
||||
var jobCacheColumns []string = []string{
|
||||
"job_cache.id", "job_cache.job_id", "job_cache.hpc_user", "job_cache.project", "job_cache.cluster",
|
||||
"job_cache.subcluster", "job_cache.start_time", "job_cache.cluster_partition",
|
||||
@@ -124,6 +143,14 @@ var jobCacheColumns []string = []string{
|
||||
"job_cache.footprint", "job_cache.energy",
|
||||
}
|
||||
|
||||
// scanJob converts a database row into a schema.Job struct.
|
||||
// It handles JSON unmarshaling of resources and footprint fields,
|
||||
// and calculates accurate duration for running jobs.
|
||||
//
|
||||
// Parameters:
|
||||
// - row: Database row implementing Scan() interface (sql.Row or sql.Rows)
|
||||
//
|
||||
// Returns the populated Job struct or an error if scanning or unmarshaling fails.
|
||||
func scanJob(row interface{ Scan(...any) error }) (*schema.Job, error) {
|
||||
job := &schema.Job{}
|
||||
|
||||
@@ -186,6 +213,16 @@ func (r *JobRepository) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchMetadata retrieves and unmarshals the metadata JSON for a job.
|
||||
// Metadata is cached with a 24-hour TTL to improve performance.
|
||||
//
|
||||
// The metadata field stores arbitrary key-value pairs associated with a job,
|
||||
// such as tags, labels, or custom attributes added by external systems.
|
||||
//
|
||||
// Parameters:
|
||||
// - job: Job struct with valid ID field, metadata will be populated in job.MetaData
|
||||
//
|
||||
// Returns the metadata map or an error if the job is nil or database query fails.
|
||||
func (r *JobRepository) FetchMetadata(job *schema.Job) (map[string]string, error) {
|
||||
if job == nil {
|
||||
return nil, fmt.Errorf("job cannot be nil")
|
||||
@@ -218,6 +255,16 @@ func (r *JobRepository) FetchMetadata(job *schema.Job) (map[string]string, error
|
||||
return job.MetaData, nil
|
||||
}
|
||||
|
||||
// UpdateMetadata adds or updates a single metadata key-value pair for a job.
|
||||
// The entire metadata map is re-marshaled and stored, and the cache is invalidated.
|
||||
// Also triggers archive metadata update via archive.UpdateMetadata.
|
||||
//
|
||||
// Parameters:
|
||||
// - job: Job struct with valid ID, existing metadata will be fetched if not present
|
||||
// - key: Metadata key to set
|
||||
// - val: Metadata value to set
|
||||
//
|
||||
// Returns an error if the job is nil, metadata fetch fails, or database update fails.
|
||||
func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err error) {
|
||||
if job == nil {
|
||||
return fmt.Errorf("job cannot be nil")
|
||||
@@ -228,7 +275,7 @@ func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err er
|
||||
if job.MetaData == nil {
|
||||
if _, err = r.FetchMetadata(job); err != nil {
|
||||
cclog.Warnf("Error while fetching metadata for job, DB ID '%v'", job.ID)
|
||||
return err
|
||||
return fmt.Errorf("failed to fetch metadata for job %d: %w", job.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +290,7 @@ func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err er
|
||||
|
||||
if job.RawMetaData, err = json.Marshal(job.MetaData); err != nil {
|
||||
cclog.Warnf("Error while marshaling metadata for job, DB ID '%v'", job.ID)
|
||||
return err
|
||||
return fmt.Errorf("failed to marshal metadata for job %d: %w", job.ID, err)
|
||||
}
|
||||
|
||||
if _, err = sq.Update("job").
|
||||
@@ -251,13 +298,23 @@ func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err er
|
||||
Where("job.id = ?", job.ID).
|
||||
RunWith(r.stmtCache).Exec(); err != nil {
|
||||
cclog.Warnf("Error while updating metadata for job, DB ID '%v'", job.ID)
|
||||
return err
|
||||
return fmt.Errorf("failed to update metadata in database for job %d: %w", job.ID, err)
|
||||
}
|
||||
|
||||
r.cache.Put(cachekey, job.MetaData, len(job.RawMetaData), 24*time.Hour)
|
||||
return archive.UpdateMetadata(job, job.MetaData)
|
||||
}
|
||||
|
||||
// FetchFootprint retrieves and unmarshals the performance footprint JSON for a job.
|
||||
// Unlike FetchMetadata, footprints are NOT cached as they can be large and change frequently.
|
||||
//
|
||||
// The footprint contains summary statistics (avg/min/max) for monitored metrics,
|
||||
// stored as JSON with keys like "cpu_load_avg", "mem_used_max", etc.
|
||||
//
|
||||
// Parameters:
|
||||
// - job: Job struct with valid ID, footprint will be populated in job.Footprint
|
||||
//
|
||||
// Returns the footprint map or an error if the job is nil or database query fails.
|
||||
func (r *JobRepository) FetchFootprint(job *schema.Job) (map[string]float64, error) {
|
||||
if job == nil {
|
||||
return nil, fmt.Errorf("job cannot be nil")
|
||||
@@ -284,6 +341,16 @@ func (r *JobRepository) FetchFootprint(job *schema.Job) (map[string]float64, err
|
||||
return job.Footprint, nil
|
||||
}
|
||||
|
||||
// FetchEnergyFootprint retrieves and unmarshals the energy footprint JSON for a job.
|
||||
// Energy footprints are cached with a 24-hour TTL as they are frequently accessed but rarely change.
|
||||
//
|
||||
// The energy footprint contains calculated energy consumption (in kWh) per metric,
|
||||
// stored as JSON with keys like "power_avg", "acc_power_avg", etc.
|
||||
//
|
||||
// Parameters:
|
||||
// - job: Job struct with valid ID, energy footprint will be populated in job.EnergyFootprint
|
||||
//
|
||||
// Returns the energy footprint map or an error if the job is nil or database query fails.
|
||||
func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float64, error) {
|
||||
if job == nil {
|
||||
return nil, fmt.Errorf("job cannot be nil")
|
||||
@@ -316,6 +383,18 @@ func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float6
|
||||
return job.EnergyFootprint, nil
|
||||
}
|
||||
|
||||
// DeleteJobsBefore removes jobs older than the specified start time.
|
||||
// Optionally preserves tagged jobs to protect important data from deletion.
|
||||
// Cache entries for deleted jobs are automatically invalidated.
|
||||
//
|
||||
// This is typically used for data retention policies and cleanup operations.
|
||||
// WARNING: This is a destructive operation that permanently deletes job records.
|
||||
//
|
||||
// Parameters:
|
||||
// - startTime: Unix timestamp, jobs with start_time < this value will be deleted
|
||||
// - omitTagged: If true, skip jobs that have associated tags (jobtag entries)
|
||||
//
|
||||
// Returns the count of deleted jobs or an error if the operation fails.
|
||||
func (r *JobRepository) DeleteJobsBefore(startTime int64, omitTagged bool) (int, error) {
|
||||
var cnt int
|
||||
q := sq.Select("count(*)").From("job").Where("job.start_time < ?", startTime)
|
||||
@@ -371,6 +450,13 @@ func (r *JobRepository) DeleteJobsBefore(startTime int64, omitTagged bool) (int,
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
// DeleteJobByID permanently removes a single job by its database ID.
|
||||
// Cache entries for the deleted job are automatically invalidated.
|
||||
//
|
||||
// Parameters:
|
||||
// - id: Database ID (primary key) of the job to delete
|
||||
//
|
||||
// Returns an error if the deletion fails.
|
||||
func (r *JobRepository) DeleteJobByID(id int64) error {
|
||||
// Invalidate cache entries before deletion
|
||||
r.cache.Del(fmt.Sprintf("metadata:%d", id))
|
||||
@@ -388,6 +474,24 @@ func (r *JobRepository) DeleteJobByID(id int64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// FindUserOrProjectOrJobname attempts to interpret a search term as a job ID,
|
||||
// username, project ID, or job name by querying the database.
|
||||
//
|
||||
// Search logic (in priority order):
|
||||
// 1. If searchterm is numeric, treat as job ID (returned immediately)
|
||||
// 2. Try exact match in job.hpc_user column (username)
|
||||
// 3. Try LIKE match in hpc_user.name column (real name)
|
||||
// 4. Try exact match in job.project column (project ID)
|
||||
// 5. If no matches, return searchterm as jobname for GraphQL query
|
||||
//
|
||||
// This powers the searchbar functionality for flexible job searching.
|
||||
// Requires authenticated user for database lookups (returns empty if user is nil).
|
||||
//
|
||||
// Parameters:
|
||||
// - user: Authenticated user context, required for database access
|
||||
// - searchterm: Search string to interpret
|
||||
//
|
||||
// Returns up to one non-empty value among (jobid, username, project, jobname).
|
||||
func (r *JobRepository) FindUserOrProjectOrJobname(user *schema.User, searchterm string) (jobid string, username string, project string, jobname string) {
|
||||
if searchterm == "" {
|
||||
return "", "", "", ""
|
||||
@@ -423,6 +527,19 @@ var (
|
||||
ErrForbidden = errors.New("not authorized")
|
||||
)
|
||||
|
||||
// FindColumnValue performs a generic column lookup in a database table with role-based access control.
|
||||
// Only users with admin, support, or manager roles can execute this query.
|
||||
//
|
||||
// Parameters:
|
||||
// - user: User context for authorization check
|
||||
// - searchterm: Value to search for (exact match or LIKE pattern)
|
||||
// - table: Database table name to query
|
||||
// - selectColumn: Column name to return in results
|
||||
// - whereColumn: Column name to filter on
|
||||
// - isLike: If true, use LIKE with wildcards; if false, use exact equality
|
||||
//
|
||||
// Returns the first matching value, ErrForbidden if user lacks permission,
|
||||
// or ErrNotFound if no matches are found.
|
||||
func (r *JobRepository) FindColumnValue(user *schema.User, searchterm string, table string, selectColumn string, whereColumn string, isLike bool) (result string, err error) {
|
||||
if user == nil {
|
||||
return "", fmt.Errorf("user cannot be nil")
|
||||
@@ -453,6 +570,19 @@ func (r *JobRepository) FindColumnValue(user *schema.User, searchterm string, ta
|
||||
}
|
||||
}
|
||||
|
||||
// FindColumnValues performs a generic column lookup returning multiple matches with role-based access control.
|
||||
// Similar to FindColumnValue but returns all matching values instead of just the first.
|
||||
// Only users with admin, support, or manager roles can execute this query.
|
||||
//
|
||||
// Parameters:
|
||||
// - user: User context for authorization check
|
||||
// - query: Search pattern (always uses LIKE with wildcards)
|
||||
// - table: Database table name to query
|
||||
// - selectColumn: Column name to return in results
|
||||
// - whereColumn: Column name to filter on
|
||||
//
|
||||
// Returns a slice of matching values, ErrForbidden if user lacks permission,
|
||||
// or ErrNotFound if no matches are found.
|
||||
func (r *JobRepository) FindColumnValues(user *schema.User, query string, table string, selectColumn string, whereColumn string) (results []string, err error) {
|
||||
if user == nil {
|
||||
return nil, fmt.Errorf("user cannot be nil")
|
||||
@@ -487,6 +617,13 @@ func (r *JobRepository) FindColumnValues(user *schema.User, query string, table
|
||||
}
|
||||
}
|
||||
|
||||
// Partitions returns a list of distinct cluster partitions for a given cluster.
|
||||
// Results are cached with a 1-hour TTL to improve performance.
|
||||
//
|
||||
// Parameters:
|
||||
// - cluster: Cluster name to query partitions for
|
||||
//
|
||||
// Returns a slice of partition names or an error if the database query fails.
|
||||
func (r *JobRepository) Partitions(cluster string) ([]string, error) {
|
||||
var err error
|
||||
start := time.Now()
|
||||
@@ -550,6 +687,19 @@ func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]in
|
||||
}
|
||||
|
||||
// FIXME: Set duration to requested walltime?
|
||||
// StopJobsExceedingWalltimeBy marks running jobs as failed if they exceed their walltime limit.
|
||||
// This is typically called periodically to clean up stuck or orphaned jobs.
|
||||
//
|
||||
// Jobs are marked with:
|
||||
// - monitoring_status: MonitoringStatusArchivingFailed
|
||||
// - duration: 0
|
||||
// - job_state: JobStateFailed
|
||||
//
|
||||
// Parameters:
|
||||
// - seconds: Grace period beyond walltime before marking as failed
|
||||
//
|
||||
// Returns an error if the database update fails.
|
||||
// Logs the number of jobs marked as failed if any were affected.
|
||||
func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error {
|
||||
start := time.Now()
|
||||
currentTime := time.Now().Unix()
|
||||
@@ -579,6 +729,12 @@ func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindJobIdsByTag returns all job database IDs associated with a specific tag.
|
||||
//
|
||||
// Parameters:
|
||||
// - tagID: Database ID of the tag to search for
|
||||
//
|
||||
// Returns a slice of job IDs or an error if the query fails.
|
||||
func (r *JobRepository) FindJobIdsByTag(tagID int64) ([]int64, error) {
|
||||
query := sq.Select("job.id").From("job").
|
||||
Join("jobtag ON jobtag.job_id = job.id").
|
||||
@@ -607,6 +763,13 @@ func (r *JobRepository) FindJobIdsByTag(tagID int64) ([]int64, error) {
|
||||
}
|
||||
|
||||
// FIXME: Reconsider filtering short jobs with harcoded threshold
|
||||
// FindRunningJobs returns all currently running jobs for a specific cluster.
|
||||
// Filters out short-running jobs based on repoConfig.MinRunningJobDuration threshold.
|
||||
//
|
||||
// Parameters:
|
||||
// - cluster: Cluster name to filter jobs
|
||||
//
|
||||
// Returns a slice of running job objects or an error if the query fails.
|
||||
func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) {
|
||||
query := sq.Select(jobColumns...).From("job").
|
||||
Where("job.cluster = ?", cluster).
|
||||
@@ -634,6 +797,12 @@ func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) {
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// UpdateDuration recalculates and updates the duration field for all running jobs.
|
||||
// Called periodically to keep job durations current without querying individual jobs.
|
||||
//
|
||||
// Duration is calculated as: current_time - job.start_time
|
||||
//
|
||||
// Returns an error if the database update fails.
|
||||
func (r *JobRepository) UpdateDuration() error {
|
||||
stmnt := sq.Update("job").
|
||||
Set("duration", sq.Expr("? - job.start_time", time.Now().Unix())).
|
||||
@@ -648,6 +817,16 @@ func (r *JobRepository) UpdateDuration() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindJobsBetween returns jobs within a specified time range.
|
||||
// If startTimeBegin is 0, returns all jobs before startTimeEnd.
|
||||
// Optionally excludes tagged jobs from results.
|
||||
//
|
||||
// Parameters:
|
||||
// - startTimeBegin: Unix timestamp for range start (use 0 for unbounded start)
|
||||
// - startTimeEnd: Unix timestamp for range end
|
||||
// - omitTagged: If true, exclude jobs with associated tags
|
||||
//
|
||||
// Returns a slice of jobs or an error if the time range is invalid or query fails.
|
||||
func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64, omitTagged bool) ([]*schema.Job, error) {
|
||||
var query sq.SelectBuilder
|
||||
|
||||
@@ -688,6 +867,14 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// UpdateMonitoringStatus updates the monitoring status for a job and invalidates its cache entries.
|
||||
// Cache invalidation affects both metadata and energy footprint to ensure consistency.
|
||||
//
|
||||
// Parameters:
|
||||
// - job: Database ID of the job to update
|
||||
// - monitoringStatus: New monitoring status value (see schema.MonitoringStatus constants)
|
||||
//
|
||||
// Returns an error if the database update fails.
|
||||
func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) {
|
||||
// Invalidate cache entries as monitoring status affects job state
|
||||
r.cache.Del(fmt.Sprintf("metadata:%d", job))
|
||||
@@ -704,6 +891,13 @@ func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs a Squirrel UpdateBuilder statement against the database.
|
||||
// This is a generic helper for executing pre-built update queries.
|
||||
//
|
||||
// Parameters:
|
||||
// - stmt: Squirrel UpdateBuilder with prepared update query
|
||||
//
|
||||
// Returns an error if the execution fails.
|
||||
func (r *JobRepository) Execute(stmt sq.UpdateBuilder) error {
|
||||
if _, err := stmt.RunWith(r.stmtCache).Exec(); err != nil {
|
||||
cclog.Errorf("Error while executing statement: %v", err)
|
||||
@@ -713,6 +907,14 @@ func (r *JobRepository) Execute(stmt sq.UpdateBuilder) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkArchived adds monitoring status update to an existing UpdateBuilder statement.
|
||||
// This is a builder helper used when constructing multi-field update queries.
|
||||
//
|
||||
// Parameters:
|
||||
// - stmt: Existing UpdateBuilder to modify
|
||||
// - monitoringStatus: Monitoring status value to set
|
||||
//
|
||||
// Returns the modified UpdateBuilder for method chaining.
|
||||
func (r *JobRepository) MarkArchived(
|
||||
stmt sq.UpdateBuilder,
|
||||
monitoringStatus int32,
|
||||
@@ -720,11 +922,22 @@ func (r *JobRepository) MarkArchived(
|
||||
return stmt.Set("monitoring_status", monitoringStatus)
|
||||
}
|
||||
|
||||
// UpdateEnergy calculates and updates the energy consumption for a job.
|
||||
// This is called for running jobs during intermediate updates or when archiving.
|
||||
//
|
||||
// Energy calculation formula:
|
||||
// - For "power" metrics: Energy (kWh) = (Power_avg * NumNodes * Duration_hours) / 1000
|
||||
// - For "energy" metrics: Currently not implemented (would need sum statistics)
|
||||
//
|
||||
// The calculation accounts for:
|
||||
// - Multi-node jobs: Multiplies by NumNodes to get total cluster energy
|
||||
// - Shared jobs: Node average is already based on partial resources, so NumNodes=1
|
||||
// - Unit conversion: Watts * hours / 1000 = kilowatt-hours (kWh)
|
||||
// - Rounding: Results rounded to 2 decimal places
|
||||
func (r *JobRepository) UpdateEnergy(
|
||||
stmt sq.UpdateBuilder,
|
||||
jobMeta *schema.Job,
|
||||
) (sq.UpdateBuilder, error) {
|
||||
/* Note: Only Called for Running Jobs during Intermediate Update or on Archiving */
|
||||
sc, err := archive.GetSubCluster(jobMeta.Cluster, jobMeta.SubCluster)
|
||||
if err != nil {
|
||||
cclog.Errorf("cannot get subcluster: %s", err.Error())
|
||||
@@ -732,25 +945,27 @@ func (r *JobRepository) UpdateEnergy(
|
||||
}
|
||||
energyFootprint := make(map[string]float64)
|
||||
|
||||
// Total Job Energy Outside Loop
|
||||
// Accumulate total energy across all energy-related metrics
|
||||
totalEnergy := 0.0
|
||||
for _, fp := range sc.EnergyFootprint {
|
||||
// Always Init Metric Energy Inside Loop
|
||||
// Calculate energy for this specific metric
|
||||
metricEnergy := 0.0
|
||||
if i, err := archive.MetricIndex(sc.MetricConfig, fp); err == nil {
|
||||
// Note: For DB data, calculate and save as kWh
|
||||
switch sc.MetricConfig[i].Energy {
|
||||
case "energy": // this metric has energy as unit (Joules or Wh)
|
||||
case "energy": // Metric already in energy units (Joules or Wh)
|
||||
cclog.Warnf("Update EnergyFootprint for Job %d and Metric %s on cluster %s: Set to 'energy' in cluster.json: Not implemented, will return 0.0", jobMeta.JobID, jobMeta.Cluster, fp)
|
||||
// FIXME: Needs sum as stats type
|
||||
case "power": // this metric has power as unit (Watt)
|
||||
// Energy: Power (in Watts) * Time (in Seconds)
|
||||
// Unit: (W * (s / 3600)) / 1000 = kWh
|
||||
// Round 2 Digits: round(Energy * 100) / 100
|
||||
// Here: (All-Node Metric Average * Number of Nodes) * (Job Duration in Seconds / 3600) / 1000
|
||||
// Note: Shared Jobs handled correctly since "Node Average" is based on partial resources, while "numNodes" factor is 1
|
||||
// FIXME: Needs sum as stats type to accumulate energy values over time
|
||||
case "power": // Metric in power units (Watts)
|
||||
// Energy (kWh) = Power (W) × Time (h) / 1000
|
||||
// Formula: (avg_power_per_node * num_nodes) * (duration_sec / 3600) / 1000
|
||||
//
|
||||
// Breakdown:
|
||||
// LoadJobStat(jobMeta, fp, "avg") = average power per node (W)
|
||||
// jobMeta.NumNodes = number of nodes (1 for shared jobs)
|
||||
// jobMeta.Duration / 3600.0 = duration in hours
|
||||
// / 1000.0 = convert Wh to kWh
|
||||
rawEnergy := ((LoadJobStat(jobMeta, fp, "avg") * float64(jobMeta.NumNodes)) * (float64(jobMeta.Duration) / 3600.0)) / 1000.0
|
||||
metricEnergy = math.Round(rawEnergy*100.0) / 100.0
|
||||
metricEnergy = math.Round(rawEnergy*100.0) / 100.0 // Round to 2 decimal places
|
||||
}
|
||||
} else {
|
||||
cclog.Warnf("Error while collecting energy metric %s for job, DB ID '%v', return '0.0'", fp, jobMeta.ID)
|
||||
@@ -758,8 +973,6 @@ func (r *JobRepository) UpdateEnergy(
|
||||
|
||||
energyFootprint[fp] = metricEnergy
|
||||
totalEnergy += metricEnergy
|
||||
|
||||
// cclog.Infof("Metric %s Average %f -> %f kWh | Job %d Total -> %f kWh", fp, LoadJobStat(jobMeta, fp, "avg"), energy, jobMeta.JobID, totalEnergy)
|
||||
}
|
||||
|
||||
var rawFootprint []byte
|
||||
@@ -771,11 +984,19 @@ func (r *JobRepository) UpdateEnergy(
|
||||
return stmt.Set("energy_footprint", string(rawFootprint)).Set("energy", (math.Round(totalEnergy*100.0) / 100.0)), nil
|
||||
}
|
||||
|
||||
// UpdateFootprint calculates and updates the performance footprint for a job.
|
||||
// This is called for running jobs during intermediate updates or when archiving.
|
||||
//
|
||||
// A footprint is a summary statistic (avg/min/max) for each monitored metric.
|
||||
// The specific statistic type is defined in the cluster config's Footprint field.
|
||||
// Results are stored as JSON with keys like "metric_avg", "metric_max", etc.
|
||||
//
|
||||
// Example: For a "cpu_load" metric with Footprint="avg", this stores
|
||||
// the average CPU load across all nodes as "cpu_load_avg": 85.3
|
||||
func (r *JobRepository) UpdateFootprint(
|
||||
stmt sq.UpdateBuilder,
|
||||
jobMeta *schema.Job,
|
||||
) (sq.UpdateBuilder, error) {
|
||||
/* Note: Only Called for Running Jobs during Intermediate Update or on Archiving */
|
||||
sc, err := archive.GetSubCluster(jobMeta.Cluster, jobMeta.SubCluster)
|
||||
if err != nil {
|
||||
cclog.Errorf("cannot get subcluster: %s", err.Error())
|
||||
@@ -783,7 +1004,10 @@ func (r *JobRepository) UpdateFootprint(
|
||||
}
|
||||
footprint := make(map[string]float64)
|
||||
|
||||
// Build footprint map with metric_stattype as keys
|
||||
for _, fp := range sc.Footprint {
|
||||
// Determine which statistic to use: avg, min, or max
|
||||
// First check global metric config, then cluster-specific config
|
||||
var statType string
|
||||
for _, gm := range archive.GlobalMetricList {
|
||||
if gm.Name == fp {
|
||||
@@ -791,15 +1015,18 @@ func (r *JobRepository) UpdateFootprint(
|
||||
}
|
||||
}
|
||||
|
||||
// Validate statistic type
|
||||
if statType != "avg" && statType != "min" && statType != "max" {
|
||||
cclog.Warnf("unknown statType for footprint update: %s", statType)
|
||||
return stmt, fmt.Errorf("unknown statType for footprint update: %s", statType)
|
||||
}
|
||||
|
||||
// Override with cluster-specific config if available
|
||||
if i, err := archive.MetricIndex(sc.MetricConfig, fp); err != nil {
|
||||
statType = sc.MetricConfig[i].Footprint
|
||||
}
|
||||
|
||||
// Store as "metric_stattype": value (e.g., "cpu_load_avg": 85.3)
|
||||
name := fmt.Sprintf("%s_%s", fp, statType)
|
||||
footprint[name] = LoadJobStat(jobMeta, fp, statType)
|
||||
}
|
||||
|
||||
500
internal/repository/jobCreate_test.go
Normal file
500
internal/repository/jobCreate_test.go
Normal file
@@ -0,0 +1,500 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/ClusterCockpit/cc-lib/v2/schema"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// createTestJob creates a minimal valid job for testing
|
||||
func createTestJob(jobID int64, cluster string) *schema.Job {
|
||||
return &schema.Job{
|
||||
JobID: jobID,
|
||||
User: "testuser",
|
||||
Project: "testproject",
|
||||
Cluster: cluster,
|
||||
SubCluster: "main",
|
||||
Partition: "batch",
|
||||
NumNodes: 1,
|
||||
NumHWThreads: 4,
|
||||
NumAcc: 0,
|
||||
Shared: "none",
|
||||
MonitoringStatus: schema.MonitoringStatusRunningOrArchiving,
|
||||
SMT: 1,
|
||||
State: schema.JobStateRunning,
|
||||
StartTime: 1234567890,
|
||||
Duration: 0,
|
||||
Walltime: 3600,
|
||||
Resources: []*schema.Resource{
|
||||
{
|
||||
Hostname: "node01",
|
||||
HWThreads: []int{0, 1, 2, 3},
|
||||
},
|
||||
},
|
||||
Footprint: map[string]float64{
|
||||
"cpu_load": 50.0,
|
||||
"mem_used": 8000.0,
|
||||
"flops_any": 0.5,
|
||||
"mem_bw": 10.0,
|
||||
"net_bw": 2.0,
|
||||
"file_bw": 1.0,
|
||||
"cpu_used": 2.0,
|
||||
"cpu_load_core": 12.5,
|
||||
},
|
||||
MetaData: map[string]string{
|
||||
"jobName": "test_job",
|
||||
"queue": "normal",
|
||||
"qosName": "default",
|
||||
"accountName": "testaccount",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertJob(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("successful insertion", func(t *testing.T) {
|
||||
job := createTestJob(999001, "testcluster")
|
||||
job.RawResources, _ = json.Marshal(job.Resources)
|
||||
job.RawFootprint, _ = json.Marshal(job.Footprint)
|
||||
job.RawMetaData, _ = json.Marshal(job.MetaData)
|
||||
|
||||
id, err := r.InsertJob(job)
|
||||
require.NoError(t, err, "InsertJob should succeed")
|
||||
assert.Greater(t, id, int64(0), "Should return valid insert ID")
|
||||
|
||||
// Verify job was inserted into job_cache
|
||||
var count int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM job_cache WHERE job_id = ? AND cluster = ?",
|
||||
job.JobID, job.Cluster).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Job should be in job_cache table")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE job_id = ? AND cluster = ?", job.JobID, job.Cluster)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("insertion with all fields", func(t *testing.T) {
|
||||
job := createTestJob(999002, "testcluster")
|
||||
job.ArrayJobID = 5000
|
||||
job.Energy = 1500.5
|
||||
job.RawResources, _ = json.Marshal(job.Resources)
|
||||
job.RawFootprint, _ = json.Marshal(job.Footprint)
|
||||
job.RawMetaData, _ = json.Marshal(job.MetaData)
|
||||
|
||||
id, err := r.InsertJob(job)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Verify all fields were stored correctly
|
||||
var retrievedJob schema.Job
|
||||
err = r.DB.QueryRow(`SELECT job_id, hpc_user, project, cluster, array_job_id, energy
|
||||
FROM job_cache WHERE id = ?`, id).Scan(
|
||||
&retrievedJob.JobID, &retrievedJob.User, &retrievedJob.Project,
|
||||
&retrievedJob.Cluster, &retrievedJob.ArrayJobID, &retrievedJob.Energy)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, job.JobID, retrievedJob.JobID)
|
||||
assert.Equal(t, job.User, retrievedJob.User)
|
||||
assert.Equal(t, job.Project, retrievedJob.Project)
|
||||
assert.Equal(t, job.Cluster, retrievedJob.Cluster)
|
||||
assert.Equal(t, job.ArrayJobID, retrievedJob.ArrayJobID)
|
||||
assert.Equal(t, job.Energy, retrievedJob.Energy)
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("successful job start with JSON encoding", func(t *testing.T) {
|
||||
job := createTestJob(999003, "testcluster")
|
||||
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err, "Start should succeed")
|
||||
assert.Greater(t, id, int64(0), "Should return valid insert ID")
|
||||
|
||||
// Verify job was inserted and JSON fields were encoded
|
||||
var rawResources, rawFootprint, rawMetaData []byte
|
||||
err = r.DB.QueryRow(`SELECT resources, footprint, meta_data FROM job_cache WHERE id = ?`, id).Scan(
|
||||
&rawResources, &rawFootprint, &rawMetaData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify resources JSON
|
||||
var resources []*schema.Resource
|
||||
err = json.Unmarshal(rawResources, &resources)
|
||||
require.NoError(t, err, "Resources should be valid JSON")
|
||||
assert.Equal(t, 1, len(resources))
|
||||
assert.Equal(t, "node01", resources[0].Hostname)
|
||||
|
||||
// Verify footprint JSON
|
||||
var footprint map[string]float64
|
||||
err = json.Unmarshal(rawFootprint, &footprint)
|
||||
require.NoError(t, err, "Footprint should be valid JSON")
|
||||
assert.Equal(t, 50.0, footprint["cpu_load"])
|
||||
assert.Equal(t, 8000.0, footprint["mem_used"])
|
||||
|
||||
// Verify metadata JSON
|
||||
var metaData map[string]string
|
||||
err = json.Unmarshal(rawMetaData, &metaData)
|
||||
require.NoError(t, err, "MetaData should be valid JSON")
|
||||
assert.Equal(t, "test_job", metaData["jobName"])
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("job start with empty footprint", func(t *testing.T) {
|
||||
job := createTestJob(999004, "testcluster")
|
||||
job.Footprint = map[string]float64{}
|
||||
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Verify empty footprint was encoded as empty JSON object
|
||||
var rawFootprint []byte
|
||||
err = r.DB.QueryRow(`SELECT footprint FROM job_cache WHERE id = ?`, id).Scan(&rawFootprint)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("{}"), rawFootprint)
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("job start with nil metadata", func(t *testing.T) {
|
||||
job := createTestJob(999005, "testcluster")
|
||||
job.MetaData = nil
|
||||
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("successful job stop", func(t *testing.T) {
|
||||
// First insert a job using Start
|
||||
job := createTestJob(999106, "testcluster")
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Move from job_cache to job table (simulate SyncJobs) - exclude id to let it auto-increment
|
||||
_, err = r.DB.Exec(`INSERT INTO job (job_id, cluster, subcluster, submit_time, start_time, hpc_user, project,
|
||||
cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes,
|
||||
num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint)
|
||||
SELECT job_id, cluster, subcluster, submit_time, start_time, hpc_user, project,
|
||||
cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes,
|
||||
num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint
|
||||
FROM job_cache WHERE id = ?`, id)
|
||||
require.NoError(t, err)
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the new job id in the job table
|
||||
err = r.DB.QueryRow("SELECT id FROM job WHERE job_id = ? AND cluster = ? AND start_time = ?",
|
||||
job.JobID, job.Cluster, job.StartTime).Scan(&id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop the job
|
||||
duration := int32(3600)
|
||||
state := schema.JobStateCompleted
|
||||
monitoringStatus := int32(schema.MonitoringStatusArchivingSuccessful)
|
||||
|
||||
err = r.Stop(id, duration, state, monitoringStatus)
|
||||
require.NoError(t, err, "Stop should succeed")
|
||||
|
||||
// Verify job was updated
|
||||
var retrievedDuration int32
|
||||
var retrievedState string
|
||||
var retrievedMonStatus int32
|
||||
err = r.DB.QueryRow(`SELECT duration, job_state, monitoring_status FROM job WHERE id = ?`, id).Scan(
|
||||
&retrievedDuration, &retrievedState, &retrievedMonStatus)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, duration, retrievedDuration)
|
||||
assert.Equal(t, string(state), retrievedState)
|
||||
assert.Equal(t, monitoringStatus, retrievedMonStatus)
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("stop updates job state transitions", func(t *testing.T) {
|
||||
// Insert a job
|
||||
job := createTestJob(999107, "testcluster")
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Move to job table
|
||||
_, err = r.DB.Exec(`INSERT INTO job (job_id, cluster, subcluster, submit_time, start_time, hpc_user, project,
|
||||
cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes,
|
||||
num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint)
|
||||
SELECT job_id, cluster, subcluster, submit_time, start_time, hpc_user, project,
|
||||
cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes,
|
||||
num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint
|
||||
FROM job_cache WHERE id = ?`, id)
|
||||
require.NoError(t, err)
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the new job id in the job table
|
||||
err = r.DB.QueryRow("SELECT id FROM job WHERE job_id = ? AND cluster = ? AND start_time = ?",
|
||||
job.JobID, job.Cluster, job.StartTime).Scan(&id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop the job with different duration
|
||||
err = r.Stop(id, 7200, schema.JobStateCompleted, int32(schema.MonitoringStatusArchivingSuccessful))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the duration was updated correctly
|
||||
var duration int32
|
||||
err = r.DB.QueryRow(`SELECT duration FROM job WHERE id = ?`, id).Scan(&duration)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int32(7200), duration, "Duration should be updated to 7200")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("stop with different states", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
jobID int64
|
||||
state schema.JobState
|
||||
monitoringStatus int32
|
||||
}{
|
||||
{"completed", 999108, schema.JobStateCompleted, int32(schema.MonitoringStatusArchivingSuccessful)},
|
||||
{"failed", 999118, schema.JobStateFailed, int32(schema.MonitoringStatusArchivingSuccessful)},
|
||||
{"cancelled", 999119, schema.JobStateCancelled, int32(schema.MonitoringStatusArchivingSuccessful)},
|
||||
{"timeout", 999120, schema.JobStateTimeout, int32(schema.MonitoringStatusArchivingSuccessful)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
job := createTestJob(tc.jobID, "testcluster")
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Move to job table
|
||||
_, err = r.DB.Exec(`INSERT INTO job (job_id, cluster, subcluster, submit_time, start_time, hpc_user, project,
|
||||
cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes,
|
||||
num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint)
|
||||
SELECT job_id, cluster, subcluster, submit_time, start_time, hpc_user, project,
|
||||
cluster_partition, array_job_id, duration, walltime, job_state, meta_data, resources, num_nodes,
|
||||
num_hwthreads, num_acc, smt, shared, monitoring_status, energy, energy_footprint, footprint
|
||||
FROM job_cache WHERE id = ?`, id)
|
||||
require.NoError(t, err)
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the new job id in the job table
|
||||
err = r.DB.QueryRow("SELECT id FROM job WHERE job_id = ? AND cluster = ? AND start_time = ?",
|
||||
job.JobID, job.Cluster, job.StartTime).Scan(&id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop with specific state
|
||||
err = r.Stop(id, 1800, tc.state, tc.monitoringStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify state was set correctly
|
||||
var retrievedState string
|
||||
err = r.DB.QueryRow(`SELECT job_state FROM job WHERE id = ?`, id).Scan(&retrievedState)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(tc.state), retrievedState)
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopCached(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("successful stop cached job", func(t *testing.T) {
|
||||
// Insert a job in job_cache
|
||||
job := createTestJob(999009, "testcluster")
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop the cached job
|
||||
duration := int32(3600)
|
||||
state := schema.JobStateCompleted
|
||||
monitoringStatus := int32(schema.MonitoringStatusArchivingSuccessful)
|
||||
|
||||
err = r.StopCached(id, duration, state, monitoringStatus)
|
||||
require.NoError(t, err, "StopCached should succeed")
|
||||
|
||||
// Verify job was updated in job_cache table
|
||||
var retrievedDuration int32
|
||||
var retrievedState string
|
||||
var retrievedMonStatus int32
|
||||
err = r.DB.QueryRow(`SELECT duration, job_state, monitoring_status FROM job_cache WHERE id = ?`, id).Scan(
|
||||
&retrievedDuration, &retrievedState, &retrievedMonStatus)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, duration, retrievedDuration)
|
||||
assert.Equal(t, string(state), retrievedState)
|
||||
assert.Equal(t, monitoringStatus, retrievedMonStatus)
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("stop cached job does not affect job table", func(t *testing.T) {
|
||||
// Insert a job in job_cache
|
||||
job := createTestJob(999010, "testcluster")
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop the cached job
|
||||
err = r.StopCached(id, 3600, schema.JobStateCompleted, int32(schema.MonitoringStatusArchivingSuccessful))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify job table was not affected
|
||||
var count int
|
||||
err = r.DB.QueryRow(`SELECT COUNT(*) FROM job WHERE job_id = ? AND cluster = ?`,
|
||||
job.JobID, job.Cluster).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count, "Job table should not be affected by StopCached")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSyncJobs(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("sync jobs from cache to main table", func(t *testing.T) {
|
||||
// Ensure cache is empty first
|
||||
_, err := r.DB.Exec("DELETE FROM job_cache")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert multiple jobs in job_cache
|
||||
job1 := createTestJob(999011, "testcluster")
|
||||
job2 := createTestJob(999012, "testcluster")
|
||||
job3 := createTestJob(999013, "testcluster")
|
||||
|
||||
_, err = r.Start(job1)
|
||||
require.NoError(t, err)
|
||||
_, err = r.Start(job2)
|
||||
require.NoError(t, err)
|
||||
_, err = r.Start(job3)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify jobs are in job_cache
|
||||
var cacheCount int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM job_cache WHERE job_id IN (?, ?, ?)",
|
||||
job1.JobID, job2.JobID, job3.JobID).Scan(&cacheCount)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, cacheCount, "All jobs should be in job_cache")
|
||||
|
||||
// Sync jobs
|
||||
jobs, err := r.SyncJobs()
|
||||
require.NoError(t, err, "SyncJobs should succeed")
|
||||
assert.Equal(t, 3, len(jobs), "Should return 3 synced jobs")
|
||||
|
||||
// Verify jobs were moved to job table
|
||||
var jobCount int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM job WHERE job_id IN (?, ?, ?)",
|
||||
job1.JobID, job2.JobID, job3.JobID).Scan(&jobCount)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, jobCount, "All jobs should be in job table")
|
||||
|
||||
// Verify job_cache was cleared
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM job_cache WHERE job_id IN (?, ?, ?)",
|
||||
job1.JobID, job2.JobID, job3.JobID).Scan(&cacheCount)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, cacheCount, "job_cache should be empty after sync")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job WHERE job_id IN (?, ?, ?)", job1.JobID, job2.JobID, job3.JobID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("sync preserves job data", func(t *testing.T) {
|
||||
// Ensure cache is empty first
|
||||
_, err := r.DB.Exec("DELETE FROM job_cache")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert a job with specific data
|
||||
job := createTestJob(999014, "testcluster")
|
||||
job.ArrayJobID = 7777
|
||||
job.Energy = 2500.75
|
||||
job.Duration = 1800
|
||||
|
||||
id, err := r.Start(job)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update some fields to simulate job progress
|
||||
result, err := r.DB.Exec(`UPDATE job_cache SET duration = ?, energy = ? WHERE id = ?`,
|
||||
3600, 3000.5, id)
|
||||
require.NoError(t, err)
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
require.Equal(t, int64(1), rowsAffected, "UPDATE should affect exactly 1 row")
|
||||
|
||||
// Verify the update worked
|
||||
var checkDuration int32
|
||||
var checkEnergy float64
|
||||
err = r.DB.QueryRow(`SELECT duration, energy FROM job_cache WHERE id = ?`, id).Scan(&checkDuration, &checkEnergy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(3600), checkDuration, "Duration should be updated to 3600 before sync")
|
||||
require.Equal(t, 3000.5, checkEnergy, "Energy should be updated to 3000.5 before sync")
|
||||
|
||||
// Sync jobs
|
||||
jobs, err := r.SyncJobs()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(jobs), "Should return exactly 1 synced job")
|
||||
|
||||
// Verify in database
|
||||
var dbJob schema.Job
|
||||
err = r.DB.QueryRow(`SELECT job_id, hpc_user, project, cluster, array_job_id, duration, energy
|
||||
FROM job WHERE job_id = ? AND cluster = ?`, job.JobID, job.Cluster).Scan(
|
||||
&dbJob.JobID, &dbJob.User, &dbJob.Project, &dbJob.Cluster,
|
||||
&dbJob.ArrayJobID, &dbJob.Duration, &dbJob.Energy)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, job.JobID, dbJob.JobID)
|
||||
assert.Equal(t, int32(3600), dbJob.Duration)
|
||||
assert.Equal(t, 3000.5, dbJob.Energy)
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM job WHERE job_id = ? AND cluster = ?", job.JobID, job.Cluster)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("sync with empty cache returns empty list", func(t *testing.T) {
|
||||
// Ensure cache is empty
|
||||
_, err := r.DB.Exec("DELETE FROM job_cache")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sync should return empty list
|
||||
jobs, err := r.SyncJobs()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(jobs), "Should return empty list when cache is empty")
|
||||
})
|
||||
}
|
||||
@@ -10,8 +10,36 @@ import (
|
||||
"github.com/ClusterCockpit/cc-lib/v2/schema"
|
||||
)
|
||||
|
||||
// JobHook interface allows external components to hook into job lifecycle events.
|
||||
// Implementations can perform actions when jobs start or stop, such as tagging,
|
||||
// logging, notifications, or triggering external workflows.
|
||||
//
|
||||
// Example implementation:
|
||||
//
|
||||
// type MyJobTagger struct{}
|
||||
//
|
||||
// func (t *MyJobTagger) JobStartCallback(job *schema.Job) {
|
||||
// if job.NumNodes > 100 {
|
||||
// // Tag large jobs automatically
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func (t *MyJobTagger) JobStopCallback(job *schema.Job) {
|
||||
// if job.State == schema.JobStateFailed {
|
||||
// // Log or alert on failed jobs
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Register hooks during application initialization:
|
||||
//
|
||||
// repository.RegisterJobHook(&MyJobTagger{})
|
||||
type JobHook interface {
|
||||
// JobStartCallback is invoked when one or more jobs start.
|
||||
// This is called synchronously, so implementations should be fast.
|
||||
JobStartCallback(job *schema.Job)
|
||||
|
||||
// JobStopCallback is invoked when a job completes.
|
||||
// This is called synchronously, so implementations should be fast.
|
||||
JobStopCallback(job *schema.Job)
|
||||
}
|
||||
|
||||
@@ -20,7 +48,13 @@ var (
|
||||
hooks []JobHook
|
||||
)
|
||||
|
||||
func RegisterJobJook(hook JobHook) {
|
||||
// RegisterJobHook registers a JobHook to receive job lifecycle callbacks.
|
||||
// Multiple hooks can be registered and will be called in registration order.
|
||||
// This function is safe to call multiple times and is typically called during
|
||||
// application initialization.
|
||||
//
|
||||
// Nil hooks are silently ignored to simplify conditional registration.
|
||||
func RegisterJobHook(hook JobHook) {
|
||||
initOnce.Do(func() {
|
||||
hooks = make([]JobHook, 0)
|
||||
})
|
||||
@@ -30,6 +64,12 @@ func RegisterJobJook(hook JobHook) {
|
||||
}
|
||||
}
|
||||
|
||||
// CallJobStartHooks invokes all registered JobHook.JobStartCallback methods
|
||||
// for each job in the provided slice. This is called internally by the repository
|
||||
// when jobs are started (e.g., via StartJob or batch job imports).
|
||||
//
|
||||
// Hooks are called synchronously in registration order. If a hook panics,
|
||||
// the panic will propagate to the caller.
|
||||
func CallJobStartHooks(jobs []*schema.Job) {
|
||||
if hooks == nil {
|
||||
return
|
||||
@@ -44,6 +84,12 @@ func CallJobStartHooks(jobs []*schema.Job) {
|
||||
}
|
||||
}
|
||||
|
||||
// CallJobStopHooks invokes all registered JobHook.JobStopCallback methods
|
||||
// for the provided job. This is called internally by the repository when a
|
||||
// job completes (e.g., via StopJob or job state updates).
|
||||
//
|
||||
// Hooks are called synchronously in registration order. If a hook panics,
|
||||
// the panic will propagate to the caller.
|
||||
func CallJobStopHooks(job *schema.Job) {
|
||||
if hooks == nil {
|
||||
return
|
||||
|
||||
@@ -16,11 +16,29 @@ import (
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
)
|
||||
|
||||
// Version is the current database schema version required by this version of cc-backend.
|
||||
// When the database schema changes, this version is incremented and a new migration file
|
||||
// is added to internal/repository/migrations/sqlite3/.
|
||||
//
|
||||
// Version history:
|
||||
// - Version 10: Current version
|
||||
//
|
||||
// Migration files are embedded at build time from the migrations directory.
|
||||
const Version uint = 10
|
||||
|
||||
//go:embed migrations/*
|
||||
var migrationFiles embed.FS
|
||||
|
||||
// checkDBVersion verifies that the database schema version matches the expected version.
|
||||
// This is called automatically during Connect() to ensure schema compatibility.
|
||||
//
|
||||
// Returns an error if:
|
||||
// - Database version is older than expected (needs migration)
|
||||
// - Database version is newer than expected (needs app upgrade)
|
||||
// - Database is in a dirty state (failed migration)
|
||||
//
|
||||
// A "dirty" database indicates a migration was started but not completed successfully.
|
||||
// This requires manual intervention to fix the database and force the version.
|
||||
func checkDBVersion(db *sql.DB) error {
|
||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
@@ -58,6 +76,8 @@ func checkDBVersion(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMigrateInstance creates a new migration instance for the given database file.
|
||||
// This is used internally by MigrateDB, RevertDB, and ForceDB.
|
||||
func getMigrateInstance(db string) (m *migrate.Migrate, err error) {
|
||||
d, err := iofs.New(migrationFiles, "migrations/sqlite3")
|
||||
if err != nil {
|
||||
@@ -72,6 +92,23 @@ func getMigrateInstance(db string) (m *migrate.Migrate, err error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// MigrateDB applies all pending database migrations to bring the schema up to date.
|
||||
// This should be run with the -migrate-db flag before starting the application
|
||||
// after upgrading to a new version that requires schema changes.
|
||||
//
|
||||
// Process:
|
||||
// 1. Checks current database version
|
||||
// 2. Applies all migrations from current version to target Version
|
||||
// 3. Updates schema_migrations table to track applied migrations
|
||||
//
|
||||
// Important:
|
||||
// - Always backup your database before running migrations
|
||||
// - Migrations are irreversible without manual intervention
|
||||
// - If a migration fails, the database is marked "dirty" and requires manual fix
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// cc-backend -migrate-db
|
||||
func MigrateDB(db string) error {
|
||||
m, err := getMigrateInstance(db)
|
||||
if err != nil {
|
||||
@@ -107,6 +144,17 @@ func MigrateDB(db string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevertDB rolls back the database schema to the previous version (Version - 1).
|
||||
// This is primarily used for testing or emergency rollback scenarios.
|
||||
//
|
||||
// Warning:
|
||||
// - This may cause data loss if newer schema added columns/tables
|
||||
// - Always backup before reverting
|
||||
// - Not all migrations are safely reversible
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// cc-backend -revert-db
|
||||
func RevertDB(db string) error {
|
||||
m, err := getMigrateInstance(db)
|
||||
if err != nil {
|
||||
@@ -125,6 +173,21 @@ func RevertDB(db string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceDB forces the database schema version to the current Version without running migrations.
|
||||
// This is only used to recover from failed migrations that left the database in a "dirty" state.
|
||||
//
|
||||
// When to use:
|
||||
// - After manually fixing a failed migration
|
||||
// - When you've manually applied schema changes and need to update the version marker
|
||||
//
|
||||
// Warning:
|
||||
// - This does NOT apply any schema changes
|
||||
// - Only use after manually verifying the schema is correct
|
||||
// - Improper use can cause schema/version mismatch
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// cc-backend -force-db
|
||||
func ForceDB(db string) error {
|
||||
m, err := getMigrateInstance(db)
|
||||
if err != nil {
|
||||
|
||||
@@ -277,6 +277,15 @@ func (r *JobRepository) JobsStats(
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// LoadJobStat retrieves a specific statistic for a metric from a job's statistics.
|
||||
// Returns 0.0 if the metric is not found or statType is invalid.
|
||||
//
|
||||
// Parameters:
|
||||
// - job: Job struct with populated Statistics field
|
||||
// - metric: Name of the metric to query (e.g., "cpu_load", "mem_used")
|
||||
// - statType: Type of statistic: "avg", "min", or "max"
|
||||
//
|
||||
// Returns the requested statistic value or 0.0 if not found.
|
||||
func LoadJobStat(job *schema.Job, metric string, statType string) float64 {
|
||||
if stats, ok := job.Statistics[metric]; ok {
|
||||
switch statType {
|
||||
@@ -579,7 +588,9 @@ func (r *JobRepository) jobsDurationStatisticsHistogram(
|
||||
return nil, qerr
|
||||
}
|
||||
|
||||
// Setup Array
|
||||
// Initialize histogram bins with zero counts
|
||||
// Each bin represents a duration range: bin N = [N*binSizeSeconds, (N+1)*binSizeSeconds)
|
||||
// Example: binSizeSeconds=3600 (1 hour), bin 1 = 0-1h, bin 2 = 1-2h, etc.
|
||||
points := make([]*model.HistoPoint, 0)
|
||||
for i := 1; i <= *targetBinCount; i++ {
|
||||
point := model.HistoPoint{Value: i * binSizeSeconds, Count: 0}
|
||||
@@ -596,7 +607,8 @@ func (r *JobRepository) jobsDurationStatisticsHistogram(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fill Array at matching $Value
|
||||
// Match query results to pre-initialized bins and fill counts
|
||||
// Query returns raw duration values that need to be mapped to correct bins
|
||||
for rows.Next() {
|
||||
point := model.HistoPoint{}
|
||||
if err := rows.Scan(&point.Value, &point.Count); err != nil {
|
||||
@@ -604,11 +616,13 @@ func (r *JobRepository) jobsDurationStatisticsHistogram(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find matching bin and update count
|
||||
// point.Value is multiplied by binSizeSeconds to match pre-calculated bin.Value
|
||||
for _, e := range points {
|
||||
if e.Value == (point.Value * binSizeSeconds) {
|
||||
// Note:
|
||||
// Matching on unmodified integer value (and multiplying point.Value by binSizeSeconds after match)
|
||||
// causes frontend to loop into highest targetBinCount, due to zoom condition instantly being fullfilled (cause unknown)
|
||||
// Note: Matching on unmodified integer value (and multiplying point.Value
|
||||
// by binSizeSeconds after match) causes frontend to loop into highest
|
||||
// targetBinCount, due to zoom condition instantly being fulfilled (cause unknown)
|
||||
e.Count = point.Count
|
||||
break
|
||||
}
|
||||
@@ -625,12 +639,16 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
|
||||
filters []*model.JobFilter,
|
||||
bins *int,
|
||||
) (*model.MetricHistoPoints, error) {
|
||||
// Get specific Peak or largest Peak
|
||||
// Determine the metric's peak value for histogram normalization
|
||||
// Peak value defines the upper bound for binning: values are distributed across
|
||||
// bins from 0 to peak. First try to get peak from filtered cluster, otherwise
|
||||
// scan all clusters to find the maximum peak value.
|
||||
var metricConfig *schema.MetricConfig
|
||||
var peak float64
|
||||
var unit string
|
||||
var footprintStat string
|
||||
|
||||
// Try to get metric config from filtered cluster
|
||||
for _, f := range filters {
|
||||
if f.Cluster != nil {
|
||||
metricConfig = archive.GetMetricConfig(*f.Cluster.Eq, metric)
|
||||
@@ -641,6 +659,8 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
|
||||
}
|
||||
}
|
||||
|
||||
// If no cluster filter or peak not found, find largest peak across all clusters
|
||||
// This ensures histogram can accommodate all possible values
|
||||
if peak == 0.0 {
|
||||
for _, c := range archive.Clusters {
|
||||
for _, m := range c.MetricConfig {
|
||||
@@ -659,11 +679,18 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
|
||||
}
|
||||
}
|
||||
|
||||
// cclog.Debugf("Metric %s, Peak %f, Unit %s", metric, peak, unit)
|
||||
// Make bins, see https://jereze.com/code/sql-histogram/ (Modified here)
|
||||
// Construct SQL histogram bins using normalized values
|
||||
// Algorithm based on: https://jereze.com/code/sql-histogram/ (modified)
|
||||
start := time.Now()
|
||||
|
||||
// Find Jobs' Value Bin Number: Divide Value by Peak, Multiply by RequestedBins, then CAST to INT: Gets Bin-Number of Job
|
||||
// Calculate bin number for each job's metric value:
|
||||
// 1. Extract metric value from JSON footprint
|
||||
// 2. Normalize to [0,1] by dividing by peak
|
||||
// 3. Multiply by number of bins to get bin number
|
||||
// 4. Cast to integer for bin assignment
|
||||
//
|
||||
// Special case: Values exactly equal to peak would fall into bin N+1,
|
||||
// so we multiply peak by 0.999999999 to force it into the last bin (bin N)
|
||||
binQuery := fmt.Sprintf(`CAST(
|
||||
((case when json_extract(footprint, "$.%s") = %f then %f*0.999999999 else json_extract(footprint, "$.%s") end) / %f)
|
||||
* %v as INTEGER )`,
|
||||
@@ -698,7 +725,9 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup Return Array With Bin-Numbers for Match and Min/Max based on Peak
|
||||
// Initialize histogram bins with calculated min/max ranges
|
||||
// Each bin represents a range of metric values
|
||||
// Example: peak=1000, bins=10 -> bin 1=[0,100), bin 2=[100,200), ..., bin 10=[900,1000]
|
||||
points := make([]*model.MetricHistoPoint, 0)
|
||||
binStep := int(peak) / *bins
|
||||
for i := 1; i <= *bins; i++ {
|
||||
@@ -708,13 +737,16 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
|
||||
points = append(points, &epoint)
|
||||
}
|
||||
|
||||
for rows.Next() { // Fill Count if Bin-No. Matches (Not every Bin exists in DB!)
|
||||
// Fill counts from query results
|
||||
// Query only returns bins that have jobs, so we match against pre-initialized bins
|
||||
for rows.Next() {
|
||||
rpoint := model.MetricHistoPoint{}
|
||||
if err := rows.Scan(&rpoint.Bin, &rpoint.Count); err != nil { // Required for Debug: &rpoint.Min, &rpoint.Max
|
||||
cclog.Warnf("Error while scanning rows for %s", metric)
|
||||
return nil, err // FIXME: Totally bricks cc-backend if returned and if all metrics requested?
|
||||
}
|
||||
|
||||
// Match query result to pre-initialized bin and update count
|
||||
for _, e := range points {
|
||||
if e.Bin != nil && rpoint.Bin != nil {
|
||||
if *e.Bin == *rpoint.Bin {
|
||||
|
||||
@@ -25,11 +25,20 @@ func TestBuildJobStatsQuery(t *testing.T) {
|
||||
func TestJobStats(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
// First, count the actual jobs in the database (excluding test jobs)
|
||||
var expectedCount int
|
||||
err := r.DB.QueryRow(`SELECT COUNT(*) FROM job WHERE cluster != 'testcluster'`).Scan(&expectedCount)
|
||||
noErr(t, err)
|
||||
|
||||
filter := &model.JobFilter{}
|
||||
// Exclude test jobs created by other tests
|
||||
testCluster := "testcluster"
|
||||
filter.Cluster = &model.StringInput{Neq: &testCluster}
|
||||
|
||||
stats, err := r.JobsStats(getContext(t), []*model.JobFilter{filter})
|
||||
noErr(t, err)
|
||||
|
||||
if stats[0].TotalJobs != 544 {
|
||||
t.Fatalf("Want 544, Got %d", stats[0].TotalJobs)
|
||||
if stats[0].TotalJobs != expectedCount {
|
||||
t.Fatalf("Want %d, Got %d", expectedCount, stats[0].TotalJobs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,8 +16,32 @@ import (
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
)
|
||||
|
||||
// Tag Scope Rules:
|
||||
//
|
||||
// Tags in ClusterCockpit have three visibility scopes that control who can see and use them:
|
||||
//
|
||||
// 1. "global" - Visible to all users, can be used by anyone
|
||||
// Example: System-generated tags like "energy-efficient", "failed", "short"
|
||||
//
|
||||
// 2. "private" - Only visible to the creating user
|
||||
// Example: Personal notes like "needs-review", "interesting-case"
|
||||
//
|
||||
// 3. "admin" - Only visible to users with admin or support roles
|
||||
// Example: Internal notes like "hardware-issue", "billing-problem"
|
||||
//
|
||||
// Authorization Rules:
|
||||
// - Regular users can only create/see "global" and their own "private" tags
|
||||
// - Admin/Support can create/see all scopes including "admin" tags
|
||||
// - Users can only add tags to jobs they have permission to view
|
||||
// - Tag scope is enforced at query time in GetTags() and CountTags()
|
||||
|
||||
// AddTag adds the tag with id `tagId` to the job with the database id `jobId`.
|
||||
// Requires user authentication for security checks.
|
||||
//
|
||||
// The user must have permission to view the job. Tag visibility is determined by scope:
|
||||
// - "global" tags: visible to all users
|
||||
// - "private" tags: only visible to the tag creator
|
||||
// - "admin" tags: only visible to admin/support users
|
||||
func (r *JobRepository) AddTag(user *schema.User, job int64, tag int64) ([]*schema.Tag, error) {
|
||||
j, err := r.FindByIDWithUser(user, job)
|
||||
if err != nil {
|
||||
@@ -180,7 +204,15 @@ func (r *JobRepository) RemoveTagById(tagID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTag creates a new tag with the specified type and name and returns its database id.
|
||||
// CreateTag creates a new tag with the specified type, name, and scope.
|
||||
// Returns the database ID of the newly created tag.
|
||||
//
|
||||
// Scope defaults to "global" if empty string is provided.
|
||||
// Valid scopes: "global", "private", "admin"
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// tagID, err := repo.CreateTag("performance", "high-memory", "global")
|
||||
func (r *JobRepository) CreateTag(tagType string, tagName string, tagScope string) (tagId int64, err error) {
|
||||
// Default to "Global" scope if none defined
|
||||
if tagScope == "" {
|
||||
@@ -199,8 +231,14 @@ func (r *JobRepository) CreateTag(tagType string, tagName string, tagScope strin
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// CountTags returns all tags visible to the user and the count of jobs for each tag.
|
||||
// Applies scope-based filtering to respect tag visibility rules.
|
||||
//
|
||||
// Returns:
|
||||
// - tags: slice of tags the user can see
|
||||
// - counts: map of tag name to job count
|
||||
// - err: any error encountered
|
||||
func (r *JobRepository) CountTags(user *schema.User) (tags []schema.Tag, counts map[string]int, err error) {
|
||||
// Fetch all Tags in DB for Display in Frontend Tag-View
|
||||
tags = make([]schema.Tag, 0, 100)
|
||||
xrows, err := r.DB.Queryx("SELECT id, tag_type, tag_name, tag_scope FROM tag")
|
||||
if err != nil {
|
||||
|
||||
311
internal/repository/transaction_test.go
Normal file
311
internal/repository/transaction_test.go
Normal file
@@ -0,0 +1,311 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransactionInit(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("successful transaction init", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err, "TransactionInit should succeed")
|
||||
require.NotNil(t, tx, "Transaction should not be nil")
|
||||
require.NotNil(t, tx.tx, "Transaction.tx should not be nil")
|
||||
|
||||
// Clean up
|
||||
err = tx.Rollback()
|
||||
require.NoError(t, err, "Rollback should succeed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionCommit(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("commit after successful operations", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert a test tag
|
||||
_, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_tag_commit", "global")
|
||||
require.NoError(t, err, "TransactionAdd should succeed")
|
||||
|
||||
// Commit the transaction
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err, "Commit should succeed")
|
||||
|
||||
// Verify the tag was inserted
|
||||
var count int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_tag_commit").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Tag should be committed to database")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM tag WHERE tag_name = ?", "test_tag_commit")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commit on already committed transaction", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err, "First commit should succeed")
|
||||
|
||||
err = tx.Commit()
|
||||
assert.Error(t, err, "Second commit should fail")
|
||||
assert.Contains(t, err.Error(), "transaction already committed or rolled back")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionRollback(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("rollback after operations", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert a test tag
|
||||
_, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_tag_rollback", "global")
|
||||
require.NoError(t, err, "TransactionAdd should succeed")
|
||||
|
||||
// Rollback the transaction
|
||||
err = tx.Rollback()
|
||||
require.NoError(t, err, "Rollback should succeed")
|
||||
|
||||
// Verify the tag was NOT inserted
|
||||
var count int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_tag_rollback").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count, "Tag should not be in database after rollback")
|
||||
})
|
||||
|
||||
t.Run("rollback on already rolled back transaction", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Rollback()
|
||||
require.NoError(t, err, "First rollback should succeed")
|
||||
|
||||
err = tx.Rollback()
|
||||
assert.NoError(t, err, "Second rollback should be safe (no-op)")
|
||||
})
|
||||
|
||||
t.Run("rollback on committed transaction", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Rollback()
|
||||
assert.NoError(t, err, "Rollback after commit should be safe (no-op)")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionAdd(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("insert with TransactionAdd", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
id, err := r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_add", "global")
|
||||
require.NoError(t, err, "TransactionAdd should succeed")
|
||||
assert.Greater(t, id, int64(0), "Should return valid insert ID")
|
||||
})
|
||||
|
||||
t.Run("error on nil transaction", func(t *testing.T) {
|
||||
tx := &Transaction{tx: nil}
|
||||
|
||||
_, err := r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_nil", "global")
|
||||
assert.Error(t, err, "Should error on nil transaction")
|
||||
assert.Contains(t, err.Error(), "transaction is nil or already completed")
|
||||
})
|
||||
|
||||
t.Run("error on invalid SQL", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = r.TransactionAdd(tx, "INVALID SQL STATEMENT")
|
||||
assert.Error(t, err, "Should error on invalid SQL")
|
||||
})
|
||||
|
||||
t.Run("error after transaction committed", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_after_commit", "global")
|
||||
assert.Error(t, err, "Should error when transaction is already committed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionAddNamed(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("insert with TransactionAddNamed", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
type TagArgs struct {
|
||||
Type string `db:"type"`
|
||||
Name string `db:"name"`
|
||||
Scope string `db:"scope"`
|
||||
}
|
||||
|
||||
args := TagArgs{
|
||||
Type: "test_type",
|
||||
Name: "test_named",
|
||||
Scope: "global",
|
||||
}
|
||||
|
||||
id, err := r.TransactionAddNamed(tx,
|
||||
"INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (:type, :name, :scope)",
|
||||
args)
|
||||
require.NoError(t, err, "TransactionAddNamed should succeed")
|
||||
assert.Greater(t, id, int64(0), "Should return valid insert ID")
|
||||
})
|
||||
|
||||
t.Run("error on nil transaction", func(t *testing.T) {
|
||||
tx := &Transaction{tx: nil}
|
||||
|
||||
_, err := r.TransactionAddNamed(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (:type, :name, :scope)",
|
||||
map[string]interface{}{"type": "test", "name": "test", "scope": "global"})
|
||||
assert.Error(t, err, "Should error on nil transaction")
|
||||
assert.Contains(t, err.Error(), "transaction is nil or already completed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionMultipleOperations(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("multiple inserts in single transaction", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
// Insert multiple tags
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = r.TransactionAdd(tx,
|
||||
"INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_multi_"+string(rune('a'+i)), "global")
|
||||
require.NoError(t, err, "Insert %d should succeed", i)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
require.NoError(t, err, "Commit should succeed")
|
||||
|
||||
// Verify all tags were inserted
|
||||
var count int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name LIKE 'test_multi_%'").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, count, "All 5 tags should be committed")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM tag WHERE tag_name LIKE 'test_multi_%'")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("rollback undoes all operations", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert multiple tags
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err = r.TransactionAdd(tx,
|
||||
"INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_rollback_"+string(rune('a'+i)), "global")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = tx.Rollback()
|
||||
require.NoError(t, err, "Rollback should succeed")
|
||||
|
||||
// Verify no tags were inserted
|
||||
var count int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name LIKE 'test_rollback_%'").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count, "No tags should be in database after rollback")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionEnd(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("deprecated TransactionEnd calls Commit", func(t *testing.T) {
|
||||
tx, err := r.TransactionInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_end", "global")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use deprecated method
|
||||
err = r.TransactionEnd(tx)
|
||||
require.NoError(t, err, "TransactionEnd should succeed")
|
||||
|
||||
// Verify the tag was committed
|
||||
var count int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_end").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Tag should be committed")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM tag WHERE tag_name = ?", "test_end")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionDeferPattern(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
t.Run("defer rollback pattern", func(t *testing.T) {
|
||||
insertTag := func() error {
|
||||
tx, err := r.TransactionInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback() // Safe to call even after commit
|
||||
|
||||
_, err = r.TransactionAdd(tx, "INSERT INTO tag (tag_type, tag_name, tag_scope) VALUES (?, ?, ?)",
|
||||
"test_type", "test_defer", "global")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
err := insertTag()
|
||||
require.NoError(t, err, "Function should succeed")
|
||||
|
||||
// Verify the tag was committed
|
||||
var count int
|
||||
err = r.DB.QueryRow("SELECT COUNT(*) FROM tag WHERE tag_name = ?", "test_defer").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Tag should be committed despite defer rollback")
|
||||
|
||||
// Clean up
|
||||
_, err = r.DB.Exec("DELETE FROM tag WHERE tag_name = ?", "test_defer")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -22,6 +22,25 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Authentication and Role System:
|
||||
//
|
||||
// ClusterCockpit supports multiple authentication sources:
|
||||
// - Local: Username/password stored in database (password hashed with bcrypt)
|
||||
// - LDAP: External LDAP/Active Directory authentication
|
||||
// - JWT: Token-based authentication for API access
|
||||
//
|
||||
// Role Hierarchy (from highest to lowest privilege):
|
||||
// 1. "admin" - Full system access, can manage all users and jobs
|
||||
// 2. "support" - Can view all jobs but limited management capabilities
|
||||
// 3. "manager" - Can manage specific projects and their users
|
||||
// 4. "api" - Programmatic access for job submission/management
|
||||
// 5. "user" - Default role, can only view own jobs
|
||||
//
|
||||
// Project Association:
|
||||
// - Managers have a list of projects they oversee
|
||||
// - Regular users' project membership is determined by job data
|
||||
// - Managers can view/manage all jobs within their projects
|
||||
|
||||
var (
|
||||
userRepoOnce sync.Once
|
||||
userRepoInstance *UserRepository
|
||||
@@ -44,6 +63,9 @@ func GetUserRepository() *UserRepository {
|
||||
return userRepoInstance
|
||||
}
|
||||
|
||||
// GetUser retrieves a user by username from the database.
|
||||
// Returns the complete user record including hashed password, roles, and projects.
|
||||
// Password field contains bcrypt hash for local auth users, empty for LDAP users.
|
||||
func (r *UserRepository) GetUser(username string) (*schema.User, error) {
|
||||
user := &schema.User{Username: username}
|
||||
var hashedPassword, name, rawRoles, email, rawProjects sql.NullString
|
||||
@@ -93,6 +115,12 @@ func (r *UserRepository) GetLdapUsernames() ([]string, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// AddUser creates a new user in the database.
|
||||
// Passwords are automatically hashed with bcrypt before storage.
|
||||
// Auth source determines authentication method (local, LDAP, etc.).
|
||||
//
|
||||
// Required fields: Username, Roles
|
||||
// Optional fields: Name, Email, Password, Projects, AuthSource
|
||||
func (r *UserRepository) AddUser(user *schema.User) error {
|
||||
rolesJson, _ := json.Marshal(user.Roles)
|
||||
projectsJson, _ := json.Marshal(user.Projects)
|
||||
@@ -229,6 +257,14 @@ func (r *UserRepository) ListUsers(specialsOnly bool) ([]*schema.User, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// AddRole adds a role to a user's role list.
|
||||
// Role string is automatically lowercased.
|
||||
// Valid roles: admin, support, manager, api, user
|
||||
//
|
||||
// Returns error if:
|
||||
// - User doesn't exist
|
||||
// - Role is invalid
|
||||
// - User already has the role
|
||||
func (r *UserRepository) AddRole(
|
||||
ctx context.Context,
|
||||
username string,
|
||||
@@ -258,6 +294,11 @@ func (r *UserRepository) AddRole(
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRole removes a role from a user's role list.
|
||||
//
|
||||
// Special rules:
|
||||
// - Cannot remove "manager" role while user has assigned projects
|
||||
// - Must remove all projects first before removing manager role
|
||||
func (r *UserRepository) RemoveRole(ctx context.Context, username string, queryrole string) error {
|
||||
oldRole := strings.ToLower(queryrole)
|
||||
user, err := r.GetUser(username)
|
||||
@@ -294,6 +335,12 @@ func (r *UserRepository) RemoveRole(ctx context.Context, username string, queryr
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddProject assigns a project to a manager user.
|
||||
// Only users with the "manager" role can have assigned projects.
|
||||
//
|
||||
// Returns error if:
|
||||
// - User doesn't have manager role
|
||||
// - User already manages the project
|
||||
func (r *UserRepository) AddProject(
|
||||
ctx context.Context,
|
||||
username string,
|
||||
|
||||
596
internal/repository/user_test.go
Normal file
596
internal/repository/user_test.go
Normal file
@@ -0,0 +1,596 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ClusterCockpit/cc-lib/v2/schema"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestAddUser(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
|
||||
t.Run("add user with all fields", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "testuser1",
|
||||
Name: "Test User One",
|
||||
Email: "test1@example.com",
|
||||
Password: "testpassword123",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{"project1", "project2"},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrievedUser, err := r.GetUser("testuser1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.Username, retrievedUser.Username)
|
||||
assert.Equal(t, user.Name, retrievedUser.Name)
|
||||
assert.Equal(t, user.Email, retrievedUser.Email)
|
||||
assert.Equal(t, user.Roles, retrievedUser.Roles)
|
||||
assert.Equal(t, user.Projects, retrievedUser.Projects)
|
||||
assert.NotEmpty(t, retrievedUser.Password)
|
||||
err = bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("testpassword123"))
|
||||
assert.NoError(t, err, "Password should be hashed correctly")
|
||||
|
||||
err = r.DelUser("testuser1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("add user with minimal fields", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "testuser2",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLDAP,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrievedUser, err := r.GetUser("testuser2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.Username, retrievedUser.Username)
|
||||
assert.Equal(t, "", retrievedUser.Name)
|
||||
assert.Equal(t, "", retrievedUser.Email)
|
||||
assert.Equal(t, "", retrievedUser.Password)
|
||||
|
||||
err = r.DelUser("testuser2")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("add duplicate user fails", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "testuser3",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.AddUser(user)
|
||||
assert.Error(t, err, "Adding duplicate user should fail")
|
||||
|
||||
err = r.DelUser("testuser3")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUser(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
|
||||
t.Run("get existing user", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "getuser1",
|
||||
Name: "Get User",
|
||||
Email: "getuser@example.com",
|
||||
Roles: []string{"user", "admin"},
|
||||
Projects: []string{"proj1"},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := r.GetUser("getuser1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.Username, retrieved.Username)
|
||||
assert.Equal(t, user.Name, retrieved.Name)
|
||||
assert.Equal(t, user.Email, retrieved.Email)
|
||||
assert.ElementsMatch(t, user.Roles, retrieved.Roles)
|
||||
assert.ElementsMatch(t, user.Projects, retrieved.Projects)
|
||||
|
||||
err = r.DelUser("getuser1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("get non-existent user", func(t *testing.T) {
|
||||
_, err := r.GetUser("nonexistent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
|
||||
t.Run("update user name", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "updateuser1",
|
||||
Name: "Original Name",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
dbUser, err := r.GetUser("updateuser1")
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser := &schema.User{
|
||||
Username: "updateuser1",
|
||||
Name: "Updated Name",
|
||||
}
|
||||
|
||||
err = r.UpdateUser(dbUser, updatedUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := r.GetUser("updateuser1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Name", retrieved.Name)
|
||||
|
||||
err = r.DelUser("updateuser1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("update with no changes", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "updateuser2",
|
||||
Name: "Same Name",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
dbUser, err := r.GetUser("updateuser2")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.UpdateUser(dbUser, dbUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = r.DelUser("updateuser2")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDelUser(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
|
||||
t.Run("delete existing user", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "deluser1",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.DelUser("deluser1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = r.GetUser("deluser1")
|
||||
assert.Error(t, err, "User should not exist after deletion")
|
||||
})
|
||||
|
||||
t.Run("delete non-existent user", func(t *testing.T) {
|
||||
err := r.DelUser("nonexistent")
|
||||
assert.NoError(t, err, "Deleting non-existent user should not error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListUsers(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
|
||||
user1 := &schema.User{
|
||||
Username: "listuser1",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
user2 := &schema.User{
|
||||
Username: "listuser2",
|
||||
Roles: []string{"admin"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
user3 := &schema.User{
|
||||
Username: "listuser3",
|
||||
Roles: []string{"manager"},
|
||||
Projects: []string{"proj1"},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user1)
|
||||
require.NoError(t, err)
|
||||
err = r.AddUser(user2)
|
||||
require.NoError(t, err)
|
||||
err = r.AddUser(user3)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("list all users", func(t *testing.T) {
|
||||
users, err := r.ListUsers(false)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(users), 3)
|
||||
|
||||
usernames := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
usernames[i] = u.Username
|
||||
}
|
||||
assert.Contains(t, usernames, "listuser1")
|
||||
assert.Contains(t, usernames, "listuser2")
|
||||
assert.Contains(t, usernames, "listuser3")
|
||||
})
|
||||
|
||||
t.Run("list special users only", func(t *testing.T) {
|
||||
users, err := r.ListUsers(true)
|
||||
require.NoError(t, err)
|
||||
|
||||
usernames := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
usernames[i] = u.Username
|
||||
}
|
||||
assert.Contains(t, usernames, "listuser2")
|
||||
assert.Contains(t, usernames, "listuser3")
|
||||
})
|
||||
|
||||
err = r.DelUser("listuser1")
|
||||
require.NoError(t, err)
|
||||
err = r.DelUser("listuser2")
|
||||
require.NoError(t, err)
|
||||
err = r.DelUser("listuser3")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetLdapUsernames(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
|
||||
ldapUser := &schema.User{
|
||||
Username: "ldapuser1",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLDAP,
|
||||
}
|
||||
localUser := &schema.User{
|
||||
Username: "localuser1",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(ldapUser)
|
||||
require.NoError(t, err)
|
||||
err = r.AddUser(localUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
usernames, err := r.GetLdapUsernames()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, usernames, "ldapuser1")
|
||||
assert.NotContains(t, usernames, "localuser1")
|
||||
|
||||
err = r.DelUser("ldapuser1")
|
||||
require.NoError(t, err)
|
||||
err = r.DelUser("localuser1")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAddRole(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("add valid role", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "roleuser1",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.AddRole(ctx, "roleuser1", "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := r.GetUser("roleuser1")
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, retrieved.Roles, "admin")
|
||||
assert.Contains(t, retrieved.Roles, "user")
|
||||
|
||||
err = r.DelUser("roleuser1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("add duplicate role", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "roleuser2",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.AddRole(ctx, "roleuser2", "user")
|
||||
assert.Error(t, err, "Adding duplicate role should fail")
|
||||
assert.Contains(t, err.Error(), "already has role")
|
||||
|
||||
err = r.DelUser("roleuser2")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("add invalid role", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "roleuser3",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.AddRole(ctx, "roleuser3", "invalidrole")
|
||||
assert.Error(t, err, "Adding invalid role should fail")
|
||||
assert.Contains(t, err.Error(), "no valid option")
|
||||
|
||||
err = r.DelUser("roleuser3")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemoveRole(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("remove existing role", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "rmroleuser1",
|
||||
Roles: []string{"user", "admin"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.RemoveRole(ctx, "rmroleuser1", "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := r.GetUser("rmroleuser1")
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, retrieved.Roles, "admin")
|
||||
assert.Contains(t, retrieved.Roles, "user")
|
||||
|
||||
err = r.DelUser("rmroleuser1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("remove non-existent role", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "rmroleuser2",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.RemoveRole(ctx, "rmroleuser2", "admin")
|
||||
assert.Error(t, err, "Removing non-existent role should fail")
|
||||
assert.Contains(t, err.Error(), "already deleted")
|
||||
|
||||
err = r.DelUser("rmroleuser2")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("remove manager role with projects", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "rmroleuser3",
|
||||
Roles: []string{"manager"},
|
||||
Projects: []string{"proj1", "proj2"},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.RemoveRole(ctx, "rmroleuser3", "manager")
|
||||
assert.Error(t, err, "Removing manager role with projects should fail")
|
||||
assert.Contains(t, err.Error(), "still has assigned project")
|
||||
|
||||
err = r.DelUser("rmroleuser3")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddProject(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("add project to manager", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "projuser1",
|
||||
Roles: []string{"manager"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.AddProject(ctx, "projuser1", "newproject")
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := r.GetUser("projuser1")
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, retrieved.Projects, "newproject")
|
||||
|
||||
err = r.DelUser("projuser1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("add project to non-manager", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "projuser2",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.AddProject(ctx, "projuser2", "newproject")
|
||||
assert.Error(t, err, "Adding project to non-manager should fail")
|
||||
assert.Contains(t, err.Error(), "not a manager")
|
||||
|
||||
err = r.DelUser("projuser2")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("add duplicate project", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "projuser3",
|
||||
Roles: []string{"manager"},
|
||||
Projects: []string{"existingproject"},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.AddProject(ctx, "projuser3", "existingproject")
|
||||
assert.Error(t, err, "Adding duplicate project should fail")
|
||||
assert.Contains(t, err.Error(), "already manages")
|
||||
|
||||
err = r.DelUser("projuser3")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemoveProject(t *testing.T) {
|
||||
_ = setup(t)
|
||||
r := GetUserRepository()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("remove existing project", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "rmprojuser1",
|
||||
Roles: []string{"manager"},
|
||||
Projects: []string{"proj1", "proj2"},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.RemoveProject(ctx, "rmprojuser1", "proj1")
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := r.GetUser("rmprojuser1")
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, retrieved.Projects, "proj1")
|
||||
assert.Contains(t, retrieved.Projects, "proj2")
|
||||
|
||||
err = r.DelUser("rmprojuser1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("remove non-existent project", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "rmprojuser2",
|
||||
Roles: []string{"manager"},
|
||||
Projects: []string{"proj1"},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.RemoveProject(ctx, "rmprojuser2", "nonexistent")
|
||||
assert.Error(t, err, "Removing non-existent project should fail")
|
||||
|
||||
err = r.DelUser("rmprojuser2")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("remove project from non-manager", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "rmprojuser3",
|
||||
Roles: []string{"user"},
|
||||
Projects: []string{},
|
||||
AuthSource: schema.AuthViaLocalPassword,
|
||||
}
|
||||
|
||||
err := r.AddUser(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.RemoveProject(ctx, "rmprojuser3", "proj1")
|
||||
assert.Error(t, err, "Removing project from non-manager should fail")
|
||||
assert.Contains(t, err.Error(), "not a manager")
|
||||
|
||||
err = r.DelUser("rmprojuser3")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserFromContext(t *testing.T) {
|
||||
t.Run("get user from context", func(t *testing.T) {
|
||||
user := &schema.User{
|
||||
Username: "contextuser",
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), ContextUserKey, user)
|
||||
retrieved := GetUserFromContext(ctx)
|
||||
|
||||
require.NotNil(t, retrieved)
|
||||
assert.Equal(t, user.Username, retrieved.Username)
|
||||
})
|
||||
|
||||
t.Run("get user from empty context", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
retrieved := GetUserFromContext(ctx)
|
||||
|
||||
assert.Nil(t, retrieved)
|
||||
})
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func newTagger() {
|
||||
func Init() {
|
||||
initOnce.Do(func() {
|
||||
newTagger()
|
||||
repository.RegisterJobJook(jobTagger)
|
||||
repository.RegisterJobHook(jobTagger)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user