mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2024-12-26 13:29:05 +01:00
use prepared statements
This commit is contained in:
parent
96f91f1a1c
commit
a528e42be6
@ -3,8 +3,10 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ClusterCockpit/cc-backend/auth"
|
"github.com/ClusterCockpit/cc-backend/auth"
|
||||||
"github.com/ClusterCockpit/cc-backend/graph/model"
|
"github.com/ClusterCockpit/cc-backend/graph/model"
|
||||||
@ -15,12 +17,43 @@ import (
|
|||||||
|
|
||||||
type JobRepository struct {
|
type JobRepository struct {
|
||||||
DB *sqlx.DB
|
DB *sqlx.DB
|
||||||
|
|
||||||
|
stmtCache *sq.StmtCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *JobRepository) Init() error {
|
func (r *JobRepository) Init() error {
|
||||||
|
r.stmtCache = sq.NewStmtCache(r.DB)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var jobColumns []string = []string{
|
||||||
|
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id",
|
||||||
|
"job.num_nodes", "job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
|
||||||
|
"job.duration", "job.resources", "job.meta_data",
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanJob(row interface{ Scan(...interface{}) error }) (*schema.Job, error) {
|
||||||
|
job := &schema.Job{}
|
||||||
|
if err := row.Scan(
|
||||||
|
&job.ID, &job.JobID, &job.User, &job.Project, &job.Cluster, &job.StartTimeUnix, &job.Partition, &job.ArrayJobId,
|
||||||
|
&job.NumNodes, &job.NumHWThreads, &job.NumAcc, &job.Exclusive, &job.MonitoringStatus, &job.SMT, &job.State,
|
||||||
|
&job.Duration, &job.RawResources, &job.MetaData); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
job.StartTime = time.Unix(job.StartTimeUnix, 0)
|
||||||
|
if job.Duration == 0 && job.State == schema.JobStateRunning {
|
||||||
|
job.Duration = int32(time.Since(job.StartTime).Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
job.RawResources = nil
|
||||||
|
return job, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Find executes a SQL query to find a specific batch job.
|
// Find executes a SQL query to find a specific batch job.
|
||||||
// The job is queried using the batch job id, the cluster name,
|
// The job is queried using the batch job id, the cluster name,
|
||||||
// and the start time of the job in UNIX epoch time seconds.
|
// and the start time of the job in UNIX epoch time seconds.
|
||||||
@ -31,22 +64,17 @@ func (r *JobRepository) Find(
|
|||||||
cluster *string,
|
cluster *string,
|
||||||
startTime *int64) (*schema.Job, error) {
|
startTime *int64) (*schema.Job, error) {
|
||||||
|
|
||||||
qb := sq.Select(schema.JobColumns...).From("job").
|
q := sq.Select(jobColumns...).From("job").
|
||||||
Where("job.job_id = ?", jobId)
|
Where("job.job_id = ?", jobId)
|
||||||
|
|
||||||
if cluster != nil {
|
if cluster != nil {
|
||||||
qb = qb.Where("job.cluster = ?", *cluster)
|
q = q.Where("job.cluster = ?", *cluster)
|
||||||
}
|
}
|
||||||
if startTime != nil {
|
if startTime != nil {
|
||||||
qb = qb.Where("job.start_time = ?", *startTime)
|
q = q.Where("job.start_time = ?", *startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlQuery, args, err := qb.ToSql()
|
return scanJob(q.RunWith(r.stmtCache).QueryRow())
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindById executes a SQL query to find a specific batch job.
|
// FindById executes a SQL query to find a specific batch job.
|
||||||
@ -55,13 +83,9 @@ func (r *JobRepository) Find(
|
|||||||
// To check if no job was found test err == sql.ErrNoRows
|
// To check if no job was found test err == sql.ErrNoRows
|
||||||
func (r *JobRepository) FindById(
|
func (r *JobRepository) FindById(
|
||||||
jobId int64) (*schema.Job, error) {
|
jobId int64) (*schema.Job, error) {
|
||||||
sqlQuery, args, err := sq.Select(schema.JobColumns...).
|
q := sq.Select(jobColumns...).
|
||||||
From("job").Where("job.id = ?", jobId).ToSql()
|
From("job").Where("job.id = ?", jobId)
|
||||||
if err != nil {
|
return scanJob(q.RunWith(r.stmtCache).QueryRow())
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start inserts a new job in the table, returning the unique job ID.
|
// Start inserts a new job in the table, returning the unique job ID.
|
||||||
@ -94,7 +118,7 @@ func (r *JobRepository) Stop(
|
|||||||
Set("monitoring_status", monitoringStatus).
|
Set("monitoring_status", monitoringStatus).
|
||||||
Where("job.id = ?", jobId)
|
Where("job.id = ?", jobId)
|
||||||
|
|
||||||
_, err = stmt.RunWith(r.DB).Exec()
|
_, err = stmt.RunWith(r.stmtCache).Exec()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -136,7 +160,7 @@ func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32
|
|||||||
Set("monitoring_status", monitoringStatus).
|
Set("monitoring_status", monitoringStatus).
|
||||||
Where("job.id = ?", job)
|
Where("job.id = ?", job)
|
||||||
|
|
||||||
_, err = stmt.RunWith(r.DB).Exec()
|
_, err = stmt.RunWith(r.stmtCache).Exec()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,7 +191,7 @@ func (r *JobRepository) Archive(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := stmt.RunWith(r.DB).Exec(); err != nil {
|
if _, err := stmt.RunWith(r.stmtCache).Exec(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -186,7 +210,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
|
|||||||
qb = qb.Where("job.user = ?", user.Username)
|
qb = qb.Where("job.user = ?", user.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := qb.RunWith(r.DB).QueryRow().Scan(&job)
|
err := qb.RunWith(r.stmtCache).QueryRow().Scan(&job)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return 0, "", err
|
return 0, "", err
|
||||||
} else if err == nil {
|
} else if err == nil {
|
||||||
@ -197,7 +221,7 @@ func (r *JobRepository) FindJobOrUser(ctx context.Context, searchterm string) (j
|
|||||||
if user == nil || user.HasRole(auth.RoleAdmin) {
|
if user == nil || user.HasRole(auth.RoleAdmin) {
|
||||||
err := sq.Select("job.user").Distinct().From("job").
|
err := sq.Select("job.user").Distinct().From("job").
|
||||||
Where("job.user = ?", searchterm).
|
Where("job.user = ?", searchterm).
|
||||||
RunWith(r.DB).QueryRow().Scan(&username)
|
RunWith(r.stmtCache).QueryRow().Scan(&username)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return 0, "", err
|
return 0, "", err
|
||||||
} else if err == nil {
|
} else if err == nil {
|
||||||
|
@ -21,7 +21,7 @@ func (r *JobRepository) QueryJobs(
|
|||||||
page *model.PageRequest,
|
page *model.PageRequest,
|
||||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
order *model.OrderByInput) ([]*schema.Job, error) {
|
||||||
|
|
||||||
query := sq.Select(schema.JobColumns...).From("job")
|
query := sq.Select(jobColumns...).From("job")
|
||||||
query = SecurityCheck(ctx, query)
|
query = SecurityCheck(ctx, query)
|
||||||
|
|
||||||
if order != nil {
|
if order != nil {
|
||||||
@ -50,14 +50,14 @@ func (r *JobRepository) QueryJobs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("SQL query: `%s`, args: %#v", sql, args)
|
log.Debugf("SQL query: `%s`, args: %#v", sql, args)
|
||||||
rows, err := r.DB.Queryx(sql, args...)
|
rows, err := query.RunWith(r.stmtCache).Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
jobs := make([]*schema.Job, 0, 50)
|
jobs := make([]*schema.Job, 0, 50)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
job, err := schema.ScanJob(rows)
|
job, err := scanJob(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -9,19 +9,19 @@ import (
|
|||||||
|
|
||||||
// Add the tag with id `tagId` to the job with the database id `jobId`.
|
// Add the tag with id `tagId` to the job with the database id `jobId`.
|
||||||
func (r *JobRepository) AddTag(jobId int64, tagId int64) error {
|
func (r *JobRepository) AddTag(jobId int64, tagId int64) error {
|
||||||
_, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId)
|
_, err := r.stmtCache.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)`, jobId, tagId)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Removes a tag from a job
|
// Removes a tag from a job
|
||||||
func (r *JobRepository) RemoveTag(job, tag int64) error {
|
func (r *JobRepository) RemoveTag(job, tag int64) error {
|
||||||
_, err := r.DB.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag)
|
_, err := r.stmtCache.Exec("DELETE FROM jobtag WHERE jobtag.job_id = $1 AND jobtag.tag_id = $2", job, tag)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTag creates a new tag with the specified type and name and returns its database id.
|
// CreateTag creates a new tag with the specified type and name and returns its database id.
|
||||||
func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
|
func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
|
||||||
res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
|
res, err := r.stmtCache.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -52,13 +52,12 @@ func (r *JobRepository) CountTags(user *string) (tags []schema.Tag, counts map[s
|
|||||||
q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user)
|
q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := q.RunWith(r.DB).Query()
|
rows, err := q.RunWith(r.stmtCache).Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
counts = make(map[string]int)
|
counts = make(map[string]int)
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var tagName string
|
var tagName string
|
||||||
var count int
|
var count int
|
||||||
@ -92,7 +91,7 @@ func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exis
|
|||||||
exists = true
|
exists = true
|
||||||
if err := sq.Select("id").From("tag").
|
if err := sq.Select("id").From("tag").
|
||||||
Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
|
Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
|
||||||
RunWith(r.DB).QueryRow().Scan(&tagId); err != nil {
|
RunWith(r.stmtCache).QueryRow().Scan(&tagId); err != nil {
|
||||||
exists = false
|
exists = false
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -105,14 +104,18 @@ func (r *JobRepository) GetTags(job *int64) ([]*schema.Tag, error) {
|
|||||||
q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job)
|
q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job)
|
||||||
}
|
}
|
||||||
|
|
||||||
sql, args, err := q.ToSql()
|
rows, err := q.RunWith(r.stmtCache).Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tags := make([]*schema.Tag, 0)
|
tags := make([]*schema.Tag, 0)
|
||||||
if err := r.DB.Select(&tags, sql, args...); err != nil {
|
for rows.Next() {
|
||||||
return nil, err
|
tag := &schema.Tag{}
|
||||||
|
if err := rows.Scan(&tag.ID, &tag.Type, &tag.Name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tags = append(tags, tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tags, nil
|
return tags, nil
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -52,6 +51,7 @@ type Job struct {
|
|||||||
// This is why there is this struct, which contains all fields from the regular job struct, but "overwrites"
|
// This is why there is this struct, which contains all fields from the regular job struct, but "overwrites"
|
||||||
// the StartTime field with one of type int64.
|
// the StartTime field with one of type int64.
|
||||||
type JobMeta struct {
|
type JobMeta struct {
|
||||||
|
ID *int64 `json:"id,omitempty"` // never used in the job-archive, only available via REST-API
|
||||||
BaseJob
|
BaseJob
|
||||||
StartTime int64 `json:"startTime" db:"start_time"`
|
StartTime int64 `json:"startTime" db:"start_time"`
|
||||||
Statistics map[string]JobStatistics `json:"statistics,omitempty"`
|
Statistics map[string]JobStatistics `json:"statistics,omitempty"`
|
||||||
@ -67,37 +67,6 @@ const (
|
|||||||
var JobDefaults BaseJob = BaseJob{
|
var JobDefaults BaseJob = BaseJob{
|
||||||
Exclusive: 1,
|
Exclusive: 1,
|
||||||
MonitoringStatus: MonitoringStatusRunningOrArchiving,
|
MonitoringStatus: MonitoringStatusRunningOrArchiving,
|
||||||
MetaData: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
var JobColumns []string = []string{
|
|
||||||
"job.id", "job.job_id", "job.user", "job.project", "job.cluster", "job.start_time", "job.partition", "job.array_job_id", "job.num_nodes",
|
|
||||||
"job.num_hwthreads", "job.num_acc", "job.exclusive", "job.monitoring_status", "job.smt", "job.job_state",
|
|
||||||
"job.duration", "job.resources", "job.meta_data",
|
|
||||||
}
|
|
||||||
|
|
||||||
type Scannable interface {
|
|
||||||
StructScan(dest interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function for scanning jobs with the `jobTableCols` columns selected.
|
|
||||||
func ScanJob(row Scannable) (*Job, error) {
|
|
||||||
job := &Job{BaseJob: JobDefaults}
|
|
||||||
if err := row.StructScan(job); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(job.RawResources, &job.Resources); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
job.StartTime = time.Unix(job.StartTimeUnix, 0)
|
|
||||||
if job.Duration == 0 && job.State == JobStateRunning {
|
|
||||||
job.Duration = int32(time.Since(job.StartTime).Seconds())
|
|
||||||
}
|
|
||||||
|
|
||||||
job.RawResources = nil
|
|
||||||
return job, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type JobStatistics struct {
|
type JobStatistics struct {
|
||||||
|
Loading…
Reference in New Issue
Block a user