diff --git a/internal/repository/query.go b/internal/repository/query.go index 0501fe1..84b8048 100644 --- a/internal/repository/query.go +++ b/internal/repository/query.go @@ -18,13 +18,17 @@ import ( sq "github.com/Masterminds/squirrel" ) -// SecurityCheck-less, private: Returns a list of jobs matching the provided filters. page and order are optional- -func (r *JobRepository) queryJobs( - query sq.SelectBuilder, +func (r *JobRepository) QueryJobs( + ctx context.Context, filters []*model.JobFilter, page *model.PageRequest, order *model.OrderByInput) ([]*schema.Job, error) { + query, qerr := SecurityCheck(ctx, sq.Select(jobColumns...).From("job")) + if qerr != nil { + return nil, qerr + } + if order != nil { field := toSnakeCase(order.Field) @@ -67,34 +71,15 @@ func (r *JobRepository) queryJobs( return jobs, nil } -// testFunction for queryJobs -func (r *JobRepository) testQueryJobs( - filters []*model.JobFilter, - page *model.PageRequest, - order *model.OrderByInput) ([]*schema.Job, error) { - - return r.queryJobs(sq.Select(jobColumns...).From("job"), filters, page, order) -} - -// Public function with added securityCheck, calls private queryJobs function above -func (r *JobRepository) QueryJobs( +func (r *JobRepository) CountJobs( ctx context.Context, - filters []*model.JobFilter, - page *model.PageRequest, - order *model.OrderByInput) ([]*schema.Job, error) { - - query, qerr := SecurityCheck(ctx, sq.Select(jobColumns...).From("job")) - if qerr != nil { - return nil, qerr - } - - return r.queryJobs(query, filters, page, order) -} - -// SecurityCheck-less, private: Returns the number of jobs matching the filters -func (r *JobRepository) countJobs(query sq.SelectBuilder, filters []*model.JobFilter) (int, error) { + query, qerr := SecurityCheck(ctx, sq.Select("count(*)").From("job")) + if qerr != nil { + return 0, qerr + } + for _, f := range filters { query = BuildWhereClause(f, query) } @@ -107,27 +92,6 @@ func (r *JobRepository) countJobs(query sq.SelectBuilder, return count, nil } -// testFunction for countJobs -func (r *JobRepository) testCountJobs( - filters []*model.JobFilter) (int, error) { - - return r.countJobs(sq.Select("count(*)").From("job"), filters) -} - -// Public function with added securityCheck, calls private countJobs function above -func (r *JobRepository) CountJobs( - ctx context.Context, - filters []*model.JobFilter) (int, error) { - - query, qerr := SecurityCheck(ctx, sq.Select("count(*)").From("job")) - - if qerr != nil { - return 0, qerr - } - - return r.countJobs(query, filters) -} - func SecurityCheck(ctx context.Context, query sq.SelectBuilder) (sq.SelectBuilder, error) { user := GetUserFromContext(ctx) if user == nil { diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index efb5395..48b692f 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -5,10 +5,12 @@ package repository import ( + "context" "testing" "github.com/ClusterCockpit/cc-backend/internal/graph/model" "github.com/ClusterCockpit/cc-backend/pkg/log" + "github.com/ClusterCockpit/cc-backend/pkg/schema" _ "github.com/mattn/go-sqlite3" ) @@ -94,7 +96,7 @@ func BenchmarkDB_CountJobs(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := db.testCountJobs([]*model.JobFilter{filter}) + _, err := db.CountJobs(getContext(b), []*model.JobFilter{filter}) noErr(b, err) } }) @@ -118,20 +120,37 @@ func BenchmarkDB_QueryJobs(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := db.testQueryJobs([]*model.JobFilter{filter}, page, order) + _, err := db.QueryJobs(getContext(b), []*model.JobFilter{filter}, page, order) noErr(b, err) } }) }) } +func getContext(tb testing.TB) context.Context { + tb.Helper() + + var roles []string + roles = append(roles, schema.GetRoleString(schema.RoleAdmin)) + projects := make([]string, 0) + + user := &schema.User{ + Username: "demo", + Name: "The man", + Roles: roles, + Projects: projects, + AuthSource: schema.AuthViaLDAP, + } + ctx := context.Background() + return context.WithValue(ctx, ContextUserKey, user) +} + func setup(tb testing.TB) *JobRepository { tb.Helper() log.Init("warn", true) dbfile := "testdata/job.db" err := MigrateDB("sqlite3", dbfile) noErr(tb, err) - Connect("sqlite3", dbfile) return GetJobRepository() } diff --git a/internal/repository/stats.go b/internal/repository/stats.go index 193eb06..3ac3ffd 100644 --- a/internal/repository/stats.go +++ b/internal/repository/stats.go @@ -233,10 +233,17 @@ func (r *JobRepository) JobsStatsGrouped( return stats, nil } -func (r *JobRepository) jobsStats( - query sq.SelectBuilder, +func (r *JobRepository) JobsStats( + ctx context.Context, filter []*model.JobFilter) ([]*model.JobsStatistics, error) { + start := time.Now() + query := r.buildStatsQuery(filter, "") + query, err := SecurityCheck(ctx, query) + if err != nil { + return nil, err + } + row := query.RunWith(r.DB).QueryRow() stats := make([]*model.JobsStatistics, 0, 1) @@ -267,29 +274,8 @@ func (r *JobRepository) jobsStats( TotalAccHours: totalAccHours}) } - return stats, nil -} - -func (r *JobRepository) testJobsStats( - filter []*model.JobFilter) ([]*model.JobsStatistics, error) { - - query := r.buildStatsQuery(filter, "") - return r.jobsStats(query, filter) -} - -func (r *JobRepository) JobsStats( - ctx context.Context, - filter []*model.JobFilter) ([]*model.JobsStatistics, error) { - - start := time.Now() - query := r.buildStatsQuery(filter, "") - query, err := SecurityCheck(ctx, query) - if err != nil { - return nil, err - } - log.Debugf("Timer JobStats %s", time.Since(start)) - return r.jobsStats(query, filter) + return stats, nil } func (r *JobRepository) JobCountGrouped( diff --git a/internal/repository/stats_test.go b/internal/repository/stats_test.go index 2672b3f..6ed3f72 100644 --- a/internal/repository/stats_test.go +++ b/internal/repository/stats_test.go @@ -26,12 +26,10 @@ func TestJobStats(t *testing.T) { r := setup(t) filter := &model.JobFilter{} - var err error - var stats []*model.JobsStatistics - stats, err = r.testJobsStats([]*model.JobFilter{filter}) + stats, err := r.JobsStats(getContext(t), []*model.JobFilter{filter}) noErr(t, err) - if stats[0].TotalJobs != 98 { + if stats[0].TotalJobs != 6 { t.Fatalf("Want 98, Got %d", stats[0].TotalJobs) } }