diff --git a/internal/repository/dbConnection.go b/internal/repository/dbConnection.go index 36d3e32..7ea24ef 100644 --- a/internal/repository/dbConnection.go +++ b/internal/repository/dbConnection.go @@ -22,7 +22,8 @@ var ( ) type DBConnection struct { - DB *sqlx.DB + DB *sqlx.DB + Driver string } func Connect(driver string, db string) { @@ -54,7 +55,7 @@ func Connect(driver string, db string) { log.Fatalf("unsupported database driver: %s", driver) } - dbConnInstance = &DBConnection{DB: dbHandle} + dbConnInstance = &DBConnection{DB: dbHandle, Driver: driver} checkDBVersion(driver, dbHandle.DB) }) } diff --git a/internal/repository/job.go b/internal/repository/job.go index d67fa2a..de21176 100644 --- a/internal/repository/job.go +++ b/internal/repository/job.go @@ -32,7 +32,8 @@ var ( ) type JobRepository struct { - DB *sqlx.DB + DB *sqlx.DB + driver string stmtCache *sq.StmtCache cache *lrucache.Cache @@ -47,6 +48,7 @@ func GetJobRepository() *JobRepository { jobRepoInstance = &JobRepository{ DB: db.DB, + driver: db.Driver, stmtCache: sq.NewStmtCache(db.DB), cache: lrucache.New(1024 * 1024), archiveChannel: make(chan *schema.Job, 128), @@ -674,17 +676,24 @@ func (r *JobRepository) JobsStatistics(ctx context.Context, start := time.Now() // In case `groupBy` is nil (not used), the model.JobsStatistics used is at the key '' (empty string) stats := map[string]*model.JobsStatistics{} + var castType string + + if r.driver == "sqlite3" { + castType = "int" + } else if r.driver == "mysql" { + castType = "unsigned" + } // `socketsPerNode` and `coresPerSocket` can differ from cluster to cluster, so we need to explicitly loop over those. for _, cluster := range archive.Clusters { for _, subcluster := range cluster.SubClusters { - corehoursCol := fmt.Sprintf("CAST(ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600) as int)", subcluster.SocketsPerNode, subcluster.CoresPerSocket) + corehoursCol := fmt.Sprintf("CAST(ROUND(SUM(job.duration * job.num_nodes * %d * %d) / 3600) as %s)", subcluster.SocketsPerNode, subcluster.CoresPerSocket, castType) var query sq.SelectBuilder if groupBy == nil { query = sq.Select( "''", "COUNT(job.id)", - "CAST(ROUND(SUM(job.duration) / 3600) as int)", + fmt.Sprintf("CAST(ROUND(SUM(job.duration) / 3600) as %s)", castType), corehoursCol, ).From("job") } else { @@ -692,7 +701,7 @@ func (r *JobRepository) JobsStatistics(ctx context.Context, query = sq.Select( col, "COUNT(job.id)", - "CAST(ROUND(SUM(job.duration) / 3600) as int)", + fmt.Sprintf("CAST(ROUND(SUM(job.duration) / 3600) as %s)", castType), corehoursCol, ).From("job").GroupBy(col) } @@ -797,7 +806,7 @@ func (r *JobRepository) JobsStatistics(ctx context.Context, if histogramsNeeded { var err error - value := fmt.Sprintf(`CAST(ROUND((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) / 3600) as int) as value`, time.Now().Unix()) + value := fmt.Sprintf(`CAST(ROUND((CASE WHEN job.job_state = "running" THEN %d - job.start_time ELSE job.duration END) / 3600) as %s) as value`, time.Now().Unix(), castType) stat.HistDuration, err = r.jobsStatisticsHistogram(ctx, value, filter, id, col) if err != nil { log.Warn("Error while loading job statistics histogram: running jobs")