make database schema mysql compatible; use prepared statements

This commit is contained in:
Lou Knauer 2022-01-20 10:00:55 +01:00
parent a64944f3c3
commit 9034cb90aa
10 changed files with 163 additions and 86 deletions

View File

@ -215,7 +215,7 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) {
} }
res, err := api.DB.NamedExec(`INSERT INTO job ( res, err := api.DB.NamedExec(`INSERT INTO job (
job_id, user, project, cluster, partition, array_job_id, num_nodes, num_hwthreads, num_acc, job_id, user, project, cluster, `+"`partition`"+`, array_job_id, num_nodes, num_hwthreads, num_acc,
exclusive, monitoring_status, smt, job_state, start_time, duration, resources, meta_data exclusive, monitoring_status, smt, job_state, start_time, duration, resources, meta_data
) VALUES ( ) VALUES (
:job_id, :user, :project, :cluster, :partition, :array_job_id, :num_nodes, :num_hwthreads, :num_acc, :job_id, :user, :project, :cluster, :partition, :array_job_id, :num_nodes, :num_hwthreads, :num_acc,

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -13,13 +12,15 @@ import (
"github.com/ClusterCockpit/cc-jobarchive/auth" "github.com/ClusterCockpit/cc-jobarchive/auth"
"github.com/ClusterCockpit/cc-jobarchive/graph/model" "github.com/ClusterCockpit/cc-jobarchive/graph/model"
"github.com/iamlouk/lrucache"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
) )
var db *sqlx.DB var db *sqlx.DB
var lookupConfigStmt *sqlx.Stmt
var lock sync.RWMutex var lock sync.RWMutex
var uiDefaults map[string]interface{} var uiDefaults map[string]interface{}
var cache lrucache.Cache = *lrucache.New(1024)
var Clusters []*model.Cluster var Clusters []*model.Cluster
func Init(usersdb *sqlx.DB, authEnabled bool, uiConfig map[string]interface{}, jobArchive string) error { func Init(usersdb *sqlx.DB, authEnabled bool, uiConfig map[string]interface{}, jobArchive string) error {
@ -57,13 +58,18 @@ func Init(usersdb *sqlx.DB, authEnabled bool, uiConfig map[string]interface{}, j
_, err := db.Exec(` _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS configuration ( CREATE TABLE IF NOT EXISTS configuration (
username varchar(255), username varchar(255),
key varchar(255), confkey varchar(255),
value varchar(255), value varchar(255),
PRIMARY KEY (username, key), PRIMARY KEY (username, confkey),
FOREIGN KEY (username) REFERENCES user (username) ON DELETE CASCADE ON UPDATE NO ACTION);`) FOREIGN KEY (username) REFERENCES user (username) ON DELETE CASCADE ON UPDATE NO ACTION);`)
if err != nil { if err != nil {
return err return err
} }
lookupConfigStmt, err = db.Preparex(`SELECT confkey, value FROM configuration WHERE configuration.username = ?`)
if err != nil {
return err
}
} }
return nil return nil
@ -72,38 +78,52 @@ func Init(usersdb *sqlx.DB, authEnabled bool, uiConfig map[string]interface{}, j
// Return the personalised UI config for the currently authenticated // Return the personalised UI config for the currently authenticated
// user or return the plain default config. // user or return the plain default config.
func GetUIConfig(r *http.Request) (map[string]interface{}, error) { func GetUIConfig(r *http.Request) (map[string]interface{}, error) {
user := auth.GetUser(r.Context())
if user == nil {
lock.RLock() lock.RLock()
copy := make(map[string]interface{}, len(uiDefaults))
for k, v := range uiDefaults {
copy[k] = v
}
lock.RUnlock()
return copy, nil
}
data := cache.Get(user.Username, func() (interface{}, time.Duration, int) {
config := make(map[string]interface{}, len(uiDefaults)) config := make(map[string]interface{}, len(uiDefaults))
for k, v := range uiDefaults { for k, v := range uiDefaults {
config[k] = v config[k] = v
} }
lock.RUnlock()
user := auth.GetUser(r.Context()) rows, err := lookupConfigStmt.Query(user.Username)
if user == nil {
return config, nil
}
rows, err := db.Query(`SELECT key, value FROM configuration WHERE configuration.username = ?`, user.Username)
if err != nil { if err != nil {
return nil, err return err, 0, 0
} }
size := 0
for rows.Next() { for rows.Next() {
var key, rawval string var key, rawval string
if err := rows.Scan(&key, &rawval); err != nil { if err := rows.Scan(&key, &rawval); err != nil {
return nil, err return err, 0, 0
} }
var val interface{} var val interface{}
if err := json.Unmarshal([]byte(rawval), &val); err != nil { if err := json.Unmarshal([]byte(rawval), &val); err != nil {
return nil, err return err, 0, 0
} }
size += len(key)
size += len(rawval)
config[key] = val config[key] = val
} }
return config, nil return config, 24 * time.Hour, size
})
if err, ok := data.(error); ok {
return nil, err
}
return data.(map[string]interface{}), nil
} }
// If the context does not have a user, update the global ui configuration without persisting it! // If the context does not have a user, update the global ui configuration without persisting it!
@ -111,21 +131,20 @@ func GetUIConfig(r *http.Request) (map[string]interface{}, error) {
func UpdateConfig(key, value string, ctx context.Context) error { func UpdateConfig(key, value string, ctx context.Context) error {
user := auth.GetUser(ctx) user := auth.GetUser(ctx)
if user == nil { if user == nil {
lock.RLock()
defer lock.RUnlock()
var val interface{} var val interface{}
if err := json.Unmarshal([]byte(value), &val); err != nil { if err := json.Unmarshal([]byte(value), &val); err != nil {
return err return err
} }
lock.Lock()
defer lock.Unlock()
uiDefaults[key] = val uiDefaults[key] = val
return nil return nil
} }
if _, err := db.Exec(`REPLACE INTO configuration (username, key, value) VALUES (?, ?, ?)`, cache.Del(user.Username)
if _, err := db.Exec(`REPLACE INTO configuration (username, confkey, value) VALUES (?, ?, ?)`,
user.Username, key, value); err != nil { user.Username, key, value); err != nil {
log.Printf("db.Exec: %s\n", err.Error())
return err return err
} }

@ -1 +1 @@
Subproject commit 68543017064707625d788d1e7f987434d0bb0714 Subproject commit 80650220a3d481b6fc82c305791073ed92fc9261

1
go.mod
View File

@ -6,6 +6,7 @@ require (
github.com/99designs/gqlgen v0.13.0 github.com/99designs/gqlgen v0.13.0
github.com/Masterminds/squirrel v1.5.1 github.com/Masterminds/squirrel v1.5.1
github.com/go-ldap/ldap/v3 v3.4.1 github.com/go-ldap/ldap/v3 v3.4.1
github.com/go-sql-driver/mysql v1.5.0
github.com/golang-jwt/jwt/v4 v4.1.0 github.com/golang-jwt/jwt/v4 v4.1.0
github.com/gorilla/handlers v1.5.1 github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log"
"regexp" "regexp"
"strings" "strings"
@ -20,6 +21,31 @@ import (
type Resolver struct { type Resolver struct {
DB *sqlx.DB DB *sqlx.DB
findJobByIdStmt *sqlx.Stmt
findJobByIdWithUserStmt *sqlx.Stmt
}
func (r *Resolver) Init() {
findJobById, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).ToSql()
if err != nil {
log.Fatal(err)
}
r.findJobByIdStmt, err = r.DB.Preparex(findJobById)
if err != nil {
log.Fatal(err)
}
findJobByIdWithUser, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).Where("job.user = ?").ToSql()
if err != nil {
log.Fatal(err)
}
r.findJobByIdWithUserStmt, err = r.DB.Preparex(findJobByIdWithUser)
if err != nil {
log.Fatal(err)
}
} }
// Helper function for the `jobs` GraphQL-Query. Is also used elsewhere when a list of jobs is needed. // Helper function for the `jobs` GraphQL-Query. Is also used elsewhere when a list of jobs is needed.
@ -82,17 +108,12 @@ func (r *Resolver) queryJobs(ctx context.Context, filters []*model.JobFilter, pa
} }
func securityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder { func securityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder {
val := ctx.Value(auth.ContextUserKey) user := auth.GetUser(ctx)
if val == nil { if user == nil || user.IsAdmin {
return query return query
} }
user := val.(*auth.User) return query.Where("job.user = ?", user.Username)
if user.IsAdmin {
return query
}
return query.Where("job.user_id = ?", user.Username)
} }
// Build a sq.SelectBuilder out of a schema.JobFilter. // Build a sq.SelectBuilder out of a schema.JobFilter.

View File

@ -148,14 +148,14 @@ func (r *queryResolver) Tags(ctx context.Context) ([]*schema.Tag, error) {
} }
func (r *queryResolver) Job(ctx context.Context, id string) (*schema.Job, error) { func (r *queryResolver) Job(ctx context.Context, id string) (*schema.Job, error) {
query := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", id) // This query is very common (mostly called through other resolvers such as JobMetrics),
query = securityCheck(ctx, query) // so we use prepared statements here.
sql, args, err := query.ToSql() user := auth.GetUser(ctx)
if err != nil { if user == nil || user.IsAdmin {
return nil, err return schema.ScanJob(r.findJobByIdStmt.QueryRowx(id))
} }
return schema.ScanJob(r.DB.QueryRowx(sql, args...)) return schema.ScanJob(r.findJobByIdWithUserStmt.QueryRowx(id, user.Username))
} }
func (r *queryResolver) JobMetrics(ctx context.Context, id string, metrics []string, scopes []schema.MetricScope) ([]*model.JobMetricWithName, error) { func (r *queryResolver) JobMetrics(ctx context.Context, id string, metrics []string, scopes []schema.MetricScope) ([]*model.JobMetricWithName, error) {

View File

@ -30,13 +30,13 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
// `socketsPerNode` and `coresPerSocket` can differ from cluster to cluster, so we need to explicitly loop over those. // `socketsPerNode` and `coresPerSocket` can differ from cluster to cluster, so we need to explicitly loop over those.
for _, cluster := range config.Clusters { for _, cluster := range config.Clusters {
for _, partition := range cluster.Partitions { for _, partition := range cluster.Partitions {
corehoursCol := fmt.Sprintf("SUM(job.duration * job.num_nodes * %d * %d) / 3600", partition.SocketsPerNode, partition.CoresPerSocket) corehoursCol := fmt.Sprintf("ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600)", partition.SocketsPerNode, partition.CoresPerSocket)
var query sq.SelectBuilder var query sq.SelectBuilder
if groupBy == nil { if groupBy == nil {
query = sq.Select( query = sq.Select(
"''", "''",
"COUNT(job.id)", "COUNT(job.id)",
"SUM(job.duration) / 3600", "ROUND(SUM(job.duration) / 3600)",
corehoursCol, corehoursCol,
).From("job") ).From("job")
} else { } else {
@ -44,7 +44,7 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF
query = sq.Select( query = sq.Select(
col, col,
"COUNT(job.id)", "COUNT(job.id)",
"SUM(job.duration) / 3600", "ROUND(SUM(job.duration) / 3600)",
corehoursCol, corehoursCol,
).From("job").GroupBy(col) ).From("job").GroupBy(col)
} }

View File

@ -13,32 +13,34 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
) )
// `AUTO_INCREMENT` is in a comment because of this hack:
// https://stackoverflow.com/a/41028314 (sqlite creates unique ids automatically)
const JOBS_DB_SCHEMA string = ` const JOBS_DB_SCHEMA string = `
DROP TABLE IF EXISTS jobtag;
DROP TABLE IF EXISTS job; DROP TABLE IF EXISTS job;
DROP TABLE IF EXISTS tag; DROP TABLE IF EXISTS tag;
DROP TABLE IF EXISTS jobtag;
CREATE TABLE job ( CREATE TABLE job (
id INTEGER PRIMARY KEY AUTOINCREMENT, -- Not needed in sqlite id INTEGER PRIMARY KEY /*!40101 AUTO_INCREMENT */,
job_id BIGINT NOT NULL, job_id BIGINT NOT NULL,
cluster VARCHAR(255) NOT NULL, cluster VARCHAR(255) NOT NULL,
start_time TIMESTAMP NOT NULL, start_time BIGINT NOT NULL, -- Unix timestamp
user VARCHAR(255) NOT NULL, user VARCHAR(255) NOT NULL,
project VARCHAR(255) NOT NULL, project VARCHAR(255) NOT NULL,
partition VARCHAR(255) NOT NULL, ` + "`partition`" + ` VARCHAR(255) NOT NULL, -- partition is a keyword in mysql -.-
array_job_id BIGINT NOT NULL, array_job_id BIGINT NOT NULL,
duration INT, duration INT,
job_state VARCHAR(255) CHECK(job_state IN ('running', 'completed', 'failed', 'canceled', 'stopped', 'timeout')) NOT NULL, job_state VARCHAR(255) NOT NULL CHECK(job_state IN ('running', 'completed', 'failed', 'canceled', 'stopped', 'timeout')),
meta_data TEXT, -- json, but sqlite has no json type meta_data TEXT, -- JSON
resources TEXT NOT NULL, -- json, but sqlite has no json type resources TEXT NOT NULL, -- JSON
num_nodes INT NOT NULL, num_nodes INT NOT NULL,
num_hwthreads INT NOT NULL, num_hwthreads INT NOT NULL,
num_acc INT NOT NULL, num_acc INT NOT NULL,
smt TINYINT CHECK(smt IN (0, 1 )) NOT NULL DEFAULT 1, smt TINYINT NOT NULL DEFAULT 1 CHECK(smt IN (0, 1 )),
exclusive TINYINT CHECK(exclusive IN (0, 1, 2)) NOT NULL DEFAULT 1, exclusive TINYINT NOT NULL DEFAULT 1 CHECK(exclusive IN (0, 1, 2)),
monitoring_status TINYINT CHECK(monitoring_status IN (0, 1 )) NOT NULL DEFAULT 1, monitoring_status TINYINT NOT NULL DEFAULT 1 CHECK(monitoring_status IN (0, 1 )),
mem_used_max REAL NOT NULL DEFAULT 0.0, mem_used_max REAL NOT NULL DEFAULT 0.0,
flops_any_avg REAL NOT NULL DEFAULT 0.0, flops_any_avg REAL NOT NULL DEFAULT 0.0,
@ -89,7 +91,7 @@ func initDB(db *sqlx.DB, archive string) error {
} }
stmt, err := tx.PrepareNamed(`INSERT INTO job ( stmt, err := tx.PrepareNamed(`INSERT INTO job (
job_id, user, project, cluster, partition, array_job_id, num_nodes, num_hwthreads, num_acc, job_id, user, project, cluster, ` + "`partition`" + `, array_job_id, num_nodes, num_hwthreads, num_acc,
exclusive, monitoring_status, smt, job_state, start_time, duration, resources, meta_data, exclusive, monitoring_status, smt, job_state, start_time, duration, resources, meta_data,
mem_used_max, flops_any_avg, mem_bw_avg, load_avg, net_bw_avg, net_data_vol_total, file_bw_avg, file_data_vol_total mem_used_max, flops_any_avg, mem_bw_avg, load_avg, net_bw_avg, net_data_vol_total, file_bw_avg, file_data_vol_total
) VALUES ( ) VALUES (
@ -201,6 +203,7 @@ func loadJob(tx *sqlx.Tx, stmt *sqlx.NamedStmt, tags map[string]int64, path stri
job := schema.Job{ job := schema.Job{
BaseJob: jobMeta.BaseJob, BaseJob: jobMeta.BaseJob,
StartTime: time.Unix(jobMeta.StartTime, 0), StartTime: time.Unix(jobMeta.StartTime, 0),
StartTimeUnix: jobMeta.StartTime,
} }
// TODO: Other metrics... // TODO: Other metrics...

View File

@ -35,7 +35,8 @@ type BaseJob struct {
type Job struct { type Job struct {
ID int64 `json:"id" db:"id"` ID int64 `json:"id" db:"id"`
BaseJob BaseJob
StartTime time.Time `json:"startTime" db:"start_time"` StartTimeUnix int64 `json:"-" db:"start_time"`
StartTime time.Time `json:"startTime"`
MemUsedMax float64 `json:"-" db:"mem_used_max"` MemUsedMax float64 `json:"-" db:"mem_used_max"`
FlopsAnyAvg float64 `json:"-" db:"flops_any_avg"` FlopsAnyAvg float64 `json:"-" db:"flops_any_avg"`
MemBwAvg float64 `json:"-" db:"mem_bw_avg"` MemBwAvg float64 `json:"-" db:"mem_bw_avg"`
@ -83,6 +84,7 @@ func ScanJob(row Scannable) (*Job, error) {
return nil, err return nil, err
} }
job.StartTime = time.Unix(job.StartTimeUnix, 0)
if job.Duration == 0 && job.State == JobStateRunning { if job.Duration == 0 && job.State == JobStateRunning {
job.Duration = int32(time.Since(job.StartTime).Seconds()) job.Duration = int32(time.Since(job.StartTime).Seconds())
} }

View File

@ -35,6 +35,8 @@ import (
"github.com/gorilla/handlers" "github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -55,7 +57,10 @@ type ProgramConfig struct {
// Folder where static assets can be found, will be served directly // Folder where static assets can be found, will be served directly
StaticFiles string `json:"static-files"` StaticFiles string `json:"static-files"`
// Currently only SQLite3 ist supported, so this should be a filename // 'sqlite3' or 'mysql' (mysql will work for mariadb as well)
DBDriver string `json:"db-driver"`
// For sqlite3 a filename, for mysql a DSN in this format: https://github.com/go-sql-driver/mysql#dsn-data-source-name (Without query parameters!).
DB string `json:"db"` DB string `json:"db"`
// Path to the job-archive // Path to the job-archive
@ -87,6 +92,7 @@ var programConfig ProgramConfig = ProgramConfig{
Addr: ":8080", Addr: ":8080",
DisableAuthentication: false, DisableAuthentication: false,
StaticFiles: "./frontend/public", StaticFiles: "./frontend/public",
DBDriver: "sqlite3",
DB: "./var/job.db", DB: "./var/job.db",
JobArchive: "./var/job-archive", JobArchive: "./var/job-archive",
AsyncArchiving: true, AsyncArchiving: true,
@ -116,7 +122,6 @@ var programConfig ProgramConfig = ProgramConfig{
"plot_view_showRoofline": true, "plot_view_showRoofline": true,
"plot_view_showStatTable": true, "plot_view_showStatTable": true,
}, },
MachineStateDir: "./var/machine-state",
} }
func main() { func main() {
@ -147,14 +152,25 @@ func main() {
} }
var err error var err error
// This might need to change for other databases: if programConfig.DBDriver == "sqlite3" {
db, err = sqlx.Open("sqlite3", fmt.Sprintf("%s?_foreign_keys=on", programConfig.DB)) db, err = sqlx.Open("sqlite3", fmt.Sprintf("%s?_foreign_keys=on", programConfig.DB))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// Only for sqlite, not needed for any other database:
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
} else if programConfig.DBDriver == "mysql" {
db, err = sqlx.Open("mysql", fmt.Sprintf("%s?multiStatements=true", programConfig.DB))
if err != nil {
log.Fatal(err)
}
db.SetConnMaxLifetime(time.Minute * 3)
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(10)
} else {
log.Fatalf("unsupported database driver: %s", programConfig.DBDriver)
}
// Initialize sub-modules... // Initialize sub-modules...
@ -220,18 +236,20 @@ func main() {
// Build routes... // Build routes...
resolver := &graph.Resolver{DB: db} resolver := &graph.Resolver{DB: db}
resolver.Init()
graphQLEndpoint := handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: resolver})) graphQLEndpoint := handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: resolver}))
if os.Getenv("DEBUG") != "1" {
graphQLEndpoint.SetRecoverFunc(func(ctx context.Context, err interface{}) error {
switch e := err.(type) {
case string:
return fmt.Errorf("panic: %s", e)
case error:
return fmt.Errorf("panic caused by: %w", e)
}
// graphQLEndpoint.SetRecoverFunc(func(ctx context.Context, err interface{}) error { return errors.New("internal server error (panic)")
// switch e := err.(type) { })
// case string: }
// return fmt.Errorf("panic: %s", e)
// case error:
// return fmt.Errorf("panic caused by: %w", e)
// }
// return errors.New("internal server error (panic)")
// })
graphQLPlayground := playground.Handler("GraphQL playground", "/query") graphQLPlayground := playground.Handler("GraphQL playground", "/query")
api := &api.RestApi{ api := &api.RestApi{
@ -388,6 +406,19 @@ func monitoringRoutes(router *mux.Router, resolver *graph.Resolver) {
} }
filterPresets["tags"] = tags filterPresets["tags"] = tags
} }
if query.Get("numNodes") != "" {
parts := strings.Split(query.Get("numNodes"), "-")
if len(parts) == 2 {
a, e1 := strconv.Atoi(parts[0])
b, e2 := strconv.Atoi(parts[1])
if e1 == nil && e2 == nil {
filterPresets["numNodes"] = map[string]int{"from": a, "to": b}
}
}
}
if query.Get("jobId") != "" {
filterPresets["jobId"] = query.Get("jobId")
}
return filterPresets return filterPresets
} }