mirror of https://github.com/ClusterCockpit/cc-backend
Merge branch 'dev' into log-aggregator
@@ -30,7 +30,7 @@ import (
 	ccconf "github.com/ClusterCockpit/cc-lib/v2/ccConfig"
 	cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
 	"github.com/ClusterCockpit/cc-lib/v2/schema"
-	"github.com/gorilla/mux"
+	"github.com/go-chi/chi/v5"

 	_ "github.com/mattn/go-sqlite3"
 )
@@ -216,9 +216,7 @@ func TestRestApi(t *testing.T) {
 		return testData, nil
 	}

-	r := mux.NewRouter()
-	r.PathPrefix("/api").Subrouter()
-	r.StrictSlash(true)
+	r := chi.NewRouter()
 	restapi.MountAPIRoutes(r)

 	var TestJobID int64 = 123
@@ -36,9 +36,9 @@ type GetClustersAPIResponse struct {
 // @router /api/clusters/ [get]
 func (api *RestAPI) getClusters(rw http.ResponseWriter, r *http.Request) {
 	if user := repository.GetUserFromContext(r.Context()); user != nil &&
-		!user.HasRole(schema.RoleApi) {
+		!user.HasRole(schema.RoleAPI) {
-		handleError(fmt.Errorf("missing role: %v", schema.GetRoleString(schema.RoleApi)), http.StatusForbidden, rw)
+		handleError(fmt.Errorf("missing role: %v", schema.GetRoleString(schema.RoleAPI)), http.StatusForbidden, rw)
 		return
 	}
@@ -27,7 +27,7 @@ import (
 	"github.com/ClusterCockpit/cc-backend/pkg/archive"
 	cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
 	"github.com/ClusterCockpit/cc-lib/v2/schema"
-	"github.com/gorilla/mux"
+	"github.com/go-chi/chi/v5"
 )

 const (
@@ -243,10 +243,10 @@ func (api *RestAPI) getJobs(rw http.ResponseWriter, r *http.Request) {
 // @router /api/jobs/{id} [get]
 func (api *RestAPI) getCompleteJobByID(rw http.ResponseWriter, r *http.Request) {
 	// Fetch job from db
-	id, ok := mux.Vars(r)["id"]
+	id := chi.URLParam(r, "id")
 	var job *schema.Job
 	var err error
-	if ok {
+	if id != "" {
 		id, e := strconv.ParseInt(id, 10, 64)
 		if e != nil {
 			handleError(fmt.Errorf("integer expected in path for id: %w", e), http.StatusBadRequest, rw)
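Note: the hunk above sets the pattern for every path-parameter change in this commit. gorilla/mux returns path variables through a map lookup with an ok flag, while chi.URLParam returns an empty string when the parameter is absent, so each `if ok` guard becomes `if id != ""`. A minimal self-contained sketch of the chi side (handler body and port are illustrative, not taken from cc-backend):

    package main

    import (
        "fmt"
        "net/http"
        "strconv"

        "github.com/go-chi/chi/v5"
    )

    func main() {
        r := chi.NewRouter()
        r.Get("/jobs/{id}", func(w http.ResponseWriter, req *http.Request) {
            // chi.URLParam returns "" for a missing parameter; the emptiness
            // check replaces gorilla/mux's ok flag.
            id := chi.URLParam(req, "id")
            if id == "" {
                http.Error(w, "missing id", http.StatusBadRequest)
                return
            }
            n, err := strconv.ParseInt(id, 10, 64)
            if err != nil {
                http.Error(w, "integer expected in path for id", http.StatusBadRequest)
                return
            }
            fmt.Fprintf(w, "job %d\n", n)
        })
        http.ListenAndServe(":8080", r)
    }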
@@ -336,10 +336,10 @@ func (api *RestAPI) getCompleteJobByID(rw http.ResponseWriter, r *http.Request)
 // @router /api/jobs/{id} [post]
 func (api *RestAPI) getJobByID(rw http.ResponseWriter, r *http.Request) {
 	// Fetch job from db
-	id, ok := mux.Vars(r)["id"]
+	id := chi.URLParam(r, "id")
 	var job *schema.Job
 	var err error
-	if ok {
+	if id != "" {
 		id, e := strconv.ParseInt(id, 10, 64)
 		if e != nil {
 			handleError(fmt.Errorf("integer expected in path for id: %w", e), http.StatusBadRequest, rw)
@@ -439,7 +439,7 @@ func (api *RestAPI) getJobByID(rw http.ResponseWriter, r *http.Request) {
 // @security ApiKeyAuth
 // @router /api/jobs/edit_meta/{id} [post]
 func (api *RestAPI) editMeta(rw http.ResponseWriter, r *http.Request) {
-	id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
+	id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
 	if err != nil {
 		handleError(fmt.Errorf("parsing job ID failed: %w", err), http.StatusBadRequest, rw)
 		return
@@ -487,7 +487,7 @@ func (api *RestAPI) editMeta(rw http.ResponseWriter, r *http.Request) {
 // @security ApiKeyAuth
 // @router /api/jobs/tag_job/{id} [post]
 func (api *RestAPI) tagJob(rw http.ResponseWriter, r *http.Request) {
-	id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
+	id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
 	if err != nil {
 		handleError(fmt.Errorf("parsing job ID failed: %w", err), http.StatusBadRequest, rw)
 		return
@@ -551,7 +551,7 @@ func (api *RestAPI) tagJob(rw http.ResponseWriter, r *http.Request) {
 // @security ApiKeyAuth
 // @router /jobs/tag_job/{id} [delete]
 func (api *RestAPI) removeTagJob(rw http.ResponseWriter, r *http.Request) {
-	id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
+	id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
 	if err != nil {
 		handleError(fmt.Errorf("parsing job ID failed: %w", err), http.StatusBadRequest, rw)
 		return
@@ -754,6 +754,7 @@ func (api *RestAPI) stopJobByRequest(rw http.ResponseWriter, r *http.Request) {
 		return
 	}

+	isCached := false
 	job, err = api.JobRepository.Find(req.JobID, req.Cluster, req.StartTime)
 	if err != nil {
 		// Try cached jobs if not found in main repository
@@ -764,9 +765,10 @@ func (api *RestAPI) stopJobByRequest(rw http.ResponseWriter, r *http.Request) {
 			return
 		}
 		job = cachedJob
+		isCached = true
 	}

-	api.checkAndHandleStopJob(rw, job, req)
+	api.checkAndHandleStopJob(rw, job, req, isCached)
 }

 // deleteJobByID godoc
@@ -786,9 +788,9 @@ func (api *RestAPI) stopJobByRequest(rw http.ResponseWriter, r *http.Request) {
 // @router /api/jobs/delete_job/{id} [delete]
 func (api *RestAPI) deleteJobByID(rw http.ResponseWriter, r *http.Request) {
 	// Fetch job (that will be stopped) from db
-	id, ok := mux.Vars(r)["id"]
+	id := chi.URLParam(r, "id")
 	var err error
-	if ok {
+	if id != "" {
 		id, e := strconv.ParseInt(id, 10, 64)
 		if e != nil {
 			handleError(fmt.Errorf("integer expected in path for id: %w", e), http.StatusBadRequest, rw)
@@ -885,9 +887,9 @@ func (api *RestAPI) deleteJobByRequest(rw http.ResponseWriter, r *http.Request)
 func (api *RestAPI) deleteJobBefore(rw http.ResponseWriter, r *http.Request) {
 	var cnt int
 	// Fetch job (that will be stopped) from db
-	id, ok := mux.Vars(r)["ts"]
+	id := chi.URLParam(r, "ts")
 	var err error
-	if ok {
+	if id != "" {
 		ts, e := strconv.ParseInt(id, 10, 64)
 		if e != nil {
 			handleError(fmt.Errorf("integer expected in path for ts: %w", e), http.StatusBadRequest, rw)
@@ -923,7 +925,7 @@ func (api *RestAPI) deleteJobBefore(rw http.ResponseWriter, r *http.Request) {
 	}
 }

-func (api *RestAPI) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Job, req StopJobAPIRequest) {
+func (api *RestAPI) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Job, req StopJobAPIRequest, isCached bool) {
 	// Sanity checks
 	if job.State != schema.JobStateRunning {
 		handleError(fmt.Errorf("jobId %d (id %d) on %s : job has already been stopped (state is: %s)", job.JobID, *job.ID, job.Cluster, job.State), http.StatusUnprocessableEntity, rw)
@@ -948,11 +950,21 @@ func (api *RestAPI) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Jo
 	api.JobRepository.Mutex.Lock()
 	defer api.JobRepository.Mutex.Unlock()

-	if err := api.JobRepository.Stop(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
-		if err := api.JobRepository.StopCached(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
-			handleError(fmt.Errorf("jobId %d (id %d) on %s : marking job as '%s' (duration: %d) in DB failed: %w", job.JobID, *job.ID, job.Cluster, job.State, job.Duration, err), http.StatusInternalServerError, rw)
+	// If the job is still in job_cache, transfer it to the job table first
+	// so that job.ID always points to the job table for downstream code
+	if isCached {
+		newID, err := api.JobRepository.TransferCachedJobToMain(*job.ID)
+		if err != nil {
+			handleError(fmt.Errorf("jobId %d (id %d) on %s : transferring cached job failed: %w", job.JobID, *job.ID, job.Cluster, err), http.StatusInternalServerError, rw)
+			return
+		}
+		cclog.Infof("transferred cached job to main table: old id %d -> new id %d (jobId=%d)", *job.ID, newID, job.JobID)
+		job.ID = &newID
+	}
+
+	if err := api.JobRepository.Stop(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
+		handleError(fmt.Errorf("jobId %d (id %d) on %s : marking job as '%s' (duration: %d) in DB failed: %w", job.JobID, *job.ID, job.Cluster, job.State, job.Duration, err), http.StatusInternalServerError, rw)
 		return
 	}

 	cclog.Infof("archiving job... (dbid: %d): cluster=%s, jobId=%d, user=%s, startTime=%d, duration=%d, state=%s", *job.ID, job.Cluster, job.JobID, job.User, job.StartTime, job.Duration, job.State)
@@ -976,7 +988,7 @@ func (api *RestAPI) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Jo
 }

 func (api *RestAPI) getJobMetrics(rw http.ResponseWriter, r *http.Request) {
-	id := mux.Vars(r)["id"]
+	id := chi.URLParam(r, "id")
 	metrics := r.URL.Query()["metric"]
 	var scopes []schema.MetricScope
 	for _, scope := range r.URL.Query()["scope"] {
@@ -1042,8 +1054,8 @@ type GetUsedNodesAPIResponse struct {
 // @router /api/jobs/used_nodes [get]
 func (api *RestAPI) getUsedNodes(rw http.ResponseWriter, r *http.Request) {
 	if user := repository.GetUserFromContext(r.Context()); user != nil &&
-		!user.HasRole(schema.RoleApi) {
-		handleError(fmt.Errorf("missing role: %v", schema.GetRoleString(schema.RoleApi)), http.StatusForbidden, rw)
+		!user.HasRole(schema.RoleAPI) {
+		handleError(fmt.Errorf("missing role: %v", schema.GetRoleString(schema.RoleAPI)), http.StatusForbidden, rw)
 		return
 	}
@@ -251,6 +251,7 @@ func (api *NatsAPI) handleStopJob(payload string) {
 		return
 	}

+	isCached := false
 	job, err := api.JobRepository.Find(req.JobID, req.Cluster, req.StartTime)
 	if err != nil {
 		cachedJob, cachedErr := api.JobRepository.FindCached(req.JobID, req.Cluster, req.StartTime)
@@ -260,6 +261,7 @@ func (api *NatsAPI) handleStopJob(payload string) {
 			return
 		}
 		job = cachedJob
+		isCached = true
 	}

 	if job.State != schema.JobStateRunning {
@@ -287,16 +289,26 @@ func (api *NatsAPI) handleStopJob(payload string) {
 	api.JobRepository.Mutex.Lock()
 	defer api.JobRepository.Mutex.Unlock()

-	if err := api.JobRepository.Stop(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
-		if err := api.JobRepository.StopCached(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
-			cclog.Errorf("NATS job stop: jobId %d (id %d) on %s: marking job as '%s' failed: %v",
-				job.JobID, job.ID, job.Cluster, job.State, err)
+	// If the job is still in job_cache, transfer it to the job table first
+	if isCached {
+		newID, err := api.JobRepository.TransferCachedJobToMain(*job.ID)
+		if err != nil {
+			cclog.Errorf("NATS job stop: jobId %d (id %d) on %s: transferring cached job failed: %v",
+				job.JobID, *job.ID, job.Cluster, err)
+			return
+		}
+		cclog.Infof("NATS: transferred cached job to main table: old id %d -> new id %d (jobId=%d)", *job.ID, newID, job.JobID)
+		job.ID = &newID
+	}
+
+	if err := api.JobRepository.Stop(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
+		cclog.Errorf("NATS job stop: jobId %d (id %d) on %s: marking job as '%s' failed: %v",
+			job.JobID, *job.ID, job.Cluster, job.State, err)
 		return
 	}

 	cclog.Infof("NATS: archiving job (dbid: %d): cluster=%s, jobId=%d, user=%s, startTime=%d, duration=%d, state=%s",
-		job.ID, job.Cluster, job.JobID, job.User, job.StartTime, job.Duration, job.State)
+		*job.ID, job.Cluster, job.JobID, job.User, job.StartTime, job.Duration, job.State)

 	if job.MonitoringStatus == schema.MonitoringStatusDisabled {
 		return
@@ -80,7 +80,7 @@ func (api *RestAPI) updateNodeStates(rw http.ResponseWriter, r *http.Request) {
 	ms := metricstore.GetMemoryStore()

 	m := make(map[string][]string)
-	healthStates := make(map[string]schema.MonitoringState)
+	healthResults := make(map[string]metricstore.HealthCheckResult)

 	startMs := time.Now()
@@ -94,8 +94,8 @@ func (api *RestAPI) updateNodeStates(rw http.ResponseWriter, r *http.Request) {
 		if sc != "" {
 			metricList := archive.GetMetricConfigSubCluster(req.Cluster, sc)
 			metricNames := metricListToNames(metricList)
-			if states, err := ms.HealthCheck(req.Cluster, nl, metricNames); err == nil {
-				maps.Copy(healthStates, states)
+			if results, err := ms.HealthCheck(req.Cluster, nl, metricNames); err == nil {
+				maps.Copy(healthResults, results)
 			}
 		}
 	}
@@ -106,8 +106,10 @@ func (api *RestAPI) updateNodeStates(rw http.ResponseWriter, r *http.Request) {
 	for _, node := range req.Nodes {
 		state := determineState(node.States)
 		healthState := schema.MonitoringStateFailed
-		if hs, ok := healthStates[node.Hostname]; ok {
-			healthState = hs
+		var healthMetrics string
+		if result, ok := healthResults[node.Hostname]; ok {
+			healthState = result.State
+			healthMetrics = result.HealthMetrics
 		}
 		nodeState := schema.NodeStateDB{
 			TimeStamp: requestReceived,
@@ -116,10 +118,14 @@ func (api *RestAPI) updateNodeStates(rw http.ResponseWriter, r *http.Request) {
 			MemoryAllocated: node.MemoryAllocated,
 			GpusAllocated:   node.GpusAllocated,
 			HealthState:     healthState,
+			HealthMetrics:   healthMetrics,
 			JobsRunning:     node.JobsRunning,
 		}

-		repo.UpdateNodeState(node.Hostname, req.Cluster, &nodeState)
+		if err := repo.UpdateNodeState(node.Hostname, req.Cluster, &nodeState); err != nil {
+			cclog.Errorf("updateNodeStates: updating node state for %s on %s failed: %v",
+				node.Hostname, req.Cluster, err)
+		}
 	}

 	cclog.Debugf("Timer updateNodeStates, SQLite Inserts: %s", time.Since(startDB))
@@ -25,7 +25,7 @@ import (
 	cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
 	"github.com/ClusterCockpit/cc-lib/v2/schema"
 	"github.com/ClusterCockpit/cc-lib/v2/util"
-	"github.com/gorilla/mux"
+	"github.com/go-chi/chi/v5"
 )

 // @title ClusterCockpit REST API
@@ -73,91 +73,95 @@ func New() *RestAPI {

 // MountAPIRoutes registers REST API endpoints for job and cluster management.
 // These routes use JWT token authentication via the X-Auth-Token header.
-func (api *RestAPI) MountAPIRoutes(r *mux.Router) {
-	r.StrictSlash(true)
+func (api *RestAPI) MountAPIRoutes(r chi.Router) {
 	// REST API Uses TokenAuth
 	// User List
-	r.HandleFunc("/users/", api.getUsers).Methods(http.MethodGet)
+	r.Get("/users/", api.getUsers)
 	// Cluster List
-	r.HandleFunc("/clusters/", api.getClusters).Methods(http.MethodGet)
+	r.Get("/clusters/", api.getClusters)
 	// Slurm node state
-	r.HandleFunc("/nodestate/", api.updateNodeStates).Methods(http.MethodPost, http.MethodPut)
+	r.Post("/nodestate/", api.updateNodeStates)
+	r.Put("/nodestate/", api.updateNodeStates)
 	// Job Handler
 	if config.Keys.APISubjects == nil {
 		cclog.Info("Enabling REST start/stop job API")
-		r.HandleFunc("/jobs/start_job/", api.startJob).Methods(http.MethodPost, http.MethodPut)
-		r.HandleFunc("/jobs/stop_job/", api.stopJobByRequest).Methods(http.MethodPost, http.MethodPut)
+		r.Post("/jobs/start_job/", api.startJob)
+		r.Put("/jobs/start_job/", api.startJob)
+		r.Post("/jobs/stop_job/", api.stopJobByRequest)
+		r.Put("/jobs/stop_job/", api.stopJobByRequest)
 	}
-	r.HandleFunc("/jobs/", api.getJobs).Methods(http.MethodGet)
-	r.HandleFunc("/jobs/used_nodes", api.getUsedNodes).Methods(http.MethodGet)
-	r.HandleFunc("/jobs/tag_job/{id}", api.tagJob).Methods(http.MethodPost, http.MethodPatch)
-	r.HandleFunc("/jobs/tag_job/{id}", api.removeTagJob).Methods(http.MethodDelete)
-	r.HandleFunc("/jobs/edit_meta/{id}", api.editMeta).Methods(http.MethodPost, http.MethodPatch)
-	r.HandleFunc("/jobs/metrics/{id}", api.getJobMetrics).Methods(http.MethodGet)
-	r.HandleFunc("/jobs/delete_job/", api.deleteJobByRequest).Methods(http.MethodDelete)
-	r.HandleFunc("/jobs/delete_job/{id}", api.deleteJobByID).Methods(http.MethodDelete)
-	r.HandleFunc("/jobs/delete_job_before/{ts}", api.deleteJobBefore).Methods(http.MethodDelete)
-	r.HandleFunc("/jobs/{id}", api.getJobByID).Methods(http.MethodPost)
-	r.HandleFunc("/jobs/{id}", api.getCompleteJobByID).Methods(http.MethodGet)
+	r.Get("/jobs/", api.getJobs)
+	r.Get("/jobs/used_nodes", api.getUsedNodes)
+	r.Post("/jobs/tag_job/{id}", api.tagJob)
+	r.Patch("/jobs/tag_job/{id}", api.tagJob)
+	r.Delete("/jobs/tag_job/{id}", api.removeTagJob)
+	r.Post("/jobs/edit_meta/{id}", api.editMeta)
+	r.Patch("/jobs/edit_meta/{id}", api.editMeta)
+	r.Get("/jobs/metrics/{id}", api.getJobMetrics)
+	r.Delete("/jobs/delete_job/", api.deleteJobByRequest)
+	r.Delete("/jobs/delete_job/{id}", api.deleteJobByID)
+	r.Delete("/jobs/delete_job_before/{ts}", api.deleteJobBefore)
+	r.Post("/jobs/{id}", api.getJobByID)
+	r.Get("/jobs/{id}", api.getCompleteJobByID)

-	r.HandleFunc("/tags/", api.removeTags).Methods(http.MethodDelete)
+	r.Delete("/tags/", api.removeTags)

 	if api.MachineStateDir != "" {
-		r.HandleFunc("/machine_state/{cluster}/{host}", api.getMachineState).Methods(http.MethodGet)
-		r.HandleFunc("/machine_state/{cluster}/{host}", api.putMachineState).Methods(http.MethodPut, http.MethodPost)
+		r.Get("/machine_state/{cluster}/{host}", api.getMachineState)
+		r.Put("/machine_state/{cluster}/{host}", api.putMachineState)
+		r.Post("/machine_state/{cluster}/{host}", api.putMachineState)
 	}
 }

 // MountUserAPIRoutes registers user-accessible REST API endpoints.
 // These are limited endpoints for regular users with JWT token authentication.
-func (api *RestAPI) MountUserAPIRoutes(r *mux.Router) {
-	r.StrictSlash(true)
+func (api *RestAPI) MountUserAPIRoutes(r chi.Router) {
 	// REST API Uses TokenAuth
-	r.HandleFunc("/jobs/", api.getJobs).Methods(http.MethodGet)
-	r.HandleFunc("/jobs/{id}", api.getJobByID).Methods(http.MethodPost)
-	r.HandleFunc("/jobs/{id}", api.getCompleteJobByID).Methods(http.MethodGet)
-	r.HandleFunc("/jobs/metrics/{id}", api.getJobMetrics).Methods(http.MethodGet)
+	r.Get("/jobs/", api.getJobs)
+	r.Post("/jobs/{id}", api.getJobByID)
+	r.Get("/jobs/{id}", api.getCompleteJobByID)
+	r.Get("/jobs/metrics/{id}", api.getJobMetrics)
 }

 // MountMetricStoreAPIRoutes registers metric storage API endpoints.
 // These endpoints handle metric data ingestion and health checks with JWT token authentication.
-func (api *RestAPI) MountMetricStoreAPIRoutes(r *mux.Router) {
+func (api *RestAPI) MountMetricStoreAPIRoutes(r chi.Router) {
 	// REST API Uses TokenAuth
-	// Note: StrictSlash handles trailing slash variations automatically
-	r.HandleFunc("/free", freeMetrics).Methods(http.MethodPost)
-	r.HandleFunc("/write", writeMetrics).Methods(http.MethodPost)
-	r.HandleFunc("/debug", debugMetrics).Methods(http.MethodGet)
-	r.HandleFunc("/healthcheck", api.updateNodeStates).Methods(http.MethodPost)
+	r.Post("/free", freeMetrics)
+	r.Post("/write", writeMetrics)
+	r.Get("/debug", debugMetrics)
+	r.Post("/healthcheck", api.updateNodeStates)
 	// Same endpoints but with trailing slash
-	r.HandleFunc("/free/", freeMetrics).Methods(http.MethodPost)
-	r.HandleFunc("/write/", writeMetrics).Methods(http.MethodPost)
-	r.HandleFunc("/debug/", debugMetrics).Methods(http.MethodGet)
-	r.HandleFunc("/healthcheck/", api.updateNodeStates).Methods(http.MethodPost)
+	r.Post("/free/", freeMetrics)
+	r.Post("/write/", writeMetrics)
+	r.Get("/debug/", debugMetrics)
+	r.Post("/healthcheck/", api.updateNodeStates)
 }

 // MountConfigAPIRoutes registers configuration and user management endpoints.
 // These routes use session-based authentication and require admin privileges.
-func (api *RestAPI) MountConfigAPIRoutes(r *mux.Router) {
-	r.StrictSlash(true)
+// Routes use full paths (including /config prefix) to avoid conflicting with
+// the /config page route when registered via Group instead of Route.
+func (api *RestAPI) MountConfigAPIRoutes(r chi.Router) {
 	// Settings Frontend Uses SessionAuth
 	if api.Authentication != nil {
-		r.HandleFunc("/roles/", api.getRoles).Methods(http.MethodGet)
-		r.HandleFunc("/users/", api.createUser).Methods(http.MethodPost, http.MethodPut)
-		r.HandleFunc("/users/", api.getUsers).Methods(http.MethodGet)
-		r.HandleFunc("/users/", api.deleteUser).Methods(http.MethodDelete)
-		r.HandleFunc("/user/{id}", api.updateUser).Methods(http.MethodPost)
-		r.HandleFunc("/notice/", api.editNotice).Methods(http.MethodPost)
+		r.Get("/config/roles/", api.getRoles)
+		r.Post("/config/users/", api.createUser)
+		r.Put("/config/users/", api.createUser)
+		r.Get("/config/users/", api.getUsers)
+		r.Delete("/config/users/", api.deleteUser)
+		r.Post("/config/user/{id}", api.updateUser)
+		r.Post("/config/notice/", api.editNotice)
 	}
 }

 // MountFrontendAPIRoutes registers frontend-specific API endpoints.
 // These routes support JWT generation and user configuration updates with session authentication.
-func (api *RestAPI) MountFrontendAPIRoutes(r *mux.Router) {
-	r.StrictSlash(true)
+func (api *RestAPI) MountFrontendAPIRoutes(r chi.Router) {
 	// Settings Frontend Uses SessionAuth
 	if api.Authentication != nil {
-		r.HandleFunc("/jwt/", api.getJWT).Methods(http.MethodGet)
-		r.HandleFunc("/configuration/", api.updateConfiguration).Methods(http.MethodPost)
+		r.Get("/jwt/", api.getJWT)
+		r.Post("/configuration/", api.updateConfiguration)
 	}
 }
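Note: with chi, a caller mounts these route groups via Route or Mount instead of mux's PathPrefix/Subrouter, and the trailing-slash tolerance that StrictSlash used to provide must now be encoded in the patterns themselves (hence the duplicated "/free" and "/free/" registrations above). A minimal sketch of the wiring, assuming an /api prefix and a stand-in handler (neither appears in this hunk):

    package main

    import (
        "net/http"

        "github.com/go-chi/chi/v5"
    )

    // mountAPIRoutes stands in for RestAPI.MountAPIRoutes above.
    func mountAPIRoutes(r chi.Router) {
        r.Get("/clusters/", func(w http.ResponseWriter, _ *http.Request) {
            w.Write([]byte("[]"))
        })
    }

    func main() {
        r := chi.NewRouter()
        // chi's Route creates a fresh sub-router for the subtree; this is
        // what MountAPIRoutes now receives in place of a *mux.Router.
        r.Route("/api", mountAPIRoutes)
        http.ListenAndServe(":8080", r)
    }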
@@ -381,9 +385,8 @@ func (api *RestAPI) putMachineState(rw http.ResponseWriter, r *http.Request) {
 		return
 	}

-	vars := mux.Vars(r)
-	cluster := vars["cluster"]
-	host := vars["host"]
+	cluster := chi.URLParam(r, "cluster")
+	host := chi.URLParam(r, "host")

 	if err := validatePathComponent(cluster, "cluster name"); err != nil {
 		handleError(err, http.StatusBadRequest, rw)
@@ -434,9 +437,8 @@ func (api *RestAPI) getMachineState(rw http.ResponseWriter, r *http.Request) {
 		return
 	}

-	vars := mux.Vars(r)
-	cluster := vars["cluster"]
-	host := vars["host"]
+	cluster := chi.URLParam(r, "cluster")
+	host := chi.URLParam(r, "host")

 	if err := validatePathComponent(cluster, "cluster name"); err != nil {
 		handleError(err, http.StatusBadRequest, rw)
@@ -13,7 +13,7 @@ import (
 	"github.com/ClusterCockpit/cc-backend/internal/repository"
 	cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
 	"github.com/ClusterCockpit/cc-lib/v2/schema"
-	"github.com/gorilla/mux"
+	"github.com/go-chi/chi/v5"
 )

 type APIReturnedUser struct {
@@ -91,7 +91,7 @@ func (api *RestAPI) updateUser(rw http.ResponseWriter, r *http.Request) {

 	// Handle role updates
 	if newrole != "" {
-		if err := repository.GetUserRepository().AddRole(r.Context(), mux.Vars(r)["id"], newrole); err != nil {
+		if err := repository.GetUserRepository().AddRole(r.Context(), chi.URLParam(r, "id"), newrole); err != nil {
 			handleError(fmt.Errorf("adding role failed: %w", err), http.StatusUnprocessableEntity, rw)
 			return
 		}
@@ -99,7 +99,7 @@ func (api *RestAPI) updateUser(rw http.ResponseWriter, r *http.Request) {
 			cclog.Errorf("Failed to encode response: %v", err)
 		}
 	} else if delrole != "" {
-		if err := repository.GetUserRepository().RemoveRole(r.Context(), mux.Vars(r)["id"], delrole); err != nil {
+		if err := repository.GetUserRepository().RemoveRole(r.Context(), chi.URLParam(r, "id"), delrole); err != nil {
 			handleError(fmt.Errorf("removing role failed: %w", err), http.StatusUnprocessableEntity, rw)
 			return
 		}
@@ -107,7 +107,7 @@ func (api *RestAPI) updateUser(rw http.ResponseWriter, r *http.Request) {
 			cclog.Errorf("Failed to encode response: %v", err)
 		}
 	} else if newproj != "" {
-		if err := repository.GetUserRepository().AddProject(r.Context(), mux.Vars(r)["id"], newproj); err != nil {
+		if err := repository.GetUserRepository().AddProject(r.Context(), chi.URLParam(r, "id"), newproj); err != nil {
 			handleError(fmt.Errorf("adding project failed: %w", err), http.StatusUnprocessableEntity, rw)
 			return
 		}
@@ -115,7 +115,7 @@ func (api *RestAPI) updateUser(rw http.ResponseWriter, r *http.Request) {
 			cclog.Errorf("Failed to encode response: %v", err)
 		}
 	} else if delproj != "" {
-		if err := repository.GetUserRepository().RemoveProject(r.Context(), mux.Vars(r)["id"], delproj); err != nil {
+		if err := repository.GetUserRepository().RemoveProject(r.Context(), chi.URLParam(r, "id"), delproj); err != nil {
 			handleError(fmt.Errorf("removing project failed: %w", err), http.StatusUnprocessableEntity, rw)
 			return
 		}
@@ -164,7 +164,7 @@ func (api *RestAPI) createUser(rw http.ResponseWriter, r *http.Request) {
 		return
 	}

-	if len(password) == 0 && role != schema.GetRoleString(schema.RoleApi) {
+	if len(password) == 0 && role != schema.GetRoleString(schema.RoleAPI) {
 		handleError(fmt.Errorf("only API users are allowed to have a blank password (login will be impossible)"), http.StatusBadRequest, rw)
 		return
 	}
@@ -170,7 +170,6 @@ All exported functions are safe for concurrent use:
 - `Start()` - Safe to call once
 - `TriggerArchiving()` - Safe from multiple goroutines
 - `Shutdown()` - Safe to call once
-- `WaitForArchiving()` - Deprecated, but safe

 Internal state is protected by:
 - Channel synchronization (`archiveChannel`)
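Note: the doc excerpt above credits thread safety to channel synchronization rather than locks. A minimal sketch of that pattern — a single worker goroutine draining a trigger channel, reusing the archiveChannel name from the excerpt but with the archiving work reduced to a print (everything except the channel discipline is illustrative):

    package main

    import "fmt"

    type archiver struct {
        archiveChannel chan int64 // job IDs to archive
        done           chan struct{}
    }

    // Start launches the single worker; concurrent TriggerArchiving calls
    // are serialized by the channel, so no mutex is needed.
    func (a *archiver) Start() {
        go func() {
            defer close(a.done)
            for id := range a.archiveChannel {
                fmt.Println("archiving job", id)
            }
        }()
    }

    func (a *archiver) TriggerArchiving(id int64) { a.archiveChannel <- id }

    // Shutdown closes the trigger channel and waits for the worker to drain it.
    func (a *archiver) Shutdown() {
        close(a.archiveChannel)
        <-a.done
    }

    func main() {
        a := &archiver{archiveChannel: make(chan int64, 8), done: make(chan struct{})}
        a.Start()
        a.TriggerArchiving(42)
        a.Shutdown()
    }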
@@ -294,6 +294,11 @@ func handleOIDCUser(OIDCUser *schema.User) {
 	handleUserSync(OIDCUser, Keys.OpenIDConfig.SyncUserOnLogin, Keys.OpenIDConfig.UpdateUserOnLogin)
 }

+// handleLdapUser syncs LDAP user with database
+func handleLdapUser(ldapUser *schema.User) {
+	handleUserSync(ldapUser, Keys.LdapConfig.SyncUserOnLogin, Keys.LdapConfig.UpdateUserOnLogin)
+}
+
 func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request, user *schema.User) error {
 	session, err := auth.sessionStore.New(r, "session")
 	if err != nil {
@@ -443,13 +448,13 @@ func (auth *Authentication) AuthAPI(
 	if user != nil {
 		switch {
 		case len(user.Roles) == 1:
-			if user.HasRole(schema.RoleApi) {
+			if user.HasRole(schema.RoleAPI) {
 				ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
 				onsuccess.ServeHTTP(rw, r.WithContext(ctx))
 				return
 			}
 		case len(user.Roles) >= 2:
-			if user.HasAllRoles([]schema.Role{schema.RoleAdmin, schema.RoleApi}) {
+			if user.HasAllRoles([]schema.Role{schema.RoleAdmin, schema.RoleAPI}) {
 				ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
 				onsuccess.ServeHTTP(rw, r.WithContext(ctx))
 				return
@@ -479,13 +484,13 @@ func (auth *Authentication) AuthUserAPI(
 	if user != nil {
 		switch {
 		case len(user.Roles) == 1:
-			if user.HasRole(schema.RoleApi) {
+			if user.HasRole(schema.RoleAPI) {
 				ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
 				onsuccess.ServeHTTP(rw, r.WithContext(ctx))
 				return
 			}
 		case len(user.Roles) >= 2:
-			if user.HasRole(schema.RoleApi) && user.HasAnyRole([]schema.Role{schema.RoleUser, schema.RoleManager, schema.RoleSupport, schema.RoleAdmin}) {
+			if user.HasRole(schema.RoleAPI) && user.HasAnyRole([]schema.Role{schema.RoleUser, schema.RoleManager, schema.RoleSupport, schema.RoleAdmin}) {
 				ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
 				onsuccess.ServeHTTP(rw, r.WithContext(ctx))
 				return
@@ -515,13 +520,13 @@ func (auth *Authentication) AuthMetricStoreAPI(
 	if user != nil {
 		switch {
 		case len(user.Roles) == 1:
-			if user.HasRole(schema.RoleApi) {
+			if user.HasRole(schema.RoleAPI) {
 				ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
 				onsuccess.ServeHTTP(rw, r.WithContext(ctx))
 				return
 			}
 		case len(user.Roles) >= 2:
-			if user.HasRole(schema.RoleApi) && user.HasAnyRole([]schema.Role{schema.RoleUser, schema.RoleManager, schema.RoleAdmin}) {
+			if user.HasRole(schema.RoleAPI) && user.HasAnyRole([]schema.Role{schema.RoleUser, schema.RoleManager, schema.RoleAdmin}) {
 				ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
 				onsuccess.ServeHTTP(rw, r.WithContext(ctx))
 				return
@@ -6,11 +6,12 @@
 package auth

 import (
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
 	"os"
 	"strings"
 	"time"

 	"github.com/ClusterCockpit/cc-backend/internal/repository"
 	cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
@@ -25,16 +26,19 @@ type LdapConfig struct {
 	UserBind        string `json:"user-bind"`
 	UserFilter      string `json:"user-filter"`
 	UserAttr        string `json:"username-attr"`
+	UIDAttr         string `json:"uid-attr"`
 	SyncInterval    string `json:"sync-interval"` // Parsed using time.ParseDuration.
 	SyncDelOldUsers bool   `json:"sync-del-old-users"`

-	// Should an non-existent user be added to the DB if user exists in ldap directory
-	SyncUserOnLogin bool `json:"sync-user-on-login"`
+	// Should a non-existent user be added to the DB if user exists in ldap directory
+	SyncUserOnLogin   bool `json:"sync-user-on-login"`
 	UpdateUserOnLogin bool `json:"update-user-on-login"`
 }

 type LdapAuthenticator struct {
 	syncPassword string
 	UserAttr     string
+	UIDAttr      string
 }

 var _ Authenticator = (*LdapAuthenticator)(nil)
@@ -51,6 +55,12 @@ func (la *LdapAuthenticator) Init() error {
 		la.UserAttr = "gecos"
 	}

+	if Keys.LdapConfig.UIDAttr != "" {
+		la.UIDAttr = Keys.LdapConfig.UIDAttr
+	} else {
+		la.UIDAttr = "uid"
+	}
+
 	return nil
 }
@@ -66,55 +76,44 @@ func (la *LdapAuthenticator) CanLogin(
 		if user.AuthSource == schema.AuthViaLDAP {
 			return user, true
 		}
-	} else {
-		if lc.SyncUserOnLogin {
-			l, err := la.getLdapConnection(true)
-			if err != nil {
-				cclog.Error("LDAP connection error")
-				return nil, false
-			}
-			defer l.Close()
-
-			// Search for the given username
-			searchRequest := ldap.NewSearchRequest(
-				lc.UserBase,
-				ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
-				fmt.Sprintf("(&%s(uid=%s))", lc.UserFilter, username),
-				[]string{"dn", "uid", la.UserAttr}, nil)
-
-			sr, err := l.Search(searchRequest)
-			if err != nil {
-				cclog.Warn(err)
-				return nil, false
-			}
-
-			if len(sr.Entries) != 1 {
-				cclog.Warn("LDAP: User does not exist or too many entries returned")
-				return nil, false
-			}
-
-			entry := sr.Entries[0]
-			name := entry.GetAttributeValue(la.UserAttr)
-			var roles []string
-			roles = append(roles, schema.GetRoleString(schema.RoleUser))
-			projects := make([]string, 0)
-
-			user = &schema.User{
-				Username:   username,
-				Name:       name,
-				Roles:      roles,
-				Projects:   projects,
-				AuthType:   schema.AuthSession,
-				AuthSource: schema.AuthViaLDAP,
-			}
-
-			if err := repository.GetUserRepository().AddUser(user); err != nil {
-				cclog.Errorf("User '%s' LDAP: Insert into DB failed", username)
-				return nil, false
-			}
-
-			return user, true
-		}
-	}
+	} else if lc.SyncUserOnLogin {
+		l, err := la.getLdapConnection(true)
+		if err != nil {
+			cclog.Error("LDAP connection error")
+			return nil, false
+		}
+		defer l.Close()
+
+		// Search for the given username
+		searchRequest := ldap.NewSearchRequest(
+			lc.UserBase,
+			ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+			fmt.Sprintf("(&%s(%s=%s))", lc.UserFilter, la.UIDAttr, ldap.EscapeFilter(username)),
+			[]string{"dn", la.UIDAttr, la.UserAttr}, nil)
+
+		sr, err := l.Search(searchRequest)
+		if err != nil {
+			cclog.Warn(err)
+			return nil, false
+		}
+
+		if len(sr.Entries) != 1 {
+			cclog.Warn("LDAP: User does not exist or too many entries returned")
+			return nil, false
+		}
+
+		entry := sr.Entries[0]
+		user = &schema.User{
+			Username:   username,
+			Name:       entry.GetAttributeValue(la.UserAttr),
+			Roles:      []string{schema.GetRoleString(schema.RoleUser)},
+			Projects:   make([]string, 0),
+			AuthType:   schema.AuthSession,
+			AuthSource: schema.AuthViaLDAP,
+		}
+
+		handleLdapUser(user)
+		return user, true
+	}

 	return nil, false
@@ -132,7 +131,7 @@ func (la *LdapAuthenticator) Login(
 	}
 	defer l.Close()

-	userDn := strings.ReplaceAll(Keys.LdapConfig.UserBind, "{username}", user.Username)
+	userDn := strings.ReplaceAll(Keys.LdapConfig.UserBind, "{username}", ldap.EscapeDN(user.Username))
 	if err := l.Bind(userDn, r.FormValue("password")); err != nil {
 		cclog.Errorf("AUTH/LDAP > Authentication for user %s failed: %v",
 			user.Username, err)
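Note: EscapeDN here and EscapeFilter in the CanLogin hunk above are go-ldap (github.com/go-ldap/ldap/v3) helpers that backslash-escape characters with special meaning in DNs and search filters, closing an LDAP injection hole when usernames are interpolated into either. A small sketch of the effect (the import path is the usual go-ldap v3 one; this commit's import block is not shown in the hunk):

    package main

    import (
        "fmt"

        "github.com/go-ldap/ldap/v3"
    )

    func main() {
        // A hostile "username" that would otherwise widen the search filter.
        username := "*)(uid=*"
        fmt.Printf("(&(objectClass=person)(uid=%s))\n", ldap.EscapeFilter(username))
        // Output: (&(objectClass=person)(uid=\2a\29\28uid=\2a))

        // DN-special characters (commas, plus signs, quotes, ...) are
        // backslash-escaped before the value is embedded in a bind DN.
        fmt.Println(ldap.EscapeDN(`evil,ou=admins`))
    }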
@@ -170,7 +169,7 @@ func (la *LdapAuthenticator) Sync() error {
 		lc.UserBase,
 		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
 		lc.UserFilter,
-		[]string{"dn", "uid", la.UserAttr}, nil))
+		[]string{"dn", la.UIDAttr, la.UserAttr}, nil))
 	if err != nil {
 		cclog.Warn("LDAP search error")
 		return err
@@ -178,9 +177,9 @@ func (la *LdapAuthenticator) Sync() error {

 	newnames := map[string]string{}
 	for _, entry := range ldapResults.Entries {
-		username := entry.GetAttributeValue("uid")
+		username := entry.GetAttributeValue(la.UIDAttr)
 		if username == "" {
-			return errors.New("no attribute 'uid'")
+			return fmt.Errorf("no attribute '%s'", la.UIDAttr)
 		}

 		_, ok := users[username]
@@ -194,20 +193,19 @@ func (la *LdapAuthenticator) Sync() error {

 	for username, where := range users {
 		if where == InDB && lc.SyncDelOldUsers {
-			ur.DelUser(username)
+			if err := ur.DelUser(username); err != nil {
+				cclog.Errorf("User '%s' LDAP: Delete from DB failed: %v", username, err)
+				return err
+			}
 			cclog.Debugf("sync: remove %v (does not show up in LDAP anymore)", username)
 		} else if where == InLdap {
 			name := newnames[username]

-			var roles []string
-			roles = append(roles, schema.GetRoleString(schema.RoleUser))
-			projects := make([]string, 0)
-
 			user := &schema.User{
 				Username:   username,
 				Name:       name,
-				Roles:      roles,
-				Projects:   projects,
+				Roles:      []string{schema.GetRoleString(schema.RoleUser)},
+				Projects:   make([]string, 0),
 				AuthSource: schema.AuthViaLDAP,
 			}
@@ -224,11 +222,13 @@ func (la *LdapAuthenticator) Sync() error {

 func (la *LdapAuthenticator) getLdapConnection(admin bool) (*ldap.Conn, error) {
 	lc := Keys.LdapConfig
-	conn, err := ldap.DialURL(lc.URL)
+	conn, err := ldap.DialURL(lc.URL,
+		ldap.DialWithDialer(&net.Dialer{Timeout: 10 * time.Second}))
 	if err != nil {
 		cclog.Warn("LDAP URL dial failed")
 		return nil, err
 	}
+	conn.SetTimeout(30 * time.Second)

 	if admin {
 		if err := conn.Bind(lc.SearchDN, la.syncPassword); err != nil {
@@ -9,6 +9,7 @@ import (
 	"context"
 	"crypto/rand"
 	"encoding/base64"
+	"fmt"
 	"io"
 	"net/http"
 	"os"
@@ -18,7 +19,7 @@ import (
 	cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
 	"github.com/ClusterCockpit/cc-lib/v2/schema"
 	"github.com/coreos/go-oidc/v3/oidc"
-	"github.com/gorilla/mux"
+	"github.com/go-chi/chi/v5"
 	"golang.org/x/oauth2"
 )
@@ -50,6 +51,7 @@ func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value strin
 		MaxAge:   int(time.Hour.Seconds()),
 		Secure:   r.TLS != nil,
 		HttpOnly: true,
+		SameSite: http.SameSiteLaxMode,
 	}
 	http.SetCookie(w, c)
 }
@@ -77,8 +79,7 @@ func NewOIDC(a *Authentication) *OIDC {
 		ClientID:     clientID,
 		ClientSecret: clientSecret,
 		Endpoint:     provider.Endpoint(),
-		RedirectURL:  "oidc-callback",
-		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
+		Scopes:       []string{oidc.ScopeOpenID, "profile"},
 	}

 	oa := &OIDC{provider: provider, client: client, clientID: clientID, authentication: a}
@@ -86,7 +87,7 @@ func NewOIDC(a *Authentication) *OIDC {
 	return oa
 }

-func (oa *OIDC) RegisterEndpoints(r *mux.Router) {
+func (oa *OIDC) RegisterEndpoints(r chi.Router) {
 	r.HandleFunc("/oidc-login", oa.OAuth2Login)
 	r.HandleFunc("/oidc-callback", oa.OAuth2Callback)
 }
@@ -122,54 +123,93 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {

 	token, err := oa.client.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
 	if err != nil {
-		http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
+		cclog.Errorf("token exchange failed: %s", err.Error())
+		http.Error(rw, "Authentication failed during token exchange", http.StatusInternalServerError)
 		return
 	}

 	// Get user info from OIDC provider with same timeout
 	userInfo, err := oa.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
 	if err != nil {
-		http.Error(rw, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
+		cclog.Errorf("failed to get userinfo: %s", err.Error())
+		http.Error(rw, "Failed to retrieve user information", http.StatusInternalServerError)
 		return
 	}

-	// // Extract the ID Token from OAuth2 token.
-	// rawIDToken, ok := token.Extra("id_token").(string)
-	// if !ok {
-	// 	http.Error(rw, "Cannot access idToken", http.StatusInternalServerError)
-	// }
-	//
-	// verifier := oa.provider.Verifier(&oidc.Config{ClientID: oa.clientID})
-	// // Parse and verify ID Token payload.
-	// idToken, err := verifier.Verify(context.Background(), rawIDToken)
-	// if err != nil {
-	// 	http.Error(rw, "Failed to extract idToken: "+err.Error(), http.StatusInternalServerError)
-	// }
+	// Verify ID token and nonce to prevent replay attacks
+	rawIDToken, ok := token.Extra("id_token").(string)
+	if !ok {
+		http.Error(rw, "ID token not found in response", http.StatusInternalServerError)
+		return
+	}
+
+	nonceCookie, err := r.Cookie("nonce")
+	if err != nil {
+		http.Error(rw, "nonce cookie not found", http.StatusBadRequest)
+		return
+	}
+
+	verifier := oa.provider.Verifier(&oidc.Config{ClientID: oa.clientID})
+	idToken, err := verifier.Verify(ctx, rawIDToken)
+	if err != nil {
+		cclog.Errorf("ID token verification failed: %s", err.Error())
+		http.Error(rw, "ID token verification failed", http.StatusInternalServerError)
+		return
+	}
+
+	if idToken.Nonce != nonceCookie.Value {
+		http.Error(rw, "Nonce mismatch", http.StatusBadRequest)
+		return
+	}

 	projects := make([]string, 0)

-	// Extract custom claims
+	// Extract custom claims from userinfo
 	var claims struct {
 		Username string `json:"preferred_username"`
 		Name     string `json:"name"`
-		Profile  struct {
+		// Keycloak realm-level roles
+		RealmAccess struct {
+			Roles []string `json:"roles"`
+		} `json:"realm_access"`
+		// Keycloak client-level roles
+		ResourceAccess struct {
 			Client struct {
 				Roles []string `json:"roles"`
 			} `json:"clustercockpit"`
 		} `json:"resource_access"`
 	}
 	if err := userInfo.Claims(&claims); err != nil {
-		http.Error(rw, "Failed to extract Claims: "+err.Error(), http.StatusInternalServerError)
+		cclog.Errorf("failed to extract claims: %s", err.Error())
+		http.Error(rw, "Failed to extract user claims", http.StatusInternalServerError)
 		return
 	}

+	if claims.Username == "" {
+		http.Error(rw, "Username claim missing from OIDC provider", http.StatusBadRequest)
+		return
+	}
+
+	// Merge roles from both client-level and realm-level access
+	oidcRoles := append(claims.ResourceAccess.Client.Roles, claims.RealmAccess.Roles...)
+
+	roleSet := make(map[string]bool)
+	for _, r := range oidcRoles {
+		switch r {
+		case "user":
+			roleSet[schema.GetRoleString(schema.RoleUser)] = true
+		case "admin":
+			roleSet[schema.GetRoleString(schema.RoleAdmin)] = true
+		case "manager":
+			roleSet[schema.GetRoleString(schema.RoleManager)] = true
+		case "support":
+			roleSet[schema.GetRoleString(schema.RoleSupport)] = true
+		}
+	}
+
 	var roles []string
-	for _, r := range claims.Profile.Client.Roles {
-		switch r {
-		case "user":
-			roles = append(roles, schema.GetRoleString(schema.RoleUser))
-		case "admin":
-			roles = append(roles, schema.GetRoleString(schema.RoleAdmin))
-		}
+	for role := range roleSet {
+		roles = append(roles, role)
 	}

 	if len(roles) == 0 {
@@ -188,8 +228,12 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {
 		handleOIDCUser(user)
 	}

-	oa.authentication.SaveSession(rw, r, user)
-	cclog.Infof("login successfull: user: %#v (roles: %v, projects: %v)", user.Username, user.Roles, user.Projects)
+	if err := oa.authentication.SaveSession(rw, r, user); err != nil {
+		cclog.Errorf("session save failed for user %q: %s", user.Username, err.Error())
+		http.Error(rw, "Failed to create session", http.StatusInternalServerError)
+		return
+	}
+	cclog.Infof("login successful: user: %#v (roles: %v, projects: %v)", user.Username, user.Roles, user.Projects)
 	userCtx := context.WithValue(r.Context(), repository.ContextUserKey, user)
 	http.RedirectHandler("/", http.StatusTemporaryRedirect).ServeHTTP(rw, r.WithContext(userCtx))
 }
@@ -206,7 +250,24 @@ func (oa *OIDC) OAuth2Login(rw http.ResponseWriter, r *http.Request) {
 	codeVerifier := oauth2.GenerateVerifier()
 	setCallbackCookie(rw, r, "verifier", codeVerifier)

+	// Generate nonce for ID token replay protection
+	nonce, err := randString(16)
+	if err != nil {
+		http.Error(rw, "Internal error", http.StatusInternalServerError)
+		return
+	}
+	setCallbackCookie(rw, r, "nonce", nonce)
+
+	// Build redirect URL from the incoming request
+	scheme := "https"
+	if r.TLS == nil && r.Header.Get("X-Forwarded-Proto") != "https" {
+		scheme = "http"
+	}
+	oa.client.RedirectURL = fmt.Sprintf("%s://%s/oidc-callback", scheme, r.Host)
+
 	// Redirect user to consent page to ask for permission
-	url := oa.client.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
+	url := oa.client.AuthCodeURL(state, oauth2.AccessTypeOffline,
+		oauth2.S256ChallengeOption(codeVerifier),
+		oidc.Nonce(nonce))
 	http.Redirect(rw, r, url, http.StatusFound)
 }
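Note: the nonce round trip spans both hunks above. OAuth2Login generates a random value, stores it in a short-lived cookie, and embeds it in the auth URL via oidc.Nonce; OAuth2Callback then compares the cookie against the Nonce claim of the verified ID token, so a replayed token from a different login attempt is rejected. A stripped-down sketch of the cookie half only (randString mirrors the helper the diff calls; the provider interaction is omitted):

    package main

    import (
        "crypto/rand"
        "encoding/base64"
        "net/http"
    )

    func randString(nByte int) (string, error) {
        b := make([]byte, nByte)
        if _, err := rand.Read(b); err != nil {
            return "", err
        }
        return base64.RawURLEncoding.EncodeToString(b), nil
    }

    func login(w http.ResponseWriter, r *http.Request) {
        nonce, _ := randString(16)
        http.SetCookie(w, &http.Cookie{
            Name: "nonce", Value: nonce,
            MaxAge: 3600, HttpOnly: true, Secure: r.TLS != nil,
            SameSite: http.SameSiteLaxMode,
        })
        // The nonce is also embedded in the auth URL via oidc.Nonce(nonce).
    }

    func callback(w http.ResponseWriter, r *http.Request, idTokenNonce string) {
        c, err := r.Cookie("nonce")
        if err != nil || c.Value != idTokenNonce {
            http.Error(w, "Nonce mismatch", http.StatusBadRequest)
            return
        }
        // Token accepted: the response belongs to the login we initiated.
    }

    func main() {
        http.HandleFunc("/oidc-login", login)
        // callback would be invoked with the nonce from the verified ID token.
    }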
@@ -92,9 +92,17 @@ var configSchema = `
 			"description": "Delete obsolete users in database.",
 			"type": "boolean"
 		},
+		"uid-attr": {
+			"description": "LDAP attribute used as login username. Default: uid",
+			"type": "string"
+		},
 		"sync-user-on-login": {
 			"description": "Add non-existent user to DB at login attempt if user exists in Ldap directory",
 			"type": "boolean"
 		},
+		"update-user-on-login": {
+			"description": "Should an existent user attributes in the DB be updated at login attempt with values from LDAP.",
+			"type": "boolean"
+		}
 	},
 	"required": ["url", "user-base", "search-dn", "user-bind", "user-filter"]
@@ -74,6 +74,23 @@ type ProgramConfig struct {

 	// Systemd unit name for log viewer (default: "clustercockpit")
 	SystemdUnit string `json:"systemd-unit"`
+
+	// Node state retention configuration
+	NodeStateRetention *NodeStateRetention `json:"nodestate-retention"`
 }

+type NodeStateRetention struct {
+	Policy             string `json:"policy"`      // "delete" or "parquet"
+	Age                int    `json:"age"`         // hours, default 24
+	TargetKind         string `json:"target-kind"` // "file" or "s3"
+	TargetPath         string `json:"target-path"`
+	TargetEndpoint     string `json:"target-endpoint"`
+	TargetBucket       string `json:"target-bucket"`
+	TargetAccessKey    string `json:"target-access-key"`
+	TargetSecretKey    string `json:"target-secret-key"`
+	TargetRegion       string `json:"target-region"`
+	TargetUsePathStyle bool   `json:"target-use-path-style"`
+	MaxFileSizeMB      int    `json:"max-file-size-mb"`
+}
+
 type ResampleConfig struct {
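Note: a minimal sketch of the config block this struct serializes to, using only a subset of the fields above (the values are illustrative, not cc-backend defaults):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Subset of NodeStateRetention from the hunk above, tags copied verbatim.
    type NodeStateRetention struct {
        Policy     string `json:"policy"`
        Age        int    `json:"age,omitempty"`
        TargetKind string `json:"target-kind,omitempty"`
        TargetPath string `json:"target-path,omitempty"`
    }

    func main() {
        b, _ := json.MarshalIndent(NodeStateRetention{
            Policy:     "parquet",
            Age:        24,
            TargetKind: "file",
            TargetPath: "/var/lib/cc-backend/nodestate-archive",
        }, "", "  ")
        fmt.Println(string(b))
        // {
        //   "policy": "parquet",
        //   "age": 24,
        //   "target-kind": "file",
        //   "target-path": "/var/lib/cc-backend/nodestate-archive"
        // }
    }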
@@ -130,6 +130,59 @@ var configSchema = `
 			}
 		},
 		"required": ["subject-job-event", "subject-node-state"]
 	},
+	"nodestate-retention": {
+		"description": "Node state retention configuration for cleaning up old node_state rows.",
+		"type": "object",
+		"properties": {
+			"policy": {
+				"description": "Retention policy: 'delete' to remove old rows, 'parquet' to archive then delete.",
+				"type": "string",
+				"enum": ["delete", "parquet"]
+			},
+			"age": {
+				"description": "Retention age in hours (default: 24).",
+				"type": "integer"
+			},
+			"target-kind": {
+				"description": "Target kind for parquet archiving: 'file' or 's3'.",
+				"type": "string",
+				"enum": ["file", "s3"]
+			},
+			"target-path": {
+				"description": "Filesystem path for parquet file target.",
+				"type": "string"
+			},
+			"target-endpoint": {
+				"description": "S3 endpoint URL.",
+				"type": "string"
+			},
+			"target-bucket": {
+				"description": "S3 bucket name.",
+				"type": "string"
+			},
+			"target-access-key": {
+				"description": "S3 access key.",
+				"type": "string"
+			},
+			"target-secret-key": {
+				"description": "S3 secret key.",
+				"type": "string"
+			},
+			"target-region": {
+				"description": "S3 region.",
+				"type": "string"
+			},
+			"target-use-path-style": {
+				"description": "Use path-style S3 addressing.",
+				"type": "boolean"
+			},
+			"max-file-size-mb": {
+				"description": "Maximum parquet file size in MB (default: 128).",
+				"type": "integer"
+			}
+		},
+		"required": ["policy"]
+	}
 }
 }`
@@ -10245,7 +10245,7 @@ func (ec *executionContext) _Series_id(ctx context.Context, field graphql.Collec
 		field,
 		ec.fieldContext_Series_id,
 		func(ctx context.Context) (any, error) {
-			return obj.Id, nil
+			return obj.ID, nil
 		},
 		nil,
 		ec.marshalOString2ᚖstring,
@@ -552,7 +552,7 @@ func (r *queryResolver) ScopedJobStats(ctx context.Context, id string, metrics [
 	for _, stat := range stats {
 		mdlStats = append(mdlStats, &model.ScopedStats{
 			Hostname: stat.Hostname,
-			ID:       stat.Id,
+			ID:       stat.ID,
 			Data:     stat.Data,
 		})
 	}
@@ -824,6 +824,7 @@ func (r *queryResolver) NodeMetricsList(ctx context.Context, cluster string, sub
 	}

 	nodeRepo := repository.GetNodeRepository()
+	// nodes -> array hostname
 	nodes, stateMap, countNodes, hasNextPage, nerr := nodeRepo.GetNodesForList(ctx, cluster, subCluster, stateFilter, nodeFilter, page)
 	if nerr != nil {
 		return nil, errors.New("could not retrieve node list required for resolving NodeMetricsList")
@@ -835,6 +836,7 @@ func (r *queryResolver) NodeMetricsList(ctx context.Context, cluster string, sub
 		}
 	}

+	// data -> map hostname:jobdata
 	data, err := metricdispatch.LoadNodeListData(cluster, subCluster, nodes, metrics, scopes, *resolution, from, to, ctx)
 	if err != nil {
 		cclog.Warn("error while loading node data (Resolver.NodeMetricsList")
@@ -842,18 +844,18 @@ func (r *queryResolver) NodeMetricsList(ctx context.Context, cluster string, sub
 	}

 	nodeMetricsList := make([]*model.NodeMetrics, 0, len(data))
-	for hostname, metrics := range data {
+	for _, hostname := range nodes {
 		host := &model.NodeMetrics{
 			Host:    hostname,
 			State:   stateMap[hostname],
-			Metrics: make([]*model.JobMetricWithName, 0, len(metrics)*len(scopes)),
+			Metrics: make([]*model.JobMetricWithName, 0),
 		}
 		host.SubCluster, err = archive.GetSubClusterByNode(cluster, hostname)
 		if err != nil {
 			cclog.Warnf("error in nodeMetrics resolver: %s", err)
 		}

-		for metric, scopedMetrics := range metrics {
+		for metric, scopedMetrics := range data[hostname] {
 			for scope, scopedMetric := range scopedMetrics {
 				host.Metrics = append(host.Metrics, &model.JobMetricWithName{
 					Name:  metric,
@@ -867,7 +869,8 @@ func (r *queryResolver) NodeMetricsList(ctx context.Context, cluster string, sub
 	}

 	nodeMetricsListResult := &model.NodesResultList{
-		Items:       nodeMetricsList,
+		Items: nodeMetricsList,
+		// TotalNodes depends on sum of nodes grouped on latest timestamp, see repo/node.go:357
 		TotalNodes:  &countNodes,
 		HasNextPage: &hasNextPage,
 	}
@@ -499,7 +499,7 @@ func copyJobMetric(src *schema.JobMetric) *schema.JobMetric {
 func copySeries(src *schema.Series) schema.Series {
 	dst := schema.Series{
 		Hostname:   src.Hostname,
-		Id:         src.Id,
+		ID:         src.ID,
 		Statistics: src.Statistics,
 		Data:       make([]schema.Float, len(src.Data)),
 	}
@@ -21,7 +21,7 @@ func TestDeepCopy(t *testing.T) {
 		Series: []schema.Series{
 			{
 				Hostname: "node001",
-				Id:       &nodeId,
+				ID:       &nodeId,
 				Data:     []schema.Float{1.0, 2.0, 3.0},
 				Statistics: schema.MetricStatistics{
 					Min: 1.0,
@@ -267,7 +267,7 @@ func (ccms *CCMetricStore) LoadData(

 	jobMetric.Series = append(jobMetric.Series, schema.Series{
 		Hostname: query.Hostname,
-		Id:       id,
+		ID:       id,
 		Statistics: schema.MetricStatistics{
 			Avg: float64(res.Avg),
 			Min: float64(res.Min),
@@ -419,7 +419,7 @@ func (ccms *CCMetricStore) LoadScopedStats(

 	scopedJobStats[metric][scope] = append(scopedJobStats[metric][scope], &schema.ScopedStats{
 		Hostname: query.Hostname,
-		Id:       id,
+		ID:       id,
 		Data: &schema.MetricStatistics{
 			Avg: float64(res.Avg),
 			Min: float64(res.Min),
@@ -634,7 +634,7 @@ func (ccms *CCMetricStore) LoadNodeListData(

 	scopeData.Series = append(scopeData.Series, schema.Series{
 		Hostname: query.Hostname,
-		Id:       id,
+		ID:       id,
 		Statistics: schema.MetricStatistics{
 			Avg: float64(res.Avg),
 			Min: float64(res.Min),
@@ -71,8 +71,9 @@ func (r *JobRepository) SyncJobs() ([]*schema.Job, error) {
 		jobs = append(jobs, job)
 	}

+	// Use INSERT OR IGNORE to skip jobs already transferred by the stop path
 	_, err = r.DB.Exec(
-		"INSERT INTO job (job_id, cluster, subcluster, start_time, hpc_user, project, cluster_partition, array_job_id, num_nodes, num_hwthreads, num_acc, shared, monitoring_status, smt, job_state, duration, walltime, footprint, energy, energy_footprint, resources, meta_data) SELECT job_id, cluster, subcluster, start_time, hpc_user, project, cluster_partition, array_job_id, num_nodes, num_hwthreads, num_acc, shared, monitoring_status, smt, job_state, duration, walltime, footprint, energy, energy_footprint, resources, meta_data FROM job_cache")
+		"INSERT OR IGNORE INTO job (job_id, cluster, subcluster, start_time, hpc_user, project, cluster_partition, array_job_id, num_nodes, num_hwthreads, num_acc, shared, monitoring_status, smt, job_state, duration, walltime, footprint, energy, energy_footprint, resources, meta_data) SELECT job_id, cluster, subcluster, start_time, hpc_user, project, cluster_partition, array_job_id, num_nodes, num_hwthreads, num_acc, shared, monitoring_status, smt, job_state, duration, walltime, footprint, energy, energy_footprint, resources, meta_data FROM job_cache")
 	if err != nil {
 		cclog.Warnf("Error while Job sync: %v", err)
 		return nil, err
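Note: in SQLite, INSERT OR IGNORE drops rows that would violate a constraint instead of failing the whole statement, so a job already moved into the job table by the stop path — presumably caught by the table's uniqueness constraint — is skipped silently rather than aborting the entire sync.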
@@ -87,6 +88,29 @@ func (r *JobRepository) SyncJobs() ([]*schema.Job, error) {
 	return jobs, nil
 }

+// TransferCachedJobToMain moves a job from job_cache to the job table.
+// Caller must hold r.Mutex. Returns the new job table ID.
+func (r *JobRepository) TransferCachedJobToMain(cacheID int64) (int64, error) {
+	res, err := r.DB.Exec(
+		"INSERT INTO job (job_id, cluster, subcluster, start_time, hpc_user, project, cluster_partition, array_job_id, num_nodes, num_hwthreads, num_acc, shared, monitoring_status, smt, job_state, duration, walltime, footprint, energy, energy_footprint, resources, meta_data) SELECT job_id, cluster, subcluster, start_time, hpc_user, project, cluster_partition, array_job_id, num_nodes, num_hwthreads, num_acc, shared, monitoring_status, smt, job_state, duration, walltime, footprint, energy, energy_footprint, resources, meta_data FROM job_cache WHERE id = ?",
+		cacheID)
+	if err != nil {
+		return 0, fmt.Errorf("transferring cached job %d to main table failed: %w", cacheID, err)
+	}
+
+	newID, err := res.LastInsertId()
+	if err != nil {
+		return 0, fmt.Errorf("getting new job ID after transfer failed: %w", err)
+	}
+
+	_, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", cacheID)
+	if err != nil {
+		return 0, fmt.Errorf("deleting cached job %d after transfer failed: %w", cacheID, err)
+	}
+
+	return newID, nil
+}
+
 // Start inserts a new job in the table, returning the unique job ID.
 // Statistics are not transfered!
 func (r *JobRepository) Start(job *schema.Job) (id int64, err error) {
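Note: TransferCachedJobToMain runs its INSERT ... SELECT and DELETE as two separate Exec calls and relies on the caller holding r.Mutex. Under the same assumptions, a variant could bundle both statements in one SQL transaction so that a failed DELETE also rolls back the INSERT; a sketch with database/sql and a deliberately shortened schema (not the real cc-backend tables):

    package main

    import (
        "database/sql"
        "fmt"

        _ "github.com/mattn/go-sqlite3"
    )

    // transferCached moves one row from job_cache to job atomically.
    func transferCached(db *sql.DB, cacheID int64) (int64, error) {
        tx, err := db.Begin()
        if err != nil {
            return 0, err
        }
        defer tx.Rollback() // no-op after a successful Commit

        res, err := tx.Exec(
            "INSERT INTO job (job_id, cluster) SELECT job_id, cluster FROM job_cache WHERE id = ?",
            cacheID)
        if err != nil {
            return 0, fmt.Errorf("insert failed: %w", err)
        }
        newID, err := res.LastInsertId()
        if err != nil {
            return 0, err
        }
        if _, err := tx.Exec("DELETE FROM job_cache WHERE id = ?", cacheID); err != nil {
            return 0, fmt.Errorf("delete failed: %w", err)
        }
        return newID, tx.Commit()
    }

    func main() {
        db, _ := sql.Open("sqlite3", ":memory:")
        defer db.Close()
        db.Exec(`CREATE TABLE job (id INTEGER PRIMARY KEY, job_id INTEGER, cluster TEXT)`)
        db.Exec(`CREATE TABLE job_cache (id INTEGER PRIMARY KEY, job_id INTEGER, cluster TEXT)`)
        db.Exec(`INSERT INTO job_cache (job_id, cluster) VALUES (123, 'testcluster')`)
        id, err := transferCached(db, 1)
        fmt.Println(id, err)
    }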
@@ -129,20 +153,3 @@ func (r *JobRepository) Stop(
 	return err
 }

-func (r *JobRepository) StopCached(
-	jobID int64,
-	duration int32,
-	state schema.JobState,
-	monitoringStatus int32,
-) (err error) {
-	// Note: StopCached updates job_cache table, not the main job table
-	// Cache invalidation happens when job is synced to main table
-	stmt := sq.Update("job_cache").
-		Set("job_state", state).
-		Set("duration", duration).
-		Set("monitoring_status", monitoringStatus).
-		Where("job_cache.id = ?", jobID)
-
-	_, err = stmt.RunWith(r.stmtCache).Exec()
-	return err
-}
@@ -331,58 +331,60 @@ func TestStop(t *testing.T) {
    })
}

func TestStopCached(t *testing.T) {
func TestTransferCachedJobToMain(t *testing.T) {
    r := setup(t)

    t.Run("successful stop cached job", func(t *testing.T) {
    t.Run("successful transfer from cache to main", func(t *testing.T) {
        // Insert a job in job_cache
        job := createTestJob(999009, "testcluster")
        id, err := r.Start(job)
        cacheID, err := r.Start(job)
        require.NoError(t, err)

        // Stop the cached job
        duration := int32(3600)
        state := schema.JobStateCompleted
        monitoringStatus := int32(schema.MonitoringStatusArchivingSuccessful)
        // Transfer the cached job to the main table
        r.Mutex.Lock()
        newID, err := r.TransferCachedJobToMain(cacheID)
        r.Mutex.Unlock()
        require.NoError(t, err, "TransferCachedJobToMain should succeed")
        assert.NotEqual(t, cacheID, newID, "New ID should differ from cache ID")

        err = r.StopCached(id, duration, state, monitoringStatus)
        require.NoError(t, err, "StopCached should succeed")

        // Verify job was updated in job_cache table
        var retrievedDuration int32
        var retrievedState string
        var retrievedMonStatus int32
        err = r.DB.QueryRow(`SELECT duration, job_state, monitoring_status FROM job_cache WHERE id = ?`, id).Scan(
            &retrievedDuration, &retrievedState, &retrievedMonStatus)
        // Verify job exists in job table
        var count int
        err = r.DB.QueryRow(`SELECT COUNT(*) FROM job WHERE id = ?`, newID).Scan(&count)
        require.NoError(t, err)
        assert.Equal(t, duration, retrievedDuration)
        assert.Equal(t, string(state), retrievedState)
        assert.Equal(t, monitoringStatus, retrievedMonStatus)
        assert.Equal(t, 1, count, "Job should exist in main table")

        // Verify job was removed from job_cache
        err = r.DB.QueryRow(`SELECT COUNT(*) FROM job_cache WHERE id = ?`, cacheID).Scan(&count)
        require.NoError(t, err)
        assert.Equal(t, 0, count, "Job should be removed from cache")

        // Clean up
        _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
        _, err = r.DB.Exec("DELETE FROM job WHERE id = ?", newID)
        require.NoError(t, err)
    })

    t.Run("stop cached job does not affect job table", func(t *testing.T) {
    t.Run("transfer preserves job data", func(t *testing.T) {
        // Insert a job in job_cache
        job := createTestJob(999010, "testcluster")
        id, err := r.Start(job)
        cacheID, err := r.Start(job)
        require.NoError(t, err)

        // Stop the cached job
        err = r.StopCached(id, 3600, schema.JobStateCompleted, int32(schema.MonitoringStatusArchivingSuccessful))
        // Transfer the cached job
        r.Mutex.Lock()
        newID, err := r.TransferCachedJobToMain(cacheID)
        r.Mutex.Unlock()
        require.NoError(t, err)

        // Verify job table was not affected
        var count int
        err = r.DB.QueryRow(`SELECT COUNT(*) FROM job WHERE job_id = ? AND cluster = ?`,
            job.JobID, job.Cluster).Scan(&count)
        // Verify the transferred job has the correct data
        var jobID int64
        var cluster string
        err = r.DB.QueryRow(`SELECT job_id, cluster FROM job WHERE id = ?`, newID).Scan(&jobID, &cluster)
        require.NoError(t, err)
        assert.Equal(t, 0, count, "Job table should not be affected by StopCached")
        assert.Equal(t, job.JobID, jobID)
        assert.Equal(t, job.Cluster, cluster)

        // Clean up
        _, err = r.DB.Exec("DELETE FROM job_cache WHERE id = ?", id)
        _, err = r.DB.Exec("DELETE FROM job WHERE id = ?", newID)
        require.NoError(t, err)
    })
}

@@ -150,7 +150,7 @@ func SecurityCheckWithUser(user *schema.User, query sq.SelectBuilder) (sq.Select
    }

    switch {
    case len(user.Roles) == 1 && user.HasRole(schema.RoleApi):
    case len(user.Roles) == 1 && user.HasRole(schema.RoleAPI):
        return query, nil
    case user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}):
        return query, nil

@@ -23,6 +23,7 @@ CREATE TABLE "node_state" (
    CHECK (health_state IN (
        'full', 'partial', 'failed'
    )),
    health_metrics TEXT, -- JSON array of strings
    node_id INTEGER,
    FOREIGN KEY (node_id) REFERENCES node (id)
);
@@ -37,6 +38,7 @@ CREATE INDEX IF NOT EXISTS nodestates_state_timestamp ON node_state (node_state,
CREATE INDEX IF NOT EXISTS nodestates_health_timestamp ON node_state (health_state, time_stamp);
CREATE INDEX IF NOT EXISTS nodestates_nodeid_state ON node_state (node_id, node_state);
CREATE INDEX IF NOT EXISTS nodestates_nodeid_health ON node_state (node_id, health_state);
CREATE INDEX IF NOT EXISTS nodestates_nodeid_timestamp ON node_state (node_id, time_stamp DESC);

-- Add NEW Indices For Increased Amounts of Tags
CREATE INDEX IF NOT EXISTS tags_jobid ON jobtag (job_id);

@@ -52,6 +52,38 @@ func GetNodeRepository() *NodeRepository {
    return nodeRepoInstance
}

// latestStateCondition returns a squirrel expression that restricts node_state
// rows to the latest per node_id using a correlated subquery.
// Requires the query to join node and node_state tables.
func latestStateCondition() sq.Sqlizer {
    return sq.Expr(
        "node_state.id = (SELECT ns2.id FROM node_state ns2 WHERE ns2.node_id = node.id ORDER BY ns2.time_stamp DESC LIMIT 1)",
    )
}
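To make the correlated subquery concrete, this is roughly what a node query built with the helper renders to. A sketch only; the selected columns are abbreviated and squirrel's ToSql is used merely to show the output:

    // Illustrative: render the SQL produced by a builder using latestStateCondition.
    q := sq.Select("node.hostname", "node_state.node_state").
        From("node").
        Join("node_state ON node_state.node_id = node.id").
        Where(latestStateCondition())
    sqlStr, _, _ := q.ToSql()
    // sqlStr is roughly:
    //   SELECT node.hostname, node_state.node_state
    //   FROM node JOIN node_state ON node_state.node_id = node.id
    //   WHERE node_state.id = (SELECT ns2.id FROM node_state ns2
    //       WHERE ns2.node_id = node.id
    //       ORDER BY ns2.time_stamp DESC LIMIT 1)
    fmt.Println(sqlStr)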

// applyNodeFilters applies common NodeFilter conditions to a query that joins
// the node and node_state tables with latestStateCondition.
func applyNodeFilters(query sq.SelectBuilder, filters []*model.NodeFilter) sq.SelectBuilder {
    for _, f := range filters {
        if f.Cluster != nil {
            query = buildStringCondition("node.cluster", f.Cluster, query)
        }
        if f.SubCluster != nil {
            query = buildStringCondition("node.subcluster", f.SubCluster, query)
        }
        if f.Hostname != nil {
            query = buildStringCondition("node.hostname", f.Hostname, query)
        }
        if f.SchedulerState != nil {
            query = query.Where("node_state.node_state = ?", f.SchedulerState)
        }
        if f.HealthState != nil {
            query = query.Where("node_state.health_state = ?", f.HealthState)
        }
    }
    return query
}

func (r *NodeRepository) FetchMetadata(hostname string, cluster string) (map[string]string, error) {
    start := time.Now()

@@ -82,17 +114,16 @@ func (r *NodeRepository) FetchMetadata(hostname string, cluster string) (map[str

func (r *NodeRepository) GetNode(hostname string, cluster string, withMeta bool) (*schema.Node, error) {
    node := &schema.Node{}
    var timestamp int
    if err := sq.Select("node.hostname", "node.cluster", "node.subcluster", "node_state.node_state",
        "node_state.health_state", "MAX(node_state.time_stamp) as time").
        From("node_state").
        Join("node ON node_state.node_id = node.id").
    if err := sq.Select("node.hostname", "node.cluster", "node.subcluster",
        "node_state.node_state", "node_state.health_state").
        From("node").
        Join("node_state ON node_state.node_id = node.id").
        Where(latestStateCondition()).
        Where("node.hostname = ?", hostname).
        Where("node.cluster = ?", cluster).
        GroupBy("node_state.node_id").
        RunWith(r.DB).
        QueryRow().Scan(&node.Hostname, &node.Cluster, &node.SubCluster, &node.NodeState, &node.HealthState, &timestamp); err != nil {
        cclog.Warnf("Error while querying node '%s' at time '%d' from database: %v", hostname, timestamp, err)
        QueryRow().Scan(&node.Hostname, &node.Cluster, &node.SubCluster, &node.NodeState, &node.HealthState); err != nil {
        cclog.Warnf("Error while querying node '%s' from database: %v", hostname, err)
        return nil, err
    }

@@ -111,16 +142,15 @@ func (r *NodeRepository) GetNode(hostname string, cluster string, withMeta bool)

func (r *NodeRepository) GetNodeByID(id int64, withMeta bool) (*schema.Node, error) {
    node := &schema.Node{}
    var timestamp int
    if err := sq.Select("node.hostname", "node.cluster", "node.subcluster", "node_state.node_state",
        "node_state.health_state", "MAX(node_state.time_stamp) as time").
        From("node_state").
        Join("node ON node_state.node_id = node.id").
    if err := sq.Select("node.hostname", "node.cluster", "node.subcluster",
        "node_state.node_state", "node_state.health_state").
        From("node").
        Join("node_state ON node_state.node_id = node.id").
        Where(latestStateCondition()).
        Where("node.id = ?", id).
        GroupBy("node_state.node_id").
        RunWith(r.DB).
        QueryRow().Scan(&node.Hostname, &node.Cluster, &node.SubCluster, &node.NodeState, &node.HealthState, &timestamp); err != nil {
        cclog.Warnf("Error while querying node ID '%d' at time '%d' from database: %v", id, timestamp, err)
        QueryRow().Scan(&node.Hostname, &node.Cluster, &node.SubCluster, &node.NodeState, &node.HealthState); err != nil {
        cclog.Warnf("Error while querying node ID '%d' from database: %v", id, err)
        return nil, err
    }

@@ -169,9 +199,10 @@ func (r *NodeRepository) AddNode(node *schema.NodeDB) (int64, error) {
}

const NamedNodeStateInsert string = `
INSERT INTO node_state (time_stamp, node_state, health_state, cpus_allocated,
    memory_allocated, gpus_allocated, jobs_running, node_id)
VALUES (:time_stamp, :node_state, :health_state, :cpus_allocated, :memory_allocated, :gpus_allocated, :jobs_running, :node_id);`
INSERT INTO node_state (time_stamp, node_state, health_state, health_metrics,
    cpus_allocated, memory_allocated, gpus_allocated, jobs_running, node_id)
VALUES (:time_stamp, :node_state, :health_state, :health_metrics,
    :cpus_allocated, :memory_allocated, :gpus_allocated, :jobs_running, :node_id);`

// TODO: Add real Monitoring Health State

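The named parameters now include :health_metrics. Assuming the sqlx-style named execution used for such statements, with r.DB as an *sqlx.DB (an assumption, not shown in this patch), an insert would look roughly like this:

    // Illustrative only: run the named insert with a NodeStateDB value
    // whose db tags match the :named parameters above.
    ns := schema.NodeStateDB{
        TimeStamp:   time.Now().Unix(),
        NodeState:   "allocated",
        HealthState: schema.MonitoringStateFull,
    }
    if _, err := r.DB.NamedExec(NamedNodeStateInsert, ns); err != nil {
        cclog.Errorf("insert node state: %v", err)
    }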
@@ -224,6 +255,75 @@ func (r *NodeRepository) UpdateNodeState(hostname string, cluster string, nodeSt
// 	return nil
// }

// NodeStateWithNode combines a node state row with denormalized node info.
type NodeStateWithNode struct {
    ID              int64  `db:"id"`
    TimeStamp       int64  `db:"time_stamp"`
    NodeState       string `db:"node_state"`
    HealthState     string `db:"health_state"`
    HealthMetrics   string `db:"health_metrics"`
    CpusAllocated   int    `db:"cpus_allocated"`
    MemoryAllocated int64  `db:"memory_allocated"`
    GpusAllocated   int    `db:"gpus_allocated"`
    JobsRunning     int    `db:"jobs_running"`
    Hostname        string `db:"hostname"`
    Cluster         string `db:"cluster"`
    SubCluster      string `db:"subcluster"`
}

// FindNodeStatesBefore returns all node_state rows with time_stamp < cutoff,
// joined with node info for denormalized archiving.
func (r *NodeRepository) FindNodeStatesBefore(cutoff int64) ([]NodeStateWithNode, error) {
    rows, err := sq.Select(
        "node_state.id", "node_state.time_stamp", "node_state.node_state",
        "node_state.health_state", "node_state.health_metrics",
        "node_state.cpus_allocated", "node_state.memory_allocated",
        "node_state.gpus_allocated", "node_state.jobs_running",
        "node.hostname", "node.cluster", "node.subcluster",
    ).
        From("node_state").
        Join("node ON node_state.node_id = node.id").
        Where(sq.Lt{"node_state.time_stamp": cutoff}).
        Where("node_state.id NOT IN (SELECT ns2.id FROM node_state ns2 WHERE ns2.time_stamp = (SELECT MAX(ns3.time_stamp) FROM node_state ns3 WHERE ns3.node_id = ns2.node_id))").
        OrderBy("node_state.time_stamp ASC").
        RunWith(r.DB).Query()
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    var result []NodeStateWithNode
    for rows.Next() {
        var ns NodeStateWithNode
        if err := rows.Scan(&ns.ID, &ns.TimeStamp, &ns.NodeState,
            &ns.HealthState, &ns.HealthMetrics,
            &ns.CpusAllocated, &ns.MemoryAllocated,
            &ns.GpusAllocated, &ns.JobsRunning,
            &ns.Hostname, &ns.Cluster, &ns.SubCluster); err != nil {
            return nil, err
        }
        result = append(result, ns)
    }
    return result, nil
}

// DeleteNodeStatesBefore removes node_state rows with time_stamp < cutoff,
// but always preserves the row with the latest timestamp per node_id.
func (r *NodeRepository) DeleteNodeStatesBefore(cutoff int64) (int64, error) {
    res, err := r.DB.Exec(
        `DELETE FROM node_state WHERE time_stamp < ?
        AND id NOT IN (
            SELECT id FROM node_state ns2
            WHERE ns2.time_stamp = (SELECT MAX(ns3.time_stamp) FROM node_state ns3 WHERE ns3.node_id = ns2.node_id)
        )`,
        cutoff,
    )
    if err != nil {
        return 0, err
    }
    return res.RowsAffected()
}

func (r *NodeRepository) DeleteNode(id int64) error {
    _, err := r.DB.Exec(`DELETE FROM node WHERE node.id = ?`, id)
    if err != nil {
@@ -243,38 +343,17 @@ func (r *NodeRepository) QueryNodes(
    order *model.OrderByInput, // Currently unused!
) ([]*schema.Node, error) {
    query, qerr := AccessCheck(ctx,
        sq.Select("hostname", "cluster", "subcluster", "node_state", "health_state", "MAX(time_stamp) as time").
        sq.Select("node.hostname", "node.cluster", "node.subcluster",
            "node_state.node_state", "node_state.health_state").
            From("node").
            Join("node_state ON node_state.node_id = node.id"))
            Join("node_state ON node_state.node_id = node.id").
            Where(latestStateCondition()))
    if qerr != nil {
        return nil, qerr
    }

    for _, f := range filters {
        if f.Cluster != nil {
            query = buildStringCondition("cluster", f.Cluster, query)
        }
        if f.SubCluster != nil {
            query = buildStringCondition("subcluster", f.SubCluster, query)
        }
        if f.Hostname != nil {
            query = buildStringCondition("hostname", f.Hostname, query)
        }
        if f.SchedulerState != nil {
            query = query.Where("node_state = ?", f.SchedulerState)
            // Requires Additional time_stamp Filter: Else the last (past!) time_stamp with queried state will be returned
            now := time.Now().Unix()
            query = query.Where(sq.Gt{"time_stamp": (now - 60)})
        }
        if f.HealthState != nil {
            query = query.Where("health_state = ?", f.HealthState)
            // Requires Additional time_stamp Filter: Else the last (past!) time_stamp with queried state will be returned
            now := time.Now().Unix()
            query = query.Where(sq.Gt{"time_stamp": (now - 60)})
        }
    }

    query = query.GroupBy("node_id").OrderBy("hostname ASC")
    query = applyNodeFilters(query, filters)
    query = query.OrderBy("node.hostname ASC")

    if page != nil && page.ItemsPerPage != -1 {
        limit := uint64(page.ItemsPerPage)
@@ -291,11 +370,10 @@ func (r *NodeRepository) QueryNodes(
    nodes := make([]*schema.Node, 0)
    for rows.Next() {
        node := schema.Node{}
        var timestamp int
        if err := rows.Scan(&node.Hostname, &node.Cluster, &node.SubCluster,
            &node.NodeState, &node.HealthState, &timestamp); err != nil {
            &node.NodeState, &node.HealthState); err != nil {
            rows.Close()
            cclog.Warnf("Error while scanning rows (QueryNodes) at time '%d'", timestamp)
            cclog.Warn("Error while scanning rows (QueryNodes)")
            return nil, err
        }
        nodes = append(nodes, &node)
@@ -305,72 +383,39 @@ func (r *NodeRepository) QueryNodes(
}

// CountNodes returns the total matched nodes based on a node filter. It always operates
// on the last state (largest timestamp).
// on the last state (largest timestamp) per node.
func (r *NodeRepository) CountNodes(
    ctx context.Context,
    filters []*model.NodeFilter,
) (int, error) {
    query, qerr := AccessCheck(ctx,
        sq.Select("time_stamp", "count(*) as countRes").
        sq.Select("COUNT(*)").
            From("node").
            Join("node_state ON node_state.node_id = node.id"))
            Join("node_state ON node_state.node_id = node.id").
            Where(latestStateCondition()))
    if qerr != nil {
        return 0, qerr
    }

    for _, f := range filters {
        if f.Cluster != nil {
            query = buildStringCondition("cluster", f.Cluster, query)
        }
        if f.SubCluster != nil {
            query = buildStringCondition("subcluster", f.SubCluster, query)
        }
        if f.Hostname != nil {
            query = buildStringCondition("hostname", f.Hostname, query)
        }
        if f.SchedulerState != nil {
            query = query.Where("node_state = ?", f.SchedulerState)
            // Requires Additional time_stamp Filter: Else the last (past!) time_stamp with queried state will be returned
            now := time.Now().Unix()
            query = query.Where(sq.Gt{"time_stamp": (now - 60)})
        }
        if f.HealthState != nil {
            query = query.Where("health_state = ?", f.HealthState)
            // Requires Additional time_stamp Filter: Else the last (past!) time_stamp with queried state will be returned
            now := time.Now().Unix()
            query = query.Where(sq.Gt{"time_stamp": (now - 60)})
        }
    }
    query = applyNodeFilters(query, filters)

    query = query.GroupBy("time_stamp").OrderBy("time_stamp DESC").Limit(1)

    rows, err := query.RunWith(r.stmtCache).Query()
    if err != nil {
    var count int
    if err := query.RunWith(r.stmtCache).QueryRow().Scan(&count); err != nil {
        queryString, queryVars, _ := query.ToSql()
        cclog.Errorf("Error while running query '%s' %v: %v", queryString, queryVars, err)
        return 0, err
    }

    var totalNodes int
    for rows.Next() {
        var timestamp int
        if err := rows.Scan(&timestamp, &totalNodes); err != nil {
            rows.Close()
            cclog.Warnf("Error while scanning rows (CountNodes) at time '%d'", timestamp)
            return 0, err
        }
    }

    return totalNodes, nil
    return count, nil
}

func (r *NodeRepository) ListNodes(cluster string) ([]*schema.Node, error) {
    q := sq.Select("node.hostname", "node.cluster", "node.subcluster", "node_state.node_state",
        "node_state.health_state", "MAX(node_state.time_stamp) as time").
    q := sq.Select("node.hostname", "node.cluster", "node.subcluster",
        "node_state.node_state", "node_state.health_state").
        From("node").
        Join("node_state ON node_state.node_id = node.id").
        Where(latestStateCondition()).
        Where("node.cluster = ?", cluster).
        GroupBy("node_state.node_id").
        OrderBy("node.hostname ASC")

    rows, err := q.RunWith(r.DB).Query()
@@ -382,10 +427,9 @@ func (r *NodeRepository) ListNodes(cluster string) ([]*schema.Node, error) {
    defer rows.Close()
    for rows.Next() {
        node := &schema.Node{}
        var timestamp int
        if err := rows.Scan(&node.Hostname, &node.Cluster,
            &node.SubCluster, &node.NodeState, &node.HealthState, &timestamp); err != nil {
            cclog.Warnf("Error while scanning node list (ListNodes) at time '%d'", timestamp)
            &node.SubCluster, &node.NodeState, &node.HealthState); err != nil {
            cclog.Warn("Error while scanning node list (ListNodes)")
            return nil, err
        }

@@ -396,11 +440,11 @@ func (r *NodeRepository) ListNodes(cluster string) ([]*schema.Node, error) {
}

func (r *NodeRepository) MapNodes(cluster string) (map[string]string, error) {
    q := sq.Select("node.hostname", "node_state.node_state", "MAX(node_state.time_stamp) as time").
    q := sq.Select("node.hostname", "node_state.node_state").
        From("node").
        Join("node_state ON node_state.node_id = node.id").
        Where(latestStateCondition()).
        Where("node.cluster = ?", cluster).
        GroupBy("node_state.node_id").
        OrderBy("node.hostname ASC")

    rows, err := q.RunWith(r.DB).Query()
@@ -413,9 +457,8 @@ func (r *NodeRepository) MapNodes(cluster string) (map[string]string, error) {
    defer rows.Close()
    for rows.Next() {
        var hostname, nodestate string
        var timestamp int
        if err := rows.Scan(&hostname, &nodestate, &timestamp); err != nil {
            cclog.Warnf("Error while scanning node list (MapNodes) at time '%d'", timestamp)
        if err := rows.Scan(&hostname, &nodestate); err != nil {
            cclog.Warn("Error while scanning node list (MapNodes)")
            return nil, err
        }

@@ -426,33 +469,16 @@ func (r *NodeRepository) MapNodes(cluster string) (map[string]string, error) {
}

func (r *NodeRepository) CountStates(ctx context.Context, filters []*model.NodeFilter, column string) ([]*model.NodeStates, error) {
    query, qerr := AccessCheck(ctx, sq.Select("hostname", column, "MAX(time_stamp) as time").From("node"))
    query, qerr := AccessCheck(ctx,
        sq.Select(column).
            From("node").
            Join("node_state ON node_state.node_id = node.id").
            Where(latestStateCondition()))
    if qerr != nil {
        return nil, qerr
    }

    query = query.Join("node_state ON node_state.node_id = node.id")

    for _, f := range filters {
        if f.Hostname != nil {
            query = buildStringCondition("hostname", f.Hostname, query)
        }
        if f.Cluster != nil {
            query = buildStringCondition("cluster", f.Cluster, query)
        }
        if f.SubCluster != nil {
            query = buildStringCondition("subcluster", f.SubCluster, query)
        }
        if f.SchedulerState != nil {
            query = query.Where("node_state = ?", f.SchedulerState)
        }
        if f.HealthState != nil {
            query = query.Where("health_state = ?", f.HealthState)
        }
    }

    // Add Group and Order
    query = query.GroupBy("hostname").OrderBy("hostname DESC")
    query = applyNodeFilters(query, filters)

    rows, err := query.RunWith(r.stmtCache).Query()
    if err != nil {
@@ -463,12 +489,10 @@ func (r *NodeRepository) CountStates(ctx context.Context, filters []*model.NodeF

    stateMap := map[string]int{}
    for rows.Next() {
        var hostname, state string
        var timestamp int

        if err := rows.Scan(&hostname, &state, &timestamp); err != nil {
        var state string
        if err := rows.Scan(&state); err != nil {
            rows.Close()
            cclog.Warnf("Error while scanning rows (CountStates) at time '%d'", timestamp)
            cclog.Warn("Error while scanning rows (CountStates)")
            return nil, err
        }

@@ -661,26 +685,14 @@ func (r *NodeRepository) GetNodesForList(
    }

    } else {
        // DB Nodes: Count and Find Next Page
        // DB Nodes: Count and derive hasNextPage from count
        var cerr error
        countNodes, cerr = r.CountNodes(ctx, queryFilters)
        if cerr != nil {
            cclog.Warn("error while counting node database data (Resolver.NodeMetricsList)")
            return nil, nil, 0, false, cerr
        }

        // Example Page 4 @ 10 IpP : Does item 41 exist?
        // Minimal Page 41 @ 1 IpP : If len(result) is 1, Page 5 exists.
        nextPage := &model.PageRequest{
            ItemsPerPage: 1,
            Page:         ((page.Page * page.ItemsPerPage) + 1),
        }
        nextNodes, err := r.QueryNodes(ctx, queryFilters, nextPage, nil) // Order not Used
        if err != nil {
            cclog.Warn("Error while querying next nodes")
            return nil, nil, 0, false, err
        }
        hasNextPage = len(nextNodes) == 1
        hasNextPage = page.Page*page.ItemsPerPage < countNodes
    }

    // Fallback for non-init'd node table in DB; Ignores stateFilter

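To see why the count-based check replaces the old probe query: with Page=4 and ItemsPerPage=10, the first four pages cover items 1 through 40, so a fifth page exists exactly when more than 40 nodes match, which is what page.Page*page.ItemsPerPage < countNodes tests. Since the count is fetched anyway, the extra one-item QueryNodes round trip becomes unnecessary.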
@@ -139,6 +139,13 @@ func nodeTestSetup(t *testing.T) {
    }
    archiveCfg := fmt.Sprintf("{\"kind\": \"file\",\"path\": \"%s\"}", jobarchive)

    if err := ResetConnection(); err != nil {
        t.Fatal(err)
    }
    t.Cleanup(func() {
        ResetConnection()
    })

    Connect(dbfilepath)

    if err := archive.Init(json.RawMessage(archiveCfg)); err != nil {
@@ -149,8 +156,12 @@ func nodeTestSetup(t *testing.T) {
func TestUpdateNodeState(t *testing.T) {
    nodeTestSetup(t)

    repo := GetNodeRepository()
    now := time.Now().Unix()

    nodeState := schema.NodeStateDB{
        TimeStamp: time.Now().Unix(), NodeState: "allocated",
        TimeStamp:       now,
        NodeState:       "allocated",
        CpusAllocated:   72,
        MemoryAllocated: 480,
        GpusAllocated:   0,
@@ -158,18 +169,152 @@ func TestUpdateNodeState(t *testing.T) {
        JobsRunning: 1,
    }

    repo := GetNodeRepository()
    err := repo.UpdateNodeState("host124", "testcluster", &nodeState)
    if err != nil {
        return
        t.Fatal(err)
    }

    node, err := repo.GetNode("host124", "testcluster", false)
    if err != nil {
        return
        t.Fatal(err)
    }

    if node.NodeState != "allocated" {
        t.Errorf("wrong node state\ngot: %s \nwant: allocated ", node.NodeState)
    }

t.Run("FindBeforeEmpty", func(t *testing.T) {
|
||||
// Only the current-timestamp row exists, so nothing should be found before now
|
||||
rows, err := repo.FindNodeStatesBefore(now)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rows) != 0 {
|
||||
t.Errorf("expected 0 rows, got %d", len(rows))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteOldRows", func(t *testing.T) {
|
||||
// Insert 2 more old rows for host124
|
||||
for i, ts := range []int64{now - 7200, now - 3600} {
|
||||
ns := schema.NodeStateDB{
|
||||
TimeStamp: ts,
|
||||
NodeState: "allocated",
|
||||
HealthState: schema.MonitoringStateFull,
|
||||
CpusAllocated: 72,
|
||||
MemoryAllocated: 480,
|
||||
JobsRunning: i,
|
||||
}
|
||||
if err := repo.UpdateNodeState("host124", "testcluster", &ns); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete rows older than 30 minutes
|
||||
cutoff := now - 1800
|
||||
cnt, err := repo.DeleteNodeStatesBefore(cutoff)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should delete the 2 old rows
|
||||
if cnt != 2 {
|
||||
t.Errorf("expected 2 deleted rows, got %d", cnt)
|
||||
}
|
||||
|
||||
// Latest row should still exist
|
||||
node, err := repo.GetNode("host124", "testcluster", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if node.NodeState != "allocated" {
|
||||
t.Errorf("expected node state 'allocated', got %s", node.NodeState)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreservesLatestPerNode", func(t *testing.T) {
|
||||
// Insert a single old row for host125 — it's the latest per node so it must survive
|
||||
ns := schema.NodeStateDB{
|
||||
TimeStamp: now - 7200,
|
||||
NodeState: "idle",
|
||||
HealthState: schema.MonitoringStateFull,
|
||||
CpusAllocated: 0,
|
||||
MemoryAllocated: 0,
|
||||
JobsRunning: 0,
|
||||
}
|
||||
if err := repo.UpdateNodeState("host125", "testcluster", &ns); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete everything older than now — the latest per node should be preserved
|
||||
_, err := repo.DeleteNodeStatesBefore(now)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The latest row for host125 must still exist
|
||||
node, err := repo.GetNode("host125", "testcluster", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if node.NodeState != "idle" {
|
||||
t.Errorf("expected node state 'idle', got %s", node.NodeState)
|
||||
}
|
||||
|
||||
// Verify exactly 1 row remains for host125
|
||||
var countAfter int
|
||||
if err := repo.DB.QueryRow(
|
||||
"SELECT COUNT(*) FROM node_state WHERE node_id = (SELECT id FROM node WHERE hostname = 'host125')").
|
||||
Scan(&countAfter); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if countAfter != 1 {
|
||||
t.Errorf("expected 1 row remaining for host125, got %d", countAfter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FindBeforeWithJoin", func(t *testing.T) {
|
||||
// Insert old and current rows for host123
|
||||
for _, ts := range []int64{now - 7200, now} {
|
||||
ns := schema.NodeStateDB{
|
||||
TimeStamp: ts,
|
||||
NodeState: "allocated",
|
||||
HealthState: schema.MonitoringStateFull,
|
||||
CpusAllocated: 8,
|
||||
MemoryAllocated: 1024,
|
||||
GpusAllocated: 1,
|
||||
JobsRunning: 1,
|
||||
}
|
||||
if err := repo.UpdateNodeState("host123", "testcluster", &ns); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Find rows older than 30 minutes, excluding latest per node
|
||||
cutoff := now - 1800
|
||||
rows, err := repo.FindNodeStatesBefore(cutoff)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should find the old host123 row
|
||||
found := false
|
||||
for _, row := range rows {
|
||||
if row.Hostname == "host123" && row.TimeStamp == now-7200 {
|
||||
found = true
|
||||
if row.Cluster != "testcluster" {
|
||||
t.Errorf("expected cluster 'testcluster', got %s", row.Cluster)
|
||||
}
|
||||
if row.SubCluster != "sc1" {
|
||||
t.Errorf("expected subcluster 'sc1', got %s", row.SubCluster)
|
||||
}
|
||||
if row.CpusAllocated != 8 {
|
||||
t.Errorf("expected cpus_allocated 8, got %d", row.CpusAllocated)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected to find old host123 row among %d results", len(rows))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ package repository

import (
    "context"
    "os"
    "path/filepath"
    "testing"

    "github.com/ClusterCockpit/cc-backend/internal/graph/model"
@@ -148,8 +150,22 @@ func getContext(tb testing.TB) context.Context {
func setup(tb testing.TB) *JobRepository {
    tb.Helper()
    cclog.Init("warn", true)
    dbfile := "testdata/job.db"
    err := MigrateDB(dbfile)

    // Copy test DB to a temp file for test isolation
    srcData, err := os.ReadFile("testdata/job.db")
    noErr(tb, err)
    dbfile := filepath.Join(tb.TempDir(), "job.db")
    err = os.WriteFile(dbfile, srcData, 0o644)
    noErr(tb, err)

    // Reset singletons so Connect uses the new temp DB
    err = ResetConnection()
    noErr(tb, err)
    tb.Cleanup(func() {
        ResetConnection()
    })

    err = MigrateDB(dbfile)
    noErr(tb, err)
    Connect(dbfile)
    return GetJobRepository()

@@ -25,17 +25,11 @@ func TestBuildJobStatsQuery(t *testing.T) {
func TestJobStats(t *testing.T) {
    r := setup(t)

    // First, count the actual jobs in the database (excluding test jobs)
    var expectedCount int
    err := r.DB.QueryRow(`SELECT COUNT(*) FROM job WHERE cluster != 'testcluster'`).Scan(&expectedCount)
    err := r.DB.QueryRow(`SELECT COUNT(*) FROM job`).Scan(&expectedCount)
    noErr(t, err)

    filter := &model.JobFilter{}
    // Exclude test jobs created by other tests
    testCluster := "testcluster"
    filter.Cluster = &model.StringInput{Neq: &testCluster}

    stats, err := r.JobsStats(getContext(t), []*model.JobFilter{filter})
    stats, err := r.JobsStats(getContext(t), []*model.JobFilter{})
    noErr(t, err)

    if stats[0].TotalJobs != expectedCount {

@@ -644,12 +644,12 @@ func (r *JobRepository) checkScopeAuth(user *schema.User, operation string, scop
    if user != nil {
        switch {
        case operation == "write" && scope == "admin":
            if user.HasRole(schema.RoleAdmin) || (len(user.Roles) == 1 && user.HasRole(schema.RoleApi)) {
            if user.HasRole(schema.RoleAdmin) || (len(user.Roles) == 1 && user.HasRole(schema.RoleAPI)) {
                return true, nil
            }
            return false, nil
        case operation == "write" && scope == "global":
            if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) || (len(user.Roles) == 1 && user.HasRole(schema.RoleApi)) {
            if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) || (len(user.Roles) == 1 && user.HasRole(schema.RoleAPI)) {
                return true, nil
            }
            return false, nil

@@ -31,8 +31,25 @@ func setupUserTest(t *testing.T) *UserCfgRepo {
    }`

    cclog.Init("info", true)
    dbfilepath := "testdata/job.db"
    err := MigrateDB(dbfilepath)

    // Copy test DB to a temp file for test isolation
    srcData, err := os.ReadFile("testdata/job.db")
    if err != nil {
        t.Fatal(err)
    }
    dbfilepath := filepath.Join(t.TempDir(), "job.db")
    if err := os.WriteFile(dbfilepath, srcData, 0o644); err != nil {
        t.Fatal(err)
    }

    if err := ResetConnection(); err != nil {
        t.Fatal(err)
    }
    t.Cleanup(func() {
        ResetConnection()
    })

    err = MigrateDB(dbfilepath)
    if err != nil {
        t.Fatal(err)
    }

@@ -20,7 +20,7 @@ import (
    cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
    "github.com/ClusterCockpit/cc-lib/v2/schema"
    "github.com/ClusterCockpit/cc-lib/v2/util"
    "github.com/gorilla/mux"
    "github.com/go-chi/chi/v5"
)

type InfoType map[string]interface{}
@@ -97,7 +97,7 @@ func setupConfigRoute(i InfoType, r *http.Request) InfoType {
}

func setupJobRoute(i InfoType, r *http.Request) InfoType {
    i["id"] = mux.Vars(r)["id"]
    i["id"] = chi.URLParam(r, "id")
    if config.Keys.EmissionConstant != 0 {
        i["emission"] = config.Keys.EmissionConstant
    }
@@ -105,7 +105,7 @@ func setupJobRoute(i InfoType, r *http.Request) InfoType {
}

func setupUserRoute(i InfoType, r *http.Request) InfoType {
    username := mux.Vars(r)["id"]
    username := chi.URLParam(r, "id")
    i["id"] = username
    i["username"] = username
    // TODO: If forbidden (== err exists), redirect to error page
@@ -117,33 +117,33 @@ func setupUserRoute(i InfoType, r *http.Request) InfoType {
}

func setupClusterStatusRoute(i InfoType, r *http.Request) InfoType {
    vars := mux.Vars(r)
    i["id"] = vars["cluster"]
    i["cluster"] = vars["cluster"]
    cluster := chi.URLParam(r, "cluster")
    i["id"] = cluster
    i["cluster"] = cluster
    i["displayType"] = "DASHBOARD"
    return i
}

func setupClusterDetailRoute(i InfoType, r *http.Request) InfoType {
    vars := mux.Vars(r)
    i["id"] = vars["cluster"]
    i["cluster"] = vars["cluster"]
    cluster := chi.URLParam(r, "cluster")
    i["id"] = cluster
    i["cluster"] = cluster
    i["displayType"] = "DETAILS"
    return i
}

func setupDashboardRoute(i InfoType, r *http.Request) InfoType {
    vars := mux.Vars(r)
    i["id"] = vars["cluster"]
    i["cluster"] = vars["cluster"]
    cluster := chi.URLParam(r, "cluster")
    i["id"] = cluster
    i["cluster"] = cluster
    i["displayType"] = "PUBLIC" // Used in Main Template
    return i
}

func setupClusterOverviewRoute(i InfoType, r *http.Request) InfoType {
    vars := mux.Vars(r)
    i["id"] = vars["cluster"]
    i["cluster"] = vars["cluster"]
    cluster := chi.URLParam(r, "cluster")
    i["id"] = cluster
    i["cluster"] = cluster
    i["displayType"] = "OVERVIEW"

    from, to := r.URL.Query().Get("from"), r.URL.Query().Get("to")
@@ -155,11 +155,12 @@ func setupClusterOverviewRoute(i InfoType, r *http.Request) InfoType {
}

func setupClusterListRoute(i InfoType, r *http.Request) InfoType {
    vars := mux.Vars(r)
    i["id"] = vars["cluster"]
    i["cluster"] = vars["cluster"]
    i["sid"] = vars["subcluster"]
    i["subCluster"] = vars["subcluster"]
    cluster := chi.URLParam(r, "cluster")
    subcluster := chi.URLParam(r, "subcluster")
    i["id"] = cluster
    i["cluster"] = cluster
    i["sid"] = subcluster
    i["subCluster"] = subcluster
    i["displayType"] = "LIST"

    from, to := r.URL.Query().Get("from"), r.URL.Query().Get("to")
@@ -171,10 +172,11 @@ func setupClusterListRoute(i InfoType, r *http.Request) InfoType {
}

func setupNodeRoute(i InfoType, r *http.Request) InfoType {
    vars := mux.Vars(r)
    i["cluster"] = vars["cluster"]
    i["hostname"] = vars["hostname"]
    i["id"] = fmt.Sprintf("%s (%s)", vars["cluster"], vars["hostname"])
    cluster := chi.URLParam(r, "cluster")
    hostname := chi.URLParam(r, "hostname")
    i["cluster"] = cluster
    i["hostname"] = hostname
    i["id"] = fmt.Sprintf("%s (%s)", cluster, hostname)
    from, to := r.URL.Query().Get("from"), r.URL.Query().Get("to")
    if from != "" && to != "" {
        i["from"] = from
@@ -184,7 +186,7 @@ func setupNodeRoute(i InfoType, r *http.Request) InfoType {
}

func setupAnalysisRoute(i InfoType, r *http.Request) InfoType {
    i["cluster"] = mux.Vars(r)["cluster"]
    i["cluster"] = chi.URLParam(r, "cluster")
    return i
}

@@ -396,7 +398,7 @@ func buildFilterPresets(query url.Values) map[string]interface{} {
    return filterPresets
}

func SetupRoutes(router *mux.Router, buildInfo web.Build) {
func SetupRoutes(router chi.Router, buildInfo web.Build) {
    userCfgRepo := repository.GetUserCfgRepo()
    for _, route := range routes {
        route := route

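For readers unfamiliar with the router migration: gorilla/mux exposes path variables as a map, while chi reads them by name from the request context; the {param} placeholder syntax in route patterns stays the same. A minimal, self-contained sketch (the route path and port are illustrative, not taken from this patch):

    package main

    import (
        "fmt"
        "net/http"

        "github.com/go-chi/chi/v5"
    )

    func main() {
        r := chi.NewRouter()
        // The value of {cluster} is read with chi.URLParam instead of mux.Vars.
        r.Get("/monitoring/status/{cluster}", func(rw http.ResponseWriter, req *http.Request) {
            cluster := chi.URLParam(req, "cluster")
            fmt.Fprintf(rw, "status for %s", cluster)
        })
        http.ListenAndServe(":8080", r)
    }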
internal/taskmanager/nodestateRetentionService.go (new file, 120 lines)
@@ -0,0 +1,120 @@
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
// All rights reserved. This file is part of cc-backend.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package taskmanager

import (
    "time"

    "github.com/ClusterCockpit/cc-backend/internal/config"
    "github.com/ClusterCockpit/cc-backend/internal/repository"
    pqarchive "github.com/ClusterCockpit/cc-backend/pkg/archive/parquet"
    cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
    "github.com/go-co-op/gocron/v2"
)

func RegisterNodeStateRetentionDeleteService(ageHours int) {
    cclog.Info("Register node state retention delete service")

    s.NewJob(gocron.DurationJob(1*time.Hour),
        gocron.NewTask(
            func() {
                cutoff := time.Now().Unix() - int64(ageHours*3600)
                nodeRepo := repository.GetNodeRepository()
                cnt, err := nodeRepo.DeleteNodeStatesBefore(cutoff)
                if err != nil {
                    cclog.Errorf("NodeState retention: error deleting old rows: %v", err)
                } else if cnt > 0 {
                    cclog.Infof("NodeState retention: deleted %d old rows", cnt)
                }
            }))
}

func RegisterNodeStateRetentionParquetService(cfg *config.NodeStateRetention) {
    cclog.Info("Register node state retention parquet service")

    maxFileSizeMB := cfg.MaxFileSizeMB
    if maxFileSizeMB <= 0 {
        maxFileSizeMB = 128
    }

    ageHours := cfg.Age
    if ageHours <= 0 {
        ageHours = 24
    }

    var target pqarchive.ParquetTarget
    var err error

    switch cfg.TargetKind {
    case "s3":
        target, err = pqarchive.NewS3Target(pqarchive.S3TargetConfig{
            Endpoint:     cfg.TargetEndpoint,
            Bucket:       cfg.TargetBucket,
            AccessKey:    cfg.TargetAccessKey,
            SecretKey:    cfg.TargetSecretKey,
            Region:       cfg.TargetRegion,
            UsePathStyle: cfg.TargetUsePathStyle,
        })
    default:
        target, err = pqarchive.NewFileTarget(cfg.TargetPath)
    }

    if err != nil {
        cclog.Errorf("NodeState parquet retention: failed to create target: %v", err)
        return
    }

    s.NewJob(gocron.DurationJob(1*time.Hour),
        gocron.NewTask(
            func() {
                cutoff := time.Now().Unix() - int64(ageHours*3600)
                nodeRepo := repository.GetNodeRepository()

                rows, err := nodeRepo.FindNodeStatesBefore(cutoff)
                if err != nil {
                    cclog.Errorf("NodeState parquet retention: error finding rows: %v", err)
                    return
                }
                if len(rows) == 0 {
                    return
                }

                cclog.Infof("NodeState parquet retention: archiving %d rows", len(rows))
                pw := pqarchive.NewNodeStateParquetWriter(target, maxFileSizeMB)

                for _, ns := range rows {
                    row := pqarchive.ParquetNodeStateRow{
                        TimeStamp:       ns.TimeStamp,
                        NodeState:       ns.NodeState,
                        HealthState:     ns.HealthState,
                        HealthMetrics:   ns.HealthMetrics,
                        CpusAllocated:   int32(ns.CpusAllocated),
                        MemoryAllocated: ns.MemoryAllocated,
                        GpusAllocated:   int32(ns.GpusAllocated),
                        JobsRunning:     int32(ns.JobsRunning),
                        Hostname:        ns.Hostname,
                        Cluster:         ns.Cluster,
                        SubCluster:      ns.SubCluster,
                    }
                    if err := pw.AddRow(row); err != nil {
                        cclog.Errorf("NodeState parquet retention: add row: %v", err)
                        continue
                    }
                }

                if err := pw.Close(); err != nil {
                    cclog.Errorf("NodeState parquet retention: close writer: %v", err)
                    return
                }

                cnt, err := nodeRepo.DeleteNodeStatesBefore(cutoff)
                if err != nil {
                    cclog.Errorf("NodeState parquet retention: error deleting rows: %v", err)
                } else {
                    cclog.Infof("NodeState parquet retention: deleted %d rows from db", cnt)
                }
            }))
}
@@ -6,63 +6,329 @@
package taskmanager

import (
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
    "strconv"
    "strings"
    "time"

    "github.com/ClusterCockpit/cc-backend/pkg/archive"
    pqarchive "github.com/ClusterCockpit/cc-backend/pkg/archive/parquet"
    cclog "github.com/ClusterCockpit/cc-lib/v2/ccLogger"
    "github.com/ClusterCockpit/cc-lib/v2/schema"
    "github.com/go-co-op/gocron/v2"
)

func RegisterRetentionDeleteService(age int, includeDB bool, omitTagged bool) {
// createParquetTarget creates a ParquetTarget (file or S3) from the retention config.
func createParquetTarget(cfg Retention) (pqarchive.ParquetTarget, error) {
    switch cfg.TargetKind {
    case "s3":
        return pqarchive.NewS3Target(pqarchive.S3TargetConfig{
            Endpoint:     cfg.TargetEndpoint,
            Bucket:       cfg.TargetBucket,
            AccessKey:    cfg.TargetAccessKey,
            SecretKey:    cfg.TargetSecretKey,
            Region:       cfg.TargetRegion,
            UsePathStyle: cfg.TargetUsePathStyle,
        })
    default:
        return pqarchive.NewFileTarget(cfg.TargetPath)
    }
}

// createTargetBackend creates a secondary archive backend (file or S3) for JSON copy/move.
func createTargetBackend(cfg Retention) (archive.ArchiveBackend, error) {
    var raw json.RawMessage
    var err error

    switch cfg.TargetKind {
    case "s3":
        raw, err = json.Marshal(map[string]interface{}{
            "kind":           "s3",
            "endpoint":       cfg.TargetEndpoint,
            "bucket":         cfg.TargetBucket,
            "access-key":     cfg.TargetAccessKey,
            "secret-key":     cfg.TargetSecretKey,
            "region":         cfg.TargetRegion,
            "use-path-style": cfg.TargetUsePathStyle,
        })
    default:
        raw, err = json.Marshal(map[string]string{
            "kind": "file",
            "path": cfg.TargetPath,
        })
    }
    if err != nil {
        return nil, fmt.Errorf("marshal target config: %w", err)
    }
    return archive.InitBackend(raw)
}
// transferJobsJSON copies job data from source archive to target backend in JSON format.
func transferJobsJSON(jobs []*schema.Job, src archive.ArchiveBackend, dst archive.ArchiveBackend) error {
    // Transfer cluster configs for all clusters referenced by jobs
    clustersDone := make(map[string]bool)
    for _, job := range jobs {
        if clustersDone[job.Cluster] {
            continue
        }
        clusterCfg, err := src.LoadClusterCfg(job.Cluster)
        if err != nil {
            cclog.Warnf("Retention: load cluster config %q: %v", job.Cluster, err)
        } else {
            if err := dst.StoreClusterCfg(job.Cluster, clusterCfg); err != nil {
                cclog.Warnf("Retention: store cluster config %q: %v", job.Cluster, err)
            }
        }
        clustersDone[job.Cluster] = true
    }

    for _, job := range jobs {
        meta, err := src.LoadJobMeta(job)
        if err != nil {
            cclog.Warnf("Retention: load meta for job %d: %v", job.JobID, err)
            continue
        }
        data, err := src.LoadJobData(job)
        if err != nil {
            cclog.Warnf("Retention: load data for job %d: %v", job.JobID, err)
            continue
        }
        if err := dst.ImportJob(meta, &data); err != nil {
            cclog.Warnf("Retention: import job %d: %v", job.JobID, err)
            continue
        }
    }
    return nil
}

// transferJobsParquet converts jobs to Parquet format, organized by cluster.
func transferJobsParquet(jobs []*schema.Job, src archive.ArchiveBackend, target pqarchive.ParquetTarget, maxSizeMB int) error {
    cw := pqarchive.NewClusterAwareParquetWriter(target, maxSizeMB)

    // Set cluster configs for all clusters referenced by jobs
    clustersDone := make(map[string]bool)
    for _, job := range jobs {
        if clustersDone[job.Cluster] {
            continue
        }
        clusterCfg, err := src.LoadClusterCfg(job.Cluster)
        if err != nil {
            cclog.Warnf("Retention: load cluster config %q: %v", job.Cluster, err)
        } else {
            cw.SetClusterConfig(job.Cluster, clusterCfg)
        }
        clustersDone[job.Cluster] = true
    }

    for _, job := range jobs {
        meta, err := src.LoadJobMeta(job)
        if err != nil {
            cclog.Warnf("Retention: load meta for job %d: %v", job.JobID, err)
            continue
        }
        data, err := src.LoadJobData(job)
        if err != nil {
            cclog.Warnf("Retention: load data for job %d: %v", job.JobID, err)
            continue
        }
        row, err := pqarchive.JobToParquetRow(meta, &data)
        if err != nil {
            cclog.Warnf("Retention: convert job %d: %v", job.JobID, err)
            continue
        }
        if err := cw.AddJob(*row); err != nil {
            cclog.Errorf("Retention: add job %d to writer: %v", job.JobID, err)
            continue
        }
    }

    return cw.Close()
}

// cleanupAfterTransfer removes jobs from archive and optionally from DB.
func cleanupAfterTransfer(jobs []*schema.Job, startTime int64, includeDB bool, omitTagged bool) {
    archive.GetHandle().CleanUp(jobs)

    if includeDB {
        cnt, err := jobRepo.DeleteJobsBefore(startTime, omitTagged)
        if err != nil {
            cclog.Errorf("Retention: delete jobs from db: %v", err)
        } else {
            cclog.Infof("Retention: removed %d jobs from db", cnt)
        }
        if err = jobRepo.Optimize(); err != nil {
            cclog.Errorf("Retention: db optimization error: %v", err)
        }
    }
}

// readCopyMarker reads the last-processed timestamp from a copy marker file.
func readCopyMarker(cfg Retention) int64 {
    var data []byte
    var err error

    switch cfg.TargetKind {
    case "s3":
        // For S3 we store the marker locally alongside the config
        data, err = os.ReadFile(copyMarkerPath(cfg))
    default:
        data, err = os.ReadFile(filepath.Join(cfg.TargetPath, ".copy-marker"))
    }
    if err != nil {
        return 0
    }
    ts, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64)
    if err != nil {
        return 0
    }
    return ts
}

// writeCopyMarker writes the last-processed timestamp to a copy marker file.
func writeCopyMarker(cfg Retention, ts int64) {
    content := []byte(strconv.FormatInt(ts, 10))
    var err error

    switch cfg.TargetKind {
    case "s3":
        err = os.WriteFile(copyMarkerPath(cfg), content, 0o640)
    default:
        err = os.WriteFile(filepath.Join(cfg.TargetPath, ".copy-marker"), content, 0o640)
    }
    if err != nil {
        cclog.Warnf("Retention: write copy marker: %v", err)
    }
}

func copyMarkerPath(cfg Retention) string {
    // For S3 targets, store the marker in a local temp-style path derived from the bucket name
    return filepath.Join(os.TempDir(), fmt.Sprintf("cc-copy-marker-%s", cfg.TargetBucket))
}
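In effect, the marker turns the copy service below into an incremental export: the first run copies every job older than the cutoff and records that cutoff, and each later run only fetches jobs whose start time falls between the stored marker and the new cutoff, so jobs are not copied twice as long as the marker file survives. Note that for S3 targets the marker lives under os.TempDir() on the local host, so if the temp directory is cleared, readCopyMarker returns 0 and the next run starts over from timestamp 0.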
func RegisterRetentionDeleteService(cfg Retention) {
    cclog.Info("Register retention delete service")

    s.NewJob(gocron.DailyJob(1, gocron.NewAtTimes(gocron.NewAtTime(3, 0, 0))),
        gocron.NewTask(
            func() {
                startTime := time.Now().Unix() - int64(age*24*3600)
                jobs, err := jobRepo.FindJobsBetween(0, startTime, omitTagged)
                startTime := time.Now().Unix() - int64(cfg.Age*24*3600)
                jobs, err := jobRepo.FindJobsBetween(0, startTime, cfg.OmitTagged)
                if err != nil {
                    cclog.Warnf("Error while looking for retention jobs: %s", err.Error())
                    cclog.Warnf("Retention delete: error finding jobs: %v", err)
                    return
                }
                if len(jobs) == 0 {
                    return
                }
                archive.GetHandle().CleanUp(jobs)

                if includeDB {
                    cnt, err := jobRepo.DeleteJobsBefore(startTime, omitTagged)
                    if err != nil {
                        cclog.Errorf("Error while deleting retention jobs from db: %s", err.Error())
                    } else {
                        cclog.Infof("Retention: Removed %d jobs from db", cnt)
                    }
                    if err = jobRepo.Optimize(); err != nil {
                        cclog.Errorf("Error occured in db optimization: %s", err.Error())
                    }
                }
                cclog.Infof("Retention delete: processing %d jobs", len(jobs))
                cleanupAfterTransfer(jobs, startTime, cfg.IncludeDB, cfg.OmitTagged)
            }))
}

func RegisterRetentionMoveService(age int, includeDB bool, location string, omitTagged bool) {
    cclog.Info("Register retention move service")
func RegisterRetentionCopyService(cfg Retention) {
    cclog.Infof("Register retention copy service (format=%s, target=%s)", cfg.Format, cfg.TargetKind)

    maxFileSizeMB := cfg.MaxFileSizeMB
    if maxFileSizeMB <= 0 {
        maxFileSizeMB = 512
    }

    s.NewJob(gocron.DailyJob(1, gocron.NewAtTimes(gocron.NewAtTime(4, 0, 0))),
        gocron.NewTask(
            func() {
                startTime := time.Now().Unix() - int64(age*24*3600)
                jobs, err := jobRepo.FindJobsBetween(0, startTime, omitTagged)
                if err != nil {
                    cclog.Warnf("Error while looking for retention jobs: %s", err.Error())
                }
                archive.GetHandle().Move(jobs, location)
                cutoff := time.Now().Unix() - int64(cfg.Age*24*3600)
                lastProcessed := readCopyMarker(cfg)

                if includeDB {
                    cnt, err := jobRepo.DeleteJobsBefore(startTime, omitTagged)
                jobs, err := jobRepo.FindJobsBetween(lastProcessed, cutoff, cfg.OmitTagged)
                if err != nil {
                    cclog.Warnf("Retention copy: error finding jobs: %v", err)
                    return
                }
                if len(jobs) == 0 {
                    return
                }

                cclog.Infof("Retention copy: processing %d jobs", len(jobs))
                ar := archive.GetHandle()

                switch cfg.Format {
                case "parquet":
                    target, err := createParquetTarget(cfg)
                    if err != nil {
                        cclog.Errorf("Error while deleting retention jobs from db: %v", err)
                    } else {
                        cclog.Infof("Retention: Removed %d jobs from db", cnt)
                        cclog.Errorf("Retention copy: create parquet target: %v", err)
                        return
                    }
                    if err = jobRepo.Optimize(); err != nil {
                        cclog.Errorf("Error occured in db optimization: %v", err)
                    if err := transferJobsParquet(jobs, ar, target, maxFileSizeMB); err != nil {
                        cclog.Errorf("Retention copy: parquet transfer: %v", err)
                        return
                    }
                default: // json
                    dst, err := createTargetBackend(cfg)
                    if err != nil {
                        cclog.Errorf("Retention copy: create target backend: %v", err)
                        return
                    }
                    if err := transferJobsJSON(jobs, ar, dst); err != nil {
                        cclog.Errorf("Retention copy: json transfer: %v", err)
                        return
                    }
                }

                writeCopyMarker(cfg, cutoff)
            }))
}

func RegisterRetentionMoveService(cfg Retention) {
    cclog.Infof("Register retention move service (format=%s, target=%s)", cfg.Format, cfg.TargetKind)

    maxFileSizeMB := cfg.MaxFileSizeMB
    if maxFileSizeMB <= 0 {
        maxFileSizeMB = 512
    }

    s.NewJob(gocron.DailyJob(1, gocron.NewAtTimes(gocron.NewAtTime(5, 0, 0))),
        gocron.NewTask(
            func() {
                startTime := time.Now().Unix() - int64(cfg.Age*24*3600)
                jobs, err := jobRepo.FindJobsBetween(0, startTime, cfg.OmitTagged)
                if err != nil {
                    cclog.Warnf("Retention move: error finding jobs: %v", err)
                    return
                }
                if len(jobs) == 0 {
                    return
                }

                cclog.Infof("Retention move: processing %d jobs", len(jobs))
                ar := archive.GetHandle()

                switch cfg.Format {
                case "parquet":
                    target, err := createParquetTarget(cfg)
                    if err != nil {
                        cclog.Errorf("Retention move: create parquet target: %v", err)
                        return
                    }
                    if err := transferJobsParquet(jobs, ar, target, maxFileSizeMB); err != nil {
                        cclog.Errorf("Retention move: parquet transfer: %v", err)
                        return
                    }
                default: // json
                    dst, err := createTargetBackend(cfg)
                    if err != nil {
                        cclog.Errorf("Retention move: create target backend: %v", err)
                        return
                    }
                    if err := transferJobsJSON(jobs, ar, dst); err != nil {
                        cclog.Errorf("Retention move: json transfer: %v", err)
                        return
                    }
                }

                cleanupAfterTransfer(jobs, startTime, cfg.IncludeDB, cfg.OmitTagged)
            }))
}

@@ -23,11 +23,20 @@ const (

// Retention defines the configuration for job retention policies.
type Retention struct {
    Policy     string `json:"policy"`
    Location   string `json:"location"`
    Age        int    `json:"age"`
    IncludeDB  bool   `json:"includeDB"`
    OmitTagged bool   `json:"omitTagged"`
    Policy             string `json:"policy"`
    Format             string `json:"format"`
    Age                int    `json:"age"`
    IncludeDB          bool   `json:"includeDB"`
    OmitTagged         bool   `json:"omitTagged"`
    TargetKind         string `json:"target-kind"`
    TargetPath         string `json:"target-path"`
    TargetEndpoint     string `json:"target-endpoint"`
    TargetBucket       string `json:"target-bucket"`
    TargetAccessKey    string `json:"target-access-key"`
    TargetSecretKey    string `json:"target-secret-key"`
    TargetRegion       string `json:"target-region"`
    TargetUsePathStyle bool   `json:"target-use-path-style"`
    MaxFileSizeMB      int    `json:"max-file-size-mb"`
}

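To illustrate how the new fields map onto configuration, here is a sketch decoding a "copy to S3 as parquet" retention block through the struct's json tags; the endpoint and bucket values are made up for the example:

    // Illustrative configuration; field names follow the json tags above.
    raw := []byte(`{
        "policy": "copy",
        "format": "parquet",
        "age": 365,
        "omitTagged": true,
        "target-kind": "s3",
        "target-endpoint": "https://s3.example.org",
        "target-bucket": "cc-archive",
        "target-region": "us-east-1",
        "max-file-size-mb": 512
    }`)
    var cfg Retention
    if err := json.Unmarshal(raw, &cfg); err != nil {
        cclog.Errorf("parse retention config: %v", err)
    }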
// CronFrequency defines the execution intervals for various background workers.
@@ -77,16 +86,11 @@ func initArchiveServices(config json.RawMessage) {

    switch cfg.Retention.Policy {
    case "delete":
        RegisterRetentionDeleteService(
            cfg.Retention.Age,
            cfg.Retention.IncludeDB,
            cfg.Retention.OmitTagged)
        RegisterRetentionDeleteService(cfg.Retention)
    case "copy":
        RegisterRetentionCopyService(cfg.Retention)
    case "move":
        RegisterRetentionMoveService(
            cfg.Retention.Age,
            cfg.Retention.IncludeDB,
            cfg.Retention.Location,
            cfg.Retention.OmitTagged)
        RegisterRetentionMoveService(cfg.Retention)
    }

    if cfg.Compression > 0 {
@@ -133,9 +137,30 @@ func Start(cronCfg, archiveConfig json.RawMessage) {
    RegisterUpdateDurationWorker()
    RegisterCommitJobService()

    if config.Keys.NodeStateRetention != nil && config.Keys.NodeStateRetention.Policy != "" {
        initNodeStateRetention()
    }

    s.Start()
}

func initNodeStateRetention() {
    cfg := config.Keys.NodeStateRetention
    age := cfg.Age
    if age <= 0 {
        age = 24
    }

    switch cfg.Policy {
    case "delete":
        RegisterNodeStateRetentionDeleteService(age)
    case "parquet":
        RegisterNodeStateRetentionParquetService(cfg)
    default:
        cclog.Warnf("Unknown nodestate-retention policy: %s", cfg.Policy)
    }
}

// Shutdown stops the task manager and its scheduler.
func Shutdown() {
    if s != nil {