From a528e42be654b6288cd587347e63b2d68e7b7dfd Mon Sep 17 00:00:00 2001 From: Lou Knauer Date: Tue, 22 Feb 2022 09:25:41 +0100 Subject: [PATCH] use prepared statements --- repository/job.go | 66 ++++++++++++++++++++++++++++++--------------- repository/query.go | 6 ++--- repository/tags.go | 21 ++++++++------- schema/job.go | 33 +---------------------- 4 files changed, 61 insertions(+), 65 deletions(-) diff --git a/repository/job.go b/repository/job.go index 16a4291..98710a2 100644 --- a/repository/job.go +++ b/repository/job.go @@ -3,8 +3,10 @@ package repository import ( "context" "database/sql" + "encoding/json" "errors" "strconv" + "time" "github.com/ClusterCockpit/cc-backend/auth" "github.com/ClusterCockpit/cc-backend/graph/model" @@ -15,12 +17,43 @@ import ( type JobRepository struct { DB *sqlx.DB + + stmtCache *sq.StmtCache } func (r *JobRepository) Init() error { + r.stmtCache = sq.NewStmtCache(r.DB) return nil } +var jobColumns []string = []string{ + "job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id", + "job.num_nodes", "job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state", + "job.duration", "job.resources", "job.meta_data", +} + +func scanJob(row interface{ Scan(...interface{}) error }) (*schema.Job, error) { + job := &schema.Job{} + if err := row.Scan( + &job.ID, &job.JobID, &job.User, &job.Project, &job.Cluster, &job.StartTimeUnix, &job.Partition, &job.ArrayJobId, + &job.NumNodes, &job.NumHWThreads, &job.NumAcc, &job.Exclusive, &job.MonitoringStatus, &job.SMT, &job.State, + &job.Duration, &job.RawResources, &job.MetaData); err != nil { + return nil, err + } + + if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil { + return nil, err + } + + job.StartTime = time.Unix(job.StartTimeUnix, 0) + if job.Duration == 0 && job.State == schema.JobStateRunning { + job.Duration = int32(time.Since(job.StartTime).Seconds()) + } + + job.RawResources = nil + return job, nil +} + // Find executes a SQL query to find a specific batch job. // The job is queried using the batch job id, the cluster name, // and the start time of the job in UNIX epoch time seconds. @@ -31,22 +64,17 @@ func (r *JobRepository) Find( cluster *string, startTime *int64) (*schema.Job, error) { - qb := sq.Select(schema.JobColumns...).From("job"). + q := sq.Select(jobColumns...).From("job"). Where("job.job_id = ?", jobId) if cluster != nil { - qb = qb.Where("job.cluster = ?", *cluster) + q = q.Where("job.cluster = ?", *cluster) } if startTime != nil { - qb = qb.Where("job.start_time = ?", *startTime) + q = q.Where("job.start_time = ?", *startTime) } - sqlQuery, args, err := qb.ToSql() - if err != nil { - return nil, err - } - - return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...)) + return scanJob(q.RunWith(r.stmtCache).QueryRow()) } // FindById executes a SQL query to find a specific batch job. @@ -55,13 +83,9 @@ func (r *JobRepository) Find( // To check if no job was found test err == sql.ErrNoRows func (r *JobRepository) FindById( jobId int64) (*schema.Job, error) { - sqlQuery, args, err := sq.Select(schema.JobColumns...). - From("job").Where("job.id = ?", jobId).ToSql() - if err != nil { - return nil, err - } - - return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...)) + q := sq.Select(jobColumns...). + From("job").Where("job.id = ?", jobId) + return scanJob(q.RunWith(r.stmtCache).QueryRow()) } // Start inserts a new job in the table, returning the unique job ID. @@ -94,7 +118,7 @@ func (r *JobRepository) Stop( Set("monitoring_status", monitoringStatus). Where("job.id = ?", jobId) - _, err = stmt.RunWith(r.DB).Exec() + _, err = stmt.RunWith(r.stmtCache).Exec() return } @@ -136,7 +160,7 @@ func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32 Set("monitoring_status", monitoringStatus). Where("job.id = ?", job) - _, err = stmt.RunWith(r.DB).Exec() + _, err = stmt.RunWith(r.stmtCache).Exec() return } @@ -167,7 +191,7 @@ func (r *JobRepository) Archive( } } - if _, err := stmt.RunWith(r.DB).Exec(); err != nil { + if _, err := stmt.RunWith(r.stmtCache).Exec(); err != nil { return err } return nil @@ -186,7 +210,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j qb = qb.Where("job.user = ?", user.Username) } - err := qb.RunWith(r.DB).QueryRow().Scan(&job) + err := qb.RunWith(r.stmtCache).QueryRow().Scan(&job) if err != nil && err != sql.ErrNoRows { return 0, "", err } else if err == nil { @@ -197,7 +221,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j if user == nil || user.HasRole(auth.RoleAdmin) { err := sq.Select("job.user").Distinct().From("job"). Where("job.user = ?", searchterm). - RunWith(r.DB).QueryRow().Scan(&username) + RunWith(r.stmtCache).QueryRow().Scan(&username) if err != nil && err != sql.ErrNoRows { return 0, "", err } else if err == nil { diff --git a/repository/query.go b/repository/query.go index 6213eaa..dc9a5d8 100644 --- a/repository/query.go +++ b/repository/query.go @@ -21,7 +21,7 @@ func (r *JobRepository) QueryJobs( page *model.PageRequest, order *model.OrderByInput) ([]*schema.Job, error) { - query := sq.Select(schema.JobColumns...).From("job") + query := sq.Select(jobColumns...).From("job") query = SecurityCheck(ctx, query) if order != nil { @@ -50,14 +50,14 @@ func (r *JobRepository) QueryJobs( } log.Debugf("SQL query: `%s`, args: %#v", sql, args) - rows, err := r.DB.Queryx(sql, args...) + rows, err := query.RunWith(r.stmtCache).Query() if err != nil { return nil, err } jobs := make([]*schema.Job, 0, 50) for rows.Next() { - job, err := schema.ScanJob(rows) + job, err := scanJob(rows) if err != nil { return nil, err } diff --git a/repository/tags.go b/repository/tags.go index 7a26814..79b1575 100644 --- a/repository/tags.go +++ b/repository/tags.go @@ -9,19 +9,19 @@ import ( // Add the tag with id `tagId` to the job with the database id `jobId`. func (r *JobRepository) AddTag(jobId int64, tagId int64) error { - _, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId) + _, err := r.stmtCache.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId) return err } // Removes a tag from a job func (r *JobRepository) RemoveTag(job, tag int64) error { - _, err := r.DB.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag) + _, err := r.stmtCache.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag) return err } // CreateTag creates a new tag with the specified type and name and returns its database id. func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) { - res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName) + res, err := r.stmtCache.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName) if err != nil { return 0, err } @@ -52,13 +52,12 @@ func (r *JobRepository) CountTags(user *string) (tags []schema.Tag, counts map[s q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user) } - rows, err := q.RunWith(r.DB).Query() + rows, err := q.RunWith(r.stmtCache).Query() if err != nil { return nil, nil, err } counts = make(map[string]int) - for rows.Next() { var tagName string var count int @@ -92,7 +91,7 @@ func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exis exists = true if err := sq.Select("id").From("tag"). Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName). - RunWith(r.DB).QueryRow().Scan(&tagId); err != nil { + RunWith(r.stmtCache).QueryRow().Scan(&tagId); err != nil { exists = false } return @@ -105,14 +104,18 @@ func (r *JobRepository) GetTags(job *int64) ([]*schema.Tag, error) { q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job) } - sql, args, err := q.ToSql() + rows, err := q.RunWith(r.stmtCache).Query() if err != nil { return nil, err } tags := make([]*schema.Tag, 0) - if err := r.DB.Select(&tags, sql, args...); err != nil { - return nil, err + for rows.Next() { + tag := &schema.Tag{} + if err := rows.Scan(&tag.ID, &tag.Type, &tag.Name); err != nil { + return nil, err + } + tags = append(tags, tag) } return tags, nil diff --git a/schema/job.go b/schema/job.go index 8fe8dc4..6b2610b 100644 --- a/schema/job.go +++ b/schema/job.go @@ -1,7 +1,6 @@ package schema import ( - "encoding/json" "errors" "fmt" "io" @@ -52,6 +51,7 @@ type Job struct { // This is why there is this struct, which contains all fields from the regular job struct, but "overwrites" // the StartTime field with one of type int64. type JobMeta struct { + ID *int64 `json:"id,omitempty"` // never used in the job-archive, only available via REST-API BaseJob StartTime int64 `json:"startTime" db:"start_time"` Statistics map[string]JobStatistics `json:"statistics,omitempty"` @@ -67,37 +67,6 @@ const ( var JobDefaults BaseJob = BaseJob{ Exclusive: 1, MonitoringStatus: MonitoringStatusRunningOrArchiving, - MetaData: "", -} - -var JobColumns []string = []string{ - "job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id", "job.num_nodes", - "job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state", - "job.duration", "job.resources", "job.meta_data", -} - -type Scannable interface { - StructScan(dest interface{}) error -} - -// Helper function for scanning jobs with the `jobTableCols` columns selected. -func ScanJob(row Scannable) (*Job, error) { - job := &Job{BaseJob: JobDefaults} - if err := row.StructScan(job); err != nil { - return nil, err - } - - if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil { - return nil, err - } - - job.StartTime = time.Unix(job.StartTimeUnix, 0) - if job.Duration == 0 && job.State == JobStateRunning { - job.Duration = int32(time.Since(job.StartTime).Seconds()) - } - - job.RawResources = nil - return job, nil } type JobStatistics struct {