diff --git a/api_test.go b/api_test.go
index e80cd21..e726ed8 100644
--- a/api_test.go
+++ b/api_test.go
@@ -109,12 +109,9 @@ func setup(t *testing.T) *api.RestApi {
t.Fatal(err)
}
- resolver := &graph.Resolver{DB: db}
- if err := resolver.Init(); err != nil {
- t.Fatal(err)
- }
+ resolver := &graph.Resolver{DB: db, Repo: &repository.JobRepository{DB: db}}
return &api.RestApi{
- JobRepository: &repository.JobRepository{DB: db},
+ JobRepository: resolver.Repo,
Resolver: resolver,
}
}
diff --git a/graph/resolver.go b/graph/resolver.go
index a023b5b..ce08e33 100644
--- a/graph/resolver.go
+++ b/graph/resolver.go
@@ -1,17 +1,7 @@
package graph
import (
- "context"
- "errors"
- "fmt"
- "regexp"
- "strings"
-
- "github.com/ClusterCockpit/cc-backend/auth"
- "github.com/ClusterCockpit/cc-backend/graph/model"
- "github.com/ClusterCockpit/cc-backend/log"
- "github.com/ClusterCockpit/cc-backend/schema"
- sq "github.com/Masterminds/squirrel"
+ "github.com/ClusterCockpit/cc-backend/repository"
"github.com/jmoiron/sqlx"
)
@@ -20,215 +10,6 @@ import (
// It serves as dependency injection for your app, add any dependencies you require here.
type Resolver struct {
- DB *sqlx.DB
-
- findJobByIdStmt *sqlx.Stmt
- findJobByIdWithUserStmt *sqlx.Stmt
-}
-
-func (r *Resolver) Init() error {
- findJobById, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).ToSql()
- if err != nil {
- return err
- }
-
- r.findJobByIdStmt, err = r.DB.Preparex(findJobById)
- if err != nil {
- return err
- }
-
- findJobByIdWithUser, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).Where("job.user = ?").ToSql()
- if err != nil {
- return err
- }
-
- r.findJobByIdWithUserStmt, err = r.DB.Preparex(findJobByIdWithUser)
- if err != nil {
- return err
- }
-
- return nil
-}
-
-// Helper function for the `jobs` GraphQL-Query. Is also used elsewhere when a list of jobs is needed.
-func (r *Resolver) queryJobs(
- ctx context.Context,
- filters []*model.JobFilter,
- page *model.PageRequest,
- order *model.OrderByInput) ([]*schema.Job, int, error) {
-
- query := sq.Select(schema.JobColumns...).From("job")
- query = securityCheck(ctx, query)
-
- if order != nil {
- field := toSnakeCase(order.Field)
- if order.Order == model.SortDirectionEnumAsc {
- query = query.OrderBy(fmt.Sprintf("job.%s ASC", field))
- } else if order.Order == model.SortDirectionEnumDesc {
- query = query.OrderBy(fmt.Sprintf("job.%s DESC", field))
- } else {
- return nil, 0, errors.New("invalid sorting order")
- }
- }
-
- if page != nil {
- limit := uint64(page.ItemsPerPage)
- query = query.Offset((uint64(page.Page) - 1) * limit).Limit(limit)
- } else {
- query = query.Limit(50)
- }
-
- for _, f := range filters {
- query = buildWhereClause(f, query)
- }
-
- sql, args, err := query.ToSql()
- if err != nil {
- return nil, 0, err
- }
-
- log.Debugf("SQL query: `%s`, args: %#v", sql, args)
- rows, err := r.DB.Queryx(sql, args...)
- if err != nil {
- return nil, 0, err
- }
-
- jobs := make([]*schema.Job, 0, 50)
- for rows.Next() {
- job, err := schema.ScanJob(rows)
- if err != nil {
- return nil, 0, err
- }
- jobs = append(jobs, job)
- }
-
- // count all jobs:
- query = sq.Select("count(*)").From("job")
- for _, f := range filters {
- query = buildWhereClause(f, query)
- }
- query = securityCheck(ctx, query)
- var count int
- if err := query.RunWith(r.DB).Scan(&count); err != nil {
- return nil, 0, err
- }
-
- return jobs, count, nil
-}
-
-func securityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder {
- user := auth.GetUser(ctx)
- if user == nil || user.HasRole(auth.RoleAdmin) {
- return query
- }
-
- return query.Where("job.user = ?", user.Username)
-}
-
-// Build a sq.SelectBuilder out of a schema.JobFilter.
-func buildWhereClause(filter *model.JobFilter, query sq.SelectBuilder) sq.SelectBuilder {
- if filter.Tags != nil {
- query = query.Join("jobtag ON jobtag.job_id = job.id").Where(sq.Eq{"jobtag.tag_id": filter.Tags})
- }
- if filter.JobID != nil {
- query = buildStringCondition("job.job_id", filter.JobID, query)
- }
- if filter.ArrayJobID != nil {
- query = query.Where("job.array_job_id = ?", *filter.ArrayJobID)
- }
- if filter.User != nil {
- query = buildStringCondition("job.user", filter.User, query)
- }
- if filter.Project != nil {
- query = buildStringCondition("job.project", filter.Project, query)
- }
- if filter.Cluster != nil {
- query = buildStringCondition("job.cluster", filter.Cluster, query)
- }
- if filter.Partition != nil {
- query = buildStringCondition("job.partition", filter.Partition, query)
- }
- if filter.StartTime != nil {
- query = buildTimeCondition("job.start_time", filter.StartTime, query)
- }
- if filter.Duration != nil {
- query = buildIntCondition("job.duration", filter.Duration, query)
- }
- if filter.State != nil {
- states := make([]string, len(filter.State))
- for i, val := range filter.State {
- states[i] = string(val)
- }
-
- query = query.Where(sq.Eq{"job.job_state": states})
- }
- if filter.NumNodes != nil {
- query = buildIntCondition("job.num_nodes", filter.NumNodes, query)
- }
- if filter.NumAccelerators != nil {
- query = buildIntCondition("job.num_acc", filter.NumAccelerators, query)
- }
- if filter.NumHWThreads != nil {
- query = buildIntCondition("job.num_hwthreads", filter.NumHWThreads, query)
- }
- if filter.FlopsAnyAvg != nil {
- query = buildFloatCondition("job.flops_any_avg", filter.FlopsAnyAvg, query)
- }
- if filter.MemBwAvg != nil {
- query = buildFloatCondition("job.mem_bw_avg", filter.MemBwAvg, query)
- }
- if filter.LoadAvg != nil {
- query = buildFloatCondition("job.load_avg", filter.LoadAvg, query)
- }
- if filter.MemUsedMax != nil {
- query = buildFloatCondition("job.mem_used_max", filter.MemUsedMax, query)
- }
- return query
-}
-
-func buildIntCondition(field string, cond *model.IntRange, query sq.SelectBuilder) sq.SelectBuilder {
- return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To)
-}
-
-func buildTimeCondition(field string, cond *model.TimeRange, query sq.SelectBuilder) sq.SelectBuilder {
- if cond.From != nil && cond.To != nil {
- return query.Where(field+" BETWEEN ? AND ?", cond.From.Unix(), cond.To.Unix())
- } else if cond.From != nil {
- return query.Where("? <= "+field, cond.From.Unix())
- } else if cond.To != nil {
- return query.Where(field+" <= ?", cond.To.Unix())
- } else {
- return query
- }
-}
-
-func buildFloatCondition(field string, cond *model.FloatRange, query sq.SelectBuilder) sq.SelectBuilder {
- return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To)
-}
-
-func buildStringCondition(field string, cond *model.StringInput, query sq.SelectBuilder) sq.SelectBuilder {
- if cond.Eq != nil {
- return query.Where(field+" = ?", *cond.Eq)
- }
- if cond.StartsWith != nil {
- return query.Where(field+" LIKE ?", fmt.Sprint(*cond.StartsWith, "%"))
- }
- if cond.EndsWith != nil {
- return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.EndsWith))
- }
- if cond.Contains != nil {
- return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.Contains, "%"))
- }
- return query
-}
-
-var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
-var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
-
-func toSnakeCase(str string) string {
- str = strings.ReplaceAll(str, "'", "")
- str = strings.ReplaceAll(str, "\\", "")
- snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
- snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
- return strings.ToLower(snake)
+ DB *sqlx.DB
+ Repo *repository.JobRepository
}
diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go
index bfe0542..bd5c5f9 100644
--- a/graph/schema.resolvers.go
+++ b/graph/schema.resolvers.go
@@ -40,12 +40,7 @@ func (r *jobResolver) Tags(ctx context.Context, obj *schema.Job) ([]*schema.Tag,
}
func (r *mutationResolver) CreateTag(ctx context.Context, typeArg string, name string) (*schema.Tag, error) {
- res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", typeArg, name)
- if err != nil {
- return nil, err
- }
-
- id, err := res.LastInsertId()
+ id, err := r.Repo.CreateTag(typeArg, name)
if err != nil {
return nil, err
}
@@ -59,18 +54,18 @@ func (r *mutationResolver) DeleteTag(ctx context.Context, id string) (string, er
}
func (r *mutationResolver) AddTagsToJob(ctx context.Context, job string, tagIds []string) ([]*schema.Tag, error) {
- jid, err := strconv.Atoi(job)
+ jid, err := strconv.ParseInt(job, 10, 64)
if err != nil {
return nil, err
}
for _, tagId := range tagIds {
- tid, err := strconv.Atoi(tagId)
+ tid, err := strconv.ParseInt(tagId, 10, 64)
if err != nil {
return nil, err
}
- if _, err := r.DB.Exec("INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)", jid, tid); err != nil {
+ if err := r.Repo.AddTag(jid, tid); err != nil {
return nil, err
}
}
@@ -148,14 +143,21 @@ func (r *queryResolver) Tags(ctx context.Context) ([]*schema.Tag, error) {
}
func (r *queryResolver) Job(ctx context.Context, id string) (*schema.Job, error) {
- // This query is very common (mostly called through other resolvers such as JobMetrics),
- // so we use prepared statements here.
- user := auth.GetUser(ctx)
- if user == nil || user.HasRole(auth.RoleAdmin) {
- return schema.ScanJob(r.findJobByIdStmt.QueryRowx(id))
+ numericId, err := strconv.ParseInt(id, 10, 64)
+ if err != nil {
+ return nil, err
}
- return schema.ScanJob(r.findJobByIdWithUserStmt.QueryRowx(id, user.Username))
+ job, err := r.Repo.FindById(numericId)
+ if err != nil {
+ return nil, err
+ }
+
+ if user := auth.GetUser(ctx); user != nil && !user.HasRole(auth.RoleAdmin) && job.User != user.Username {
+ return nil, errors.New("you are not allowed to see this job")
+ }
+
+ return job, nil
}
func (r *queryResolver) JobMetrics(ctx context.Context, id string, metrics []string, scopes []schema.MetricScope) ([]*model.JobMetricWithName, error) {
@@ -191,7 +193,12 @@ func (r *queryResolver) JobsFootprints(ctx context.Context, filter []*model.JobF
}
func (r *queryResolver) Jobs(ctx context.Context, filter []*model.JobFilter, page *model.PageRequest, order *model.OrderByInput) (*model.JobResultList, error) {
- jobs, count, err := r.queryJobs(ctx, filter, page, order)
+ jobs, err := r.Repo.QueryJobs(ctx, filter, page, order)
+ if err != nil {
+ return nil, err
+ }
+
+ count, err := r.Repo.CountJobs(ctx, filter)
if err != nil {
return nil, err
}
diff --git a/graph/stats.go b/graph/stats.go
index d979229..bca0afe 100644
--- a/graph/stats.go
+++ b/graph/stats.go
@@ -11,6 +11,7 @@ import (
"github.com/ClusterCockpit/cc-backend/config"
"github.com/ClusterCockpit/cc-backend/graph/model"
"github.com/ClusterCockpit/cc-backend/metricdata"
+ "github.com/ClusterCockpit/cc-backend/repository"
"github.com/ClusterCockpit/cc-backend/schema"
sq "github.com/Masterminds/squirrel"
)
@@ -53,9 +54,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
Where("job.cluster = ?", cluster.Name).
Where("job.partition = ?", partition.Name)
- query = securityCheck(ctx, query)
+ query = repository.SecurityCheck(ctx, query)
for _, f := range filter {
- query = buildWhereClause(f, query)
+ query = repository.BuildWhereClause(f, query)
}
rows, err := query.RunWith(r.DB).Query()
@@ -90,9 +91,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
if groupBy == nil {
query := sq.Select("COUNT(job.id)").From("job").Where("job.duration < 120")
- query = securityCheck(ctx, query)
+ query = repository.SecurityCheck(ctx, query)
for _, f := range filter {
- query = buildWhereClause(f, query)
+ query = repository.BuildWhereClause(f, query)
}
if err := query.RunWith(r.DB).QueryRow().Scan(&(stats[""].ShortJobs)); err != nil {
return nil, err
@@ -100,9 +101,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
} else {
col := groupBy2column[*groupBy]
query := sq.Select(col, "COUNT(job.id)").From("job").Where("job.duration < 120")
- query = securityCheck(ctx, query)
+ query = repository.SecurityCheck(ctx, query)
for _, f := range filter {
- query = buildWhereClause(f, query)
+ query = repository.BuildWhereClause(f, query)
}
rows, err := query.RunWith(r.DB).Query()
if err != nil {
@@ -162,9 +163,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
// to add a condition to the query of the kind "
= ".
func (r *queryResolver) jobsStatisticsHistogram(ctx context.Context, value string, filters []*model.JobFilter, id, col string) ([]*model.HistoPoint, error) {
query := sq.Select(value, "COUNT(job.id) AS count").From("job")
- query = securityCheck(ctx, query)
+ query = repository.SecurityCheck(ctx, query)
for _, f := range filters {
- query = buildWhereClause(f, query)
+ query = repository.BuildWhereClause(f, query)
}
if len(id) != 0 && len(col) != 0 {
@@ -188,14 +189,16 @@ func (r *queryResolver) jobsStatisticsHistogram(ctx context.Context, value strin
return points, nil
}
+const MAX_JOBS_FOR_ANALYSIS = 500
+
// Helper function for the rooflineHeatmap GraphQL query placed here so that schema.resolvers.go is not too full.
func (r *Resolver) rooflineHeatmap(ctx context.Context, filter []*model.JobFilter, rows int, cols int, minX float64, minY float64, maxX float64, maxY float64) ([][]float64, error) {
- jobs, count, err := r.queryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: 501}, nil)
+ jobs, err := r.Repo.QueryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: MAX_JOBS_FOR_ANALYSIS + 1}, nil)
if err != nil {
return nil, err
}
- if len(jobs) > 500 {
- return nil, fmt.Errorf("too many jobs matched (matched: %d, max: %d)", count, 500)
+ if len(jobs) > MAX_JOBS_FOR_ANALYSIS {
+ return nil, fmt.Errorf("too many jobs matched (max: %d)", MAX_JOBS_FOR_ANALYSIS)
}
fcols, frows := float64(cols), float64(rows)
@@ -250,12 +253,12 @@ func (r *Resolver) rooflineHeatmap(ctx context.Context, filter []*model.JobFilte
// Helper function for the jobsFootprints GraphQL query placed here so that schema.resolvers.go is not too full.
func (r *queryResolver) jobsFootprints(ctx context.Context, filter []*model.JobFilter, metrics []string) ([]*model.MetricFootprints, error) {
- jobs, count, err := r.queryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: 501}, nil)
+ jobs, err := r.Repo.QueryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: MAX_JOBS_FOR_ANALYSIS + 1}, nil)
if err != nil {
return nil, err
}
- if len(jobs) > 500 {
- return nil, fmt.Errorf("too many jobs matched (matched: %d, max: %d)", count, 500)
+ if len(jobs) > MAX_JOBS_FOR_ANALYSIS {
+ return nil, fmt.Errorf("too many jobs matched (max: %d)", MAX_JOBS_FOR_ANALYSIS)
}
avgs := make([][]schema.Float, len(metrics))
diff --git a/repository/job.go b/repository/job.go
index 9d01051..770074f 100644
--- a/repository/job.go
+++ b/repository/job.go
@@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
- "fmt"
"strconv"
"github.com/ClusterCockpit/cc-backend/auth"
@@ -42,8 +41,7 @@ func (r *JobRepository) Find(
return nil, err
}
- job, err := schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
- return job, err
+ return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
}
// FindById executes a SQL query to find a specific batch job.
@@ -58,8 +56,7 @@ func (r *JobRepository) FindById(
return nil, err
}
- job, err := schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
- return job, err
+ return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...))
}
// Start inserts a new job in the table, returning the unique job ID.
@@ -96,20 +93,9 @@ func (r *JobRepository) Stop(
return
}
-// CountJobs returns the number of jobs for the specified user (if a non-admin user is found in that context) and state.
+// CountJobsPerCluster returns the number of jobs for the specified user (if a non-admin user is found in that context) and state.
// The counts are grouped by cluster.
-func (r *JobRepository) CountJobs(ctx context.Context, state *schema.JobState) (map[string]int, error) {
- // q := sq.Select("count(*)").From("job")
- // if cluster != nil {
- // q = q.Where("job.cluster = ?", cluster)
- // }
- // if state != nil {
- // q = q.Where("job.job_state = ?", string(*state))
- // }
-
- // err = q.RunWith(r.DB).QueryRow().Scan(&count)
- // return
-
+func (r *JobRepository) CountJobsPerCluster(ctx context.Context, state *schema.JobState) (map[string]int, error) {
q := sq.Select("job.cluster, count(*)").From("job").GroupBy("job.cluster")
if state != nil {
q = q.Where("job.job_state = ?", string(*state))
@@ -137,13 +123,6 @@ func (r *JobRepository) CountJobs(ctx context.Context, state *schema.JobState) (
return counts, nil
}
-// func (r *JobRepository) Query(
-// filters []*model.JobFilter,
-// page *model.PageRequest,
-// order *model.OrderByInput) ([]*schema.Job, int, error) {
-
-// }
-
func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) {
stmt := sq.Update("job").
Set("monitoring_status", monitoringStatus).
@@ -186,91 +165,6 @@ func (r *JobRepository) Archive(
return nil
}
-// Add the tag with id `tagId` to the job with the database id `jobId`.
-func (r *JobRepository) AddTag(jobId int64, tagId int64) error {
- _, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES (?, ?)`, jobId, tagId)
- return err
-}
-
-// CreateTag creates a new tag with the specified type and name and returns its database id.
-func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
- res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
- if err != nil {
- return 0, err
- }
-
- return res.LastInsertId()
-}
-
-func (r *JobRepository) GetTags(user *string) (tags []schema.Tag, counts map[string]int, err error) {
- tags = make([]schema.Tag, 0, 100)
- xrows, err := r.DB.Queryx("SELECT * FROM tag")
- if err != nil {
- return nil, nil, err
- }
-
- for xrows.Next() {
- var t schema.Tag
- if err := xrows.StructScan(&t); err != nil {
- return nil, nil, err
- }
- tags = append(tags, t)
- }
-
- q := sq.Select("t.tag_name, count(jt.tag_id)").
- From("tag t").
- LeftJoin("jobtag jt ON t.id = jt.tag_id").
- GroupBy("t.tag_name")
- if user != nil {
- q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user)
- }
-
- rows, err := q.RunWith(r.DB).Query()
- if err != nil {
- return nil, nil, err
- }
-
- counts = make(map[string]int)
-
- for rows.Next() {
- var tagName string
- var count int
- err = rows.Scan(&tagName, &count)
- if err != nil {
- fmt.Println(err)
- }
- counts[tagName] = count
- }
- err = rows.Err()
-
- return
-}
-
-// AddTagOrCreate adds the tag with the specified type and name to the job with the database id `jobId`.
-// If such a tag does not yet exist, it is created.
-func (r *JobRepository) AddTagOrCreate(jobId int64, tagType string, tagName string) (tagId int64, err error) {
- tagId, exists := r.TagId(tagType, tagName)
- if !exists {
- tagId, err = r.CreateTag(tagType, tagName)
- if err != nil {
- return 0, err
- }
- }
-
- return tagId, r.AddTag(jobId, tagId)
-}
-
-// TagId returns the database id of the tag with the specified type and name.
-func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exists bool) {
- exists = true
- if err := sq.Select("id").From("tag").
- Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
- RunWith(r.DB).QueryRow().Scan(&tagId); err != nil {
- exists = false
- }
- return
-}
-
var ErrNotFound = errors.New("no such job or user")
// FindJobOrUser returns a job database ID or a username if a job or user machtes the search term.
diff --git a/repository/job_test.go b/repository/job_test.go
index 8a43617..2c8e75f 100644
--- a/repository/job_test.go
+++ b/repository/job_test.go
@@ -7,8 +7,6 @@ import (
"github.com/jmoiron/sqlx"
"github.com/ClusterCockpit/cc-backend/test"
- _ "github.com/go-sql-driver/mysql"
- _ "github.com/mattn/go-sqlite3"
)
var db *sqlx.DB
@@ -57,7 +55,7 @@ func TestFindById(t *testing.T) {
func TestGetTags(t *testing.T) {
r := setup(t)
- tags, counts, err := r.GetTags(nil)
+ tags, counts, err := r.CountTags(nil)
if err != nil {
t.Fatal(err)
}
diff --git a/repository/query.go b/repository/query.go
new file mode 100644
index 0000000..53a863e
--- /dev/null
+++ b/repository/query.go
@@ -0,0 +1,212 @@
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "regexp"
+ "strings"
+
+ "github.com/ClusterCockpit/cc-backend/auth"
+ "github.com/ClusterCockpit/cc-backend/graph/model"
+ "github.com/ClusterCockpit/cc-backend/log"
+ "github.com/ClusterCockpit/cc-backend/schema"
+ sq "github.com/Masterminds/squirrel"
+)
+
+// QueryJobs returns a list of jobs matching the provided filters. page and order are optional-
+func (r *JobRepository) QueryJobs(
+ ctx context.Context,
+ filters []*model.JobFilter,
+ page *model.PageRequest,
+ order *model.OrderByInput) ([]*schema.Job, error) {
+
+ query := sq.Select(schema.JobColumns...).From("job")
+ query = SecurityCheck(ctx, query)
+
+ if order != nil {
+ field := toSnakeCase(order.Field)
+ if order.Order == model.SortDirectionEnumAsc {
+ query = query.OrderBy(fmt.Sprintf("job.%s ASC", field))
+ } else if order.Order == model.SortDirectionEnumDesc {
+ query = query.OrderBy(fmt.Sprintf("job.%s DESC", field))
+ } else {
+ return nil, errors.New("invalid sorting order")
+ }
+ }
+
+ if page != nil {
+ limit := uint64(page.ItemsPerPage)
+ query = query.Offset((uint64(page.Page) - 1) * limit).Limit(limit)
+ } else {
+ query = query.Limit(50)
+ }
+
+ for _, f := range filters {
+ query = BuildWhereClause(f, query)
+ }
+
+ sql, args, err := query.ToSql()
+ if err != nil {
+ return nil, err
+ }
+
+ log.Debugf("SQL query: `%s`, args: %#v", sql, args)
+ rows, err := r.DB.Queryx(sql, args...)
+ if err != nil {
+ return nil, err
+ }
+
+ jobs := make([]*schema.Job, 0, 50)
+ for rows.Next() {
+ job, err := schema.ScanJob(rows)
+ if err != nil {
+ return nil, err
+ }
+ jobs = append(jobs, job)
+ }
+
+ return jobs, nil
+}
+
+// CountJobs counts the number of jobs matching the filters.
+func (r *JobRepository) CountJobs(
+ ctx context.Context,
+ filters []*model.JobFilter) (int, error) {
+
+ // count all jobs:
+ query := sq.Select("count(*)").From("job")
+ query = SecurityCheck(ctx, query)
+ for _, f := range filters {
+ query = BuildWhereClause(f, query)
+ }
+ var count int
+ if err := query.RunWith(r.DB).Scan(&count); err != nil {
+ return 0, err
+ }
+
+ return count, nil
+}
+
+func SecurityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder {
+ user := auth.GetUser(ctx)
+ if user == nil || user.HasRole(auth.RoleAdmin) {
+ return query
+ }
+
+ return query.Where("job.user = ?", user.Username)
+}
+
+// Build a sq.SelectBuilder out of a schema.JobFilter.
+func BuildWhereClause(filter *model.JobFilter, query sq.SelectBuilder) sq.SelectBuilder {
+ if filter.Tags != nil {
+ query = query.Join("jobtag ON jobtag.job_id = job.id").Where(sq.Eq{"jobtag.tag_id": filter.Tags})
+ }
+ if filter.JobID != nil {
+ query = buildStringCondition("job.job_id", filter.JobID, query)
+ }
+ if filter.ArrayJobID != nil {
+ query = query.Where("job.array_job_id = ?", *filter.ArrayJobID)
+ }
+ if filter.User != nil {
+ query = buildStringCondition("job.user", filter.User, query)
+ }
+ if filter.Project != nil {
+ query = buildStringCondition("job.project", filter.Project, query)
+ }
+ if filter.Cluster != nil {
+ query = buildStringCondition("job.cluster", filter.Cluster, query)
+ }
+ if filter.Partition != nil {
+ query = buildStringCondition("job.partition", filter.Partition, query)
+ }
+ if filter.StartTime != nil {
+ query = buildTimeCondition("job.start_time", filter.StartTime, query)
+ }
+ if filter.Duration != nil {
+ query = buildIntCondition("job.duration", filter.Duration, query)
+ }
+ if filter.State != nil {
+ states := make([]string, len(filter.State))
+ for i, val := range filter.State {
+ states[i] = string(val)
+ }
+
+ query = query.Where(sq.Eq{"job.job_state": states})
+ }
+ if filter.NumNodes != nil {
+ query = buildIntCondition("job.num_nodes", filter.NumNodes, query)
+ }
+ if filter.NumAccelerators != nil {
+ query = buildIntCondition("job.num_acc", filter.NumAccelerators, query)
+ }
+ if filter.NumHWThreads != nil {
+ query = buildIntCondition("job.num_hwthreads", filter.NumHWThreads, query)
+ }
+ if filter.FlopsAnyAvg != nil {
+ query = buildFloatCondition("job.flops_any_avg", filter.FlopsAnyAvg, query)
+ }
+ if filter.MemBwAvg != nil {
+ query = buildFloatCondition("job.mem_bw_avg", filter.MemBwAvg, query)
+ }
+ if filter.LoadAvg != nil {
+ query = buildFloatCondition("job.load_avg", filter.LoadAvg, query)
+ }
+ if filter.MemUsedMax != nil {
+ query = buildFloatCondition("job.mem_used_max", filter.MemUsedMax, query)
+ }
+ return query
+}
+
+func buildIntCondition(field string, cond *model.IntRange, query sq.SelectBuilder) sq.SelectBuilder {
+ return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To)
+}
+
+func buildTimeCondition(field string, cond *model.TimeRange, query sq.SelectBuilder) sq.SelectBuilder {
+ if cond.From != nil && cond.To != nil {
+ return query.Where(field+" BETWEEN ? AND ?", cond.From.Unix(), cond.To.Unix())
+ } else if cond.From != nil {
+ return query.Where("? <= "+field, cond.From.Unix())
+ } else if cond.To != nil {
+ return query.Where(field+" <= ?", cond.To.Unix())
+ } else {
+ return query
+ }
+}
+
+func buildFloatCondition(field string, cond *model.FloatRange, query sq.SelectBuilder) sq.SelectBuilder {
+ return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To)
+}
+
+func buildStringCondition(field string, cond *model.StringInput, query sq.SelectBuilder) sq.SelectBuilder {
+ if cond.Eq != nil {
+ return query.Where(field+" = ?", *cond.Eq)
+ }
+ if cond.StartsWith != nil {
+ return query.Where(field+" LIKE ?", fmt.Sprint(*cond.StartsWith, "%"))
+ }
+ if cond.EndsWith != nil {
+ return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.EndsWith))
+ }
+ if cond.Contains != nil {
+ return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.Contains, "%"))
+ }
+ return query
+}
+
+var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
+var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
+
+func toSnakeCase(str string) string {
+ for _, c := range str {
+ if c == '\'' || c == '\\' {
+ panic("A hacker (probably not)!!!")
+ }
+ }
+
+ str = strings.ReplaceAll(str, "'", "")
+ str = strings.ReplaceAll(str, "\\", "")
+ snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
+ snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
+ return strings.ToLower(snake)
+}
diff --git a/repository/tags.go b/repository/tags.go
new file mode 100644
index 0000000..c8d8310
--- /dev/null
+++ b/repository/tags.go
@@ -0,0 +1,93 @@
+package repository
+
+import (
+ "fmt"
+
+ "github.com/ClusterCockpit/cc-backend/schema"
+ sq "github.com/Masterminds/squirrel"
+)
+
+// Add the tag with id `tagId` to the job with the database id `jobId`.
+func (r *JobRepository) AddTag(jobId int64, tagId int64) error {
+ _, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES (?, ?)`, jobId, tagId)
+ return err
+}
+
+// CreateTag creates a new tag with the specified type and name and returns its database id.
+func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
+ res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName)
+ if err != nil {
+ return 0, err
+ }
+
+ return res.LastInsertId()
+}
+
+func (r *JobRepository) CountTags(user *string) (tags []schema.Tag, counts map[string]int, err error) {
+ tags = make([]schema.Tag, 0, 100)
+ xrows, err := r.DB.Queryx("SELECT * FROM tag")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for xrows.Next() {
+ var t schema.Tag
+ if err := xrows.StructScan(&t); err != nil {
+ return nil, nil, err
+ }
+ tags = append(tags, t)
+ }
+
+ q := sq.Select("t.tag_name, count(jt.tag_id)").
+ From("tag t").
+ LeftJoin("jobtag jt ON t.id = jt.tag_id").
+ GroupBy("t.tag_name")
+ if user != nil {
+ q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user)
+ }
+
+ rows, err := q.RunWith(r.DB).Query()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ counts = make(map[string]int)
+
+ for rows.Next() {
+ var tagName string
+ var count int
+ err = rows.Scan(&tagName, &count)
+ if err != nil {
+ fmt.Println(err)
+ }
+ counts[tagName] = count
+ }
+ err = rows.Err()
+
+ return
+}
+
+// AddTagOrCreate adds the tag with the specified type and name to the job with the database id `jobId`.
+// If such a tag does not yet exist, it is created.
+func (r *JobRepository) AddTagOrCreate(jobId int64, tagType string, tagName string) (tagId int64, err error) {
+ tagId, exists := r.TagId(tagType, tagName)
+ if !exists {
+ tagId, err = r.CreateTag(tagType, tagName)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ return tagId, r.AddTag(jobId, tagId)
+}
+
+// TagId returns the database id of the tag with the specified type and name.
+func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exists bool) {
+ exists = true
+ if err := sq.Select("id").From("tag").
+ Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
+ RunWith(r.DB).QueryRow().Scan(&tagId); err != nil {
+ exists = false
+ }
+ return
+}
diff --git a/server.go b/server.go
index 5a78d4a..ba127c8 100644
--- a/server.go
+++ b/server.go
@@ -130,12 +130,12 @@ func setupHomeRoute(i InfoType, r *http.Request) InfoType {
}
state := schema.JobStateRunning
- runningJobs, err := jobRepo.CountJobs(r.Context(), &state)
+ runningJobs, err := jobRepo.CountJobsPerCluster(r.Context(), &state)
if err != nil {
log.Errorf("failed to count jobs: %s", err.Error())
runningJobs = map[string]int{}
}
- totalJobs, err := jobRepo.CountJobs(r.Context(), nil)
+ totalJobs, err := jobRepo.CountJobsPerCluster(r.Context(), nil)
if err != nil {
log.Errorf("failed to count jobs: %s", err.Error())
totalJobs = map[string]int{}
@@ -200,7 +200,7 @@ func setupTaglistRoute(i InfoType, r *http.Request) InfoType {
username = &user.Username
}
- tags, counts, err := jobRepo.GetTags(username)
+ tags, counts, err := jobRepo.CountTags(username)
tagMap := make(map[string][]map[string]interface{})
if err != nil {
log.Errorf("GetTags failed: %s", err.Error())
@@ -360,10 +360,7 @@ func main() {
// Build routes...
- resolver := &graph.Resolver{DB: db}
- if err := resolver.Init(); err != nil {
- log.Fatal(err)
- }
+ resolver := &graph.Resolver{DB: db, Repo: jobRepo}
graphQLEndpoint := handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: resolver}))
if os.Getenv("DEBUG") != "1" {
graphQLEndpoint.SetRecoverFunc(func(ctx context.Context, err interface{}) error {
diff --git a/test/db.go b/test/db.go
index aa42fb8..8553ef6 100644
--- a/test/db.go
+++ b/test/db.go
@@ -5,6 +5,7 @@ import (
"os"
"github.com/jmoiron/sqlx"
+ _ "github.com/mattn/go-sqlite3"
)
func InitDB() *sqlx.DB {