diff --git a/api/rest.go b/api/rest.go index 125cc24..8f50c8c 100644 --- a/api/rest.go +++ b/api/rest.go @@ -17,7 +17,6 @@ import ( "time" "github.com/ClusterCockpit/cc-backend/auth" - "github.com/ClusterCockpit/cc-backend/config" "github.com/ClusterCockpit/cc-backend/graph" "github.com/ClusterCockpit/cc-backend/graph/model" "github.com/ClusterCockpit/cc-backend/log" @@ -255,14 +254,11 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) { return } - if config.GetClusterConfig(req.Cluster) == nil || config.GetPartition(req.Cluster, req.Partition) == nil { - handleError(fmt.Errorf("cluster or partition does not exist: %#v/%#v", req.Cluster, req.Partition), http.StatusBadRequest, rw) - return + if req.State == "" { + req.State = schema.JobStateRunning } - - // TODO: Do more such checks, be smarter with them. - if len(req.Resources) == 0 || len(req.User) == 0 || req.NumNodes == 0 { - handleError(errors.New("the fields 'resources', 'user' and 'numNodes' are required"), http.StatusBadRequest, rw) + if err := repository.SanityChecks(&req.BaseJob); err != nil { + handleError(err, http.StatusBadRequest, rw) return } @@ -278,10 +274,6 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) { return } - if req.State == "" { - req.State = schema.JobStateRunning - } - req.RawResources, err = json.Marshal(req.Resources) if err != nil { handleError(fmt.Errorf("basically impossible: %w", err), http.StatusBadRequest, rw) diff --git a/init-db.go b/init-db.go index 3c2595f..9319aae 100644 --- a/init-db.go +++ b/init-db.go @@ -218,6 +218,10 @@ func loadJob(tx *sqlx.Tx, stmt *sqlx.NamedStmt, tags map[string]int64, path stri return err } + if err := repository.SanityChecks(&job.BaseJob); err != nil { + return err + } + res, err := stmt.Exec(job) if err != nil { return err diff --git a/repository/import.go b/repository/import.go index ac0b293..5689278 100644 --- a/repository/import.go +++ b/repository/import.go @@ -4,12 +4,12 @@ import ( "bytes" "database/sql" "encoding/json" - "errors" "fmt" "os" "strings" "time" + "github.com/ClusterCockpit/cc-backend/config" "github.com/ClusterCockpit/cc-backend/log" "github.com/ClusterCockpit/cc-backend/metricdata" "github.com/ClusterCockpit/cc-backend/schema" @@ -30,7 +30,7 @@ func (r *JobRepository) HandleImportFlag(flag string) error { for _, pair := range strings.Split(flag, ",") { files := strings.Split(pair, ":") if len(files) != 2 { - return errors.New("invalid import flag format") + return fmt.Errorf("invalid import flag format") } raw, err := os.ReadFile(files[0]) @@ -94,6 +94,10 @@ func (r *JobRepository) ImportJob(jobMeta *schema.JobMeta, jobData *schema.JobDa return err } + if err := SanityChecks(&job.BaseJob); err != nil { + return err + } + res, err := r.DB.NamedExec(NamedJobInsert, job) if err != nil { return err @@ -114,6 +118,29 @@ func (r *JobRepository) ImportJob(jobMeta *schema.JobMeta, jobData *schema.JobDa return nil } +func SanityChecks(job *schema.BaseJob) error { + if c := config.GetClusterConfig(job.Cluster); c == nil { + return fmt.Errorf("no such cluster: %#v", job.Cluster) + } + if p := config.GetPartition(job.Cluster, job.Partition); p == nil { + return fmt.Errorf("no such partition: %#v (on cluster %#v)", job.Partition, job.Cluster) + } + if !job.State.Valid() { + return fmt.Errorf("not a valid job state: %#v", job.State) + } + if len(job.Resources) == 0 || len(job.User) == 0 { + return fmt.Errorf("'resources' and 'user' should not be empty") + } + if job.NumAcc < 0 || job.NumHWThreads < 0 || job.NumNodes < 1 { + return fmt.Errorf("'numNodes', 'numAcc' or 'numHWThreads' invalid") + } + if len(job.Resources) != int(job.NumNodes) { + return fmt.Errorf("len(resources) does not equal numNodes (%d vs %d)", len(job.Resources), job.NumNodes) + } + + return nil +} + func loadJobStat(job *schema.JobMeta, metric string) float64 { if stats, ok := job.Statistics[metric]; ok { return stats.Avg