db pgx backend

This commit is contained in:
Pay Giesselmann 2024-07-16 17:58:01 +02:00
parent 93c515098c
commit a8aa92ad9d
31 changed files with 339 additions and 104 deletions

5
go.mod
View File

@ -17,6 +17,7 @@ require (
github.com/gorilla/mux v1.8.1
github.com/gorilla/sessions v1.3.0
github.com/influxdata/influxdb-client-go/v2 v2.13.0
github.com/jackc/pgx/v5 v5.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/prometheus/client_golang v1.19.1
@ -54,11 +55,15 @@ require (
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/influxdata/line-protocol v0.0.0-20210922203350-b1ad95c89adf // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/jpillora/backoff v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect

9
go.sum
View File

@ -119,6 +119,14 @@ github.com/influxdata/influxdb-client-go/v2 v2.13.0 h1:ioBbLmR5NMbAjP4UVA5r9b5xG
github.com/influxdata/influxdb-client-go/v2 v2.13.0/go.mod h1:k+spCbt9hcvqvUiz0sr5D8LolXHqAAOfPw9v/RIRHl4=
github.com/influxdata/line-protocol v0.0.0-20210922203350-b1ad95c89adf h1:7JTmneyiNEwVBOHSjoMxiWAqB992atOeepeFYegn5RU=
github.com/influxdata/line-protocol v0.0.0-20210922203350-b1ad95c89adf/go.mod h1:xaLFMmpvUxqXtVkUJfg9QmT88cDaCJ3ZKgdZ78oO8Qo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
@ -219,6 +227,7 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=

View File

@ -97,7 +97,7 @@ func InitDB() error {
continue
}
id, err := r.TransactionAdd(t, job)
id, err := r.TransactionAdd(t, &job)
if err != nil {
log.Errorf("repository initDB(): %v", err)
errorOccured++

View File

@ -10,6 +10,8 @@ import (
"time"
"github.com/ClusterCockpit/cc-backend/pkg/log"
sqrl "github.com/Masterminds/squirrel"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
"github.com/qustavo/sqlhooks/v2"
@ -22,6 +24,7 @@ var (
type DBConnection struct {
DB *sqlx.DB
SQ sqrl.StatementBuilderType
Driver string
}
@ -46,6 +49,8 @@ func Connect(driver string, db string) {
ConnectionMaxIdleTime: time.Hour,
}
sq := sqrl.StatementBuilderType{}
switch driver {
case "sqlite3":
// - Set WAL mode (not strictly necessary each time because it's persisted in the database, but good for first run)
@ -68,6 +73,13 @@ func Connect(driver string, db string) {
if err != nil {
log.Fatalf("sqlx.Open() error: %v", err)
}
case "postgres":
opts.URL += ""
dbHandle, err = sqlx.Open("pgx", opts.URL)
sq = sqrl.StatementBuilder.PlaceholderFormat(sqrl.Dollar)
if err != nil {
log.Fatalf("sqlx.Open() error: %v", err)
}
default:
log.Fatalf("unsupported database driver: %s", driver)
}
@ -77,7 +89,10 @@ func Connect(driver string, db string) {
dbHandle.SetConnMaxLifetime(opts.ConnectionMaxLifetime)
dbHandle.SetConnMaxIdleTime(opts.ConnectionMaxIdleTime)
dbConnInstance = &DBConnection{DB: dbHandle, Driver: driver}
dbConnInstance = &DBConnection{
DB: dbHandle,
SQ: sq,
Driver: driver}
err = checkDBVersion(driver, dbHandle.DB)
if err != nil {
log.Fatal(err)

View File

@ -31,6 +31,7 @@ var (
type JobRepository struct {
DB *sqlx.DB
SQ sq.StatementBuilderType
stmtCache *sq.StmtCache
cache *lrucache.Cache
archiveChannel chan *schema.Job
@ -44,6 +45,7 @@ func GetJobRepository() *JobRepository {
jobRepoInstance = &JobRepository{
DB: db.DB,
SQ: db.SQ,
driver: db.Driver,
stmtCache: sq.NewStmtCache(db.DB),
@ -107,6 +109,10 @@ func (r *JobRepository) Optimize() error {
}
case "mysql":
log.Info("Optimize currently not supported for mysql driver")
case "postgres":
if _, err = r.DB.Exec(`VACUUM`); err != nil {
return err
}
}
return nil
@ -142,6 +148,16 @@ func (r *JobRepository) Flush() error {
if _, err = r.DB.Exec(`SET FOREIGN_KEY_CHECKS = 1`); err != nil {
return err
}
case "postgres":
if _, err = r.DB.Exec(`DELETE FROM jobtag`); err != nil {
return err
}
if _, err = r.DB.Exec(`DELETE FROM tag`); err != nil {
return err
}
if _, err = r.DB.Exec(`DELETE FROM job`); err != nil {
return err
}
}
return nil
@ -166,7 +182,7 @@ func (r *JobRepository) FetchMetadata(job *schema.Job) (map[string]string, error
return job.MetaData, nil
}
if err := sq.Select("job.meta_data").From("job").Where("job.id = ?", job.ID).
if err := r.SQ.Select("job.meta_data").From("job").Where("job.id = ?", job.ID).
RunWith(r.stmtCache).QueryRow().Scan(&job.RawMetaData); err != nil {
log.Warn("Error while scanning for job metadata")
return nil, err
@ -212,7 +228,7 @@ func (r *JobRepository) UpdateMetadata(job *schema.Job, key, val string) (err er
return err
}
if _, err = sq.Update("job").Set("meta_data", job.RawMetaData).Where("job.id = ?", job.ID).RunWith(r.stmtCache).Exec(); err != nil {
if _, err = r.SQ.Update("job").Set("meta_data", job.RawMetaData).Where("job.id = ?", job.ID).RunWith(r.stmtCache).Exec(); err != nil {
log.Warnf("Error while updating metadata for job, DB ID '%v'", job.ID)
return err
}
@ -229,7 +245,7 @@ func (r *JobRepository) FetchFootprint(job *schema.Job) (map[string]float64, err
return job.Footprint, nil
}
if err := sq.Select("job.footprint").From("job").Where("job.id = ?", job.ID).
if err := r.SQ.Select("job.footprint").From("job").Where("job.id = ?", job.ID).
RunWith(r.stmtCache).QueryRow().Scan(&job.RawFootprint); err != nil {
log.Warn("Error while scanning for job footprint")
return nil, err
@ -251,9 +267,9 @@ func (r *JobRepository) FetchFootprint(job *schema.Job) (map[string]float64, err
func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
var cnt int
q := sq.Select("count(*)").From("job").Where("job.start_time < ?", startTime)
q := r.SQ.Select("count(*)").From("job").Where("job.start_time < ?", startTime)
q.RunWith(r.DB).QueryRow().Scan(cnt)
qd := sq.Delete("job").Where("job.start_time < ?", startTime)
qd := r.SQ.Delete("job").Where("job.start_time < ?", startTime)
_, err := qd.RunWith(r.DB).Exec()
if err != nil {
@ -266,7 +282,7 @@ func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
}
func (r *JobRepository) DeleteJobById(id int64) error {
qd := sq.Delete("job").Where("job.id = ?", id)
qd := r.SQ.Delete("job").Where("job.id = ?", id)
_, err := qd.RunWith(r.DB).Exec()
if err != nil {
@ -279,7 +295,7 @@ func (r *JobRepository) DeleteJobById(id int64) error {
}
func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) {
stmt := sq.Update("job").
stmt := r.SQ.Update("job").
Set("monitoring_status", monitoringStatus).
Where("job.id = ?", job)
@ -292,7 +308,7 @@ func (r *JobRepository) MarkArchived(
jobMeta *schema.JobMeta,
monitoringStatus int32,
) error {
stmt := sq.Update("job").
stmt := r.SQ.Update("job").
Set("monitoring_status", monitoringStatus).
Where("job.id = ?", jobMeta.JobID)
@ -412,7 +428,7 @@ func (r *JobRepository) FindColumnValue(user *schema.User, searchterm string, ta
query = "%" + searchterm + "%"
}
if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport, schema.RoleManager}) {
theQuery := sq.Select(table+"."+selectColumn).Distinct().From(table).
theQuery := r.SQ.Select(table+"."+selectColumn).Distinct().From(table).
Where(table+"."+whereColumn+compareStr, query)
// theSql, args, theErr := theQuery.ToSql()
@ -439,7 +455,7 @@ func (r *JobRepository) FindColumnValue(user *schema.User, searchterm string, ta
func (r *JobRepository) FindColumnValues(user *schema.User, query string, table string, selectColumn string, whereColumn string) (results []string, err error) {
emptyResult := make([]string, 0)
if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport, schema.RoleManager}) {
rows, err := sq.Select(table+"."+selectColumn).Distinct().From(table).
rows, err := r.SQ.Select(table+"."+selectColumn).Distinct().From(table).
Where(table+"."+whereColumn+" LIKE ?", fmt.Sprint("%", query, "%")).
RunWith(r.stmtCache).Query()
if err != nil && err != sql.ErrNoRows {
@ -488,7 +504,7 @@ func (r *JobRepository) Partitions(cluster string) ([]string, error) {
func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]int, error) {
start := time.Now()
subclusters := make(map[string]map[string]int)
rows, err := sq.Select("resources", "subcluster").From("job").
rows, err := r.SQ.Select("resources", "subcluster").From("job").
Where("job.job_state = 'running'").
Where("job.cluster = ?", cluster).
RunWith(r.stmtCache).Query()
@ -529,7 +545,7 @@ func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]in
func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error {
start := time.Now()
res, err := sq.Update("job").
res, err := r.SQ.Update("job").
Set("monitoring_status", schema.MonitoringStatusArchivingFailed).
Set("duration", 0).
Set("job_state", schema.JobStateFailed).
@ -564,11 +580,11 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
if startTimeBegin == 0 {
log.Infof("Find jobs before %d", startTimeEnd)
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
query = r.SQ.Select(jobColumns...).From("job").Where(fmt.Sprintf(
"job.start_time < %d", startTimeEnd))
} else {
log.Infof("Find jobs between %d and %d", startTimeBegin, startTimeEnd)
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
query = r.SQ.Select(jobColumns...).From("job").Where(fmt.Sprintf(
"job.start_time BETWEEN %d AND %d", startTimeBegin, startTimeEnd))
}

View File

@ -7,29 +7,44 @@ package repository
import (
"encoding/json"
"fmt"
"time"
"github.com/ClusterCockpit/cc-backend/pkg/log"
"github.com/ClusterCockpit/cc-backend/pkg/schema"
sq "github.com/Masterminds/squirrel"
)
// TODO conditional on r.driver
// ` + "`partition`" + `
const NamedJobInsert string = `INSERT INTO job (
job_id, user, project, cluster, subcluster, ` + "`partition`" + `, array_job_id, num_nodes, num_hwthreads, num_acc,
job_id, "user", project, cluster, subcluster, "partition", array_job_id, num_nodes, num_hwthreads, num_acc,
exclusive, monitoring_status, smt, job_state, start_time, duration, walltime, footprint, resources, meta_data
) VALUES (
:job_id, :user, :project, :cluster, :subcluster, :partition, :array_job_id, :num_nodes, :num_hwthreads, :num_acc,
:exclusive, :monitoring_status, :smt, :job_state, :start_time, :duration, :walltime, :footprint, :resources, :meta_data
);`
func (r *JobRepository) InsertJob(job *schema.JobMeta) (int64, error) {
res, err := r.DB.NamedExec(NamedJobInsert, job)
func (r *JobRepository) InsertJob(jobMeta *schema.JobMeta) (int64, error) {
//res, err := r.DB.NamedExec(NamedJobInsert, job)
job := schema.Job{
BaseJob: jobMeta.BaseJob,
StartTime: time.Unix(jobMeta.StartTime, 0),
StartTimeUnix: jobMeta.StartTime,
}
t, err := r.TransactionInit()
if err != nil {
log.Warn("Error while initializing SQL transactions")
return 0, err
}
id, err := r.TransactionAdd(t, &job)
if err != nil {
log.Warn("Error while NamedJobInsert")
return 0, err
}
id, err := res.LastInsertId()
err = r.TransactionEnd(t)
//id, err := res.LastInsertId()
if err != nil {
log.Warn("Error while getting last insert ID")
return 0, err
}
@ -64,7 +79,7 @@ func (r *JobRepository) Stop(
state schema.JobState,
monitoringStatus int32,
) (err error) {
stmt := sq.Update("job").
stmt := r.SQ.Update("job").
Set("job_state", state).
Set("duration", duration).
Set("monitoring_status", monitoringStatus).

View File

@ -13,7 +13,6 @@ import (
"github.com/ClusterCockpit/cc-backend/internal/graph/model"
"github.com/ClusterCockpit/cc-backend/pkg/log"
"github.com/ClusterCockpit/cc-backend/pkg/schema"
sq "github.com/Masterminds/squirrel"
)
// Find executes a SQL query to find a specific batch job.
@ -27,7 +26,7 @@ func (r *JobRepository) Find(
startTime *int64,
) (*schema.Job, error) {
start := time.Now()
q := sq.Select(jobColumns...).From("job").
q := r.SQ.Select(jobColumns...).From("job").
Where("job.job_id = ?", *jobId)
if cluster != nil {
@ -53,7 +52,7 @@ func (r *JobRepository) FindAll(
startTime *int64,
) ([]*schema.Job, error) {
start := time.Now()
q := sq.Select(jobColumns...).From("job").
q := r.SQ.Select(jobColumns...).From("job").
Where("job.job_id = ?", *jobId)
if cluster != nil {
@ -87,10 +86,10 @@ func (r *JobRepository) FindAll(
// It returns a pointer to a schema.Job data structure and an error variable.
// To check if no job was found test err == sql.ErrNoRows
func (r *JobRepository) FindById(ctx context.Context, jobId int64) (*schema.Job, error) {
q := sq.Select(jobColumns...).
q := r.SQ.Select(jobColumns...).
From("job").Where("job.id = ?", jobId)
q, qerr := SecurityCheck(ctx, q)
q, qerr := r.SecurityCheck(ctx, q)
if qerr != nil {
return nil, qerr
}
@ -103,7 +102,7 @@ func (r *JobRepository) FindById(ctx context.Context, jobId int64) (*schema.Job,
// It returns a pointer to a schema.Job data structure and an error variable.
// To check if no job was found test err == sql.ErrNoRows
func (r *JobRepository) FindByIdDirect(jobId int64) (*schema.Job, error) {
q := sq.Select(jobColumns...).
q := r.SQ.Select(jobColumns...).
From("job").Where("job.id = ?", jobId)
return scanJob(q.RunWith(r.stmtCache).QueryRow())
}
@ -113,13 +112,13 @@ func (r *JobRepository) FindByIdDirect(jobId int64) (*schema.Job, error) {
// It returns a pointer to a schema.Job data structure and an error variable.
// To check if no job was found test err == sql.ErrNoRows
func (r *JobRepository) FindByJobId(ctx context.Context, jobId int64, startTime int64, cluster string) (*schema.Job, error) {
q := sq.Select(jobColumns...).
q := r.SQ.Select(jobColumns...).
From("job").
Where("job.job_id = ?", jobId).
Where("job.cluster = ?", cluster).
Where("job.start_time = ?", startTime)
q, qerr := SecurityCheck(ctx, q)
q, qerr := r.SecurityCheck(ctx, q)
if qerr != nil {
return nil, qerr
}
@ -132,7 +131,7 @@ func (r *JobRepository) FindByJobId(ctx context.Context, jobId int64, startTime
// It returns a bool.
// If job was found, user is owner: test err != sql.ErrNoRows
func (r *JobRepository) IsJobOwner(jobId int64, startTime int64, user string, cluster string) bool {
q := sq.Select("id").
q := r.SQ.Select("id").
From("job").
Where("job.job_id = ?", jobId).
Where("job.user = ?", user).
@ -151,7 +150,7 @@ func (r *JobRepository) FindConcurrentJobs(
return nil, nil
}
query, qerr := SecurityCheck(ctx, sq.Select("job.id", "job.job_id", "job.start_time").From("job"))
query, qerr := r.SecurityCheck(ctx, r.SQ.Select("job.id", "job.job_id", "job.start_time").From("job"))
if qerr != nil {
return nil, qerr
}

View File

@ -24,7 +24,7 @@ func (r *JobRepository) QueryJobs(
page *model.PageRequest,
order *model.OrderByInput,
) ([]*schema.Job, error) {
query, qerr := SecurityCheck(ctx, sq.Select(jobColumns...).From("job"))
query, qerr := r.SecurityCheck(ctx, r.SQ.Select(jobColumns...).From("job"))
if qerr != nil {
return nil, qerr
}
@ -75,7 +75,7 @@ func (r *JobRepository) CountJobs(
ctx context.Context,
filters []*model.JobFilter,
) (int, error) {
query, qerr := SecurityCheck(ctx, sq.Select("count(*)").From("job"))
query, qerr := r.SecurityCheck(ctx, r.SQ.Select("count(*)").From("job"))
if qerr != nil {
return 0, qerr
}
@ -92,11 +92,11 @@ func (r *JobRepository) CountJobs(
return count, nil
}
func SecurityCheck(ctx context.Context, query sq.SelectBuilder) (sq.SelectBuilder, error) {
func (r *JobRepository) SecurityCheck(ctx context.Context, query sq.SelectBuilder) (sq.SelectBuilder, error) {
user := GetUserFromContext(ctx)
if user == nil {
var qnil sq.SelectBuilder
return qnil, fmt.Errorf("user context is nil")
//var qnil sq.SelectBuilder
return r.SQ.Select(), fmt.Errorf("user context is nil")
}
switch {
@ -114,8 +114,8 @@ func SecurityCheck(ctx context.Context, query sq.SelectBuilder) (sq.SelectBuilde
case user.HasRole(schema.RoleUser): // User : Only personal jobs
return query.Where("job.user = ?", user.Username), nil
default: // No known Role, return error
var qnil sq.SelectBuilder
return qnil, fmt.Errorf("user has no or unknown roles")
//var qnil sq.SelectBuilder
return r.SQ.Select(), fmt.Errorf("user has no or unknown roles")
}
}

View File

@ -12,6 +12,7 @@ import (
"github.com/ClusterCockpit/cc-backend/pkg/log"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/mysql"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
)
@ -53,6 +54,20 @@ func checkDBVersion(backend string, db *sql.DB) error {
if err != nil {
return err
}
case "postgres":
driver, err := postgres.WithInstance(db, &postgres.Config{})
if err != nil {
return err
}
d, err := iofs.New(migrationFiles, "migrations/postgres")
if err != nil {
return err
}
m, err = migrate.NewWithInstance("iofs", d, "postgres", driver)
if err != nil {
return err
}
default:
log.Fatalf("unsupported database backend: %s", backend)
}
@ -101,6 +116,16 @@ func getMigrateInstance(backend string, db string) (m *migrate.Migrate, err erro
if err != nil {
return m, err
}
case "postgres":
d, err := iofs.New(migrationFiles, "migrations/postgres")
if err != nil {
return m, err
}
m, err = migrate.NewWithSourceInstance("iofs", d, db)
if err != nil {
return m, err
}
default:
log.Fatalf("unsupported database backend: %s", backend)
}

View File

@ -0,0 +1,6 @@
-- Down migration for the initial PostgreSQL schema: drop everything the up
-- script created, plus golang-migrate's bookkeeping table.
-- Order matters for foreign keys: jobtag references job and tag, and
-- configuration references "user", so the referencing tables go first.
DROP TABLE IF EXISTS job_meta;
DROP TABLE IF EXISTS configuration;
DROP TABLE IF EXISTS jobtag;
-- BUGFIX: the job table was created by the up migration but never dropped here.
DROP TABLE IF EXISTS job;
DROP TABLE IF EXISTS tag;
DROP TABLE IF EXISTS "user"; -- quoted: USER is a reserved word in PostgreSQL
DROP TABLE IF EXISTS schema_migrations;

View File

@ -0,0 +1,68 @@
-- Initial PostgreSQL schema for the pgx backend.
-- Core table: one row per batch job.
CREATE TABLE IF NOT EXISTS job (
id SERIAL PRIMARY KEY,
job_id BIGINT NOT NULL, -- scheduler-assigned job id (not unique on its own)
cluster VARCHAR(255) NOT NULL,
subcluster VARCHAR(255) NOT NULL,
start_time BIGINT NOT NULL, -- Unix timestamp
"user" VARCHAR(255) NOT NULL, -- quoted: USER is a reserved word in PostgreSQL
project VARCHAR(255) NOT NULL,
"partition" VARCHAR(255) NOT NULL, -- quoted: PARTITION is a PostgreSQL keyword
array_job_id BIGINT NOT NULL,
duration INT NOT NULL DEFAULT 0,
walltime INT NOT NULL DEFAULT 0,
job_state VARCHAR(255) NOT NULL
CHECK (job_state IN ('running', 'completed', 'failed', 'cancelled',
'stopped', 'timeout', 'preempted', 'out_of_memory')),
meta_data TEXT, -- JSON
resources TEXT NOT NULL, -- JSON
num_nodes INT NOT NULL,
num_hwthreads INT NOT NULL,
num_acc INT NOT NULL,
smt SMALLINT NOT NULL DEFAULT 1 CHECK (smt IN (0, 1)),
exclusive SMALLINT NOT NULL DEFAULT 1 CHECK (exclusive IN (0, 1, 2)),
monitoring_status SMALLINT NOT NULL DEFAULT 1 CHECK (monitoring_status IN (0, 1, 2, 3)),
-- Pre-aggregated footprint metrics (dropped again by a later migration).
mem_used_max REAL NOT NULL DEFAULT 0.0,
flops_any_avg REAL NOT NULL DEFAULT 0.0,
mem_bw_avg REAL NOT NULL DEFAULT 0.0,
load_avg REAL NOT NULL DEFAULT 0.0,
net_bw_avg REAL NOT NULL DEFAULT 0.0,
net_data_vol_total REAL NOT NULL DEFAULT 0.0,
file_bw_avg REAL NOT NULL DEFAULT 0.0,
file_data_vol_total REAL NOT NULL DEFAULT 0.0,
-- A job is identified by (job_id, cluster, start_time); job_id alone may repeat.
UNIQUE (job_id, cluster, start_time)
);
-- Tags are (type, name) pairs shared across jobs.
CREATE TABLE IF NOT EXISTS tag (
id SERIAL PRIMARY KEY,
tag_type VARCHAR(255) NOT NULL,
tag_name VARCHAR(255) NOT NULL,
UNIQUE (tag_type, tag_name)
);
-- Many-to-many join between job and tag; rows vanish with either side.
CREATE TABLE IF NOT EXISTS jobtag (
job_id INTEGER,
tag_id INTEGER,
PRIMARY KEY (job_id, tag_id),
FOREIGN KEY (job_id) REFERENCES job (id) ON DELETE CASCADE,
FOREIGN KEY (tag_id) REFERENCES tag (id) ON DELETE CASCADE
);
-- Accounts; table name quoted because USER is reserved in PostgreSQL.
CREATE TABLE IF NOT EXISTS "user" (
username VARCHAR(255) PRIMARY KEY NOT NULL,
password VARCHAR(255) DEFAULT NULL,
ldap SMALLINT NOT NULL DEFAULT 0, -- "ldap" for historic reasons, fills the "AuthSource"
name VARCHAR(255) DEFAULT NULL,
roles VARCHAR(255) NOT NULL DEFAULT '[]', -- JSON array of role names
email VARCHAR(255) DEFAULT NULL
);
-- Per-user key/value UI configuration.
CREATE TABLE IF NOT EXISTS configuration (
username VARCHAR(255),
confkey VARCHAR(255),
value VARCHAR(255),
PRIMARY KEY (username, confkey),
FOREIGN KEY (username) REFERENCES "user" (username) ON DELETE CASCADE ON UPDATE NO ACTION
);

View File

@ -0,0 +1,8 @@
-- Down migration: remove all job-table indexes created by the matching up script.
DROP INDEX IF EXISTS job_stats;
DROP INDEX IF EXISTS job_by_user;
DROP INDEX IF EXISTS job_by_starttime;
DROP INDEX IF EXISTS job_by_job_id;
DROP INDEX IF EXISTS job_list;
DROP INDEX IF EXISTS job_list_user;
DROP INDEX IF EXISTS job_list_users;
DROP INDEX IF EXISTS job_list_users_start;

View File

@ -0,0 +1,8 @@
-- Indexes backing the common job queries (stats grouping, per-user listing,
-- lookup by scheduler id, time-range scans). "user" is quoted because USER
-- is a reserved word in PostgreSQL.
CREATE INDEX IF NOT EXISTS job_stats ON job (cluster, subcluster, "user");
CREATE INDEX IF NOT EXISTS job_by_user ON job ("user");
CREATE INDEX IF NOT EXISTS job_by_starttime ON job (start_time);
CREATE INDEX IF NOT EXISTS job_by_job_id ON job (job_id);
CREATE INDEX IF NOT EXISTS job_list ON job (cluster, job_state);
CREATE INDEX IF NOT EXISTS job_list_user ON job ("user", cluster, job_state);
CREATE INDEX IF NOT EXISTS job_list_users ON job ("user", job_state);
CREATE INDEX IF NOT EXISTS job_list_users_start ON job (start_time, "user", job_state);

View File

@ -0,0 +1 @@
ALTER TABLE user DROP COLUMN projects;

View File

@ -0,0 +1 @@
ALTER TABLE "user" ADD COLUMN projects VARCHAR(255) NOT NULL DEFAULT '[]';

View File

@ -0,0 +1,5 @@
-- Down migration: restore the previous column definitions.
-- BUGFIX: the original used MySQL syntax (backticks and MODIFY), which is
-- invalid in PostgreSQL; translated to ALTER COLUMN ... TYPE / SET NOT NULL.
-- "partition" is quoted because PARTITION is a PostgreSQL keyword.
ALTER TABLE job
ALTER COLUMN "partition" TYPE VARCHAR(255),
ALTER COLUMN "partition" SET NOT NULL,
ALTER COLUMN array_job_id TYPE BIGINT,
ALTER COLUMN array_job_id SET NOT NULL,
ALTER COLUMN num_hwthreads TYPE INT,
ALTER COLUMN num_hwthreads SET NOT NULL,
ALTER COLUMN num_acc TYPE INT,
ALTER COLUMN num_acc SET NOT NULL;

View File

@ -0,0 +1,5 @@
-- Up migration: (re)declare the column types.
-- "partition" is quoted for consistency with the initial schema; PARTITION
-- is a PostgreSQL keyword (unreserved, but quoting avoids any ambiguity).
ALTER TABLE job
ALTER COLUMN "partition" TYPE VARCHAR(255),
ALTER COLUMN array_job_id TYPE BIGINT,
ALTER COLUMN num_hwthreads TYPE INT,
ALTER COLUMN num_acc TYPE INT;

View File

@ -0,0 +1,2 @@
-- Down migration: remove the insert_time audit columns again.
ALTER TABLE tag DROP COLUMN insert_time;
ALTER TABLE jobtag DROP COLUMN insert_time;

View File

@ -0,0 +1,2 @@
-- Record when tags are created and when they are attached to jobs.
-- Existing rows get the migration time via the column default.
ALTER TABLE tag ADD COLUMN insert_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
ALTER TABLE jobtag ADD COLUMN insert_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP;

View File

@ -0,0 +1 @@
ALTER TABLE configuration MODIFY value VARCHAR(255);

View File

@ -0,0 +1,2 @@
ALTER TABLE configuration
ALTER COLUMN value TYPE TEXT;

View File

@ -0,0 +1,3 @@
-- Down migration for a MySQL-specific fix (tag.id AUTO_INCREMENT).
-- BUGFIX: SET FOREIGN_KEY_CHECKS and MODIFY are MySQL-only and fail on
-- PostgreSQL; this file is a deliberate no-op here, mirroring the matching
-- up migration, which is already commented out for the same reason.
-- SET FOREIGN_KEY_CHECKS = 0;
-- ALTER TABLE tag MODIFY id INTEGER;
-- SET FOREIGN_KEY_CHECKS = 1;

View File

@ -0,0 +1,3 @@
-- MySQL-only fix (re-adding AUTO_INCREMENT to tag.id); intentionally a
-- no-op on PostgreSQL, where tag.id is already SERIAL. Kept commented out
-- so the migration version numbers stay in sync across backends.
-- SET FOREIGN_KEY_CHECKS = 0;
-- ALTER TABLE tag MODIFY id INTEGER AUTO_INCREMENT;
-- SET FOREIGN_KEY_CHECKS = 1;

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS job_meta;

View File

@ -0,0 +1,20 @@
-- Replace the fixed per-metric columns with a generic JSON footprint plus a
-- total-energy value.
ALTER TABLE job ADD COLUMN energy REAL NOT NULL DEFAULT 0.0;
ALTER TABLE job ADD COLUMN footprint TEXT DEFAULT NULL; -- JSON map metric -> value
-- These are superseded by the footprint column above.
ALTER TABLE job DROP flops_any_avg;
ALTER TABLE job DROP mem_bw_avg;
ALTER TABLE job DROP mem_used_max;
ALTER TABLE job DROP load_avg;
-- Rename away from the reserved word USER; later code queries the "users" table.
ALTER TABLE "user" RENAME TO users;
-- Staging table for raw job metadata/metrics, keyed like the job table.
CREATE TABLE IF NOT EXISTS job_meta (
id SERIAL PRIMARY KEY,
job_id BIGINT NOT NULL,
cluster VARCHAR(255) NOT NULL,
start_time BIGINT NOT NULL, -- Unix timestamp
meta_data JSONB, -- JSON
metric_data JSONB, -- JSON
UNIQUE (job_id, cluster, start_time)
);

View File

@ -47,10 +47,10 @@ func (r *JobRepository) buildCountQuery(
if col != "" {
// Scan columns: id, cnt
query = sq.Select(col, "COUNT(job.id)").From("job").GroupBy(col)
query = r.SQ.Select(col, "COUNT(job.id)").From("job").GroupBy(col)
} else {
// Scan columns: cnt
query = sq.Select("COUNT(job.id)").From("job")
query = r.SQ.Select("COUNT(job.id)").From("job")
}
switch kind {
@ -78,25 +78,25 @@ func (r *JobRepository) buildStatsQuery(
if col != "" {
// Scan columns: id, totalJobs, totalWalltime, totalNodes, totalNodeHours, totalCores, totalCoreHours, totalAccs, totalAccHours
query = sq.Select(col, "COUNT(job.id) as totalJobs",
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END)) / 3600) as %s) as totalWalltime`, time.Now().Unix(), castType),
query = r.SQ.Select(col, "COUNT(job.id) as totalJobs",
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END)) / 3600) as %s) as totalWalltime`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(SUM(job.num_nodes) as %s) as totalNodes`, castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) * job.num_nodes) / 3600) as %s) as totalNodeHours`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END) * job.num_nodes) / 3600) as %s) as totalNodeHours`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(SUM(job.num_hwthreads) as %s) as totalCores`, castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) * job.num_hwthreads) / 3600) as %s) as totalCoreHours`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END) * job.num_hwthreads) / 3600) as %s) as totalCoreHours`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(SUM(job.num_acc) as %s) as totalAccs`, castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) * job.num_acc) / 3600) as %s) as totalAccHours`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END) * job.num_acc) / 3600) as %s) as totalAccHours`, time.Now().Unix(), castType),
).From("job").GroupBy(col)
} else {
// Scan columns: totalJobs, totalWalltime, totalNodes, totalNodeHours, totalCores, totalCoreHours, totalAccs, totalAccHours
query = sq.Select("COUNT(job.id)",
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END)) / 3600) as %s)`, time.Now().Unix(), castType),
query = r.SQ.Select("COUNT(job.id)",
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END)) / 3600) as %s)`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(SUM(job.num_nodes) as %s)`, castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) * job.num_nodes) / 3600) as %s)`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END) * job.num_nodes) / 3600) as %s)`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(SUM(job.num_hwthreads) as %s)`, castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) * job.num_hwthreads) / 3600) as %s)`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END) * job.num_hwthreads) / 3600) as %s)`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(SUM(job.num_acc) as %s)`, castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) * job.num_acc) / 3600) as %s)`, time.Now().Unix(), castType),
fmt.Sprintf(`CAST(ROUND(SUM((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END) * job.num_acc) / 3600) as %s)`, time.Now().Unix(), castType),
).From("job")
}
@ -109,7 +109,7 @@ func (r *JobRepository) buildStatsQuery(
func (r *JobRepository) getUserName(ctx context.Context, id string) string {
user := GetUserFromContext(ctx)
name, _ := r.FindColumnValue(user, id, "user", "name", "username", false)
name, _ := r.FindColumnValue(user, id, "users", "name", "username", false)
if name != "" {
return name
} else {
@ -125,6 +125,8 @@ func (r *JobRepository) getCastType() string {
castType = "int"
case "mysql":
castType = "unsigned"
case "postgres":
castType = "int"
default:
castType = ""
}
@ -143,7 +145,7 @@ func (r *JobRepository) JobsStatsGrouped(
col := groupBy2column[*groupBy]
query := r.buildStatsQuery(filter, col)
query, err := SecurityCheck(ctx, query)
query, err := r.SecurityCheck(ctx, query)
if err != nil {
return nil, err
}
@ -246,7 +248,7 @@ func (r *JobRepository) JobsStats(
) ([]*model.JobsStatistics, error) {
start := time.Now()
query := r.buildStatsQuery(filter, "")
query, err := SecurityCheck(ctx, query)
query, err := r.SecurityCheck(ctx, query)
if err != nil {
return nil, err
}
@ -307,7 +309,7 @@ func (r *JobRepository) JobCountGrouped(
start := time.Now()
col := groupBy2column[*groupBy]
query := r.buildCountQuery(filter, "", col)
query, err := SecurityCheck(ctx, query)
query, err := r.SecurityCheck(ctx, query)
if err != nil {
return nil, err
}
@ -349,7 +351,7 @@ func (r *JobRepository) AddJobCountGrouped(
start := time.Now()
col := groupBy2column[*groupBy]
query := r.buildCountQuery(filter, kind, col)
query, err := SecurityCheck(ctx, query)
query, err := r.SecurityCheck(ctx, query)
if err != nil {
return nil, err
}
@ -396,7 +398,7 @@ func (r *JobRepository) AddJobCount(
) ([]*model.JobsStatistics, error) {
start := time.Now()
query := r.buildCountQuery(filter, kind, "")
query, err := SecurityCheck(ctx, query)
query, err := r.SecurityCheck(ctx, query)
if err != nil {
return nil, err
}
@ -442,7 +444,7 @@ func (r *JobRepository) AddHistograms(
castType := r.getCastType()
var err error
value := fmt.Sprintf(`CAST(ROUND((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) / 3600) as %s) as value`, time.Now().Unix(), castType)
value := fmt.Sprintf(`CAST(ROUND((CASE WHEN job.job_state = 'running' THEN %d - job.start_time ELSE job.duration END) / 3600) as %s) as value`, time.Now().Unix(), castType)
stat.HistDuration, err = r.jobsStatisticsHistogram(ctx, value, filter)
if err != nil {
log.Warn("Error while loading job statistics histogram: running jobs")
@ -512,8 +514,8 @@ func (r *JobRepository) jobsStatisticsHistogram(
filters []*model.JobFilter,
) ([]*model.HistoPoint, error) {
start := time.Now()
query, qerr := SecurityCheck(ctx,
sq.Select(value, "COUNT(job.id) AS count").From("job"))
query, qerr := r.SecurityCheck(ctx,
r.SQ.Select(value, "COUNT(job.id) AS count").From("job"))
if qerr != nil {
return nil, qerr
@ -583,7 +585,7 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
start := time.Now()
jm := fmt.Sprintf(`json_extract(footprint, "$.%s")`, metric)
crossJoinQuery := sq.Select(
crossJoinQuery := r.SQ.Select(
fmt.Sprintf(`max(%s) as max`, jm),
fmt.Sprintf(`min(%s) as min`, jm),
).From("job").Where(
@ -592,7 +594,7 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
fmt.Sprintf(`%s <= %f`, jm, peak),
)
crossJoinQuery, cjqerr := SecurityCheck(ctx, crossJoinQuery)
crossJoinQuery, cjqerr := r.SecurityCheck(ctx, crossJoinQuery)
if cjqerr != nil {
return nil, cjqerr
@ -612,7 +614,7 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
then value.max*0.999999999 else %s end - value.min) / (value.max -
value.min) * %d as INTEGER )`, jm, jm, bins)
mainQuery := sq.Select(
mainQuery := r.SQ.Select(
fmt.Sprintf(`%s + 1 as bin`, binQuery),
fmt.Sprintf(`count(%s) as count`, jm),
fmt.Sprintf(`CAST(((value.max / %d) * (%s )) as INTEGER ) as min`, bins, binQuery),
@ -621,7 +623,7 @@ func (r *JobRepository) jobsMetricStatisticsHistogram(
fmt.Sprintf(`(%s) as value`, crossJoinQuerySql), crossJoinQueryArgs...,
).Where(fmt.Sprintf(`%s is not null and %s <= %f`, jm, jm, peak))
mainQuery, qerr := SecurityCheck(ctx, mainQuery)
mainQuery, qerr := r.SecurityCheck(ctx, mainQuery)
if qerr != nil {
return nil, qerr

View File

@ -10,12 +10,11 @@ import (
"github.com/ClusterCockpit/cc-backend/internal/archive"
"github.com/ClusterCockpit/cc-backend/pkg/log"
"github.com/ClusterCockpit/cc-backend/pkg/schema"
sq "github.com/Masterminds/squirrel"
)
// Add the tag with id `tagId` to the job with the database id `jobId`.
func (r *JobRepository) AddTag(job int64, tag int64) ([]*schema.Tag, error) {
q := sq.Insert("jobtag").Columns("job_id", "tag_id").Values(job, tag)
q := r.SQ.Insert("jobtag").Columns("job_id", "tag_id").Values(job, tag)
if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
s, _, _ := q.ToSql()
@ -40,7 +39,7 @@ func (r *JobRepository) AddTag(job int64, tag int64) ([]*schema.Tag, error) {
// Removes a tag from a job
func (r *JobRepository) RemoveTag(job, tag int64) ([]*schema.Tag, error) {
q := sq.Delete("jobtag").Where("jobtag.job_id = ?", job).Where("jobtag.tag_id = ?", tag)
q := r.SQ.Delete("jobtag").Where("jobtag.job_id = ?", job).Where("jobtag.tag_id = ?", tag)
if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
s, _, _ := q.ToSql()
@ -65,7 +64,7 @@ func (r *JobRepository) RemoveTag(job, tag int64) ([]*schema.Tag, error) {
// CreateTag creates a new tag with the specified type and name and returns its database id.
func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, err error) {
q := sq.Insert("tag").Columns("tag_type", "tag_name").Values(tagType, tagName)
q := r.SQ.Insert("tag").Columns("tag_type", "tag_name").Values(tagType, tagName)
res, err := q.RunWith(r.stmtCache).Exec()
if err != nil {
@ -92,7 +91,7 @@ func (r *JobRepository) CountTags(user *schema.User) (tags []schema.Tag, counts
tags = append(tags, t)
}
q := sq.Select("t.tag_name, count(jt.tag_id)").
q := r.SQ.Select("t.tag_name, count(jt.tag_id)").
From("tag t").
LeftJoin("jobtag jt ON t.id = jt.tag_id").
GroupBy("t.tag_name")
@ -147,7 +146,7 @@ func (r *JobRepository) AddTagOrCreate(jobId int64, tagType string, tagName stri
// TagId returns the database id of the tag with the specified type and name.
func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exists bool) {
exists = true
if err := sq.Select("id").From("tag").
if err := r.SQ.Select("id").From("tag").
Where("tag.tag_type = ?", tagType).Where("tag.tag_name = ?", tagName).
RunWith(r.stmtCache).QueryRow().Scan(&tagId); err != nil {
exists = false
@ -157,7 +156,7 @@ func (r *JobRepository) TagId(tagType string, tagName string) (tagId int64, exis
// GetTags returns a list of all tags if job is nil or of the tags that the job with that database ID has.
func (r *JobRepository) GetTags(job *int64) ([]*schema.Tag, error) {
q := sq.Select("id", "tag_type", "tag_name").From("tag")
q := r.SQ.Select("id", "tag_type", "tag_name").From("tag")
if job != nil {
q = q.Join("jobtag ON jobtag.tag_id = tag.id").Where("jobtag.job_id = ?", *job)
}

View File

@ -56,25 +56,29 @@ func (r *JobRepository) TransactionCommit(t *Transaction) error {
func (r *JobRepository) TransactionEnd(t *Transaction) error {
if err := t.tx.Commit(); err != nil {
log.Warn("Error while committing SQL transactions")
log.Warn("Error while ending SQL transactions")
return err
}
return nil
}
func (r *JobRepository) TransactionAdd(t *Transaction, job schema.Job) (int64, error) {
res, err := t.stmt.Exec(job)
func (r *JobRepository) TransactionAdd(t *Transaction, job *schema.Job) (int64, error) {
var id int64
_, err := t.stmt.Exec(job)
if err != nil {
log.Errorf("repository initDB(): %v", err)
log.Errorf("Error while adding SQL transactions: %v", err)
return 0, err
}
id, err := res.LastInsertId()
if err != nil {
log.Errorf("repository initDB(): %v", err)
return 0, err
}
//id, err := res.LastInsertId()
// err = t.stmt.QueryRowx(job).Scan(&id)
id = 0
// if err != nil {
// log.Errorf("Error while getting last insert ID: %v", err)
// log.Debugf("Insert job %d, %s, %d", job.JobID, job.Cluster, job.StartTimeUnix)
// return 0, err
// }
return id, nil
}

View File

@ -28,6 +28,7 @@ var (
type UserRepository struct {
DB *sqlx.DB
SQ sq.StatementBuilderType
driver string
}
@ -37,6 +38,7 @@ func GetUserRepository() *UserRepository {
userRepoInstance = &UserRepository{
DB: db.DB,
SQ: db.SQ,
driver: db.Driver,
}
})
@ -46,8 +48,8 @@ func GetUserRepository() *UserRepository {
func (r *UserRepository) GetUser(username string) (*schema.User, error) {
user := &schema.User{Username: username}
var hashedPassword, name, rawRoles, email, rawProjects sql.NullString
if err := sq.Select("password", "ldap", "name", "roles", "email", "projects").From("user").
Where("user.username = ?", username).RunWith(r.DB).
if err := r.SQ.Select("password", "ldap", "name", "roles", "email", "projects").From("users").
Where("users.username = ?", username).RunWith(r.DB).
QueryRow().Scan(&hashedPassword, &user.AuthSource, &name, &rawRoles, &email, &rawProjects); err != nil {
log.Warnf("Error while querying user '%v' from database", username)
return nil, err
@ -73,7 +75,7 @@ func (r *UserRepository) GetUser(username string) (*schema.User, error) {
func (r *UserRepository) GetLdapUsernames() ([]string, error) {
var users []string
rows, err := r.DB.Query(`SELECT username FROM user WHERE user.ldap = 1`)
rows, err := r.DB.Query(`SELECT username FROM users WHERE user.ldap = 1`)
if err != nil {
log.Warn("Error while querying usernames")
return nil, err
@ -121,7 +123,7 @@ func (r *UserRepository) AddUser(user *schema.User) error {
vals = append(vals, int(user.AuthSource))
}
if _, err := sq.Insert("user").Columns(cols...).Values(vals...).RunWith(r.DB).Exec(); err != nil {
if _, err := r.SQ.Insert("users").Columns(cols...).Values(vals...).RunWith(r.DB).Exec(); err != nil {
log.Errorf("Error while inserting new user '%v' into DB", user.Username)
return err
}
@ -131,7 +133,7 @@ func (r *UserRepository) AddUser(user *schema.User) error {
}
func (r *UserRepository) DelUser(username string) error {
_, err := r.DB.Exec(`DELETE FROM user WHERE user.username = ?`, username)
_, err := r.DB.Exec(`DELETE FROM users WHERE users.username = ?`, username)
if err != nil {
log.Errorf("Error while deleting user '%s' from DB", username)
return err
@ -141,7 +143,7 @@ func (r *UserRepository) DelUser(username string) error {
}
func (r *UserRepository) ListUsers(specialsOnly bool) ([]*schema.User, error) {
q := sq.Select("username", "name", "email", "roles", "projects").From("user")
q := r.SQ.Select("username", "name", "email", "roles", "projects").From("users")
if specialsOnly {
q = q.Where("(roles != '[\"user\"]' AND roles != '[]')")
}
@ -202,7 +204,7 @@ func (r *UserRepository) AddRole(
}
roles, _ := json.Marshal(append(user.Roles, newRole))
if _, err := sq.Update("user").Set("roles", roles).Where("user.username = ?", username).RunWith(r.DB).Exec(); err != nil {
if _, err := r.SQ.Update("users").Set("roles", roles).Where("users.username = ?", username).RunWith(r.DB).Exec(); err != nil {
log.Errorf("error while adding new role for user '%s'", user.Username)
return err
}
@ -238,7 +240,7 @@ func (r *UserRepository) RemoveRole(ctx context.Context, username string, queryr
}
mroles, _ := json.Marshal(newroles)
if _, err := sq.Update("user").Set("roles", mroles).Where("user.username = ?", username).RunWith(r.DB).Exec(); err != nil {
if _, err := r.SQ.Update("users").Set("roles", mroles).Where("users.username = ?", username).RunWith(r.DB).Exec(); err != nil {
log.Errorf("Error while removing role for user '%s'", user.Username)
return err
}
@ -264,7 +266,7 @@ func (r *UserRepository) AddProject(
}
projects, _ := json.Marshal(append(user.Projects, project))
if _, err := sq.Update("user").Set("projects", projects).Where("user.username = ?", username).RunWith(r.DB).Exec(); err != nil {
if _, err := r.SQ.Update("users").Set("projects", projects).Where("users.username = ?", username).RunWith(r.DB).Exec(); err != nil {
return err
}
@ -302,7 +304,7 @@ func (r *UserRepository) RemoveProject(ctx context.Context, username string, pro
} else {
result, _ = json.Marshal(newprojects)
}
if _, err := sq.Update("user").Set("projects", result).Where("user.username = ?", username).RunWith(r.DB).Exec(); err != nil {
if _, err := r.SQ.Update("users").Set("projects", result).Where("users.username = ?", username).RunWith(r.DB).Exec(); err != nil {
return err
}
return nil
@ -333,7 +335,7 @@ func (r *UserRepository) FetchUserInCtx(ctx context.Context, username string) (*
user := &model.User{Username: username}
var name, email sql.NullString
if err := sq.Select("name", "email").From("user").Where("user.username = ?", username).
if err := r.SQ.Select("name", "email").From("users").Where("users.username = ?", username).
RunWith(r.DB).QueryRow().Scan(&name, &email); err != nil {
if err == sql.ErrNoRows {
/* This warning will be logged *often* for non-local users, i.e. users mentioned only in job-table or archive, */

View File

@ -13,6 +13,7 @@ import (
"github.com/ClusterCockpit/cc-backend/pkg/log"
"github.com/ClusterCockpit/cc-backend/pkg/lrucache"
"github.com/ClusterCockpit/cc-backend/pkg/schema"
sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
)
@ -23,7 +24,7 @@ var (
type UserCfgRepo struct {
DB *sqlx.DB
Lookup *sqlx.Stmt
SQ sq.StatementBuilderType
uiDefaults map[string]interface{}
cache *lrucache.Cache
lock sync.RWMutex
@ -33,14 +34,9 @@ func GetUserCfgRepo() *UserCfgRepo {
userCfgRepoOnce.Do(func() {
db := GetConnection()
lookupConfigStmt, err := db.DB.Preparex(`SELECT confkey, value FROM configuration WHERE configuration.username = ?`)
if err != nil {
log.Fatalf("db.DB.Preparex() error: %v", err)
}
userCfgRepoInstance = &UserCfgRepo{
DB: db.DB,
Lookup: lookupConfigStmt,
SQ: db.SQ,
uiDefaults: config.Keys.UiDefaults,
cache: lrucache.New(1024),
}
@ -68,7 +64,9 @@ func (uCfg *UserCfgRepo) GetUIConfig(user *schema.User) (map[string]interface{},
uiconfig[k] = v
}
rows, err := uCfg.Lookup.Query(user.Username)
rows, err := uCfg.SQ.Select("confkey", "value").
From("configuration").Where("configuration.username = ?", user.Username).
RunWith(uCfg.DB).Query()
if err != nil {
log.Warnf("Error while looking up user uiconfig for user '%v'", user.Username)
return err, 0, 0
@ -127,9 +125,18 @@ func (uCfg *UserCfgRepo) UpdateConfig(
return nil
}
if _, err := uCfg.DB.Exec(`REPLACE INTO configuration (username, confkey, value) VALUES (?, ?, ?)`, user.Username, key, value); err != nil {
log.Warnf("Error while replacing user config in DB for user '%v'", user.Username)
return err
// REPLACE is SQlite specific, use generic insert or update pattern
if _, err := uCfg.SQ.Insert("configuration").
Columns("username", "confkey", "value").
Values(user.Username, key, value).RunWith(uCfg.DB).Exec(); err != nil {
// insert failed, update key
if _, err = uCfg.SQ.Update("configuration").
Set("username", user.Username).
Set("confkey", key).
Set("value", value).RunWith(uCfg.DB).Exec(); err != nil {
log.Warnf("Error while replacing user config in DB for user '%v': %v", user.Username, err)
return err
}
}
uCfg.cache.Del(user.Username)

View File

@ -33,7 +33,8 @@
"type": "string",
"enum": [
"sqlite3",
"mysql"
"mysql",
"postgres"
]
},
"db": {