mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2024-11-10 08:57:25 +01:00
use prepared statements
This commit is contained in:
parent
96f91f1a1c
commit
a528e42be6
@ -3,8 +3,10 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/auth"
|
||||
"github.com/ClusterCockpit/cc-backend/graph/model"
|
||||
@ -15,12 +17,43 @@ import (
|
||||
|
||||
type JobRepository struct {
|
||||
DB *sqlx.DB
|
||||
|
||||
stmtCache *sq.StmtCache
|
||||
}
|
||||
|
||||
func (r *JobRepository) Init() error {
|
||||
r.stmtCache = sq.NewStmtCache(r.DB)
|
||||
return nil
|
||||
}
|
||||
|
||||
var jobColumns []string = []string{
|
||||
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id",
|
||||
"job.num_nodes", "job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
|
||||
"job.duration", "job.resources", "job.meta_data",
|
||||
}
|
||||
|
||||
func scanJob(row interface{ Scan(...interface{}) error }) (*schema.Job, error) {
|
||||
job := &schema.Job{}
|
||||
if err := row.Scan(
|
||||
&job.ID, &job.JobID, &job.User, &job.Project, &job.Cluster, &job.StartTimeUnix, &job.Partition, &job.ArrayJobId,
|
||||
&job.NumNodes, &job.NumHWThreads, &job.NumAcc, &job.Exclusive, &job.MonitoringStatus, &job.SMT, &job.State,
|
||||
&job.Duration, &job.RawResources, &job.MetaData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
job.StartTime = time.Unix(job.StartTimeUnix, 0)
|
||||
if job.Duration == 0 && job.State == schema.JobStateRunning {
|
||||
job.Duration = int32(time.Since(job.StartTime).Seconds())
|
||||
}
|
||||
|
||||
job.RawResources = nil
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// Find executes a SQL query to find a specific batch job.
|
||||
// The job is queried using the batch job id, the cluster name,
|
||||
// and the start time of the job in UNIX epoch time seconds.
|
||||
@ -31,22 +64,17 @@ func (r *JobRepository) Find(
|
||||
cluster *string,
|
||||
startTime *int64) (*schema.Job, error) {
|
||||
|
||||
qb := sq.Select(schema.JobColumns...).From("job").
|
||||
q := sq.Select(jobColumns...).From("job").
|
||||
Where("job.job_id = ?", jobId)
|
||||
|
||||
if cluster != nil {
|
||||
qb = qb.Where("job.cluster = ?", *cluster)
|
||||
q = q.Where("job.cluster = ?", *cluster)
|
||||
}
|
||||
if startTime != nil {
|
||||
qb = qb.Where("job.start_time = ?", *startTime)
|
||||
q = q.Where("job.start_time = ?", *startTime)
|
||||
}
|
||||
|
||||
sqlQuery, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
|
||||
return scanJob(q.RunWith(r.stmtCache).QueryRow())
|
||||
}
|
||||
|
||||
// FindById executes a SQL query to find a specific batch job.
|
||||
@ -55,13 +83,9 @@ func (r *JobRepository) Find(
|
||||
// To check if no job was found test err == sql.ErrNoRows
|
||||
func (r *JobRepository) FindById(
|
||||
jobId int64) (*schema.Job, error) {
|
||||
sqlQuery, args, err := sq.Select(schema.JobColumns...).
|
||||
From("job").Where("job.id = ?", jobId).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
|
||||
q := sq.Select(jobColumns...).
|
||||
From("job").Where("job.id = ?", jobId)
|
||||
return scanJob(q.RunWith(r.stmtCache).QueryRow())
|
||||
}
|
||||
|
||||
// Start inserts a new job in the table, returning the unique job ID.
|
||||
@ -94,7 +118,7 @@ func (r *JobRepository) Stop(
|
||||
Set("monitoring_status", monitoringStatus).
|
||||
Where("job.id = ?", jobId)
|
||||
|
||||
_, err = stmt.RunWith(r.DB).Exec()
|
||||
_, err = stmt.RunWith(r.stmtCache).Exec()
|
||||
return
|
||||
}
|
||||
|
||||
@ -136,7 +160,7 @@ func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32
|
||||
Set("monitoring_status", monitoringStatus).
|
||||
Where("job.id = ?", job)
|
||||
|
||||
_, err = stmt.RunWith(r.DB).Exec()
|
||||
_, err = stmt.RunWith(r.stmtCache).Exec()
|
||||
return
|
||||
}
|
||||
|
||||
@ -167,7 +191,7 @@ func (r *JobRepository) Archive(
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := stmt.RunWith(r.DB).Exec(); err != nil {
|
||||
if _, err := stmt.RunWith(r.stmtCache).Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@ -186,7 +210,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
|
||||
qb = qb.Where("job.user = ?", user.Username)
|
||||
}
|
||||
|
||||
err := qb.RunWith(r.DB).QueryRow().Scan(&job)
|
||||
err := qb.RunWith(r.stmtCache).QueryRow().Scan(&job)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, "", err
|
||||
} else if err == nil {
|
||||
@ -197,7 +221,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
|
||||
if user == nil || user.HasRole(auth.RoleAdmin) {
|
||||
err := sq.Select("job.user").Distinct().From("job").
|
||||
Where("job.user = ?", searchterm).
|
||||
RunWith(r.DB).QueryRow().Scan(&username)
|
||||
RunWith(r.stmtCache).QueryRow().Scan(&username)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, "", err
|
||||
} else if err == nil {
|
||||
|
@ -21,7 +21,7 @@ func (r *JobRepository) QueryJobs(
|
||||
page *model.PageRequest,
|
||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
||||
|
||||
query := sq.Select(schema.JobColumns...).From("job")
|
||||
query := sq.Select(jobColumns...).From("job")
|
||||
query = SecurityCheck(ctx, query)
|
||||
|
||||
if order != nil {
|
||||
@ -50,14 +50,14 @@ func (r *JobRepository) QueryJobs(
|
||||
}
|
||||
|
||||
log.Debugf("SQL query: `%s`, args: %#v", sql, args)
|
||||
rows, err := r.DB.Queryx(sql, args...)
|
||||
rows, err := query.RunWith(r.stmtCache).Query()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jobs := make([]*schema.Job, 0, 50)
|
||||
for rows.Next() {
|
||||
job, err := schema.ScanJob(rows)
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -9,19 +9,19 @@ import (
|
||||
|
||||
// Add the tag with id `tagId` to the job with the database id `jobId`.
|
||||
func (r *JobRepository) AddTag(jobId int64, tagId int64) error {
|
||||
_, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId)
|
||||
_, err := r.stmtCache.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId)
|
||||
return err
|
||||
}
|
||||
|
||||
// Removes a tag from a job
|
||||
func (r *JobRepository) RemoveTag(job, tag int64) error {
|
||||
_, err := r.DB.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag)
|
||||
_, err := r.stmtCache.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateTag creates a new tag with the specified type and name and returns its database id.
|
||||
func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
|
||||
res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
|
||||
res, err := r.stmtCache.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -52,13 +52,12 @@ func (r *JobRepository) CountTags(user *string) (tags []schema.Tag, counts map[s
|
||||
q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user)
|
||||
}
|
||||
|
||||
rows, err := q.RunWith(r.DB).Query()
|
||||
rows, err := q.RunWith(r.stmtCache).Query()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
counts = make(map[string]int)
|
||||
|
||||
for rows.Next() {
|
||||
var tagName string
|
||||
var count int
|
||||
@ -92,7 +91,7 @@ func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exis
|
||||
exists = true
|
||||
if err := sq.Select("id").From("tag").
|
||||
Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
|
||||
RunWith(r.DB).QueryRow().Scan(&tagId); err != nil {
|
||||
RunWith(r.stmtCache).QueryRow().Scan(&tagId); err != nil {
|
||||
exists = false
|
||||
}
|
||||
return
|
||||
@ -105,14 +104,18 @@ func (r *JobRepository) GetTags(job *int64) ([]*schema.Tag, error) {
|
||||
q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job)
|
||||
}
|
||||
|
||||
sql, args, err := q.ToSql()
|
||||
rows, err := q.RunWith(r.stmtCache).Query()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tags := make([]*schema.Tag, 0)
|
||||
if err := r.DB.Select(&tags, sql, args...); err != nil {
|
||||
return nil, err
|
||||
for rows.Next() {
|
||||
tag := &schema.Tag{}
|
||||
if err := rows.Scan(&tag.ID, &tag.Type, &tag.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
|
||||
return tags, nil
|
||||
|
@ -1,7 +1,6 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -52,6 +51,7 @@ type Job struct {
|
||||
// This is why there is this struct, which contains all fields from the regular job struct, but "overwrites"
|
||||
// the StartTime field with one of type int64.
|
||||
type JobMeta struct {
|
||||
ID *int64 `json:"id,omitempty"` // never used in the job-archive, only available via REST-API
|
||||
BaseJob
|
||||
StartTime int64 `json:"startTime" db:"start_time"`
|
||||
Statistics map[string]JobStatistics `json:"statistics,omitempty"`
|
||||
@ -67,37 +67,6 @@ const (
|
||||
var JobDefaults BaseJob = BaseJob{
|
||||
Exclusive: 1,
|
||||
MonitoringStatus: MonitoringStatusRunningOrArchiving,
|
||||
MetaData: "",
|
||||
}
|
||||
|
||||
var JobColumns []string = []string{
|
||||
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id", "job.num_nodes",
|
||||
"job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
|
||||
"job.duration", "job.resources", "job.meta_data",
|
||||
}
|
||||
|
||||
type Scannable interface {
|
||||
StructScan(dest interface{}) error
|
||||
}
|
||||
|
||||
// Helper function for scanning jobs with the `jobTableCols` columns selected.
|
||||
func ScanJob(row Scannable) (*Job, error) {
|
||||
job := &Job{BaseJob: JobDefaults}
|
||||
if err := row.StructScan(job); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
job.StartTime = time.Unix(job.StartTimeUnix, 0)
|
||||
if job.Duration == 0 && job.State == JobStateRunning {
|
||||
job.Duration = int32(time.Since(job.StartTime).Seconds())
|
||||
}
|
||||
|
||||
job.RawResources = nil
|
||||
return job, nil
|
||||
}
|
||||
|
||||
type JobStatistics struct {
|
||||
|
Loading…
Reference in New Issue
Block a user