cc-backend/graph/resolver.go

226 lines
6.4 KiB
Go
Raw Normal View History

package graph
import (
2021-12-08 10:12:19 +01:00
"context"
"errors"
"fmt"
"log"
2021-05-03 10:23:47 +02:00
"regexp"
"strings"
2022-01-27 09:40:59 +01:00
"github.com/ClusterCockpit/cc-backend/auth"
"github.com/ClusterCockpit/cc-backend/graph/model"
"github.com/ClusterCockpit/cc-backend/schema"
sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
)
// This file will not be regenerated automatically.
//
// It serves as dependency injection for your app, add any dependencies you require here.
2021-05-03 10:23:47 +02:00
type Resolver struct {
2021-11-26 10:35:07 +01:00
DB *sqlx.DB
findJobByIdStmt *sqlx.Stmt
findJobByIdWithUserStmt *sqlx.Stmt
}
// Init prepares the two statements used to look up a single job: one that
// matches on the job id alone and one that additionally matches on the job's
// user. r.DB must be set before Init is called. If a statement cannot be built
// or prepared, the process is terminated through log.Fatal.
func (r *Resolver) Init() {
	// mustPrepare renders builder to SQL text and prepares it on r.DB. The
	// argument list returned by ToSql is discarded; only the text with its
	// placeholders is needed here.
	mustPrepare := func(builder sq.SelectBuilder) *sqlx.Stmt {
		text, _, err := builder.ToSql()
		if err != nil {
			log.Fatal(err)
		}

		stmt, err := r.DB.Preparex(text)
		if err != nil {
			log.Fatal(err)
		}
		return stmt
	}

	r.findJobByIdStmt = mustPrepare(
		sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil))
	r.findJobByIdWithUserStmt = mustPrepare(
		sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).Where("job.user = ?"))
}
// Helper function for the `jobs` GraphQL-Query. Is also used elsewhere when a list of jobs is needed.
2021-12-17 15:49:22 +01:00
func (r *Resolver) queryJobs(ctx context.Context, filters []*model.JobFilter, page *model.PageRequest, order *model.OrderByInput) ([]*schema.Job, int, error) {
query := sq.Select(schema.JobColumns...).From("job")
2021-12-08 10:12:19 +01:00
query = securityCheck(ctx, query)
2021-04-21 10:12:19 +02:00
if order != nil {
field := toSnakeCase(order.Field)
if order.Order == model.SortDirectionEnumAsc {
query = query.OrderBy(fmt.Sprintf("job.%s ASC", field))
} else if order.Order == model.SortDirectionEnumDesc {
query = query.OrderBy(fmt.Sprintf("job.%s DESC", field))
} else {
return nil, 0, errors.New("invalid sorting order")
}
}
if page != nil {
limit := uint64(page.ItemsPerPage)
query = query.Offset((uint64(page.Page) - 1) * limit).Limit(limit)
} else {
query = query.Limit(50)
}
for _, f := range filters {
query = buildWhereClause(f, query)
}
2021-12-17 15:49:22 +01:00
sql, args, err := query.ToSql()
if err != nil {
return nil, 0, err
}
rows, err := r.DB.Queryx(sql, args...)
2021-05-21 09:30:15 +02:00
if err != nil {
return nil, 0, err
2021-05-21 09:30:15 +02:00
}
2021-12-17 15:49:22 +01:00
jobs := make([]*schema.Job, 0, 50)
2021-05-21 09:30:15 +02:00
for rows.Next() {
2021-12-17 15:49:22 +01:00
job, err := schema.ScanJob(rows)
2021-05-21 09:30:15 +02:00
if err != nil {
return nil, 0, err
2021-05-21 09:30:15 +02:00
}
jobs = append(jobs, job)
2021-05-21 09:30:15 +02:00
}
2021-12-17 15:49:22 +01:00
// count all jobs:
query = sq.Select("count(*)").From("job")
for _, f := range filters {
query = buildWhereClause(f, query)
}
var count int
2021-12-17 15:49:22 +01:00
if err := query.RunWith(r.DB).Scan(&count); err != nil {
return nil, 0, err
}
return jobs, count, nil
}
2021-12-08 10:12:19 +01:00
func securityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder {
user := auth.GetUser(ctx)
if user == nil || user.HasRole(auth.RoleAdmin) {
2021-12-08 10:12:19 +01:00
return query
}
return query.Where("job.user = ?", user.Username)
2021-12-08 10:12:19 +01:00
}
2021-12-17 15:49:22 +01:00
// Build a sq.SelectBuilder out of a schema.JobFilter.
func buildWhereClause(filter *model.JobFilter, query sq.SelectBuilder) sq.SelectBuilder {
if filter.Tags != nil {
2022-01-07 09:44:34 +01:00
query = query.Join("jobtag ON jobtag.job_id = job.id").Where(sq.Eq{"jobtag.tag_id": filter.Tags})
}
if filter.JobID != nil {
query = buildStringCondition("job.job_id", filter.JobID, query)
}
2022-01-27 10:40:48 +01:00
if filter.ArrayJobID != nil {
query = query.Where("job.array_job_id = ?", *filter.ArrayJobID)
}
if filter.User != nil {
query = buildStringCondition("job.user", filter.User, query)
}
if filter.Project != nil {
query = buildStringCondition("job.project", filter.Project, query)
2021-04-22 15:00:54 +02:00
}
if filter.Cluster != nil {
query = buildStringCondition("job.cluster", filter.Cluster, query)
2021-04-22 15:00:54 +02:00
}
2022-01-27 10:40:48 +01:00
if filter.Partition != nil {
query = buildStringCondition("job.partition", filter.Partition, query)
}
if filter.StartTime != nil {
query = buildTimeCondition("job.start_time", filter.StartTime, query)
2021-04-07 09:19:21 +02:00
}
if filter.Duration != nil {
query = buildIntCondition("job.duration", filter.Duration, query)
2021-04-07 09:19:21 +02:00
}
2021-12-17 15:49:22 +01:00
if filter.State != nil {
2022-01-07 09:44:34 +01:00
states := make([]string, len(filter.State))
for i, val := range filter.State {
states[i] = string(val)
}
query = query.Where(sq.Eq{"job.job_state": states})
2021-04-07 09:19:21 +02:00
}
if filter.NumNodes != nil {
query = buildIntCondition("job.num_nodes", filter.NumNodes, query)
2021-04-14 18:53:18 +02:00
}
if filter.NumAccelerators != nil {
query = buildIntCondition("job.num_acc", filter.NumAccelerators, query)
}
if filter.NumHWThreads != nil {
query = buildIntCondition("job.num_hwthreads", filter.NumHWThreads, query)
}
if filter.FlopsAnyAvg != nil {
query = buildFloatCondition("job.flops_any_avg", filter.FlopsAnyAvg, query)
2021-04-14 18:53:18 +02:00
}
if filter.MemBwAvg != nil {
query = buildFloatCondition("job.mem_bw_avg", filter.MemBwAvg, query)
}
if filter.LoadAvg != nil {
query = buildFloatCondition("job.load_avg", filter.LoadAvg, query)
}
if filter.MemUsedMax != nil {
query = buildFloatCondition("job.mem_used_max", filter.MemUsedMax, query)
}
return query
}
func buildIntCondition(field string, cond *model.IntRange, query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To)
2021-04-14 18:53:18 +02:00
}
// buildTimeCondition adds a condition on field to query based on which bounds
// of cond are set: a BETWEEN for both, a lower limit for From alone and an
// upper limit for To alone. Bounds are passed to the database as the result of
// their Unix method. With neither bound set, query is returned unchanged.
func buildTimeCondition(field string, cond *model.TimeRange, query sq.SelectBuilder) sq.SelectBuilder {
	from, to := cond.From, cond.To

	switch {
	case from != nil && to != nil:
		return query.Where(field+" BETWEEN ? AND ?", from.Unix(), to.Unix())
	case from != nil:
		return query.Where("? <= "+field, from.Unix())
	case to != nil:
		return query.Where(field+" <= ?", to.Unix())
	}
	return query
}
func buildFloatCondition(field string, cond *model.FloatRange, query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To)
2021-05-21 09:30:15 +02:00
}
func buildStringCondition(field string, cond *model.StringInput, query sq.SelectBuilder) sq.SelectBuilder {
if cond.Eq != nil {
return query.Where(field+" = ?", *cond.Eq)
2021-05-21 09:30:15 +02:00
}
if cond.StartsWith != nil {
2022-01-07 09:44:34 +01:00
return query.Where(field+" LIKE ?", fmt.Sprint(*cond.StartsWith, "%"))
2021-05-21 09:30:15 +02:00
}
if cond.EndsWith != nil {
2022-01-07 09:44:34 +01:00
return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.EndsWith))
2021-05-21 09:30:15 +02:00
}
if cond.Contains != nil {
2022-01-07 09:44:34 +01:00
return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.Contains, "%"))
2021-05-21 09:30:15 +02:00
}
return query
2021-05-21 09:30:15 +02:00
}
2022-01-07 09:44:34 +01:00
// Helpers of toSnakeCase, built once at package initialisation.
var (
	// matchFirstCap finds any character directly followed by an upper-case
	// letter and one or more lower-case letters.
	matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
	// matchAllCap finds a lower-case letter or digit directly followed by an
	// upper-case letter.
	matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
	// snakeCaseStripper deletes single quotes and backslashes in one pass.
	snakeCaseStripper = strings.NewReplacer("'", "", "\\", "")
)

// toSnakeCase converts a camelCase name such as "startTime" into its
// snake_case form ("start_time"). Single quotes and backslashes are removed
// from str before the conversion; the result is lower-cased.
func toSnakeCase(str string) string {
	cleaned := snakeCaseStripper.Replace(str)

	// First separate "xYyy" style boundaries, then whatever lower/digit to
	// upper-case transitions are left (e.g. in front of an acronym).
	snake := matchFirstCap.ReplaceAllString(cleaned, "${1}_${2}")
	snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")

	return strings.ToLower(snake)
}