mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2024-12-26 05:19:05 +01:00
Refactor repository tests
Add context to tests. Remove special test routines
This commit is contained in:
parent
59c749a164
commit
9533f06eaf
@ -18,13 +18,17 @@ import (
|
|||||||
sq "github.com/Masterminds/squirrel"
|
sq "github.com/Masterminds/squirrel"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecurityCheck-less, private: Returns a list of jobs matching the provided filters. page and order are optional-
|
func (r *JobRepository) QueryJobs(
|
||||||
func (r *JobRepository) queryJobs(
|
ctx context.Context,
|
||||||
query sq.SelectBuilder,
|
|
||||||
filters []*model.JobFilter,
|
filters []*model.JobFilter,
|
||||||
page *model.PageRequest,
|
page *model.PageRequest,
|
||||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
order *model.OrderByInput) ([]*schema.Job, error) {
|
||||||
|
|
||||||
|
query, qerr := SecurityCheck(ctx, sq.Select(jobColumns...).From("job"))
|
||||||
|
if qerr != nil {
|
||||||
|
return nil, qerr
|
||||||
|
}
|
||||||
|
|
||||||
if order != nil {
|
if order != nil {
|
||||||
field := toSnakeCase(order.Field)
|
field := toSnakeCase(order.Field)
|
||||||
|
|
||||||
@ -67,34 +71,15 @@ func (r *JobRepository) queryJobs(
|
|||||||
return jobs, nil
|
return jobs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// testFunction for queryJobs
|
func (r *JobRepository) CountJobs(
|
||||||
func (r *JobRepository) testQueryJobs(
|
|
||||||
filters []*model.JobFilter,
|
|
||||||
page *model.PageRequest,
|
|
||||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
|
||||||
|
|
||||||
return r.queryJobs(sq.Select(jobColumns...).From("job"), filters, page, order)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Public function with added securityCheck, calls private queryJobs function above
|
|
||||||
func (r *JobRepository) QueryJobs(
|
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
filters []*model.JobFilter,
|
|
||||||
page *model.PageRequest,
|
|
||||||
order *model.OrderByInput) ([]*schema.Job, error) {
|
|
||||||
|
|
||||||
query, qerr := SecurityCheck(ctx, sq.Select(jobColumns...).From("job"))
|
|
||||||
if qerr != nil {
|
|
||||||
return nil, qerr
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.queryJobs(query, filters, page, order)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SecurityCheck-less, private: Returns the number of jobs matching the filters
|
|
||||||
func (r *JobRepository) countJobs(query sq.SelectBuilder,
|
|
||||||
filters []*model.JobFilter) (int, error) {
|
filters []*model.JobFilter) (int, error) {
|
||||||
|
|
||||||
|
query, qerr := SecurityCheck(ctx, sq.Select("count(*)").From("job"))
|
||||||
|
if qerr != nil {
|
||||||
|
return 0, qerr
|
||||||
|
}
|
||||||
|
|
||||||
for _, f := range filters {
|
for _, f := range filters {
|
||||||
query = BuildWhereClause(f, query)
|
query = BuildWhereClause(f, query)
|
||||||
}
|
}
|
||||||
@ -107,27 +92,6 @@ func (r *JobRepository) countJobs(query sq.SelectBuilder,
|
|||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// testFunction for countJobs
|
|
||||||
func (r *JobRepository) testCountJobs(
|
|
||||||
filters []*model.JobFilter) (int, error) {
|
|
||||||
|
|
||||||
return r.countJobs(sq.Select("count(*)").From("job"), filters)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Public function with added securityCheck, calls private countJobs function above
|
|
||||||
func (r *JobRepository) CountJobs(
|
|
||||||
ctx context.Context,
|
|
||||||
filters []*model.JobFilter) (int, error) {
|
|
||||||
|
|
||||||
query, qerr := SecurityCheck(ctx, sq.Select("count(*)").From("job"))
|
|
||||||
|
|
||||||
if qerr != nil {
|
|
||||||
return 0, qerr
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.countJobs(query, filters)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SecurityCheck(ctx context.Context, query sq.SelectBuilder) (sq.SelectBuilder, error) {
|
func SecurityCheck(ctx context.Context, query sq.SelectBuilder) (sq.SelectBuilder, error) {
|
||||||
user := GetUserFromContext(ctx)
|
user := GetUserFromContext(ctx)
|
||||||
if user == nil {
|
if user == nil {
|
||||||
|
@ -5,10 +5,12 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ClusterCockpit/cc-backend/internal/graph/model"
|
"github.com/ClusterCockpit/cc-backend/internal/graph/model"
|
||||||
"github.com/ClusterCockpit/cc-backend/pkg/log"
|
"github.com/ClusterCockpit/cc-backend/pkg/log"
|
||||||
|
"github.com/ClusterCockpit/cc-backend/pkg/schema"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,7 +96,7 @@ func BenchmarkDB_CountJobs(b *testing.B) {
|
|||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
_, err := db.testCountJobs([]*model.JobFilter{filter})
|
_, err := db.CountJobs(getContext(b), []*model.JobFilter{filter})
|
||||||
noErr(b, err)
|
noErr(b, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -118,20 +120,37 @@ func BenchmarkDB_QueryJobs(b *testing.B) {
|
|||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
_, err := db.testQueryJobs([]*model.JobFilter{filter}, page, order)
|
_, err := db.QueryJobs(getContext(b), []*model.JobFilter{filter}, page, order)
|
||||||
noErr(b, err)
|
noErr(b, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getContext(tb testing.TB) context.Context {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
var roles []string
|
||||||
|
roles = append(roles, schema.GetRoleString(schema.RoleAdmin))
|
||||||
|
projects := make([]string, 0)
|
||||||
|
|
||||||
|
user := &schema.User{
|
||||||
|
Username: "demo",
|
||||||
|
Name: "The man",
|
||||||
|
Roles: roles,
|
||||||
|
Projects: projects,
|
||||||
|
AuthSource: schema.AuthViaLDAP,
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
return context.WithValue(ctx, ContextUserKey, user)
|
||||||
|
}
|
||||||
|
|
||||||
func setup(tb testing.TB) *JobRepository {
|
func setup(tb testing.TB) *JobRepository {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
log.Init("warn", true)
|
log.Init("warn", true)
|
||||||
dbfile := "testdata/job.db"
|
dbfile := "testdata/job.db"
|
||||||
err := MigrateDB("sqlite3", dbfile)
|
err := MigrateDB("sqlite3", dbfile)
|
||||||
noErr(tb, err)
|
noErr(tb, err)
|
||||||
|
|
||||||
Connect("sqlite3", dbfile)
|
Connect("sqlite3", dbfile)
|
||||||
return GetJobRepository()
|
return GetJobRepository()
|
||||||
}
|
}
|
||||||
|
@ -233,10 +233,17 @@ func (r *JobRepository) JobsStatsGrouped(
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *JobRepository) jobsStats(
|
func (r *JobRepository) JobsStats(
|
||||||
query sq.SelectBuilder,
|
ctx context.Context,
|
||||||
filter []*model.JobFilter) ([]*model.JobsStatistics, error) {
|
filter []*model.JobFilter) ([]*model.JobsStatistics, error) {
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
query := r.buildStatsQuery(filter, "")
|
||||||
|
query, err := SecurityCheck(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
row := query.RunWith(r.DB).QueryRow()
|
row := query.RunWith(r.DB).QueryRow()
|
||||||
stats := make([]*model.JobsStatistics, 0, 1)
|
stats := make([]*model.JobsStatistics, 0, 1)
|
||||||
|
|
||||||
@ -267,29 +274,8 @@ func (r *JobRepository) jobsStats(
|
|||||||
TotalAccHours: totalAccHours})
|
TotalAccHours: totalAccHours})
|
||||||
}
|
}
|
||||||
|
|
||||||
return stats, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *JobRepository) testJobsStats(
|
|
||||||
filter []*model.JobFilter) ([]*model.JobsStatistics, error) {
|
|
||||||
|
|
||||||
query := r.buildStatsQuery(filter, "")
|
|
||||||
return r.jobsStats(query, filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *JobRepository) JobsStats(
|
|
||||||
ctx context.Context,
|
|
||||||
filter []*model.JobFilter) ([]*model.JobsStatistics, error) {
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
query := r.buildStatsQuery(filter, "")
|
|
||||||
query, err := SecurityCheck(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Timer JobStats %s", time.Since(start))
|
log.Debugf("Timer JobStats %s", time.Since(start))
|
||||||
return r.jobsStats(query, filter)
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *JobRepository) JobCountGrouped(
|
func (r *JobRepository) JobCountGrouped(
|
||||||
|
@ -26,12 +26,10 @@ func TestJobStats(t *testing.T) {
|
|||||||
r := setup(t)
|
r := setup(t)
|
||||||
|
|
||||||
filter := &model.JobFilter{}
|
filter := &model.JobFilter{}
|
||||||
var err error
|
stats, err := r.JobsStats(getContext(t), []*model.JobFilter{filter})
|
||||||
var stats []*model.JobsStatistics
|
|
||||||
stats, err = r.testJobsStats([]*model.JobFilter{filter})
|
|
||||||
noErr(t, err)
|
noErr(t, err)
|
||||||
|
|
||||||
if stats[0].TotalJobs != 98 {
|
if stats[0].TotalJobs != 6 {
|
||||||
t.Fatalf("Want 98, Got %d", stats[0].TotalJobs)
|
t.Fatalf("Want 98, Got %d", stats[0].TotalJobs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user