mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2025-01-27 03:39:05 +01:00
Refactor repository tests
Add context to tests. Remove special test routines
This commit is contained in:
parent
59c749a164
commit
9533f06eaf
@ -18,13 +18,17 @@ import (
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
)
|
||||
|
||||
// SecurityCheck-less, private: Returns a list of jobs matching the provided filters. page and order are optional-
|
||||
func (r *JobRepository) queryJobs(
|
||||
query sq.SelectBuilder,
|
||||
func (r *JobRepository) QueryJobs(
|
||||
ctx context.Context,
|
||||
filters []*model.JobFilter,
|
||||
page *model.PageRequest,
|
||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
||||
|
||||
query, qerr := SecurityCheck(ctx, sq.Select(jobColumns...).From("job"))
|
||||
if qerr != nil {
|
||||
return nil, qerr
|
||||
}
|
||||
|
||||
if order != nil {
|
||||
field := toSnakeCase(order.Field)
|
||||
|
||||
@ -67,34 +71,15 @@ func (r *JobRepository) queryJobs(
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// testFunction for queryJobs
|
||||
func (r *JobRepository) testQueryJobs(
|
||||
filters []*model.JobFilter,
|
||||
page *model.PageRequest,
|
||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
||||
|
||||
return r.queryJobs(sq.Select(jobColumns...).From("job"), filters, page, order)
|
||||
}
|
||||
|
||||
// Public function with added securityCheck, calls private queryJobs function above
|
||||
func (r *JobRepository) QueryJobs(
|
||||
func (r *JobRepository) CountJobs(
|
||||
ctx context.Context,
|
||||
filters []*model.JobFilter,
|
||||
page *model.PageRequest,
|
||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
||||
|
||||
query, qerr := SecurityCheck(ctx, sq.Select(jobColumns...).From("job"))
|
||||
if qerr != nil {
|
||||
return nil, qerr
|
||||
}
|
||||
|
||||
return r.queryJobs(query, filters, page, order)
|
||||
}
|
||||
|
||||
// SecurityCheck-less, private: Returns the number of jobs matching the filters
|
||||
func (r *JobRepository) countJobs(query sq.SelectBuilder,
|
||||
filters []*model.JobFilter) (int, error) {
|
||||
|
||||
query, qerr := SecurityCheck(ctx, sq.Select("count(*)").From("job"))
|
||||
if qerr != nil {
|
||||
return 0, qerr
|
||||
}
|
||||
|
||||
for _, f := range filters {
|
||||
query = BuildWhereClause(f, query)
|
||||
}
|
||||
@ -107,27 +92,6 @@ func (r *JobRepository) countJobs(query sq.SelectBuilder,
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// testFunction for countJobs
|
||||
func (r *JobRepository) testCountJobs(
|
||||
filters []*model.JobFilter) (int, error) {
|
||||
|
||||
return r.countJobs(sq.Select("count(*)").From("job"), filters)
|
||||
}
|
||||
|
||||
// Public function with added securityCheck, calls private countJobs function above
|
||||
func (r *JobRepository) CountJobs(
|
||||
ctx context.Context,
|
||||
filters []*model.JobFilter) (int, error) {
|
||||
|
||||
query, qerr := SecurityCheck(ctx, sq.Select("count(*)").From("job"))
|
||||
|
||||
if qerr != nil {
|
||||
return 0, qerr
|
||||
}
|
||||
|
||||
return r.countJobs(query, filters)
|
||||
}
|
||||
|
||||
func SecurityCheck(ctx context.Context, query sq.SelectBuilder) (sq.SelectBuilder, error) {
|
||||
user := GetUserFromContext(ctx)
|
||||
if user == nil {
|
||||
|
@ -5,10 +5,12 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/graph/model"
|
||||
"github.com/ClusterCockpit/cc-backend/pkg/log"
|
||||
"github.com/ClusterCockpit/cc-backend/pkg/schema"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
@ -94,7 +96,7 @@ func BenchmarkDB_CountJobs(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := db.testCountJobs([]*model.JobFilter{filter})
|
||||
_, err := db.CountJobs(getContext(b), []*model.JobFilter{filter})
|
||||
noErr(b, err)
|
||||
}
|
||||
})
|
||||
@ -118,20 +120,37 @@ func BenchmarkDB_QueryJobs(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := db.testQueryJobs([]*model.JobFilter{filter}, page, order)
|
||||
_, err := db.QueryJobs(getContext(b), []*model.JobFilter{filter}, page, order)
|
||||
noErr(b, err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func getContext(tb testing.TB) context.Context {
|
||||
tb.Helper()
|
||||
|
||||
var roles []string
|
||||
roles = append(roles, schema.GetRoleString(schema.RoleAdmin))
|
||||
projects := make([]string, 0)
|
||||
|
||||
user := &schema.User{
|
||||
Username: "demo",
|
||||
Name: "The man",
|
||||
Roles: roles,
|
||||
Projects: projects,
|
||||
AuthSource: schema.AuthViaLDAP,
|
||||
}
|
||||
ctx := context.Background()
|
||||
return context.WithValue(ctx, ContextUserKey, user)
|
||||
}
|
||||
|
||||
func setup(tb testing.TB) *JobRepository {
|
||||
tb.Helper()
|
||||
log.Init("warn", true)
|
||||
dbfile := "testdata/job.db"
|
||||
err := MigrateDB("sqlite3", dbfile)
|
||||
noErr(tb, err)
|
||||
|
||||
Connect("sqlite3", dbfile)
|
||||
return GetJobRepository()
|
||||
}
|
||||
|
@ -233,10 +233,17 @@ func (r *JobRepository) JobsStatsGrouped(
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *JobRepository) jobsStats(
|
||||
query sq.SelectBuilder,
|
||||
func (r *JobRepository) JobsStats(
|
||||
ctx context.Context,
|
||||
filter []*model.JobFilter) ([]*model.JobsStatistics, error) {
|
||||
|
||||
start := time.Now()
|
||||
query := r.buildStatsQuery(filter, "")
|
||||
query, err := SecurityCheck(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
row := query.RunWith(r.DB).QueryRow()
|
||||
stats := make([]*model.JobsStatistics, 0, 1)
|
||||
|
||||
@ -267,29 +274,8 @@ func (r *JobRepository) jobsStats(
|
||||
TotalAccHours: totalAccHours})
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *JobRepository) testJobsStats(
|
||||
filter []*model.JobFilter) ([]*model.JobsStatistics, error) {
|
||||
|
||||
query := r.buildStatsQuery(filter, "")
|
||||
return r.jobsStats(query, filter)
|
||||
}
|
||||
|
||||
func (r *JobRepository) JobsStats(
|
||||
ctx context.Context,
|
||||
filter []*model.JobFilter) ([]*model.JobsStatistics, error) {
|
||||
|
||||
start := time.Now()
|
||||
query := r.buildStatsQuery(filter, "")
|
||||
query, err := SecurityCheck(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("Timer JobStats %s", time.Since(start))
|
||||
return r.jobsStats(query, filter)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *JobRepository) JobCountGrouped(
|
||||
|
@ -26,12 +26,10 @@ func TestJobStats(t *testing.T) {
|
||||
r := setup(t)
|
||||
|
||||
filter := &model.JobFilter{}
|
||||
var err error
|
||||
var stats []*model.JobsStatistics
|
||||
stats, err = r.testJobsStats([]*model.JobFilter{filter})
|
||||
stats, err := r.JobsStats(getContext(t), []*model.JobFilter{filter})
|
||||
noErr(t, err)
|
||||
|
||||
if stats[0].TotalJobs != 98 {
|
||||
if stats[0].TotalJobs != 6 {
|
||||
t.Fatalf("Want 98, Got %d", stats[0].TotalJobs)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user