From 9034cb90aa75abb8dbc0735f9dcad46ceede946e Mon Sep 17 00:00:00 2001 From: Lou Knauer Date: Thu, 20 Jan 2022 10:00:55 +0100 Subject: [PATCH] make database schema mysql compatible; use prepared statements --- api/rest.go | 2 +- config/config.go | 87 ++++++++++++++++++++++++--------------- frontend | 2 +- go.mod | 1 + graph/resolver.go | 37 +++++++++++++---- graph/schema.resolvers.go | 12 +++--- graph/stats.go | 6 +-- init-db.go | 29 +++++++------ schema/job.go | 4 +- server.go | 69 ++++++++++++++++++++++--------- 10 files changed, 163 insertions(+), 86 deletions(-) diff --git a/api/rest.go b/api/rest.go index 9526252..65388eb 100644 --- a/api/rest.go +++ b/api/rest.go @@ -215,7 +215,7 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) { } res, err := api.DB.NamedExec(`INSERT INTO job ( - job_id, user, project, cluster, partition, array_job_id, num_nodes, num_hwthreads, num_acc, + job_id, user, project, cluster, `+"`partition`"+`, array_job_id, num_nodes, num_hwthreads, num_acc, exclusive, monitoring_status, smt, job_state, start_time, duration, resources, meta_data ) VALUES ( :job_id, :user, :project, :cluster, :partition, :array_job_id, :num_nodes, :num_hwthreads, :num_acc, diff --git a/config/config.go b/config/config.go index e13e101..a60a42f 100644 --- a/config/config.go +++ b/config/config.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "log" "net/http" "os" "path/filepath" @@ -13,13 +12,15 @@ import ( "github.com/ClusterCockpit/cc-jobarchive/auth" "github.com/ClusterCockpit/cc-jobarchive/graph/model" + "github.com/iamlouk/lrucache" "github.com/jmoiron/sqlx" ) var db *sqlx.DB +var lookupConfigStmt *sqlx.Stmt var lock sync.RWMutex var uiDefaults map[string]interface{} - +var cache lrucache.Cache = *lrucache.New(1024) var Clusters []*model.Cluster func Init(usersdb *sqlx.DB, authEnabled bool, uiConfig map[string]interface{}, jobArchive string) error { @@ -57,13 +58,18 @@ func Init(usersdb *sqlx.DB, authEnabled bool, uiConfig map[string]interface{}, j _, err := db.Exec(` CREATE TABLE IF NOT EXISTS configuration ( username varchar(255), - key varchar(255), + confkey varchar(255), value varchar(255), - PRIMARY KEY (username, key), + PRIMARY KEY (username, confkey), FOREIGN KEY (username) REFERENCES user (username) ON DELETE CASCADE ON UPDATE NO ACTION);`) if err != nil { return err } + + lookupConfigStmt, err = db.Preparex(`SELECT confkey, value FROM configuration WHERE configuration.username = ?`) + if err != nil { + return err + } } return nil @@ -72,38 +78,52 @@ func Init(usersdb *sqlx.DB, authEnabled bool, uiConfig map[string]interface{}, j // Return the personalised UI config for the currently authenticated // user or return the plain default config. 
func GetUIConfig(r *http.Request) (map[string]interface{}, error) { - lock.RLock() - config := make(map[string]interface{}, len(uiDefaults)) - for k, v := range uiDefaults { - config[k] = v - } - lock.RUnlock() - user := auth.GetUser(r.Context()) if user == nil { - return config, nil + lock.RLock() + copy := make(map[string]interface{}, len(uiDefaults)) + for k, v := range uiDefaults { + copy[k] = v + } + lock.RUnlock() + return copy, nil } - rows, err := db.Query(`SELECT key, value FROM configuration WHERE configuration.username = ?`, user.Username) - if err != nil { + data := cache.Get(user.Username, func() (interface{}, time.Duration, int) { + config := make(map[string]interface{}, len(uiDefaults)) + for k, v := range uiDefaults { + config[k] = v + } + + rows, err := lookupConfigStmt.Query(user.Username) + if err != nil { + return err, 0, 0 + } + + size := 0 + for rows.Next() { + var key, rawval string + if err := rows.Scan(&key, &rawval); err != nil { + return err, 0, 0 + } + + var val interface{} + if err := json.Unmarshal([]byte(rawval), &val); err != nil { + return err, 0, 0 + } + + size += len(key) + size += len(rawval) + config[key] = val + } + + return config, 24 * time.Hour, size + }) + if err, ok := data.(error); ok { return nil, err } - for rows.Next() { - var key, rawval string - if err := rows.Scan(&key, &rawval); err != nil { - return nil, err - } - - var val interface{} - if err := json.Unmarshal([]byte(rawval), &val); err != nil { - return nil, err - } - - config[key] = val - } - - return config, nil + return data.(map[string]interface{}), nil } // If the context does not have a user, update the global ui configuration without persisting it! @@ -111,21 +131,20 @@ func GetUIConfig(r *http.Request) (map[string]interface{}, error) { func UpdateConfig(key, value string, ctx context.Context) error { user := auth.GetUser(ctx) if user == nil { - lock.RLock() - defer lock.RUnlock() - var val interface{} if err := json.Unmarshal([]byte(value), &val); err != nil { return err } + lock.Lock() + defer lock.Unlock() uiDefaults[key] = val return nil } - if _, err := db.Exec(`REPLACE INTO configuration (username, key, value) VALUES (?, ?, ?)`, + cache.Del(user.Username) + if _, err := db.Exec(`REPLACE INTO configuration (username, confkey, value) VALUES (?, ?, ?)`, user.Username, key, value); err != nil { - log.Printf("db.Exec: %s\n", err.Error()) return err } diff --git a/frontend b/frontend index 6854301..8065022 160000 --- a/frontend +++ b/frontend @@ -1 +1 @@ -Subproject commit 68543017064707625d788d1e7f987434d0bb0714 +Subproject commit 80650220a3d481b6fc82c305791073ed92fc9261 diff --git a/go.mod b/go.mod index 8aca229..a222d3d 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/99designs/gqlgen v0.13.0 github.com/Masterminds/squirrel v1.5.1 github.com/go-ldap/ldap/v3 v3.4.1 + github.com/go-sql-driver/mysql v1.5.0 github.com/golang-jwt/jwt/v4 v4.1.0 github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 diff --git a/graph/resolver.go b/graph/resolver.go index ee90752..2e5f373 100644 --- a/graph/resolver.go +++ b/graph/resolver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "regexp" "strings" @@ -20,6 +21,31 @@ import ( type Resolver struct { DB *sqlx.DB + + findJobByIdStmt *sqlx.Stmt + findJobByIdWithUserStmt *sqlx.Stmt +} + +func (r *Resolver) Init() { + findJobById, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).ToSql() + if err != nil { + log.Fatal(err) + } + + r.findJobByIdStmt, err = 
r.DB.Preparex(findJobById) + if err != nil { + log.Fatal(err) + } + + findJobByIdWithUser, _, err := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", nil).Where("job.user = ?").ToSql() + if err != nil { + log.Fatal(err) + } + + r.findJobByIdWithUserStmt, err = r.DB.Preparex(findJobByIdWithUser) + if err != nil { + log.Fatal(err) + } } // Helper function for the `jobs` GraphQL-Query. Is also used elsewhere when a list of jobs is needed. @@ -82,17 +108,12 @@ func (r *Resolver) queryJobs(ctx context.Context, filters []*model.JobFilter, pa } func securityCheck(ctx context.Context, query sq.SelectBuilder) sq.SelectBuilder { - val := ctx.Value(auth.ContextUserKey) - if val == nil { + user := auth.GetUser(ctx) + if user == nil || user.IsAdmin { return query } - user := val.(*auth.User) - if user.IsAdmin { - return query - } - - return query.Where("job.user_id = ?", user.Username) + return query.Where("job.user = ?", user.Username) } // Build a sq.SelectBuilder out of a schema.JobFilter. diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 20ba974..dd14179 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -148,14 +148,14 @@ func (r *queryResolver) Tags(ctx context.Context) ([]*schema.Tag, error) { } func (r *queryResolver) Job(ctx context.Context, id string) (*schema.Job, error) { - query := sq.Select(schema.JobColumns...).From("job").Where("job.id = ?", id) - query = securityCheck(ctx, query) - sql, args, err := query.ToSql() - if err != nil { - return nil, err + // This query is very common (mostly called through other resolvers such as JobMetrics), + // so we use prepared statements here. + user := auth.GetUser(ctx) + if user == nil || user.IsAdmin { + return schema.ScanJob(r.findJobByIdStmt.QueryRowx(id)) } - return schema.ScanJob(r.DB.QueryRowx(sql, args...)) + return schema.ScanJob(r.findJobByIdWithUserStmt.QueryRowx(id, user.Username)) } func (r *queryResolver) JobMetrics(ctx context.Context, id string, metrics []string, scopes []schema.MetricScope) ([]*model.JobMetricWithName, error) { diff --git a/graph/stats.go b/graph/stats.go index 2bb0505..6794d80 100644 --- a/graph/stats.go +++ b/graph/stats.go @@ -30,13 +30,13 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF // `socketsPerNode` and `coresPerSocket` can differ from cluster to cluster, so we need to explicitly loop over those. 
for _, cluster := range config.Clusters { for _, partition := range cluster.Partitions { - corehoursCol := fmt.Sprintf("SUM(job.duration * job.num_nodes * %d * %d) / 3600", partition.SocketsPerNode, partition.CoresPerSocket) + corehoursCol := fmt.Sprintf("ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600)", partition.SocketsPerNode, partition.CoresPerSocket) var query sq.SelectBuilder if groupBy == nil { query = sq.Select( "''", "COUNT(job.id)", - "SUM(job.duration) / 3600", + "ROUND(SUM(job.duration) / 3600)", corehoursCol, ).From("job") } else { @@ -44,7 +44,7 @@ func (r *queryResolver) jobsStatistics(ctx context.Context, filter []*model.JobF query = sq.Select( col, "COUNT(job.id)", - "SUM(job.duration) / 3600", + "ROUND(SUM(job.duration) / 3600)", corehoursCol, ).From("job").GroupBy(col) } diff --git a/init-db.go b/init-db.go index 502c5f5..85a3bb3 100644 --- a/init-db.go +++ b/init-db.go @@ -13,32 +13,34 @@ import ( "github.com/jmoiron/sqlx" ) +// `AUTO_INCREMENT` is in a comment because of this hack: +// https://stackoverflow.com/a/41028314 (sqlite creates unique ids automatically) const JOBS_DB_SCHEMA string = ` + DROP TABLE IF EXISTS jobtag; DROP TABLE IF EXISTS job; DROP TABLE IF EXISTS tag; - DROP TABLE IF EXISTS jobtag; CREATE TABLE job ( - id INTEGER PRIMARY KEY AUTOINCREMENT, -- Not needed in sqlite + id INTEGER PRIMARY KEY /*!40101 AUTO_INCREMENT */, job_id BIGINT NOT NULL, cluster VARCHAR(255) NOT NULL, - start_time TIMESTAMP NOT NULL, + start_time BIGINT NOT NULL, -- Unix timestamp user VARCHAR(255) NOT NULL, project VARCHAR(255) NOT NULL, - partition VARCHAR(255) NOT NULL, + ` + "`partition`" + ` VARCHAR(255) NOT NULL, -- partition is a keyword in mysql -.- array_job_id BIGINT NOT NULL, duration INT, - job_state VARCHAR(255) CHECK(job_state IN ('running', 'completed', 'failed', 'canceled', 'stopped', 'timeout')) NOT NULL, - meta_data TEXT, -- json, but sqlite has no json type - resources TEXT NOT NULL, -- json, but sqlite has no json type + job_state VARCHAR(255) NOT NULL CHECK(job_state IN ('running', 'completed', 'failed', 'canceled', 'stopped', 'timeout')), + meta_data TEXT, -- JSON + resources TEXT NOT NULL, -- JSON num_nodes INT NOT NULL, num_hwthreads INT NOT NULL, num_acc INT NOT NULL, - smt TINYINT CHECK(smt IN (0, 1 )) NOT NULL DEFAULT 1, - exclusive TINYINT CHECK(exclusive IN (0, 1, 2)) NOT NULL DEFAULT 1, - monitoring_status TINYINT CHECK(monitoring_status IN (0, 1 )) NOT NULL DEFAULT 1, + smt TINYINT NOT NULL DEFAULT 1 CHECK(smt IN (0, 1 )), + exclusive TINYINT NOT NULL DEFAULT 1 CHECK(exclusive IN (0, 1, 2)), + monitoring_status TINYINT NOT NULL DEFAULT 1 CHECK(monitoring_status IN (0, 1 )), mem_used_max REAL NOT NULL DEFAULT 0.0, flops_any_avg REAL NOT NULL DEFAULT 0.0, @@ -89,7 +91,7 @@ func initDB(db *sqlx.DB, archive string) error { } stmt, err := tx.PrepareNamed(`INSERT INTO job ( - job_id, user, project, cluster, partition, array_job_id, num_nodes, num_hwthreads, num_acc, + job_id, user, project, cluster, ` + "`partition`" + `, array_job_id, num_nodes, num_hwthreads, num_acc, exclusive, monitoring_status, smt, job_state, start_time, duration, resources, meta_data, mem_used_max, flops_any_avg, mem_bw_avg, load_avg, net_bw_avg, net_data_vol_total, file_bw_avg, file_data_vol_total ) VALUES ( @@ -199,8 +201,9 @@ func loadJob(tx *sqlx.Tx, stmt *sqlx.NamedStmt, tags map[string]int64, path stri } job := schema.Job{ - BaseJob: jobMeta.BaseJob, - StartTime: time.Unix(jobMeta.StartTime, 0), + BaseJob: jobMeta.BaseJob, + StartTime: 
time.Unix(jobMeta.StartTime, 0), + StartTimeUnix: jobMeta.StartTime, } // TODO: Other metrics... diff --git a/schema/job.go b/schema/job.go index 8781776..a80acf8 100644 --- a/schema/job.go +++ b/schema/job.go @@ -35,7 +35,8 @@ type BaseJob struct { type Job struct { ID int64 `json:"id" db:"id"` BaseJob - StartTime time.Time `json:"startTime" db:"start_time"` + StartTimeUnix int64 `json:"-" db:"start_time"` + StartTime time.Time `json:"startTime"` MemUsedMax float64 `json:"-" db:"mem_used_max"` FlopsAnyAvg float64 `json:"-" db:"flops_any_avg"` MemBwAvg float64 `json:"-" db:"mem_bw_avg"` @@ -83,6 +84,7 @@ func ScanJob(row Scannable) (*Job, error) { return nil, err } + job.StartTime = time.Unix(job.StartTimeUnix, 0) if job.Duration == 0 && job.State == JobStateRunning { job.Duration = int32(time.Since(job.StartTime).Seconds()) } diff --git a/server.go b/server.go index 1f1bc2f..ffde0bd 100644 --- a/server.go +++ b/server.go @@ -35,6 +35,8 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/mux" "github.com/jmoiron/sqlx" + + _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" ) @@ -55,7 +57,10 @@ type ProgramConfig struct { // Folder where static assets can be found, will be served directly StaticFiles string `json:"static-files"` - // Currently only SQLite3 ist supported, so this should be a filename + // 'sqlite3' or 'mysql' (mysql will work for mariadb as well) + DBDriver string `json:"db-driver"` + + // For sqlite3 a filename, for mysql a DSN in this format: https://github.com/go-sql-driver/mysql#dsn-data-source-name (Without query parameters!). DB string `json:"db"` // Path to the job-archive @@ -87,6 +92,7 @@ var programConfig ProgramConfig = ProgramConfig{ Addr: ":8080", DisableAuthentication: false, StaticFiles: "./frontend/public", + DBDriver: "sqlite3", DB: "./var/job.db", JobArchive: "./var/job-archive", AsyncArchiving: true, @@ -116,7 +122,6 @@ var programConfig ProgramConfig = ProgramConfig{ "plot_view_showRoofline": true, "plot_view_showStatTable": true, }, - MachineStateDir: "./var/machine-state", } func main() { @@ -147,14 +152,25 @@ func main() { } var err error - // This might need to change for other databases: - db, err = sqlx.Open("sqlite3", fmt.Sprintf("%s?_foreign_keys=on", programConfig.DB)) - if err != nil { - log.Fatal(err) - } + if programConfig.DBDriver == "sqlite3" { + db, err = sqlx.Open("sqlite3", fmt.Sprintf("%s?_foreign_keys=on", programConfig.DB)) + if err != nil { + log.Fatal(err) + } - // Only for sqlite, not needed for any other database: - db.SetMaxOpenConns(1) + db.SetMaxOpenConns(1) + } else if programConfig.DBDriver == "mysql" { + db, err = sqlx.Open("mysql", fmt.Sprintf("%s?multiStatements=true", programConfig.DB)) + if err != nil { + log.Fatal(err) + } + + db.SetConnMaxLifetime(time.Minute * 3) + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(10) + } else { + log.Fatalf("unsupported database driver: %s", programConfig.DBDriver) + } // Initialize sub-modules... @@ -220,18 +236,20 @@ func main() { // Build routes... 
resolver := &graph.Resolver{DB: db} + resolver.Init() graphQLEndpoint := handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: resolver})) + if os.Getenv("DEBUG") != "1" { + graphQLEndpoint.SetRecoverFunc(func(ctx context.Context, err interface{}) error { + switch e := err.(type) { + case string: + return fmt.Errorf("panic: %s", e) + case error: + return fmt.Errorf("panic caused by: %w", e) + } - // graphQLEndpoint.SetRecoverFunc(func(ctx context.Context, err interface{}) error { - // switch e := err.(type) { - // case string: - // return fmt.Errorf("panic: %s", e) - // case error: - // return fmt.Errorf("panic caused by: %w", e) - // } - - // return errors.New("internal server error (panic)") - // }) + return errors.New("internal server error (panic)") + }) + } graphQLPlayground := playground.Handler("GraphQL playground", "/query") api := &api.RestApi{ @@ -388,6 +406,19 @@ func monitoringRoutes(router *mux.Router, resolver *graph.Resolver) { } filterPresets["tags"] = tags } + if query.Get("numNodes") != "" { + parts := strings.Split(query.Get("numNodes"), "-") + if len(parts) == 2 { + a, e1 := strconv.Atoi(parts[0]) + b, e2 := strconv.Atoi(parts[1]) + if e1 == nil && e2 == nil { + filterPresets["numNodes"] = map[string]int{"from": a, "to": b} + } + } + } + if query.Get("jobId") != "" { + filterPresets["jobId"] = query.Get("jobId") + } return filterPresets }
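
Note (illustration only, not part of the patch): a minimal sketch of what the new mysql code path expects. Host, credentials and database name below are placeholders; the real value comes from the new "db" config field, which for db-driver "mysql" must be a plain DSN in the go-sql-driver/mysql format without query parameters, because server.go appends "?multiStatements=true" itself (presumably so the multi-statement schema in init-db.go can be executed in one call).

package main

import (
	"fmt"
	"log"
	"time"

	_ "github.com/go-sql-driver/mysql"
	"github.com/jmoiron/sqlx"
)

func main() {
	// Placeholder DSN; in practice this is the "db" field from the config file.
	dsn := "ccuser:secret@tcp(localhost:3306)/ccjobarchive"

	// Mirrors the mysql branch added to server.go in this patch.
	db, err := sqlx.Open("mysql", fmt.Sprintf("%s?multiStatements=true", dsn))
	if err != nil {
		log.Fatal(err)
	}

	// Same connection-pool settings the patch uses for mysql
	// (the sqlite3 branch instead keeps SetMaxOpenConns(1)).
	db.SetConnMaxLifetime(3 * time.Minute)
	db.SetMaxOpenConns(10)
	db.SetMaxIdleConns(10)

	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}
}

With db-driver left at its default of "sqlite3", the "db" field keeps its old meaning of a path to the database file.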