From becb2bfa3a5967daad12388e2bc63a287ccbbc92 Mon Sep 17 00:00:00 2001 From: Jan Eitzinger Date: Wed, 7 Jun 2023 11:58:58 +0200 Subject: [PATCH] Refactor Jobs stats resolver --- internal/graph/schema.resolvers.go | 58 ++++- internal/repository/stats.go | 382 ++++++++++++++++++----------- internal/repository/stats_test.go | 24 ++ 3 files changed, 317 insertions(+), 147 deletions(-) create mode 100644 internal/repository/stats_test.go diff --git a/internal/graph/schema.resolvers.go b/internal/graph/schema.resolvers.go index a24fac0..c2a2ed7 100644 --- a/internal/graph/schema.resolvers.go +++ b/internal/graph/schema.resolvers.go @@ -11,6 +11,7 @@ import ( "strconv" "time" + "github.com/99designs/gqlgen/graphql" "github.com/ClusterCockpit/cc-backend/internal/auth" "github.com/ClusterCockpit/cc-backend/internal/graph/generated" "github.com/ClusterCockpit/cc-backend/internal/graph/model" @@ -38,7 +39,6 @@ func (r *jobResolver) Tags(ctx context.Context, obj *schema.Job) ([]*schema.Tag, // ConcurrentJobs is the resolver for the concurrentJobs field. func (r *jobResolver) ConcurrentJobs(ctx context.Context, obj *schema.Job) (*model.JobLinkResultList, error) { - exc := int(obj.Exclusive) if exc != 1 { filter := []*model.JobFilter{} @@ -269,7 +269,43 @@ func (r *queryResolver) Jobs(ctx context.Context, filter []*model.JobFilter, pag // JobsStatistics is the resolver for the jobsStatistics field. func (r *queryResolver) JobsStatistics(ctx context.Context, filter []*model.JobFilter, groupBy *model.Aggregate) ([]*model.JobsStatistics, error) { - return r.Repo.JobsStatistics(ctx, filter, groupBy) + var err error + var stats []*model.JobsStatistics + + if requireField(ctx, "TotalJobs") { + if requireField(ctx, "TotalCoreHours") { + if groupBy == nil { + stats, err = r.Repo.JobsStatsPlain(ctx, filter) + } else { + stats, err = r.Repo.JobsStats(ctx, filter, groupBy) + } + } else { + if groupBy == nil { + stats, err = r.Repo.JobsStatsPlainNoCoreH(ctx, filter) + } else { + stats, err = r.Repo.JobsStatsNoCoreH(ctx, filter, groupBy) + } + } + } else { + stats = make([]*model.JobsStatistics, 0, 1) + } + + if err != nil { + return nil, err + } + + if requireField(ctx, "histDuration") || requireField(ctx, "histNumNodes") { + if groupBy == nil { + stats[0], err = r.Repo.AddHistograms(ctx, filter, stats[0]) + if err != nil { + return nil, err + } + } else { + return nil, errors.New("histograms only implemented without groupBy argument") + } + } + + return stats, nil } // JobsCount is the resolver for the jobsCount field. @@ -367,3 +403,21 @@ type jobResolver struct{ *Resolver } type mutationResolver struct{ *Resolver } type queryResolver struct{ *Resolver } type subClusterResolver struct{ *Resolver } + +// !!! WARNING !!! +// The code below was going to be deleted when updating resolvers. It has been copied here so you have +// one last chance to move it out of harms way if you want. There are two reasons this happens: +// - When renaming or deleting a resolver the old code will be put in here. You can safely delete +// it when you're done. +// - You have helper methods in this file. Move them out to keep these resolver files clean. +func requireField(ctx context.Context, name string) bool { + fields := graphql.CollectAllFields(ctx) + + for _, f := range fields { + if f == name { + return true + } + } + + return false +} diff --git a/internal/repository/stats.go b/internal/repository/stats.go index 5ea589f..9ec646f 100644 --- a/internal/repository/stats.go +++ b/internal/repository/stats.go @@ -10,9 +10,7 @@ import ( "fmt" "time" - "github.com/99designs/gqlgen/graphql" "github.com/ClusterCockpit/cc-backend/internal/auth" - "github.com/ClusterCockpit/cc-backend/internal/config" "github.com/ClusterCockpit/cc-backend/internal/graph/model" "github.com/ClusterCockpit/cc-backend/pkg/archive" "github.com/ClusterCockpit/cc-backend/pkg/log" @@ -26,14 +24,43 @@ var groupBy2column = map[model.Aggregate]string{ model.AggregateCluster: "job.cluster", } -// Helper function for the jobsStatistics GraphQL query placed here so that schema.resolvers.go is not too full. -func (r *JobRepository) JobsStatistics(ctx context.Context, +func (r *JobRepository) buildJobsStatsQuery( filter []*model.JobFilter, - groupBy *model.Aggregate) ([]*model.JobsStatistics, error) { + col string) sq.SelectBuilder { - start := time.Now() - // In case `groupBy` is nil (not used), the model.JobsStatistics used is at the key '' (empty string) - stats := map[string]*model.JobsStatistics{} + var query sq.SelectBuilder + castType := r.getCastType() + + if col != "" { + // Scan columns: id, totalJobs, totalWalltime + query = sq.Select(col, "COUNT(job.id)", + fmt.Sprintf("CAST(ROUND(SUM(job.duration) / 3600) as %s)", castType), + ).From("job").GroupBy(col) + } else { + // Scan columns: totalJobs, totalWalltime + query = sq.Select("COUNT(job.id)", + fmt.Sprintf("CAST(ROUND(SUM(job.duration) / 3600) as %s)", castType), + ).From("job") + } + + for _, f := range filter { + query = BuildWhereClause(f, query) + } + + return query +} + +func (r *JobRepository) getUserName(ctx context.Context, id string) string { + user := auth.GetUser(ctx) + name, _ := r.FindColumnValue(user, id, "user", "name", "username", false) + if name != "" { + return name + } else { + return "-" + } +} + +func (r *JobRepository) getCastType() string { var castType string switch r.driver { @@ -41,45 +68,175 @@ func (r *JobRepository) JobsStatistics(ctx context.Context, castType = "int" case "mysql": castType = "unsigned" + default: + castType = "" } - // `socketsPerNode` and `coresPerSocket` can differ from cluster to cluster, so we need to explicitly loop over those. + return castType +} + +// with groupBy and without coreHours +func (r *JobRepository) JobsStatsNoCoreH( + ctx context.Context, + filter []*model.JobFilter, + groupBy *model.Aggregate) ([]*model.JobsStatistics, error) { + + start := time.Now() + col := groupBy2column[*groupBy] + query := r.buildJobsStatsQuery(filter, col) + query, err := SecurityCheck(ctx, query) + if err != nil { + return nil, err + } + + rows, err := query.RunWith(r.DB).Query() + if err != nil { + log.Warn("Error while querying DB for job statistics") + return nil, err + } + + stats := make([]*model.JobsStatistics, 0, 100) + + for rows.Next() { + var id sql.NullString + var jobs, walltime sql.NullInt64 + if err := rows.Scan(&id, &jobs, &walltime); err != nil { + log.Warn("Error while scanning rows") + return nil, err + } + + if id.Valid { + if col == "job.user" { + name := r.getUserName(ctx, id.String) + stats = append(stats, + &model.JobsStatistics{ + ID: id.String, + Name: &name, + TotalJobs: int(jobs.Int64), + TotalWalltime: int(walltime.Int64)}) + } else { + stats = append(stats, + &model.JobsStatistics{ + ID: id.String, + TotalJobs: int(jobs.Int64), + TotalWalltime: int(walltime.Int64)}) + } + } + } + + log.Infof("Timer JobStatistics %s", time.Since(start)) + return stats, nil +} + +// without groupBy and without coreHours +func (r *JobRepository) JobsStatsPlainNoCoreH( + ctx context.Context, + filter []*model.JobFilter) ([]*model.JobsStatistics, error) { + + start := time.Now() + query := r.buildJobsStatsQuery(filter, "") + query, err := SecurityCheck(ctx, query) + if err != nil { + return nil, err + } + + row := query.RunWith(r.DB).QueryRow() + stats := make([]*model.JobsStatistics, 0, 1) + var jobs, walltime sql.NullInt64 + if err := row.Scan(&jobs, &walltime); err != nil { + log.Warn("Error while scanning rows") + return nil, err + } + + if jobs.Valid { + stats = append(stats, + &model.JobsStatistics{ + TotalJobs: int(jobs.Int64), + TotalWalltime: int(walltime.Int64)}) + } + + log.Infof("Timer JobStatistics %s", time.Since(start)) + return stats, nil +} + +// without groupBy and with coreHours +func (r *JobRepository) JobsStatsPlain( + ctx context.Context, + filter []*model.JobFilter) ([]*model.JobsStatistics, error) { + + start := time.Now() + query := r.buildJobsStatsQuery(filter, "") + query, err := SecurityCheck(ctx, query) + if err != nil { + return nil, err + } + + castType := r.getCastType() + var totalJobs, totalWalltime, totalCoreHours int64 + for _, cluster := range archive.Clusters { for _, subcluster := range cluster.SubClusters { - corehoursCol := fmt.Sprintf("CAST(ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600) as %s)", subcluster.SocketsPerNode, subcluster.CoresPerSocket, castType) - var rawQuery sq.SelectBuilder - if groupBy == nil { - rawQuery = sq.Select( - "''", - "COUNT(job.id)", - fmt.Sprintf("CAST(ROUND(SUM(job.duration) / 3600) as %s)", castType), - corehoursCol, - ).From("job") - } else { - col := groupBy2column[*groupBy] - rawQuery = sq.Select( - col, - "COUNT(job.id)", - fmt.Sprintf("CAST(ROUND(SUM(job.duration) / 3600) as %s)", castType), - corehoursCol, - ).From("job").GroupBy(col) - } - rawQuery = rawQuery. - Where("job.cluster = ?", cluster.Name). + scQuery := query.Column(fmt.Sprintf( + "CAST(ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600) as %s)", + subcluster.SocketsPerNode, subcluster.CoresPerSocket, castType)) + scQuery = scQuery.Where("job.cluster = ?", cluster.Name). Where("job.subcluster = ?", subcluster.Name) - query, qerr := SecurityCheck(ctx, rawQuery) - - if qerr != nil { - return nil, qerr + row := scQuery.RunWith(r.DB).QueryRow() + var jobs, walltime, corehours sql.NullInt64 + if err := row.Scan(&jobs, &walltime, &corehours); err != nil { + log.Warn("Error while scanning rows") + return nil, err } - for _, f := range filter { - query = BuildWhereClause(f, query) + if jobs.Valid { + totalJobs += jobs.Int64 + totalWalltime += walltime.Int64 + totalCoreHours += corehours.Int64 } + } + } + stats := make([]*model.JobsStatistics, 0, 1) + stats = append(stats, + &model.JobsStatistics{ + TotalJobs: int(totalJobs), + TotalWalltime: int(totalWalltime), + TotalCoreHours: int(totalCoreHours)}) - rows, err := query.RunWith(r.DB).Query() + log.Infof("Timer JobStatistics %s", time.Since(start)) + return stats, nil +} + +// with groupBy and with coreHours +func (r *JobRepository) JobsStats( + ctx context.Context, + filter []*model.JobFilter, + groupBy *model.Aggregate) ([]*model.JobsStatistics, error) { + + start := time.Now() + + stats := map[string]*model.JobsStatistics{} + col := groupBy2column[*groupBy] + query := r.buildJobsStatsQuery(filter, col) + query, err := SecurityCheck(ctx, query) + if err != nil { + return nil, err + } + + castType := r.getCastType() + + for _, cluster := range archive.Clusters { + for _, subcluster := range cluster.SubClusters { + + scQuery := query.Column(fmt.Sprintf( + "CAST(ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600) as %s)", + subcluster.SocketsPerNode, subcluster.CoresPerSocket, castType)) + + scQuery = scQuery.Where("job.cluster = ?", cluster.Name). + Where("job.subcluster = ?", subcluster.Name) + + rows, err := scQuery.RunWith(r.DB).Query() if err != nil { log.Warn("Error while querying DB for job statistics") return nil, err @@ -93,11 +250,20 @@ func (r *JobRepository) JobsStatistics(ctx context.Context, return nil, err } - if id.Valid { - if s, ok := stats[id.String]; ok { - s.TotalJobs += int(jobs.Int64) - s.TotalWalltime += int(walltime.Int64) - s.TotalCoreHours += int(corehours.Int64) + if s, ok := stats[id.String]; ok { + s.TotalJobs += int(jobs.Int64) + s.TotalWalltime += int(walltime.Int64) + s.TotalCoreHours += int(corehours.Int64) + } else { + if col == "job.user" { + name := r.getUserName(ctx, id.String) + stats[id.String] = &model.JobsStatistics{ + ID: id.String, + Name: &name, + TotalJobs: int(jobs.Int64), + TotalWalltime: int(walltime.Int64), + TotalCoreHours: int(corehours.Int64), + } } else { stats[id.String] = &model.JobsStatistics{ ID: id.String, @@ -108,120 +274,50 @@ func (r *JobRepository) JobsStatistics(ctx context.Context, } } } - - } - } - - if groupBy == nil { - - query := sq.Select("COUNT(job.id)").From("job").Where("job.duration < ?", config.Keys.ShortRunningJobsDuration) - query, qerr := SecurityCheck(ctx, query) - - if qerr != nil { - return nil, qerr - } - - for _, f := range filter { - query = BuildWhereClause(f, query) - } - if err := query.RunWith(r.DB).QueryRow().Scan(&(stats[""].ShortJobs)); err != nil { - log.Warn("Error while scanning rows for short job stats") - return nil, err - } - } else { - col := groupBy2column[*groupBy] - - query := sq.Select(col, "COUNT(job.id)").From("job").Where("job.duration < ?", config.Keys.ShortRunningJobsDuration) - - query, qerr := SecurityCheck(ctx, query) - - if qerr != nil { - return nil, qerr - } - - for _, f := range filter { - query = BuildWhereClause(f, query) - } - rows, err := query.RunWith(r.DB).Query() - if err != nil { - log.Warn("Error while querying jobs for short jobs") - return nil, err - } - - for rows.Next() { - var id sql.NullString - var shortJobs sql.NullInt64 - if err := rows.Scan(&id, &shortJobs); err != nil { - log.Warn("Error while scanning rows for short jobs") - return nil, err - } - - if id.Valid { - stats[id.String].ShortJobs = int(shortJobs.Int64) - } - } - - if col == "job.user" { - for id := range stats { - emptyDash := "-" - user := auth.GetUser(ctx) - name, _ := r.FindColumnValue(user, id, "user", "name", "username", false) - if name != "" { - stats[id].Name = &name - } else { - stats[id].Name = &emptyDash - } - } - } - } - - // Calculating the histogram data is expensive, so only do it if needed. - // An explicit resolver can not be used because we need to know the filters. - histogramsNeeded := false - fields := graphql.CollectFieldsCtx(ctx, nil) - for _, col := range fields { - if col.Name == "histDuration" || col.Name == "histNumNodes" { - histogramsNeeded = true } } res := make([]*model.JobsStatistics, 0, len(stats)) for _, stat := range stats { res = append(res, stat) - id, col := "", "" - if groupBy != nil { - id = stat.ID - col = groupBy2column[*groupBy] - } - - if histogramsNeeded { - var err error - value := fmt.Sprintf(`CAST(ROUND((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) / 3600) as %s) as value`, time.Now().Unix(), castType) - stat.HistDuration, err = r.jobsStatisticsHistogram(ctx, value, filter, id, col) - if err != nil { - log.Warn("Error while loading job statistics histogram: running jobs") - return nil, err - } - - stat.HistNumNodes, err = r.jobsStatisticsHistogram(ctx, "job.num_nodes as value", filter, id, col) - if err != nil { - log.Warn("Error while loading job statistics histogram: num nodes") - return nil, err - } - } } log.Infof("Timer JobStatistics %s", time.Since(start)) return res, nil } -// `value` must be the column grouped by, but renamed to "value". `id` and `col` can optionally be used -// to add a condition to the query of the kind " = ". -func (r *JobRepository) jobsStatisticsHistogram(ctx context.Context, - value string, filters []*model.JobFilter, id, col string) ([]*model.HistoPoint, error) { +func (r *JobRepository) AddHistograms( + ctx context.Context, + filter []*model.JobFilter, + stat *model.JobsStatistics) (*model.JobsStatistics, error) { + + castType := r.getCastType() + var err error + value := fmt.Sprintf(`CAST(ROUND((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) / 3600) as %s) as value`, time.Now().Unix(), castType) + stat.HistDuration, err = r.jobsStatisticsHistogram(ctx, value, filter) + if err != nil { + log.Warn("Error while loading job statistics histogram: running jobs") + return nil, err + } + + stat.HistNumNodes, err = r.jobsStatisticsHistogram(ctx, "job.num_nodes as value", filter) + if err != nil { + log.Warn("Error while loading job statistics histogram: num nodes") + return nil, err + } + + return stat, nil +} + +// `value` must be the column grouped by, but renamed to "value" +func (r *JobRepository) jobsStatisticsHistogram( + ctx context.Context, + value string, + filters []*model.JobFilter) ([]*model.HistoPoint, error) { start := time.Now() - query, qerr := SecurityCheck(ctx, sq.Select(value, "COUNT(job.id) AS count").From("job")) + query, qerr := SecurityCheck(ctx, + sq.Select(value, "COUNT(job.id) AS count").From("job")) if qerr != nil { return nil, qerr @@ -231,10 +327,6 @@ func (r *JobRepository) jobsStatisticsHistogram(ctx context.Context, query = BuildWhereClause(f, query) } - if len(id) != 0 && len(col) != 0 { - query = query.Where(col+" = ?", id) - } - rows, err := query.GroupBy("value").RunWith(r.DB).Query() if err != nil { log.Error("Error while running query") diff --git a/internal/repository/stats_test.go b/internal/repository/stats_test.go new file mode 100644 index 0000000..6ed485b --- /dev/null +++ b/internal/repository/stats_test.go @@ -0,0 +1,24 @@ +// Copyright (C) 2022 NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. +package repository + +import ( + "fmt" + "testing" +) + +func TestBuildJobStatsQuery(t *testing.T) { + r := setup(t) + q := r.buildJobsStatsQuery(nil, "USER") + + sql, _, err := q.ToSql() + noErr(t, err) + + fmt.Printf("SQL: %s\n", sql) + + if 1 != 5 { + t.Errorf("wrong summary for diagnostic 3\ngot: %d \nwant: 1366", 5) + } +}