mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2024-12-25 12:59:06 +01:00
Sanitize job before import
This commit is contained in:
parent
49fbfc23d4
commit
f330d24768
16
api/rest.go
16
api/rest.go
@ -17,7 +17,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/auth"
|
||||
"github.com/ClusterCockpit/cc-backend/config"
|
||||
"github.com/ClusterCockpit/cc-backend/graph"
|
||||
"github.com/ClusterCockpit/cc-backend/graph/model"
|
||||
"github.com/ClusterCockpit/cc-backend/log"
|
||||
@ -255,14 +254,11 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if config.GetClusterConfig(req.Cluster) == nil || config.GetPartition(req.Cluster, req.Partition) == nil {
|
||||
handleError(fmt.Errorf("cluster or partition does not exist: %#v/%#v", req.Cluster, req.Partition), http.StatusBadRequest, rw)
|
||||
return
|
||||
if req.State == "" {
|
||||
req.State = schema.JobStateRunning
|
||||
}
|
||||
|
||||
// TODO: Do more such checks, be smarter with them.
|
||||
if len(req.Resources) == 0 || len(req.User) == 0 || req.NumNodes == 0 {
|
||||
handleError(errors.New("the fields 'resources', 'user' and 'numNodes' are required"), http.StatusBadRequest, rw)
|
||||
if err := repository.SanityChecks(&req.BaseJob); err != nil {
|
||||
handleError(err, http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@ -278,10 +274,6 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.State == "" {
|
||||
req.State = schema.JobStateRunning
|
||||
}
|
||||
|
||||
req.RawResources, err = json.Marshal(req.Resources)
|
||||
if err != nil {
|
||||
handleError(fmt.Errorf("basically impossible: %w", err), http.StatusBadRequest, rw)
|
||||
|
@ -218,6 +218,10 @@ func loadJob(tx *sqlx.Tx, stmt *sqlx.NamedStmt, tags map[string]int64, path stri
|
||||
return err
|
||||
}
|
||||
|
||||
if err := repository.SanityChecks(&job.BaseJob); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := stmt.Exec(job)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -4,12 +4,12 @@ import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/config"
|
||||
"github.com/ClusterCockpit/cc-backend/log"
|
||||
"github.com/ClusterCockpit/cc-backend/metricdata"
|
||||
"github.com/ClusterCockpit/cc-backend/schema"
|
||||
@ -30,7 +30,7 @@ func (r *JobRepository) HandleImportFlag(flag string) error {
|
||||
for _, pair := range strings.Split(flag, ",") {
|
||||
files := strings.Split(pair, ":")
|
||||
if len(files) != 2 {
|
||||
return errors.New("invalid import flag format")
|
||||
return fmt.Errorf("invalid import flag format")
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(files[0])
|
||||
@ -94,6 +94,10 @@ func (r *JobRepository) ImportJob(jobMeta *schema.JobMeta, jobData *schema.JobDa
|
||||
return err
|
||||
}
|
||||
|
||||
if err := SanityChecks(&job.BaseJob); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := r.DB.NamedExec(NamedJobInsert, job)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -114,6 +118,29 @@ func (r *JobRepository) ImportJob(jobMeta *schema.JobMeta, jobData *schema.JobDa
|
||||
return nil
|
||||
}
|
||||
|
||||
func SanityChecks(job *schema.BaseJob) error {
|
||||
if c := config.GetClusterConfig(job.Cluster); c == nil {
|
||||
return fmt.Errorf("no such cluster: %#v", job.Cluster)
|
||||
}
|
||||
if p := config.GetPartition(job.Cluster, job.Partition); p == nil {
|
||||
return fmt.Errorf("no such partition: %#v (on cluster %#v)", job.Partition, job.Cluster)
|
||||
}
|
||||
if !job.State.Valid() {
|
||||
return fmt.Errorf("not a valid job state: %#v", job.State)
|
||||
}
|
||||
if len(job.Resources) == 0 || len(job.User) == 0 {
|
||||
return fmt.Errorf("'resources' and 'user' should not be empty")
|
||||
}
|
||||
if job.NumAcc < 0 || job.NumHWThreads < 0 || job.NumNodes < 1 {
|
||||
return fmt.Errorf("'numNodes', 'numAcc' or 'numHWThreads' invalid")
|
||||
}
|
||||
if len(job.Resources) != int(job.NumNodes) {
|
||||
return fmt.Errorf("len(resources) does not equal numNodes (%d vs %d)", len(job.Resources), job.NumNodes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadJobStat(job *schema.JobMeta, metric string) float64 {
|
||||
if stats, ok := job.Statistics[metric]; ok {
|
||||
return stats.Avg
|
||||
|
Loading…
Reference in New Issue
Block a user