use prepared statements

This commit is contained in:
Lou Knauer 2022-02-22 09:25:41 +01:00
parent 96f91f1a1c
commit a528e42be6
4 changed files with 61 additions and 65 deletions

View File

@ -3,8 +3,10 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"time"
"github.com/ClusterCockpit/cc-backend/auth"
"github.com/ClusterCockpit/cc-backend/graph/model"
@ -15,12 +17,43 @@ import (
type JobRepository struct {
DB *sqlx.DB
stmtCache *sq.StmtCache
}
func (r *JobRepository) Init() error {
r.stmtCache = sq.NewStmtCache(r.DB)
return nil
}
var jobColumns []string = []string{
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id",
"job.num_nodes", "job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
"job.duration", "job.resources", "job.meta_data",
}
func scanJob(row interface{ Scan(...interface{}) error }) (*schema.Job, error) {
job := &schema.Job{}
if err := row.Scan(
&job.ID, &job.JobID, &job.User, &job.Project, &job.Cluster, &job.StartTimeUnix, &job.Partition, &job.ArrayJobId,
&job.NumNodes, &job.NumHWThreads, &job.NumAcc, &job.Exclusive, &job.MonitoringStatus, &job.SMT, &job.State,
&job.Duration, &job.RawResources, &job.MetaData); err != nil {
return nil, err
}
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
return nil, err
}
job.StartTime = time.Unix(job.StartTimeUnix, 0)
if job.Duration == 0 && job.State == schema.JobStateRunning {
job.Duration = int32(time.Since(job.StartTime).Seconds())
}
job.RawResources = nil
return job, nil
}
// Find executes a SQL query to find a specific batch job.
// The job is queried using the batch job id, the cluster name,
// and the start time of the job in UNIX epoch time seconds.
@ -31,22 +64,17 @@ func (r *JobRepository) Find(
cluster *string,
startTime *int64) (*schema.Job, error) {
qb := sq.Select(schema.JobColumns...).From("job").
q := sq.Select(jobColumns...).From("job").
Where("job.job_id = ?", jobId)
if cluster != nil {
qb = qb.Where("job.cluster = ?", *cluster)
q = q.Where("job.cluster = ?", *cluster)
}
if startTime != nil {
qb = qb.Where("job.start_time = ?", *startTime)
q = q.Where("job.start_time = ?", *startTime)
}
sqlQuery, args, err := qb.ToSql()
if err != nil {
return nil, err
}
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
return scanJob(q.RunWith(r.stmtCache).QueryRow())
}
// FindById executes a SQL query to find a specific batch job.
@ -55,13 +83,9 @@ func (r *JobRepository) Find(
// To check if no job was found test err == sql.ErrNoRows
func (r *JobRepository) FindById(
jobId int64) (*schema.Job, error) {
sqlQuery, args, err := sq.Select(schema.JobColumns...).
From("job").Where("job.id = ?", jobId).ToSql()
if err != nil {
return nil, err
}
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
q := sq.Select(jobColumns...).
From("job").Where("job.id = ?", jobId)
return scanJob(q.RunWith(r.stmtCache).QueryRow())
}
// Start inserts a new job in the table, returning the unique job ID.
@ -94,7 +118,7 @@ func (r *JobRepository) Stop(
Set("monitoring_status", monitoringStatus).
Where("job.id = ?", jobId)
_, err = stmt.RunWith(r.DB).Exec()
_, err = stmt.RunWith(r.stmtCache).Exec()
return
}
@ -136,7 +160,7 @@ func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32
Set("monitoring_status", monitoringStatus).
Where("job.id = ?", job)
_, err = stmt.RunWith(r.DB).Exec()
_, err = stmt.RunWith(r.stmtCache).Exec()
return
}
@ -167,7 +191,7 @@ func (r *JobRepository) Archive(
}
}
if _, err := stmt.RunWith(r.DB).Exec(); err != nil {
if _, err := stmt.RunWith(r.stmtCache).Exec(); err != nil {
return err
}
return nil
@ -186,7 +210,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
qb = qb.Where("job.user = ?", user.Username)
}
err := qb.RunWith(r.DB).QueryRow().Scan(&job)
err := qb.RunWith(r.stmtCache).QueryRow().Scan(&job)
if err != nil && err != sql.ErrNoRows {
return 0, "", err
} else if err == nil {
@ -197,7 +221,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
if user == nil || user.HasRole(auth.RoleAdmin) {
err := sq.Select("job.user").Distinct().From("job").
Where("job.user = ?", searchterm).
RunWith(r.DB).QueryRow().Scan(&username)
RunWith(r.stmtCache).QueryRow().Scan(&username)
if err != nil && err != sql.ErrNoRows {
return 0, "", err
} else if err == nil {

View File

@ -21,7 +21,7 @@ func (r *JobRepository) QueryJobs(
page *model.PageRequest,
order *model.OrderByInput) ([]*schema.Job, error) {
query := sq.Select(schema.JobColumns...).From("job")
query := sq.Select(jobColumns...).From("job")
query = SecurityCheck(ctx, query)
if order != nil {
@ -50,14 +50,14 @@ func (r *JobRepository) QueryJobs(
}
log.Debugf("SQL query: `%s`, args: %#v", sql, args)
rows, err := r.DB.Queryx(sql, args...)
rows, err := query.RunWith(r.stmtCache).Query()
if err != nil {
return nil, err
}
jobs := make([]*schema.Job, 0, 50)
for rows.Next() {
job, err := schema.ScanJob(rows)
job, err := scanJob(rows)
if err != nil {
return nil, err
}

View File

@ -9,19 +9,19 @@ import (
// Add the tag with id `tagId` to the job with the database id `jobId`.
func (r *JobRepository) AddTag(jobId int64, tagId int64) error {
_, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId)
_, err := r.stmtCache.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId)
return err
}
// Removes a tag from a job
func (r *JobRepository) RemoveTag(job, tag int64) error {
_, err := r.DB.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag)
_, err := r.stmtCache.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag)
return err
}
// CreateTag creates a new tag with the specified type and name and returns its database id.
func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
res, err := r.stmtCache.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
if err != nil {
return 0, err
}
@ -52,13 +52,12 @@ func (r *JobRepository) CountTags(user *string) (tags []schema.Tag, counts map[s
q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user)
}
rows, err := q.RunWith(r.DB).Query()
rows, err := q.RunWith(r.stmtCache).Query()
if err != nil {
return nil, nil, err
}
counts = make(map[string]int)
for rows.Next() {
var tagName string
var count int
@ -92,7 +91,7 @@ func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exis
exists = true
if err := sq.Select("id").From("tag").
Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
RunWith(r.DB).QueryRow().Scan(&tagId); err != nil {
RunWith(r.stmtCache).QueryRow().Scan(&tagId); err != nil {
exists = false
}
return
@ -105,14 +104,18 @@ func (r *JobRepository) GetTags(job *int64) ([]*schema.Tag, error) {
q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job)
}
sql, args, err := q.ToSql()
rows, err := q.RunWith(r.stmtCache).Query()
if err != nil {
return nil, err
}
tags := make([]*schema.Tag, 0)
if err := r.DB.Select(&tags, sql, args...); err != nil {
return nil, err
for rows.Next() {
tag := &schema.Tag{}
if err := rows.Scan(&tag.ID, &tag.Type, &tag.Name); err != nil {
return nil, err
}
tags = append(tags, tag)
}
return tags, nil

View File

@ -1,7 +1,6 @@
package schema
import (
"encoding/json"
"errors"
"fmt"
"io"
@ -52,6 +51,7 @@ type Job struct {
// This is why there is this struct, which contains all fields from the regular job struct, but "overwrites"
// the StartTime field with one of type int64.
type JobMeta struct {
ID *int64 `json:"id,omitempty"` // never used in the job-archive, only available via REST-API
BaseJob
StartTime int64 `json:"startTime" db:"start_time"`
Statistics map[string]JobStatistics `json:"statistics,omitempty"`
@ -67,37 +67,6 @@ const (
var JobDefaults BaseJob = BaseJob{
Exclusive: 1,
MonitoringStatus: MonitoringStatusRunningOrArchiving,
MetaData: "",
}
var JobColumns []string = []string{
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id", "job.num_nodes",
"job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
"job.duration", "job.resources", "job.meta_data",
}
type Scannable interface {
StructScan(dest interface{}) error
}
// Helper function for scanning jobs with the `jobTableCols` columns selected.
func ScanJob(row Scannable) (*Job, error) {
job := &Job{BaseJob: JobDefaults}
if err := row.StructScan(job); err != nil {
return nil, err
}
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
return nil, err
}
job.StartTime = time.Unix(job.StartTimeUnix, 0)
if job.Duration == 0 && job.State == JobStateRunning {
job.Duration = int32(time.Since(job.StartTime).Seconds())
}
job.RawResources = nil
return job, nil
}
type JobStatistics struct {