make database schema mysql compatible; use prepared statements

This commit is contained in:
Lou Knauer
2022-01-20 10:00:55 +01:00
parent a64944f3c3
commit 9034cb90aa
10 changed files with 163 additions and 86 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log"
"regexp"
"strings"
@@ -20,6 +21,31 @@ import (
type Resolver struct {
DB *sqlx.DB
findJobByIdStmt *sqlx.Stmt
findJobByIdWithUserStmt *sqlx.Stmt
}
func (r *Resolver) Init() {
findJobById, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).ToSql()
if err != nil {
log.Fatal(err)
}
r.findJobByIdStmt, err = r.DB.Preparex(findJobById)
if err != nil {
log.Fatal(err)
}
findJobByIdWithUser, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).Where("job.user = ?").ToSql()
if err != nil {
log.Fatal(err)
}
r.findJobByIdWithUserStmt, err = r.DB.Preparex(findJobByIdWithUser)
if err != nil {
log.Fatal(err)
}
}
// Helper function for the `jobs` GraphQL-Query. Is also used elsewhere when a list of jobs is needed.
@@ -82,17 +108,12 @@ func (r *Resolver) queryJobs(ctx context.Context, filters []*model.JobFilter, pa
}
func securityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder {
val := ctx.Value(auth.ContextUserKey)
if val == nil {
user := auth.GetUser(ctx)
if user == nil || user.IsAdmin {
return query
}
user := val.(*auth.User)
if user.IsAdmin {
return query
}
return query.Where("job.user_id = ?", user.Username)
return query.Where("job.user = ?", user.Username)
}
// Build a sq.SelectBuilder out of a schema.JobFilter.

View File

@@ -148,14 +148,14 @@ func (r *queryResolver) Tags(ctx context.Context) ([]*schema.Tag, error) {
}
func (r *queryResolver) Job(ctx context.Context, id string) (*schema.Job, error) {
query := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", id)
query = securityCheck(ctx, query)
sql, args, err := query.ToSql()
if err != nil {
return nil, err
// This query is very common (mostly called through other resolvers such as JobMetrics),
// so we use prepared statements here.
user := auth.GetUser(ctx)
if user == nil || user.IsAdmin {
return schema.ScanJob(r.findJobByIdStmt.QueryRowx(id))
}
return schema.ScanJob(r.DB.QueryRowx(sql, args...))
return schema.ScanJob(r.findJobByIdWithUserStmt.QueryRowx(id, user.Username))
}
func (r *queryResolver) JobMetrics(ctx context.Context, id string, metrics []string, scopes []schema.MetricScope) ([]*model.JobMetricWithName, error) {

View File

@@ -30,13 +30,13 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
// `socketsPerNode` and `coresPerSocket` can differ from cluster to cluster, so we need to explicitly loop over those.
for _, cluster := range config.Clusters {
for _, partition := range cluster.Partitions {
corehoursCol := fmt.Sprintf("SUM(job.duration * job.num_nodes * %d * %d) / 3600", partition.SocketsPerNode, partition.CoresPerSocket)
corehoursCol := fmt.Sprintf("ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600)", partition.SocketsPerNode, partition.CoresPerSocket)
var query sq.SelectBuilder
if groupBy == nil {
query = sq.Select(
"''",
"COUNT(job.id)",
"SUM(job.duration) / 3600",
"ROUND(SUM(job.duration) / 3600)",
corehoursCol,
).From("job")
} else {
@@ -44,7 +44,7 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
query = sq.Select(
col,
"COUNT(job.id)",
"SUM(job.duration) / 3600",
"ROUND(SUM(job.duration) / 3600)",
corehoursCol,
).From("job").GroupBy(col)
}