diff --git a/api_test.go b/api_test.go index e80cd21..e726ed8 100644 --- a/api_test.go +++ b/api_test.go @@ -109,12 +109,9 @@ func setup(t *testing.T) *api.RestApi { t.Fatal(err) } - resolver := &graph.Resolver{DB: db} - if err := resolver.Init(); err != nil { - t.Fatal(err) - } + resolver := &graph.Resolver{DB: db, Repo: &repository.JobRepository{DB: db}} return &api.RestApi{ - JobRepository: &repository.JobRepository{DB: db}, + JobRepository: resolver.Repo, Resolver: resolver, } } diff --git a/graph/resolver.go b/graph/resolver.go index a023b5b..ce08e33 100644 --- a/graph/resolver.go +++ b/graph/resolver.go @@ -1,17 +1,7 @@ package graph import ( - "context" - "errors" - "fmt" - "regexp" - "strings" - - "github.com/ClusterCockpit/cc-backend/auth" - "github.com/ClusterCockpit/cc-backend/graph/model" - "github.com/ClusterCockpit/cc-backend/log" - "github.com/ClusterCockpit/cc-backend/schema" - sq "github.com/Masterminds/squirrel" + "github.com/ClusterCockpit/cc-backend/repository" "github.com/jmoiron/sqlx" ) @@ -20,215 +10,6 @@ import ( // It serves as dependency injection for your app, add any dependencies you require here. type Resolver struct { - DB *sqlx.DB - - findJobByIdStmt *sqlx.Stmt - findJobByIdWithUserStmt *sqlx.Stmt -} - -func (r *Resolver) Init() error { - findJobById, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).ToSql() - if err != nil { - return err - } - - r.findJobByIdStmt, err = r.DB.Preparex(findJobById) - if err != nil { - return err - } - - findJobByIdWithUser, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).Where("job.user = ?").ToSql() - if err != nil { - return err - } - - r.findJobByIdWithUserStmt, err = r.DB.Preparex(findJobByIdWithUser) - if err != nil { - return err - } - - return nil -} - -// Helper function for the `jobs` GraphQL-Query. Is also used elsewhere when a list of jobs is needed. -func (r *Resolver) queryJobs( - ctx context.Context, - filters []*model.JobFilter, - page *model.PageRequest, - order *model.OrderByInput) ([]*schema.Job, int, error) { - - query := sq.Select(schema.JobColumns...).From("job") - query = securityCheck(ctx, query) - - if order != nil { - field := toSnakeCase(order.Field) - if order.Order == model.SortDirectionEnumAsc { - query = query.OrderBy(fmt.Sprintf("job.%s ASC", field)) - } else if order.Order == model.SortDirectionEnumDesc { - query = query.OrderBy(fmt.Sprintf("job.%s DESC", field)) - } else { - return nil, 0, errors.New("invalid sorting order") - } - } - - if page != nil { - limit := uint64(page.ItemsPerPage) - query = query.Offset((uint64(page.Page) - 1) * limit).Limit(limit) - } else { - query = query.Limit(50) - } - - for _, f := range filters { - query = buildWhereClause(f, query) - } - - sql, args, err := query.ToSql() - if err != nil { - return nil, 0, err - } - - log.Debugf("SQL query: `%s`, args: %#v", sql, args) - rows, err := r.DB.Queryx(sql, args...) - if err != nil { - return nil, 0, err - } - - jobs := make([]*schema.Job, 0, 50) - for rows.Next() { - job, err := schema.ScanJob(rows) - if err != nil { - return nil, 0, err - } - jobs = append(jobs, job) - } - - // count all jobs: - query = sq.Select("count(*)").From("job") - for _, f := range filters { - query = buildWhereClause(f, query) - } - query = securityCheck(ctx, query) - var count int - if err := query.RunWith(r.DB).Scan(&count); err != nil { - return nil, 0, err - } - - return jobs, count, nil -} - -func securityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder { - user := auth.GetUser(ctx) - if user == nil || user.HasRole(auth.RoleAdmin) { - return query - } - - return query.Where("job.user = ?", user.Username) -} - -// Build a sq.SelectBuilder out of a schema.JobFilter. -func buildWhereClause(filter *model.JobFilter, query sq.SelectBuilder) sq.SelectBuilder { - if filter.Tags != nil { - query = query.Join("jobtag ON jobtag.job_id = job.id").Where(sq.Eq{"jobtag.tag_id": filter.Tags}) - } - if filter.JobID != nil { - query = buildStringCondition("job.job_id", filter.JobID, query) - } - if filter.ArrayJobID != nil { - query = query.Where("job.array_job_id = ?", *filter.ArrayJobID) - } - if filter.User != nil { - query = buildStringCondition("job.user", filter.User, query) - } - if filter.Project != nil { - query = buildStringCondition("job.project", filter.Project, query) - } - if filter.Cluster != nil { - query = buildStringCondition("job.cluster", filter.Cluster, query) - } - if filter.Partition != nil { - query = buildStringCondition("job.partition", filter.Partition, query) - } - if filter.StartTime != nil { - query = buildTimeCondition("job.start_time", filter.StartTime, query) - } - if filter.Duration != nil { - query = buildIntCondition("job.duration", filter.Duration, query) - } - if filter.State != nil { - states := make([]string, len(filter.State)) - for i, val := range filter.State { - states[i] = string(val) - } - - query = query.Where(sq.Eq{"job.job_state": states}) - } - if filter.NumNodes != nil { - query = buildIntCondition("job.num_nodes", filter.NumNodes, query) - } - if filter.NumAccelerators != nil { - query = buildIntCondition("job.num_acc", filter.NumAccelerators, query) - } - if filter.NumHWThreads != nil { - query = buildIntCondition("job.num_hwthreads", filter.NumHWThreads, query) - } - if filter.FlopsAnyAvg != nil { - query = buildFloatCondition("job.flops_any_avg", filter.FlopsAnyAvg, query) - } - if filter.MemBwAvg != nil { - query = buildFloatCondition("job.mem_bw_avg", filter.MemBwAvg, query) - } - if filter.LoadAvg != nil { - query = buildFloatCondition("job.load_avg", filter.LoadAvg, query) - } - if filter.MemUsedMax != nil { - query = buildFloatCondition("job.mem_used_max", filter.MemUsedMax, query) - } - return query -} - -func buildIntCondition(field string, cond *model.IntRange, query sq.SelectBuilder) sq.SelectBuilder { - return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To) -} - -func buildTimeCondition(field string, cond *model.TimeRange, query sq.SelectBuilder) sq.SelectBuilder { - if cond.From != nil && cond.To != nil { - return query.Where(field+" BETWEEN ? AND ?", cond.From.Unix(), cond.To.Unix()) - } else if cond.From != nil { - return query.Where("? <= "+field, cond.From.Unix()) - } else if cond.To != nil { - return query.Where(field+" <= ?", cond.To.Unix()) - } else { - return query - } -} - -func buildFloatCondition(field string, cond *model.FloatRange, query sq.SelectBuilder) sq.SelectBuilder { - return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To) -} - -func buildStringCondition(field string, cond *model.StringInput, query sq.SelectBuilder) sq.SelectBuilder { - if cond.Eq != nil { - return query.Where(field+" = ?", *cond.Eq) - } - if cond.StartsWith != nil { - return query.Where(field+" LIKE ?", fmt.Sprint(*cond.StartsWith, "%")) - } - if cond.EndsWith != nil { - return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.EndsWith)) - } - if cond.Contains != nil { - return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.Contains, "%")) - } - return query -} - -var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") -var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") - -func toSnakeCase(str string) string { - str = strings.ReplaceAll(str, "'", "") - str = strings.ReplaceAll(str, "\\", "") - snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") - snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") - return strings.ToLower(snake) + DB *sqlx.DB + Repo *repository.JobRepository } diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index bfe0542..bd5c5f9 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -40,12 +40,7 @@ func (r *jobResolver) Tags(ctx context.Context, obj *schema.Job) ([]*schema.Tag, } func (r *mutationResolver) CreateTag(ctx context.Context, typeArg string, name string) (*schema.Tag, error) { - res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", typeArg, name) - if err != nil { - return nil, err - } - - id, err := res.LastInsertId() + id, err := r.Repo.CreateTag(typeArg, name) if err != nil { return nil, err } @@ -59,18 +54,18 @@ func (r *mutationResolver) DeleteTag(ctx context.Context, id string) (string, er } func (r *mutationResolver) AddTagsToJob(ctx context.Context, job string, tagIds []string) ([]*schema.Tag, error) { - jid, err := strconv.Atoi(job) + jid, err := strconv.ParseInt(job, 10, 64) if err != nil { return nil, err } for _, tagId := range tagIds { - tid, err := strconv.Atoi(tagId) + tid, err := strconv.ParseInt(tagId, 10, 64) if err != nil { return nil, err } - if _, err := r.DB.Exec("INSERT INTO jobtag (job_id, tag_id) VALUES ($1, $2)", jid, tid); err != nil { + if err := r.Repo.AddTag(jid, tid); err != nil { return nil, err } } @@ -148,14 +143,21 @@ func (r *queryResolver) Tags(ctx context.Context) ([]*schema.Tag, error) { } func (r *queryResolver) Job(ctx context.Context, id string) (*schema.Job, error) { - // This query is very common (mostly called through other resolvers such as JobMetrics), - // so we use prepared statements here. - user := auth.GetUser(ctx) - if user == nil || user.HasRole(auth.RoleAdmin) { - return schema.ScanJob(r.findJobByIdStmt.QueryRowx(id)) + numericId, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, err } - return schema.ScanJob(r.findJobByIdWithUserStmt.QueryRowx(id, user.Username)) + job, err := r.Repo.FindById(numericId) + if err != nil { + return nil, err + } + + if user := auth.GetUser(ctx); user != nil && !user.HasRole(auth.RoleAdmin) && job.User != user.Username { + return nil, errors.New("you are not allowed to see this job") + } + + return job, nil } func (r *queryResolver) JobMetrics(ctx context.Context, id string, metrics []string, scopes []schema.MetricScope) ([]*model.JobMetricWithName, error) { @@ -191,7 +193,12 @@ func (r *queryResolver) JobsFootprints(ctx context.Context, filter []*model.JobF } func (r *queryResolver) Jobs(ctx context.Context, filter []*model.JobFilter, page *model.PageRequest, order *model.OrderByInput) (*model.JobResultList, error) { - jobs, count, err := r.queryJobs(ctx, filter, page, order) + jobs, err := r.Repo.QueryJobs(ctx, filter, page, order) + if err != nil { + return nil, err + } + + count, err := r.Repo.CountJobs(ctx, filter) if err != nil { return nil, err } diff --git a/graph/stats.go b/graph/stats.go index d979229..bca0afe 100644 --- a/graph/stats.go +++ b/graph/stats.go @@ -11,6 +11,7 @@ import ( "github.com/ClusterCockpit/cc-backend/config" "github.com/ClusterCockpit/cc-backend/graph/model" "github.com/ClusterCockpit/cc-backend/metricdata" + "github.com/ClusterCockpit/cc-backend/repository" "github.com/ClusterCockpit/cc-backend/schema" sq "github.com/Masterminds/squirrel" ) @@ -53,9 +54,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF Where("job.cluster = ?", cluster.Name). Where("job.partition = ?", partition.Name) - query = securityCheck(ctx, query) + query = repository.SecurityCheck(ctx, query) for _, f := range filter { - query = buildWhereClause(f, query) + query = repository.BuildWhereClause(f, query) } rows, err := query.RunWith(r.DB).Query() @@ -90,9 +91,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF if groupBy == nil { query := sq.Select("COUNT(job.id)").From("job").Where("job.duration < 120") - query = securityCheck(ctx, query) + query = repository.SecurityCheck(ctx, query) for _, f := range filter { - query = buildWhereClause(f, query) + query = repository.BuildWhereClause(f, query) } if err := query.RunWith(r.DB).QueryRow().Scan(&(stats[""].ShortJobs)); err != nil { return nil, err @@ -100,9 +101,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF } else { col := groupBy2column[*groupBy] query := sq.Select(col, "COUNT(job.id)").From("job").Where("job.duration < 120") - query = securityCheck(ctx, query) + query = repository.SecurityCheck(ctx, query) for _, f := range filter { - query = buildWhereClause(f, query) + query = repository.BuildWhereClause(f, query) } rows, err := query.RunWith(r.DB).Query() if err != nil { @@ -162,9 +163,9 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF // to add a condition to the query of the kind " = ". func (r *queryResolver) jobsStatisticsHistogram(ctx context.Context, value string, filters []*model.JobFilter, id, col string) ([]*model.HistoPoint, error) { query := sq.Select(value, "COUNT(job.id) AS count").From("job") - query = securityCheck(ctx, query) + query = repository.SecurityCheck(ctx, query) for _, f := range filters { - query = buildWhereClause(f, query) + query = repository.BuildWhereClause(f, query) } if len(id) != 0 && len(col) != 0 { @@ -188,14 +189,16 @@ func (r *queryResolver) jobsStatisticsHistogram(ctx context.Context, value strin return points, nil } +const MAX_JOBS_FOR_ANALYSIS = 500 + // Helper function for the rooflineHeatmap GraphQL query placed here so that schema.resolvers.go is not too full. func (r *Resolver) rooflineHeatmap(ctx context.Context, filter []*model.JobFilter, rows int, cols int, minX float64, minY float64, maxX float64, maxY float64) ([][]float64, error) { - jobs, count, err := r.queryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: 501}, nil) + jobs, err := r.Repo.QueryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: MAX_JOBS_FOR_ANALYSIS + 1}, nil) if err != nil { return nil, err } - if len(jobs) > 500 { - return nil, fmt.Errorf("too many jobs matched (matched: %d, max: %d)", count, 500) + if len(jobs) > MAX_JOBS_FOR_ANALYSIS { + return nil, fmt.Errorf("too many jobs matched (max: %d)", MAX_JOBS_FOR_ANALYSIS) } fcols, frows := float64(cols), float64(rows) @@ -250,12 +253,12 @@ func (r *Resolver) rooflineHeatmap(ctx context.Context, filter []*model.JobFilte // Helper function for the jobsFootprints GraphQL query placed here so that schema.resolvers.go is not too full. func (r *queryResolver) jobsFootprints(ctx context.Context, filter []*model.JobFilter, metrics []string) ([]*model.MetricFootprints, error) { - jobs, count, err := r.queryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: 501}, nil) + jobs, err := r.Repo.QueryJobs(ctx, filter, &model.PageRequest{Page: 1, ItemsPerPage: MAX_JOBS_FOR_ANALYSIS + 1}, nil) if err != nil { return nil, err } - if len(jobs) > 500 { - return nil, fmt.Errorf("too many jobs matched (matched: %d, max: %d)", count, 500) + if len(jobs) > MAX_JOBS_FOR_ANALYSIS { + return nil, fmt.Errorf("too many jobs matched (max: %d)", MAX_JOBS_FOR_ANALYSIS) } avgs := make([][]schema.Float, len(metrics)) diff --git a/repository/job.go b/repository/job.go index 9d01051..770074f 100644 --- a/repository/job.go +++ b/repository/job.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "errors" - "fmt" "strconv" "github.com/ClusterCockpit/cc-backend/auth" @@ -42,8 +41,7 @@ func (r *JobRepository) Find( return nil, err } - job, err := schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...)) - return job, err + return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...)) } // FindById executes a SQL query to find a specific batch job. @@ -58,8 +56,7 @@ func (r *JobRepository) FindById( return nil, err } - job, err := schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...)) - return job, err + return schema.ScanJob(r.DB.QueryRowx(sqlQuery, args...)) } // Start inserts a new job in the table, returning the unique job ID. @@ -96,20 +93,9 @@ func (r *JobRepository) Stop( return } -// CountJobs returns the number of jobs for the specified user (if a non-admin user is found in that context) and state. +// CountJobsPerCluster returns the number of jobs for the specified user (if a non-admin user is found in that context) and state. // The counts are grouped by cluster. -func (r *JobRepository) CountJobs(ctx context.Context, state *schema.JobState) (map[string]int, error) { - // q := sq.Select("count(*)").From("job") - // if cluster != nil { - // q = q.Where("job.cluster = ?", cluster) - // } - // if state != nil { - // q = q.Where("job.job_state = ?", string(*state)) - // } - - // err = q.RunWith(r.DB).QueryRow().Scan(&count) - // return - +func (r *JobRepository) CountJobsPerCluster(ctx context.Context, state *schema.JobState) (map[string]int, error) { q := sq.Select("job.cluster, count(*)").From("job").GroupBy("job.cluster") if state != nil { q = q.Where("job.job_state = ?", string(*state)) @@ -137,13 +123,6 @@ func (r *JobRepository) CountJobs(ctx context.Context, state *schema.JobState) ( return counts, nil } -// func (r *JobRepository) Query( -// filters []*model.JobFilter, -// page *model.PageRequest, -// order *model.OrderByInput) ([]*schema.Job, int, error) { - -// } - func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) { stmt := sq.Update("job"). Set("monitoring_status", monitoringStatus). @@ -186,91 +165,6 @@ func (r *JobRepository) Archive( return nil } -// Add the tag with id `tagId` to the job with the database id `jobId`. -func (r *JobRepository) AddTag(jobId int64, tagId int64) error { - _, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES (?, ?)`, jobId, tagId) - return err -} - -// CreateTag creates a new tag with the specified type and name and returns its database id. -func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) { - res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName) - if err != nil { - return 0, err - } - - return res.LastInsertId() -} - -func (r *JobRepository) GetTags(user *string) (tags []schema.Tag, counts map[string]int, err error) { - tags = make([]schema.Tag, 0, 100) - xrows, err := r.DB.Queryx("SELECT * FROM tag") - if err != nil { - return nil, nil, err - } - - for xrows.Next() { - var t schema.Tag - if err := xrows.StructScan(&t); err != nil { - return nil, nil, err - } - tags = append(tags, t) - } - - q := sq.Select("t.tag_name, count(jt.tag_id)"). - From("tag t"). - LeftJoin("jobtag jt ON t.id = jt.tag_id"). - GroupBy("t.tag_name") - if user != nil { - q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user) - } - - rows, err := q.RunWith(r.DB).Query() - if err != nil { - return nil, nil, err - } - - counts = make(map[string]int) - - for rows.Next() { - var tagName string - var count int - err = rows.Scan(&tagName, &count) - if err != nil { - fmt.Println(err) - } - counts[tagName] = count - } - err = rows.Err() - - return -} - -// AddTagOrCreate adds the tag with the specified type and name to the job with the database id `jobId`. -// If such a tag does not yet exist, it is created. -func (r *JobRepository) AddTagOrCreate(jobId int64, tagType string, tagName string) (tagId int64, err error) { - tagId, exists := r.TagId(tagType, tagName) - if !exists { - tagId, err = r.CreateTag(tagType, tagName) - if err != nil { - return 0, err - } - } - - return tagId, r.AddTag(jobId, tagId) -} - -// TagId returns the database id of the tag with the specified type and name. -func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exists bool) { - exists = true - if err := sq.Select("id").From("tag"). - Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName). - RunWith(r.DB).QueryRow().Scan(&tagId); err != nil { - exists = false - } - return -} - var ErrNotFound = errors.New("no such job or user") // FindJobOrUser returns a job database ID or a username if a job or user machtes the search term. diff --git a/repository/job_test.go b/repository/job_test.go index 8a43617..2c8e75f 100644 --- a/repository/job_test.go +++ b/repository/job_test.go @@ -7,8 +7,6 @@ import ( "github.com/jmoiron/sqlx" "github.com/ClusterCockpit/cc-backend/test" - _ "github.com/go-sql-driver/mysql" - _ "github.com/mattn/go-sqlite3" ) var db *sqlx.DB @@ -57,7 +55,7 @@ func TestFindById(t *testing.T) { func TestGetTags(t *testing.T) { r := setup(t) - tags, counts, err := r.GetTags(nil) + tags, counts, err := r.CountTags(nil) if err != nil { t.Fatal(err) } diff --git a/repository/query.go b/repository/query.go new file mode 100644 index 0000000..53a863e --- /dev/null +++ b/repository/query.go @@ -0,0 +1,212 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "regexp" + "strings" + + "github.com/ClusterCockpit/cc-backend/auth" + "github.com/ClusterCockpit/cc-backend/graph/model" + "github.com/ClusterCockpit/cc-backend/log" + "github.com/ClusterCockpit/cc-backend/schema" + sq "github.com/Masterminds/squirrel" +) + +// QueryJobs returns a list of jobs matching the provided filters. page and order are optional- +func (r *JobRepository) QueryJobs( + ctx context.Context, + filters []*model.JobFilter, + page *model.PageRequest, + order *model.OrderByInput) ([]*schema.Job, error) { + + query := sq.Select(schema.JobColumns...).From("job") + query = SecurityCheck(ctx, query) + + if order != nil { + field := toSnakeCase(order.Field) + if order.Order == model.SortDirectionEnumAsc { + query = query.OrderBy(fmt.Sprintf("job.%s ASC", field)) + } else if order.Order == model.SortDirectionEnumDesc { + query = query.OrderBy(fmt.Sprintf("job.%s DESC", field)) + } else { + return nil, errors.New("invalid sorting order") + } + } + + if page != nil { + limit := uint64(page.ItemsPerPage) + query = query.Offset((uint64(page.Page) - 1) * limit).Limit(limit) + } else { + query = query.Limit(50) + } + + for _, f := range filters { + query = BuildWhereClause(f, query) + } + + sql, args, err := query.ToSql() + if err != nil { + return nil, err + } + + log.Debugf("SQL query: `%s`, args: %#v", sql, args) + rows, err := r.DB.Queryx(sql, args...) + if err != nil { + return nil, err + } + + jobs := make([]*schema.Job, 0, 50) + for rows.Next() { + job, err := schema.ScanJob(rows) + if err != nil { + return nil, err + } + jobs = append(jobs, job) + } + + return jobs, nil +} + +// CountJobs counts the number of jobs matching the filters. +func (r *JobRepository) CountJobs( + ctx context.Context, + filters []*model.JobFilter) (int, error) { + + // count all jobs: + query := sq.Select("count(*)").From("job") + query = SecurityCheck(ctx, query) + for _, f := range filters { + query = BuildWhereClause(f, query) + } + var count int + if err := query.RunWith(r.DB).Scan(&count); err != nil { + return 0, err + } + + return count, nil +} + +func SecurityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder { + user := auth.GetUser(ctx) + if user == nil || user.HasRole(auth.RoleAdmin) { + return query + } + + return query.Where("job.user = ?", user.Username) +} + +// Build a sq.SelectBuilder out of a schema.JobFilter. +func BuildWhereClause(filter *model.JobFilter, query sq.SelectBuilder) sq.SelectBuilder { + if filter.Tags != nil { + query = query.Join("jobtag ON jobtag.job_id = job.id").Where(sq.Eq{"jobtag.tag_id": filter.Tags}) + } + if filter.JobID != nil { + query = buildStringCondition("job.job_id", filter.JobID, query) + } + if filter.ArrayJobID != nil { + query = query.Where("job.array_job_id = ?", *filter.ArrayJobID) + } + if filter.User != nil { + query = buildStringCondition("job.user", filter.User, query) + } + if filter.Project != nil { + query = buildStringCondition("job.project", filter.Project, query) + } + if filter.Cluster != nil { + query = buildStringCondition("job.cluster", filter.Cluster, query) + } + if filter.Partition != nil { + query = buildStringCondition("job.partition", filter.Partition, query) + } + if filter.StartTime != nil { + query = buildTimeCondition("job.start_time", filter.StartTime, query) + } + if filter.Duration != nil { + query = buildIntCondition("job.duration", filter.Duration, query) + } + if filter.State != nil { + states := make([]string, len(filter.State)) + for i, val := range filter.State { + states[i] = string(val) + } + + query = query.Where(sq.Eq{"job.job_state": states}) + } + if filter.NumNodes != nil { + query = buildIntCondition("job.num_nodes", filter.NumNodes, query) + } + if filter.NumAccelerators != nil { + query = buildIntCondition("job.num_acc", filter.NumAccelerators, query) + } + if filter.NumHWThreads != nil { + query = buildIntCondition("job.num_hwthreads", filter.NumHWThreads, query) + } + if filter.FlopsAnyAvg != nil { + query = buildFloatCondition("job.flops_any_avg", filter.FlopsAnyAvg, query) + } + if filter.MemBwAvg != nil { + query = buildFloatCondition("job.mem_bw_avg", filter.MemBwAvg, query) + } + if filter.LoadAvg != nil { + query = buildFloatCondition("job.load_avg", filter.LoadAvg, query) + } + if filter.MemUsedMax != nil { + query = buildFloatCondition("job.mem_used_max", filter.MemUsedMax, query) + } + return query +} + +func buildIntCondition(field string, cond *model.IntRange, query sq.SelectBuilder) sq.SelectBuilder { + return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To) +} + +func buildTimeCondition(field string, cond *model.TimeRange, query sq.SelectBuilder) sq.SelectBuilder { + if cond.From != nil && cond.To != nil { + return query.Where(field+" BETWEEN ? AND ?", cond.From.Unix(), cond.To.Unix()) + } else if cond.From != nil { + return query.Where("? <= "+field, cond.From.Unix()) + } else if cond.To != nil { + return query.Where(field+" <= ?", cond.To.Unix()) + } else { + return query + } +} + +func buildFloatCondition(field string, cond *model.FloatRange, query sq.SelectBuilder) sq.SelectBuilder { + return query.Where(field+" BETWEEN ? AND ?", cond.From, cond.To) +} + +func buildStringCondition(field string, cond *model.StringInput, query sq.SelectBuilder) sq.SelectBuilder { + if cond.Eq != nil { + return query.Where(field+" = ?", *cond.Eq) + } + if cond.StartsWith != nil { + return query.Where(field+" LIKE ?", fmt.Sprint(*cond.StartsWith, "%")) + } + if cond.EndsWith != nil { + return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.EndsWith)) + } + if cond.Contains != nil { + return query.Where(field+" LIKE ?", fmt.Sprint("%", *cond.Contains, "%")) + } + return query +} + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +func toSnakeCase(str string) string { + for _, c := range str { + if c == '\'' || c == '\\' { + panic("A hacker (probably not)!!!") + } + } + + str = strings.ReplaceAll(str, "'", "") + str = strings.ReplaceAll(str, "\\", "") + snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +} diff --git a/repository/tags.go b/repository/tags.go new file mode 100644 index 0000000..c8d8310 --- /dev/null +++ b/repository/tags.go @@ -0,0 +1,93 @@ +package repository + +import ( + "fmt" + + "github.com/ClusterCockpit/cc-backend/schema" + sq "github.com/Masterminds/squirrel" +) + +// Add the tag with id `tagId` to the job with the database id `jobId`. +func (r *JobRepository) AddTag(jobId int64, tagId int64) error { + _, err := r.DB.Exec(`INSERT INTO jobtag (job_id, tag_id) VALUES (?, ?)`, jobId, tagId) + return err +} + +// CreateTag creates a new tag with the specified type and name and returns its database id. +func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) { + res, err := r.DB.Exec("INSERT INTO tag (tag_type, tag_name) VALUES ($1, $2)", tagType, tagName) + if err != nil { + return 0, err + } + + return res.LastInsertId() +} + +func (r *JobRepository) CountTags(user *string) (tags []schema.Tag, counts map[string]int, err error) { + tags = make([]schema.Tag, 0, 100) + xrows, err := r.DB.Queryx("SELECT * FROM tag") + if err != nil { + return nil, nil, err + } + + for xrows.Next() { + var t schema.Tag + if err := xrows.StructScan(&t); err != nil { + return nil, nil, err + } + tags = append(tags, t) + } + + q := sq.Select("t.tag_name, count(jt.tag_id)"). + From("tag t"). + LeftJoin("jobtag jt ON t.id = jt.tag_id"). + GroupBy("t.tag_name") + if user != nil { + q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user) + } + + rows, err := q.RunWith(r.DB).Query() + if err != nil { + return nil, nil, err + } + + counts = make(map[string]int) + + for rows.Next() { + var tagName string + var count int + err = rows.Scan(&tagName, &count) + if err != nil { + fmt.Println(err) + } + counts[tagName] = count + } + err = rows.Err() + + return +} + +// AddTagOrCreate adds the tag with the specified type and name to the job with the database id `jobId`. +// If such a tag does not yet exist, it is created. +func (r *JobRepository) AddTagOrCreate(jobId int64, tagType string, tagName string) (tagId int64, err error) { + tagId, exists := r.TagId(tagType, tagName) + if !exists { + tagId, err = r.CreateTag(tagType, tagName) + if err != nil { + return 0, err + } + } + + return tagId, r.AddTag(jobId, tagId) +} + +// TagId returns the database id of the tag with the specified type and name. +func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exists bool) { + exists = true + if err := sq.Select("id").From("tag"). + Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName). + RunWith(r.DB).QueryRow().Scan(&tagId); err != nil { + exists = false + } + return +} diff --git a/server.go b/server.go index 5a78d4a..ba127c8 100644 --- a/server.go +++ b/server.go @@ -130,12 +130,12 @@ func setupHomeRoute(i InfoType, r *http.Request) InfoType { } state := schema.JobStateRunning - runningJobs, err := jobRepo.CountJobs(r.Context(), &state) + runningJobs, err := jobRepo.CountJobsPerCluster(r.Context(), &state) if err != nil { log.Errorf("failed to count jobs: %s", err.Error()) runningJobs = map[string]int{} } - totalJobs, err := jobRepo.CountJobs(r.Context(), nil) + totalJobs, err := jobRepo.CountJobsPerCluster(r.Context(), nil) if err != nil { log.Errorf("failed to count jobs: %s", err.Error()) totalJobs = map[string]int{} @@ -200,7 +200,7 @@ func setupTaglistRoute(i InfoType, r *http.Request) InfoType { username = &user.Username } - tags, counts, err := jobRepo.GetTags(username) + tags, counts, err := jobRepo.CountTags(username) tagMap := make(map[string][]map[string]interface{}) if err != nil { log.Errorf("GetTags failed: %s", err.Error()) @@ -360,10 +360,7 @@ func main() { // Build routes... - resolver := &graph.Resolver{DB: db} - if err := resolver.Init(); err != nil { - log.Fatal(err) - } + resolver := &graph.Resolver{DB: db, Repo: jobRepo} graphQLEndpoint := handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: resolver})) if os.Getenv("DEBUG") != "1" { graphQLEndpoint.SetRecoverFunc(func(ctx context.Context, err interface{}) error { diff --git a/test/db.go b/test/db.go index aa42fb8..8553ef6 100644 --- a/test/db.go +++ b/test/db.go @@ -5,6 +5,7 @@ import ( "os" "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" ) func InitDB() *sqlx.DB {