use prepared statements

This commit is contained in:
Lou Knauer 2022-02-22 09:25:41 +01:00
parent 96f91f1a1c
commit a528e42be6
4 changed files with 61 additions and 65 deletions

View File

@ -3,8 +3,10 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"strconv" "strconv"
"time"
"github.com/ClusterCockpit/cc-backend/auth" "github.com/ClusterCockpit/cc-backend/auth"
"github.com/ClusterCockpit/cc-backend/graph/model" "github.com/ClusterCockpit/cc-backend/graph/model"
@ -15,12 +17,43 @@ import (
type JobRepository struct { type JobRepository struct {
DB *sqlx.DB DB *sqlx.DB
stmtCache *sq.StmtCache
} }
func (r *JobRepository) Init() error { func (r *JobRepository) Init() error {
r.stmtCache = sq.NewStmtCache(r.DB)
return nil return nil
} }
var jobColumns []string = []string{
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id",
"job.num_nodes", "job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
"job.duration", "job.resources", "job.meta_data",
}
func scanJob(row interface{ Scan(...interface{}) error }) (*schema.Job, error) {
job := &schema.Job{}
if err := row.Scan(
&job.ID, &job.JobID, &job.User, &job.Project, &job.Cluster, &job.StartTimeUnix, &job.Partition, &job.ArrayJobId,
&job.NumNodes, &job.NumHWThreads, &job.NumAcc, &job.Exclusive, &job.MonitoringStatus, &job.SMT, &job.State,
&job.Duration, &job.RawResources, &job.MetaData); err != nil {
return nil, err
}
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
return nil, err
}
job.StartTime = time.Unix(job.StartTimeUnix, 0)
if job.Duration == 0 && job.State == schema.JobStateRunning {
job.Duration = int32(time.Since(job.StartTime).Seconds())
}
job.RawResources = nil
return job, nil
}
// Find executes a SQL query to find a specific batch job. // Find executes a SQL query to find a specific batch job.
// The job is queried using the batch job id, the cluster name, // The job is queried using the batch job id, the cluster name,
// and the start time of the job in UNIX epoch time seconds. // and the start time of the job in UNIX epoch time seconds.
@ -31,22 +64,17 @@ func (r *JobRepository) Find(
cluster *string, cluster *string,
startTime *int64) (*schema.Job, error) { startTime *int64) (*schema.Job, error) {
qb := sq.Select(schema.JobColumns...).From("job"). q := sq.Select(jobColumns...).From("job").
Where("job.job_id = ?", jobId) Where("job.job_id = ?", jobId)
if cluster != nil { if cluster != nil {
qb = qb.Where("job.cluster = ?", *cluster) q = q.Where("job.cluster = ?", *cluster)
} }
if startTime != nil { if startTime != nil {
qb = qb.Where("job.start_time = ?", *startTime) q = q.Where("job.start_time = ?", *startTime)
} }
sqlQuery, args, err := qb.ToSql() return scanJob(q.RunWith(r.stmtCache).QueryRow())
if err != nil {
return nil, err
}
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
} }
// FindById executes a SQL query to find a specific batch job. // FindById executes a SQL query to find a specific batch job.
@ -55,13 +83,9 @@ func (r *JobRepository) Find(
// To check if no job was found test err == sql.ErrNoRows // To check if no job was found test err == sql.ErrNoRows
func (r *JobRepository) FindById( func (r *JobRepository) FindById(
jobId int64) (*schema.Job, error) { jobId int64) (*schema.Job, error) {
sqlQuery, args, err := sq.Select(schema.JobColumns...). q := sq.Select(jobColumns...).
From("job").Where("job.id = ?", jobId).ToSql() From("job").Where("job.id = ?", jobId)
if err != nil { return scanJob(q.RunWith(r.stmtCache).QueryRow())
return nil, err
}
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
} }
// Start inserts a new job in the table, returning the unique job ID. // Start inserts a new job in the table, returning the unique job ID.
@ -94,7 +118,7 @@ func (r *JobRepository) Stop(
Set("monitoring_status", monitoringStatus). Set("monitoring_status", monitoringStatus).
Where("job.id = ?", jobId) Where("job.id = ?", jobId)
_, err = stmt.RunWith(r.DB).Exec() _, err = stmt.RunWith(r.stmtCache).Exec()
return return
} }
@ -136,7 +160,7 @@ func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32
Set("monitoring_status", monitoringStatus). Set("monitoring_status", monitoringStatus).
Where("job.id = ?", job) Where("job.id = ?", job)
_, err = stmt.RunWith(r.DB).Exec() _, err = stmt.RunWith(r.stmtCache).Exec()
return return
} }
@ -167,7 +191,7 @@ func (r *JobRepository) Archive(
} }
} }
if _, err := stmt.RunWith(r.DB).Exec(); err != nil { if _, err := stmt.RunWith(r.stmtCache).Exec(); err != nil {
return err return err
} }
return nil return nil
@ -186,7 +210,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
qb = qb.Where("job.user = ?", user.Username) qb = qb.Where("job.user = ?", user.Username)
} }
err := qb.RunWith(r.DB).QueryRow().Scan(&job) err := qb.RunWith(r.stmtCache).QueryRow().Scan(&job)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return 0, "", err return 0, "", err
} else if err == nil { } else if err == nil {
@ -197,7 +221,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
if user == nil || user.HasRole(auth.RoleAdmin) { if user == nil || user.HasRole(auth.RoleAdmin) {
err := sq.Select("job.user").Distinct().From("job"). err := sq.Select("job.user").Distinct().From("job").
Where("job.user = ?", searchterm). Where("job.user = ?", searchterm).
RunWith(r.DB).QueryRow().Scan(&username) RunWith(r.stmtCache).QueryRow().Scan(&username)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return 0, "", err return 0, "", err
} else if err == nil { } else if err == nil {

View File

@ -21,7 +21,7 @@ func (r *JobRepository) QueryJobs(
page *model.PageRequest, page *model.PageRequest,
order *model.OrderByInput) ([]*schema.Job, error) { order *model.OrderByInput) ([]*schema.Job, error) {
query := sq.Select(schema.JobColumns...).From("job") query := sq.Select(jobColumns...).From("job")
query = SecurityCheck(ctx, query) query = SecurityCheck(ctx, query)
if order != nil { if order != nil {
@ -50,14 +50,14 @@ func (r *JobRepository) QueryJobs(
} }
log.Debugf("SQL query: `%s`, args: %#v", sql, args) log.Debugf("SQL query: `%s`, args: %#v", sql, args)
rows, err := r.DB.Queryx(sql, args...) rows, err := query.RunWith(r.stmtCache).Query()
if err != nil { if err != nil {
return nil, err return nil, err
} }
jobs := make([]*schema.Job, 0, 50) jobs := make([]*schema.Job, 0, 50)
for rows.Next() { for rows.Next() {
job, err := schema.ScanJob(rows) job, err := scanJob(rows)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,19 +9,19 @@ import (
// Add the tag with id `tagId` to the job with the database id `jobId`. // Add the tag with id `tagId` to the job with the database id `jobId`.
func (r *JobRepository) AddTag(jobId int64, tagId int64) error { func (r *JobRepository) AddTag(jobId int64, tagId int64) error {
_, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId) _, err := r.stmtCache.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId)
return err return err
} }
// Removes a tag from a job // Removes a tag from a job
func (r *JobRepository) RemoveTag(job, tag int64) error { func (r *JobRepository) RemoveTag(job, tag int64) error {
_, err := r.DB.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag) _, err := r.stmtCache.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag)
return err return err
} }
// CreateTag creates a new tag with the specified type and name and returns its database id. // CreateTag creates a new tag with the specified type and name and returns its database id.
func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) { func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName) res, err := r.stmtCache.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -52,13 +52,12 @@ func (r *JobRepository) CountTags(user *string) (tags []schema.Tag, counts map[s
q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user) q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user)
} }
rows, err := q.RunWith(r.DB).Query() rows, err := q.RunWith(r.stmtCache).Query()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
counts = make(map[string]int) counts = make(map[string]int)
for rows.Next() { for rows.Next() {
var tagName string var tagName string
var count int var count int
@ -92,7 +91,7 @@ func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exis
exists = true exists = true
if err := sq.Select("id").From("tag"). if err := sq.Select("id").From("tag").
Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName). Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
RunWith(r.DB).QueryRow().Scan(&tagId); err != nil { RunWith(r.stmtCache).QueryRow().Scan(&tagId); err != nil {
exists = false exists = false
} }
return return
@ -105,14 +104,18 @@ func (r *JobRepository) GetTags(job *int64) ([]*schema.Tag, error) {
q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job) q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job)
} }
sql, args, err := q.ToSql() rows, err := q.RunWith(r.stmtCache).Query()
if err != nil { if err != nil {
return nil, err return nil, err
} }
tags := make([]*schema.Tag, 0) tags := make([]*schema.Tag, 0)
if err := r.DB.Select(&tags, sql, args...); err != nil { for rows.Next() {
return nil, err tag := &schema.Tag{}
if err := rows.Scan(&tag.ID, &tag.Type, &tag.Name); err != nil {
return nil, err
}
tags = append(tags, tag)
} }
return tags, nil return tags, nil

View File

@ -1,7 +1,6 @@
package schema package schema
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -52,6 +51,7 @@ type Job struct {
// This is why there is this struct, which contains all fields from the regular job struct, but "overwrites" // This is why there is this struct, which contains all fields from the regular job struct, but "overwrites"
// the StartTime field with one of type int64. // the StartTime field with one of type int64.
type JobMeta struct { type JobMeta struct {
ID *int64 `json:"id,omitempty"` // never used in the job-archive, only available via REST-API
BaseJob BaseJob
StartTime int64 `json:"startTime" db:"start_time"` StartTime int64 `json:"startTime" db:"start_time"`
Statistics map[string]JobStatistics `json:"statistics,omitempty"` Statistics map[string]JobStatistics `json:"statistics,omitempty"`
@ -67,37 +67,6 @@ const (
var JobDefaults BaseJob = BaseJob{ var JobDefaults BaseJob = BaseJob{
Exclusive: 1, Exclusive: 1,
MonitoringStatus: MonitoringStatusRunningOrArchiving, MonitoringStatus: MonitoringStatusRunningOrArchiving,
MetaData: "",
}
var JobColumns []string = []string{
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id", "job.num_nodes",
"job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
"job.duration", "job.resources", "job.meta_data",
}
type Scannable interface {
StructScan(dest interface{}) error
}
// Helper function for scanning jobs with the `jobTableCols` columns selected.
func ScanJob(row Scannable) (*Job, error) {
job := &Job{BaseJob: JobDefaults}
if err := row.StructScan(job); err != nil {
return nil, err
}
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
return nil, err
}
job.StartTime = time.Unix(job.StartTimeUnix, 0)
if job.Duration == 0 && job.State == JobStateRunning {
job.Duration = int32(time.Since(job.StartTime).Seconds())
}
job.RawResources = nil
return job, nil
} }
type JobStatistics struct { type JobStatistics struct {