Mirror of https://github.com/ClusterCockpit/cc-backend (synced 2025-12-15 03:36:16 +01:00)

Merge branch 'ai-review' into dev
@@ -65,7 +65,7 @@ func setup(t *testing.T) *api.RestApi {
}
]
}`
const testclusterJson = `{
const testclusterJSON = `{
"name": "testcluster",
"subClusters": [
{
@@ -128,7 +128,7 @@ func setup(t *testing.T) *api.RestApi {
t.Fatal(err)
}

if err := os.WriteFile(filepath.Join(jobarchive, "version.txt"), fmt.Appendf(nil, "%d", 2), 0o666); err != nil {
if err := os.WriteFile(filepath.Join(jobarchive, "version.txt"), fmt.Appendf(nil, "%d", 3), 0o666); err != nil {
t.Fatal(err)
}

@@ -136,7 +136,7 @@ func setup(t *testing.T) *api.RestApi {
t.Fatal(err)
}

if err := os.WriteFile(filepath.Join(jobarchive, "testcluster", "cluster.json"), []byte(testclusterJson), 0o666); err != nil {
if err := os.WriteFile(filepath.Join(jobarchive, "testcluster", "cluster.json"), []byte(testclusterJSON), 0o666); err != nil {
t.Fatal(err)
}

@@ -2,6 +2,7 @@
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
@@ -29,9 +30,15 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// StopJobApiRequest model
|
||||
type StopJobApiRequest struct {
|
||||
JobId *int64 `json:"jobId" example:"123000"`
|
||||
const (
|
||||
// secondsPerDay is the number of seconds in 24 hours.
|
||||
// Used for duplicate job detection within a day window.
|
||||
secondsPerDay = 86400
|
||||
)
|
||||
|
||||
// StopJobAPIRequest model
|
||||
type StopJobAPIRequest struct {
|
||||
JobID *int64 `json:"jobId" example:"123000"`
|
||||
Cluster *string `json:"cluster" example:"fritz"`
|
||||
StartTime *int64 `json:"startTime" example:"1649723812"`
|
||||
State schema.JobState `json:"jobState" validate:"required" example:"completed"`
|
||||
@@ -40,7 +47,7 @@ type StopJobApiRequest struct {
|
||||
|
||||
// DeleteJobApiRequest model
|
||||
type DeleteJobApiRequest struct {
|
||||
JobId *int64 `json:"jobId" validate:"required" example:"123000"` // Cluster Job ID of job
|
||||
JobID *int64 `json:"jobId" validate:"required" example:"123000"` // Cluster Job ID of job
|
||||
Cluster *string `json:"cluster" example:"fritz"` // Cluster of job
|
||||
StartTime *int64 `json:"startTime" example:"1649723812"` // Start Time of job as epoch
|
||||
}
|
||||
@@ -113,7 +120,8 @@ func (api *RestApi) getJobs(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
for key, vals := range r.URL.Query() {
|
||||
switch key {
|
||||
// TODO: add project filter
|
||||
case "project":
|
||||
filter.Project = &model.StringInput{Eq: &vals[0]}
|
||||
case "state":
|
||||
for _, s := range vals {
|
||||
state := schema.JobState(s)
|
||||
@@ -363,7 +371,7 @@ func (api *RestApi) getJobById(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var metrics GetJobApiRequest
|
||||
if err = decode(r.Body, &metrics); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("decoding request failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -434,30 +442,32 @@ func (api *RestApi) getJobById(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *RestApi) editMeta(rw http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("parsing job ID failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
job, err := api.JobRepository.FindById(r.Context(), id)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusNotFound)
|
||||
handleError(fmt.Errorf("finding job failed: %w", err), http.StatusNotFound, rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req EditMetaRequest
|
||||
if err := decode(r.Body, &req); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("decoding request failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.JobRepository.UpdateMetadata(job, req.Key, req.Value); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("updating metadata failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(rw).Encode(job)
|
||||
if err := json.NewEncoder(rw).Encode(job); err != nil {
|
||||
cclog.Errorf("Failed to encode job response: %v", err)
|
||||
}
|
||||
}
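The encode-and-log idiom introduced here (set the headers, write the status, then only log an encoding failure because the response is already committed) recurs in most handlers below. A minimal sketch of how it could be shared; the writeJSON helper is hypothetical and not part of this commit:

// Sketch only; assumes this file's existing imports:
//   "encoding/json", "net/http", cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
//
// writeJSON mirrors the pattern used above: once WriteHeader has been called,
// an encoding error can only be logged, not reported to the client.
func writeJSON(rw http.ResponseWriter, status int, v any) {
	rw.Header().Add("Content-Type", "application/json")
	rw.WriteHeader(status)
	if err := json.NewEncoder(rw).Encode(v); err != nil {
		cclog.Errorf("Failed to encode response: %v", err)
	}
}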
|
||||
|
||||
// tagJob godoc
|
||||
@@ -480,32 +490,32 @@ func (api *RestApi) editMeta(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *RestApi) tagJob(rw http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("parsing job ID failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
job, err := api.JobRepository.FindById(r.Context(), id)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusNotFound)
|
||||
handleError(fmt.Errorf("finding job failed: %w", err), http.StatusNotFound, rw)
|
||||
return
|
||||
}
|
||||
|
||||
job.Tags, err = api.JobRepository.GetTags(repository.GetUserFromContext(r.Context()), job.ID)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("getting tags failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req TagJobApiRequest
|
||||
if err := decode(r.Body, &req); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("decoding request failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
for _, tag := range req {
|
||||
tagId, err := api.JobRepository.AddTagOrCreate(repository.GetUserFromContext(r.Context()), *job.ID, tag.Type, tag.Name, tag.Scope)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("adding tag failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -519,7 +529,9 @@ func (api *RestApi) tagJob(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(rw).Encode(job)
|
||||
if err := json.NewEncoder(rw).Encode(job); err != nil {
|
||||
cclog.Errorf("Failed to encode job response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// removeTagJob godoc
|
||||
@@ -542,25 +554,25 @@ func (api *RestApi) tagJob(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *RestApi) removeTagJob(rw http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("parsing job ID failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
job, err := api.JobRepository.FindById(r.Context(), id)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusNotFound)
|
||||
handleError(fmt.Errorf("finding job failed: %w", err), http.StatusNotFound, rw)
|
||||
return
|
||||
}
|
||||
|
||||
job.Tags, err = api.JobRepository.GetTags(repository.GetUserFromContext(r.Context()), job.ID)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("getting tags failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req TagJobApiRequest
|
||||
if err := decode(r.Body, &req); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("decoding request failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -573,7 +585,7 @@ func (api *RestApi) removeTagJob(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
remainingTags, err := api.JobRepository.RemoveJobTagByRequest(repository.GetUserFromContext(r.Context()), *job.ID, rtag.Type, rtag.Name, rtag.Scope)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("removing tag failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -582,7 +594,9 @@ func (api *RestApi) removeTagJob(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(rw).Encode(job)
|
||||
if err := json.NewEncoder(rw).Encode(job); err != nil {
|
||||
cclog.Errorf("Failed to encode job response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// removeTags godoc
|
||||
@@ -604,7 +618,7 @@ func (api *RestApi) removeTagJob(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *RestApi) removeTags(rw http.ResponseWriter, r *http.Request) {
|
||||
var req TagJobApiRequest
|
||||
if err := decode(r.Body, &req); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("decoding request failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -619,11 +633,10 @@ func (api *RestApi) removeTags(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err := api.JobRepository.RemoveTagByRequest(rtag.Type, rtag.Name, rtag.Scope)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("removing tag failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
} else {
|
||||
currentCount++
|
||||
}
|
||||
currentCount++
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
@@ -674,9 +687,11 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
handleError(fmt.Errorf("checking for duplicate failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
} else if err == nil {
|
||||
}
|
||||
if err == nil {
|
||||
for _, job := range jobs {
|
||||
if (req.StartTime - job.StartTime) < 86400 {
|
||||
// Check if jobs are within the same day (prevent duplicates)
|
||||
if (req.StartTime - job.StartTime) < secondsPerDay {
|
||||
handleError(fmt.Errorf("a job with that jobId, cluster and startTime already exists: dbid: %d, jobid: %d", job.ID, job.JobID), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
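The duplicate check above compares start times against the new secondsPerDay constant; a small sketch of the predicate it implements, not part of the commit:

// isDuplicateStart reports whether a new start request falls inside the
// one-day window used above for duplicate detection (illustrative sketch).
func isDuplicateStart(newStart, existingStart int64) bool {
	return (newStart - existingStart) < secondsPerDay
}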
|
||||
@@ -693,7 +708,6 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
for _, tag := range req.Tags {
|
||||
if _, err := api.JobRepository.AddTagOrCreate(repository.GetUserFromContext(r.Context()), id, tag.Type, tag.Name, tag.Scope); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("adding tag to new job %d failed: %w", id, err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
@@ -702,9 +716,11 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) {
|
||||
cclog.Printf("new job (id: %d): cluster=%s, jobId=%d, user=%s, startTime=%d", id, req.Cluster, req.JobID, req.User, req.StartTime)
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
Message: "success",
|
||||
})
|
||||
}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stopJobByRequest godoc
|
||||
@@ -725,7 +741,7 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) {
|
||||
// @router /api/jobs/stop_job/ [post]
|
||||
func (api *RestApi) stopJobByRequest(rw http.ResponseWriter, r *http.Request) {
|
||||
// Parse request body
|
||||
req := StopJobApiRequest{}
|
||||
req := StopJobAPIRequest{}
|
||||
if err := decode(r.Body, &req); err != nil {
|
||||
handleError(fmt.Errorf("parsing request body failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
@@ -734,20 +750,22 @@ func (api *RestApi) stopJobByRequest(rw http.ResponseWriter, r *http.Request) {
|
||||
// Fetch job (that will be stopped) from db
|
||||
var job *schema.Job
|
||||
var err error
|
||||
if req.JobId == nil {
|
||||
if req.JobID == nil {
|
||||
handleError(errors.New("the field 'jobId' is required"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
// cclog.Printf("loading db job for stopJobByRequest... : stopJobApiRequest=%v", req)
|
||||
job, err = api.JobRepository.Find(req.JobId, req.Cluster, req.StartTime)
|
||||
job, err = api.JobRepository.Find(req.JobID, req.Cluster, req.StartTime)
|
||||
if err != nil {
|
||||
job, err = api.JobRepository.FindCached(req.JobId, req.Cluster, req.StartTime)
|
||||
// FIXME: Previous error is hidden
|
||||
if err != nil {
|
||||
handleError(fmt.Errorf("finding job failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
// Try cached jobs if not found in main repository
|
||||
cachedJob, cachedErr := api.JobRepository.FindCached(req.JobID, req.Cluster, req.StartTime)
|
||||
if cachedErr != nil {
|
||||
// Combine both errors for better debugging
|
||||
handleError(fmt.Errorf("finding job failed: %w (cached lookup also failed: %v)", err, cachedErr), http.StatusNotFound, rw)
|
||||
return
|
||||
}
|
||||
job = cachedJob
|
||||
}
|
||||
|
||||
api.checkAndHandleStopJob(rw, job, req)
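Wrapping the primary lookup error with %w while only formatting the cached error with %v keeps the original failure matchable by callers; a short sketch of that property with illustrative values, not code from this commit:

// Sketch: errors.Is still finds the wrapped primary error, while the
// cached-lookup failure is carried only as text.
func exampleCombinedError() {
	primary := sql.ErrNoRows
	cachedErr := errors.New("cache miss")
	combined := fmt.Errorf("finding job failed: %w (cached lookup also failed: %v)", primary, cachedErr)
	if errors.Is(combined, sql.ErrNoRows) {
		// true: the %w verb preserves the chain to the primary error; %v does not
	}
}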
|
||||
@@ -790,9 +808,11 @@ func (api *RestApi) deleteJobById(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
Message: fmt.Sprintf("Successfully deleted job %s", id),
|
||||
})
|
||||
}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteJobByRequest godoc
|
||||
@@ -822,12 +842,12 @@ func (api *RestApi) deleteJobByRequest(rw http.ResponseWriter, r *http.Request)
|
||||
// Fetch job (that will be deleted) from db
|
||||
var job *schema.Job
|
||||
var err error
|
||||
if req.JobId == nil {
|
||||
if req.JobID == nil {
|
||||
handleError(errors.New("the field 'jobId' is required"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
job, err = api.JobRepository.Find(req.JobId, req.Cluster, req.StartTime)
|
||||
job, err = api.JobRepository.Find(req.JobID, req.Cluster, req.StartTime)
|
||||
if err != nil {
|
||||
handleError(fmt.Errorf("finding job failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
@@ -841,9 +861,11 @@ func (api *RestApi) deleteJobByRequest(rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
Message: fmt.Sprintf("Successfully deleted job %d", job.ID),
|
||||
})
|
||||
}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteJobBefore godoc
|
||||
@@ -885,19 +907,21 @@ func (api *RestApi) deleteJobBefore(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{
|
||||
Message: fmt.Sprintf("Successfully deleted %d jobs", cnt),
|
||||
})
|
||||
}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *RestApi) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Job, req StopJobApiRequest) {
|
||||
func (api *RestApi) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Job, req StopJobAPIRequest) {
|
||||
// Sanity checks
|
||||
if job.State != schema.JobStateRunning {
|
||||
handleError(fmt.Errorf("jobId %d (id %d) on %s : job has already been stopped (state is: %s)", job.JobID, job.ID, job.Cluster, job.State), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
|
||||
|
||||
if job == nil || job.StartTime > req.StopTime {
|
||||
if job.StartTime > req.StopTime {
|
||||
handleError(fmt.Errorf("jobId %d (id %d) on %s : stopTime %d must be larger/equal than startTime %d", job.JobID, job.ID, job.Cluster, req.StopTime, job.StartTime), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
@@ -913,14 +937,14 @@ func (api *RestApi) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Jo
|
||||
job.Duration = int32(req.StopTime - job.StartTime)
|
||||
job.State = req.State
|
||||
api.JobRepository.Mutex.Lock()
|
||||
defer api.JobRepository.Mutex.Unlock()
|
||||
|
||||
if err := api.JobRepository.Stop(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
|
||||
if err := api.JobRepository.StopCached(*job.ID, job.Duration, job.State, job.MonitoringStatus); err != nil {
|
||||
api.JobRepository.Mutex.Unlock()
|
||||
handleError(fmt.Errorf("jobId %d (id %d) on %s : marking job as '%s' (duration: %d) in DB failed: %w", job.JobID, job.ID, job.Cluster, job.State, job.Duration, err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
}
|
||||
api.JobRepository.Mutex.Unlock()
|
||||
|
||||
cclog.Printf("archiving job... (dbid: %d): cluster=%s, jobId=%d, user=%s, startTime=%d, duration=%d, state=%s", job.ID, job.Cluster, job.JobID, job.User, job.StartTime, job.Duration, job.State)
|
||||
|
||||
@@ -929,7 +953,9 @@ func (api *RestApi) checkAndHandleStopJob(rw http.ResponseWriter, job *schema.Jo
|
||||
// writing to the filesystem fails, the client will not know.
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(rw).Encode(job)
|
||||
if err := json.NewEncoder(rw).Encode(job); err != nil {
|
||||
cclog.Errorf("Failed to encode job response: %v", err)
|
||||
}
|
||||
|
||||
// Monitoring is disabled...
|
||||
if job.MonitoringStatus == schema.MonitoringStatusDisabled {
|
||||
@@ -947,7 +973,7 @@ func (api *RestApi) getJobMetrics(rw http.ResponseWriter, r *http.Request) {
|
||||
for _, scope := range r.URL.Query()["scope"] {
|
||||
var s schema.MetricScope
|
||||
if err := s.UnmarshalGQL(scope); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("unmarshaling scope failed: %w", err), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
scopes = append(scopes, s)
|
||||
@@ -956,7 +982,7 @@ func (api *RestApi) getJobMetrics(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
type Respone struct {
|
||||
type Response struct {
|
||||
Data *struct {
|
||||
JobMetrics []*model.JobMetricWithName `json:"jobMetrics"`
|
||||
} `json:"data"`
|
||||
@@ -968,17 +994,21 @@ func (api *RestApi) getJobMetrics(rw http.ResponseWriter, r *http.Request) {
|
||||
resolver := graph.GetResolverInstance()
|
||||
data, err := resolver.Query().JobMetrics(r.Context(), id, metrics, scopes, nil)
|
||||
if err != nil {
|
||||
json.NewEncoder(rw).Encode(Respone{
|
||||
if err := json.NewEncoder(rw).Encode(Response{
|
||||
Error: &struct {
|
||||
Message string "json:\"message\""
|
||||
Message string `json:"message"`
|
||||
}{Message: err.Error()},
|
||||
})
|
||||
}); err != nil {
|
||||
cclog.Errorf("Failed to encode error response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(rw).Encode(Respone{
|
||||
if err := json.NewEncoder(rw).Encode(Response{
|
||||
Data: &struct {
|
||||
JobMetrics []*model.JobMetricWithName "json:\"jobMetrics\""
|
||||
JobMetrics []*model.JobMetricWithName `json:"jobMetrics"`
|
||||
}{JobMetrics: data},
|
||||
})
|
||||
}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,13 +50,6 @@ func freeMetrics(rw http.ResponseWriter, r *http.Request) {
return
}

// // TODO: lastCheckpoint might be modified by different go-routines.
// // Load it using the sync/atomic package?
// freeUpTo := lastCheckpoint.Unix()
// if to < freeUpTo {
// freeUpTo = to
// }

bodyDec := json.NewDecoder(r.Body)
var selectors [][]string
err = bodyDec.Decode(&selectors)

@@ -2,6 +2,11 @@
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package api provides the REST API layer for ClusterCockpit.
|
||||
// It handles HTTP requests for job management, user administration,
|
||||
// cluster queries, node state updates, and metrics storage operations.
|
||||
// The API supports both JWT token authentication and session-based authentication.
|
||||
package api
|
||||
|
||||
import (
|
||||
@@ -11,6 +16,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/auth"
|
||||
@@ -39,10 +45,19 @@ import (
|
||||
// @in header
|
||||
// @name X-Auth-Token
|
||||
|
||||
const (
|
||||
noticeFilePath = "./var/notice.txt"
|
||||
noticeFilePerms = 0o644
|
||||
)
|
||||
|
||||
type RestApi struct {
|
||||
JobRepository *repository.JobRepository
|
||||
Authentication *auth.Authentication
|
||||
MachineStateDir string
|
||||
// RepositoryMutex protects job creation operations from race conditions
|
||||
// when checking for duplicate jobs during startJob API calls.
|
||||
// It prevents concurrent job starts with the same jobId/cluster/startTime
|
||||
// from creating duplicate entries in the database.
|
||||
RepositoryMutex sync.Mutex
|
||||
}
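A condensed sketch of the critical section the RepositoryMutex comment describes; the helper below is hypothetical and only illustrates the lock, check, insert ordering used by startJob:

// createJobOnce is illustrative only: hold RepositoryMutex across the
// duplicate check and the insert so concurrent start_job requests cannot
// both pass the check.
func (api *RestApi) createJobOnce(isDuplicate func() bool, insert func() error) error {
	api.RepositoryMutex.Lock()
	defer api.RepositoryMutex.Unlock()
	if isDuplicate() {
		return fmt.Errorf("a job with that jobId, cluster and startTime already exists")
	}
	return insert()
}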
|
||||
|
||||
@@ -66,7 +81,6 @@ func (api *RestApi) MountApiRoutes(r *mux.Router) {
|
||||
// Job Handler
|
||||
r.HandleFunc("/jobs/start_job/", api.startJob).Methods(http.MethodPost, http.MethodPut)
|
||||
r.HandleFunc("/jobs/stop_job/", api.stopJobByRequest).Methods(http.MethodPost, http.MethodPut)
|
||||
// r.HandleFunc("/jobs/import/", api.importJob).Methods(http.MethodPost, http.MethodPut)
|
||||
r.HandleFunc("/jobs/", api.getJobs).Methods(http.MethodGet)
|
||||
r.HandleFunc("/jobs/{id}", api.getJobById).Methods(http.MethodPost)
|
||||
r.HandleFunc("/jobs/{id}", api.getCompleteJobById).Methods(http.MethodGet)
|
||||
@@ -97,6 +111,7 @@ func (api *RestApi) MountUserApiRoutes(r *mux.Router) {
|
||||
|
||||
func (api *RestApi) MountMetricStoreApiRoutes(r *mux.Router) {
|
||||
// REST API Uses TokenAuth
|
||||
// Note: StrictSlash handles trailing slash variations automatically
|
||||
r.HandleFunc("/api/free", freeMetrics).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/write", writeMetrics).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/debug", debugMetrics).Methods(http.MethodGet)
|
||||
@@ -146,10 +161,12 @@ func handleError(err error, statusCode int, rw http.ResponseWriter) {
|
||||
cclog.Warnf("REST ERROR : %s", err.Error())
|
||||
rw.Header().Add("Content-Type", "application/json")
|
||||
rw.WriteHeader(statusCode)
|
||||
json.NewEncoder(rw).Encode(ErrorResponse{
|
||||
if err := json.NewEncoder(rw).Encode(ErrorResponse{
|
||||
Status: http.StatusText(statusCode),
|
||||
Error: err.Error(),
|
||||
})
|
||||
}); err != nil {
|
||||
cclog.Errorf("Failed to encode error response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func decode(r io.Reader, val any) error {
|
||||
@@ -162,41 +179,41 @@ func (api *RestApi) editNotice(rw http.ResponseWriter, r *http.Request) {
|
||||
// SecuredCheck() only worked with TokenAuth: Removed
|
||||
|
||||
if user := repository.GetUserFromContext(r.Context()); !user.HasRole(schema.RoleAdmin) {
|
||||
http.Error(rw, "Only admins are allowed to update the notice.txt file", http.StatusForbidden)
|
||||
handleError(fmt.Errorf("only admins are allowed to update the notice.txt file"), http.StatusForbidden, rw)
|
||||
return
|
||||
}
|
||||
|
||||
// Get Value
|
||||
newContent := r.FormValue("new-content")
|
||||
|
||||
// Check FIle
|
||||
noticeExists := util.CheckFileExists("./var/notice.txt")
|
||||
// Validate content length to prevent DoS
|
||||
if len(newContent) > 10000 {
|
||||
handleError(fmt.Errorf("notice content exceeds maximum length of 10000 characters"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
// Check File
|
||||
noticeExists := util.CheckFileExists(noticeFilePath)
|
||||
if !noticeExists {
|
||||
ntxt, err := os.Create("./var/notice.txt")
|
||||
ntxt, err := os.Create(noticeFilePath)
|
||||
if err != nil {
|
||||
cclog.Errorf("Creating ./var/notice.txt failed: %s", err.Error())
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("creating notice file failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
ntxt.Close()
|
||||
}
|
||||
|
||||
if err := os.WriteFile(noticeFilePath, []byte(newContent), noticeFilePerms); err != nil {
|
||||
handleError(fmt.Errorf("writing to notice file failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "text/plain")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
if newContent != "" {
|
||||
if err := os.WriteFile("./var/notice.txt", []byte(newContent), 0o666); err != nil {
|
||||
cclog.Errorf("Writing to ./var/notice.txt failed: %s", err.Error())
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
return
|
||||
} else {
|
||||
rw.Write([]byte("Update Notice Content Success"))
|
||||
}
|
||||
rw.Write([]byte("Update Notice Content Success"))
|
||||
} else {
|
||||
if err := os.WriteFile("./var/notice.txt", []byte(""), 0o666); err != nil {
|
||||
cclog.Errorf("Writing to ./var/notice.txt failed: %s", err.Error())
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
return
|
||||
} else {
|
||||
rw.Write([]byte("Empty Notice Content Success"))
|
||||
}
|
||||
rw.Write([]byte("Empty Notice Content Success"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,21 +223,20 @@ func (api *RestApi) getJWT(rw http.ResponseWriter, r *http.Request) {
|
||||
me := repository.GetUserFromContext(r.Context())
|
||||
if !me.HasRole(schema.RoleAdmin) {
|
||||
if username != me.Username {
|
||||
http.Error(rw, "Only admins are allowed to sign JWTs not for themselves",
|
||||
http.StatusForbidden)
|
||||
handleError(fmt.Errorf("only admins are allowed to sign JWTs not for themselves"), http.StatusForbidden, rw)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
user, err := repository.GetUserRepository().GetUser(username)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("getting user failed: %w", err), http.StatusNotFound, rw)
|
||||
return
|
||||
}
|
||||
|
||||
jwt, err := api.Authentication.JwtAuth.ProvideJWT(user)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("providing JWT failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,17 +249,20 @@ func (api *RestApi) getRoles(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
user := repository.GetUserFromContext(r.Context())
|
||||
if !user.HasRole(schema.RoleAdmin) {
|
||||
http.Error(rw, "only admins are allowed to fetch a list of roles", http.StatusForbidden)
|
||||
handleError(fmt.Errorf("only admins are allowed to fetch a list of roles"), http.StatusForbidden, rw)
|
||||
return
|
||||
}
|
||||
|
||||
roles, err := schema.GetValidRoles(user)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("getting valid roles failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(rw).Encode(roles)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(rw).Encode(roles); err != nil {
|
||||
cclog.Errorf("Failed to encode roles response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *RestApi) updateConfiguration(rw http.ResponseWriter, r *http.Request) {
|
||||
@@ -251,38 +270,50 @@ func (api *RestApi) updateConfiguration(rw http.ResponseWriter, r *http.Request)
|
||||
key, value := r.FormValue("key"), r.FormValue("value")
|
||||
|
||||
if err := repository.GetUserCfgRepo().UpdateConfig(key, value, repository.GetUserFromContext(r.Context())); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("updating configuration failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.Write([]byte("success"))
|
||||
}
|
||||
|
||||
func (api *RestApi) putMachineState(rw http.ResponseWriter, r *http.Request) {
|
||||
if api.MachineStateDir == "" {
|
||||
http.Error(rw, "REST > machine state not enabled", http.StatusNotFound)
|
||||
handleError(fmt.Errorf("machine state not enabled"), http.StatusNotFound, rw)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
cluster := vars["cluster"]
|
||||
host := vars["host"]
|
||||
|
||||
// Validate cluster and host to prevent path traversal attacks
|
||||
if strings.Contains(cluster, "..") || strings.Contains(cluster, "/") || strings.Contains(cluster, "\\") {
|
||||
handleError(fmt.Errorf("invalid cluster name"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
if strings.Contains(host, "..") || strings.Contains(host, "/") || strings.Contains(host, "\\") {
|
||||
handleError(fmt.Errorf("invalid host name"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
dir := filepath.Join(api.MachineStateDir, cluster)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("creating directory failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
filename := filepath.Join(dir, fmt.Sprintf("%s.json", host))
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("creating file failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(f, r.Body); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("writing file failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -291,12 +322,25 @@ func (api *RestApi) putMachineState(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (api *RestApi) getMachineState(rw http.ResponseWriter, r *http.Request) {
|
||||
if api.MachineStateDir == "" {
|
||||
http.Error(rw, "REST > machine state not enabled", http.StatusNotFound)
|
||||
handleError(fmt.Errorf("machine state not enabled"), http.StatusNotFound, rw)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
filename := filepath.Join(api.MachineStateDir, vars["cluster"], fmt.Sprintf("%s.json", vars["host"]))
|
||||
cluster := vars["cluster"]
|
||||
host := vars["host"]
|
||||
|
||||
// Validate cluster and host to prevent path traversal attacks
|
||||
if strings.Contains(cluster, "..") || strings.Contains(cluster, "/") || strings.Contains(cluster, "\\") {
|
||||
handleError(fmt.Errorf("invalid cluster name"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
if strings.Contains(host, "..") || strings.Contains(host, "/") || strings.Contains(host, "\\") {
|
||||
handleError(fmt.Errorf("invalid host name"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
filename := filepath.Join(api.MachineStateDir, cluster, fmt.Sprintf("%s.json", host))
|
||||
|
||||
// Sets the content-type and 'Last-Modified' Header and so on automatically
|
||||
http.ServeFile(rw, r, filename)
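Both machine-state handlers now repeat the same three substring checks; a sketch of how they could be shared (the validatePathComponent helper is hypothetical, not part of this commit):

// validatePathComponent mirrors the inline checks above: reject any value
// that could escape MachineStateDir via separators or parent references.
func validatePathComponent(name string) error {
	if strings.Contains(name, "..") || strings.Contains(name, "/") || strings.Contains(name, "\\") {
		return fmt.Errorf("invalid path component %q", name)
	}
	return nil
}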
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
@@ -10,11 +11,12 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
type ApiReturnedUser struct {
|
||||
type APIReturnedUser struct {
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Roles []string `json:"roles"`
|
||||
@@ -40,24 +42,42 @@ func (api *RestApi) getUsers(rw http.ResponseWriter, r *http.Request) {
|
||||
// SecuredCheck() only worked with TokenAuth: Removed
|
||||
|
||||
if user := repository.GetUserFromContext(r.Context()); !user.HasRole(schema.RoleAdmin) {
|
||||
http.Error(rw, "Only admins are allowed to fetch a list of users", http.StatusForbidden)
|
||||
handleError(fmt.Errorf("only admins are allowed to fetch a list of users"), http.StatusForbidden, rw)
|
||||
return
|
||||
}
|
||||
|
||||
users, err := repository.GetUserRepository().ListUsers(r.URL.Query().Get("not-just-user") == "true")
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("listing users failed: %w", err), http.StatusInternalServerError, rw)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(rw).Encode(users)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(rw).Encode(users); err != nil {
|
||||
cclog.Errorf("Failed to encode users response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// updateUser godoc
|
||||
// @summary Update user roles and projects
|
||||
// @tags User
|
||||
// @description Allows admins to add/remove roles and projects for a user
|
||||
// @produce plain
|
||||
// @param id path string true "Username"
|
||||
// @param add-role formData string false "Role to add"
|
||||
// @param remove-role formData string false "Role to remove"
|
||||
// @param add-project formData string false "Project to add"
|
||||
// @param remove-project formData string false "Project to remove"
|
||||
// @success 200 {string} string "Success message"
|
||||
// @failure 403 {object} api.ErrorResponse "Forbidden"
|
||||
// @failure 422 {object} api.ErrorResponse "Unprocessable Entity"
|
||||
// @security ApiKeyAuth
|
||||
// @router /api/user/{id} [post]
|
||||
func (api *RestApi) updateUser(rw http.ResponseWriter, r *http.Request) {
|
||||
// SecuredCheck() only worked with TokenAuth: Removed
|
||||
|
||||
if user := repository.GetUserFromContext(r.Context()); !user.HasRole(schema.RoleAdmin) {
|
||||
http.Error(rw, "Only admins are allowed to update a user", http.StatusForbidden)
|
||||
handleError(fmt.Errorf("only admins are allowed to update a user"), http.StatusForbidden, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -67,43 +87,70 @@ func (api *RestApi) updateUser(rw http.ResponseWriter, r *http.Request) {
|
||||
newproj := r.FormValue("add-project")
|
||||
delproj := r.FormValue("remove-project")
|
||||
|
||||
// TODO: Handle anything but roles...
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Handle role updates
|
||||
if newrole != "" {
|
||||
if err := repository.GetUserRepository().AddRole(r.Context(), mux.Vars(r)["id"], newrole); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("adding role failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
|
||||
rw.Write([]byte("Add Role Success"))
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{Message: "Add Role Success"}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
} else if delrole != "" {
|
||||
if err := repository.GetUserRepository().RemoveRole(r.Context(), mux.Vars(r)["id"], delrole); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("removing role failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
|
||||
rw.Write([]byte("Remove Role Success"))
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{Message: "Remove Role Success"}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
} else if newproj != "" {
|
||||
if err := repository.GetUserRepository().AddProject(r.Context(), mux.Vars(r)["id"], newproj); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("adding project failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
|
||||
rw.Write([]byte("Add Project Success"))
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{Message: "Add Project Success"}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
} else if delproj != "" {
|
||||
if err := repository.GetUserRepository().RemoveProject(r.Context(), mux.Vars(r)["id"], delproj); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("removing project failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
|
||||
rw.Write([]byte("Remove Project Success"))
|
||||
if err := json.NewEncoder(rw).Encode(DefaultApiResponse{Message: "Remove Project Success"}); err != nil {
|
||||
cclog.Errorf("Failed to encode response: %v", err)
|
||||
}
|
||||
} else {
|
||||
http.Error(rw, "Not Add or Del [role|project]?", http.StatusInternalServerError)
|
||||
handleError(fmt.Errorf("no operation specified: must provide add-role, remove-role, add-project, or remove-project"), http.StatusBadRequest, rw)
|
||||
}
|
||||
}
|
||||
|
||||
// createUser godoc
|
||||
// @summary Create a new user
|
||||
// @tags User
|
||||
// @description Creates a new user with specified credentials and role
|
||||
// @produce plain
|
||||
// @param username formData string true "Username"
|
||||
// @param password formData string false "Password (not required for API users)"
|
||||
// @param role formData string true "User role"
|
||||
// @param name formData string false "Full name"
|
||||
// @param email formData string false "Email address"
|
||||
// @param project formData string false "Project (required for managers)"
|
||||
// @success 200 {string} string "Success message"
|
||||
// @failure 400 {object} api.ErrorResponse "Bad Request"
|
||||
// @failure 403 {object} api.ErrorResponse "Forbidden"
|
||||
// @failure 422 {object} api.ErrorResponse "Unprocessable Entity"
|
||||
// @security ApiKeyAuth
|
||||
// @router /api/users/ [post]
|
||||
func (api *RestApi) createUser(rw http.ResponseWriter, r *http.Request) {
|
||||
// SecuredCheck() only worked with TokenAuth: Removed
|
||||
|
||||
rw.Header().Set("Content-Type", "text/plain")
|
||||
me := repository.GetUserFromContext(r.Context())
|
||||
if !me.HasRole(schema.RoleAdmin) {
|
||||
http.Error(rw, "Only admins are allowed to create new users", http.StatusForbidden)
|
||||
handleError(fmt.Errorf("only admins are allowed to create new users"), http.StatusForbidden, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -111,18 +158,22 @@ func (api *RestApi) createUser(rw http.ResponseWriter, r *http.Request) {
|
||||
r.FormValue("password"), r.FormValue("role"), r.FormValue("name"),
|
||||
r.FormValue("email"), r.FormValue("project")
|
||||
|
||||
// Validate username length
|
||||
if len(username) == 0 || len(username) > 100 {
|
||||
handleError(fmt.Errorf("username must be between 1 and 100 characters"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
if len(password) == 0 && role != schema.GetRoleString(schema.RoleApi) {
|
||||
http.Error(rw, "Only API users are allowed to have a blank password (login will be impossible)", http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("only API users are allowed to have a blank password (login will be impossible)"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
if len(project) != 0 && role != schema.GetRoleString(schema.RoleManager) {
|
||||
http.Error(rw, "only managers require a project (can be changed later)",
|
||||
http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("only managers require a project (can be changed later)"), http.StatusBadRequest, rw)
|
||||
return
|
||||
} else if len(project) == 0 && role == schema.GetRoleString(schema.RoleManager) {
|
||||
http.Error(rw, "managers require a project to manage (can be changed later)",
|
||||
http.StatusBadRequest)
|
||||
handleError(fmt.Errorf("managers require a project to manage (can be changed later)"), http.StatusBadRequest, rw)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -134,24 +185,35 @@ func (api *RestApi) createUser(rw http.ResponseWriter, r *http.Request) {
|
||||
Projects: []string{project},
|
||||
Roles: []string{role},
|
||||
}); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("adding user failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(rw, "User %v successfully created!\n", username)
|
||||
}
|
||||
|
||||
// deleteUser godoc
|
||||
// @summary Delete a user
|
||||
// @tags User
|
||||
// @description Deletes a user from the system
|
||||
// @produce plain
|
||||
// @param username formData string true "Username to delete"
|
||||
// @success 200 {string} string "Success"
|
||||
// @failure 403 {object} api.ErrorResponse "Forbidden"
|
||||
// @failure 422 {object} api.ErrorResponse "Unprocessable Entity"
|
||||
// @security ApiKeyAuth
|
||||
// @router /api/users/ [delete]
|
||||
func (api *RestApi) deleteUser(rw http.ResponseWriter, r *http.Request) {
|
||||
// SecuredCheck() only worked with TokenAuth: Removed
|
||||
|
||||
if user := repository.GetUserFromContext(r.Context()); !user.HasRole(schema.RoleAdmin) {
|
||||
http.Error(rw, "Only admins are allowed to delete a user", http.StatusForbidden)
|
||||
handleError(fmt.Errorf("only admins are allowed to delete a user"), http.StatusForbidden, rw)
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
if err := repository.GetUserRepository().DelUser(username); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusUnprocessableEntity)
|
||||
handleError(fmt.Errorf("deleting user failed: %w", err), http.StatusUnprocessableEntity, rw)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -32,8 +31,19 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// Authenticator is the interface for all authentication methods.
|
||||
// Each authenticator determines if it can handle a login request (CanLogin)
|
||||
// and performs the actual authentication (Login).
|
||||
type Authenticator interface {
|
||||
// CanLogin determines if this authenticator can handle the login request.
|
||||
// It returns the user object if available and a boolean indicating if this
|
||||
// authenticator should attempt the login. This method should not perform
|
||||
// expensive operations or actual authentication.
|
||||
CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool)
|
||||
|
||||
// Login performs the actual authentication for the user.
|
||||
// It returns the authenticated user or an error if authentication fails.
|
||||
// The user parameter may be nil if the user doesn't exist in the database yet.
|
||||
Login(user *schema.User, rw http.ResponseWriter, r *http.Request) (*schema.User, error)
|
||||
}
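A minimal sketch of an implementation of this interface; the type below is illustrative only and is not one of the project's real authenticators:

// exampleAuthenticator accepts a single fixed username; it exists only to
// show the split between the cheap CanLogin check and the actual Login step.
type exampleAuthenticator struct{}

func (a *exampleAuthenticator) CanLogin(user *schema.User, username string,
	rw http.ResponseWriter, r *http.Request) (*schema.User, bool) {
	// No expensive work here, as the interface comment requires.
	return user, username == "demo"
}

func (a *exampleAuthenticator) Login(user *schema.User,
	rw http.ResponseWriter, r *http.Request) (*schema.User, error) {
	if user == nil {
		return nil, errors.New("unknown user")
	}
	return user, nil
}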
|
||||
|
||||
@@ -42,27 +52,70 @@ var (
|
||||
authInstance *Authentication
|
||||
)
|
||||
|
||||
var ipUserLimiters sync.Map
|
||||
|
||||
func getIPUserLimiter(ip, username string) *rate.Limiter {
|
||||
key := ip + ":" + username
|
||||
limiter, ok := ipUserLimiters.Load(key)
|
||||
if !ok {
|
||||
newLimiter := rate.NewLimiter(rate.Every(time.Hour/10), 10)
|
||||
ipUserLimiters.Store(key, newLimiter)
|
||||
return newLimiter
|
||||
}
|
||||
return limiter.(*rate.Limiter)
|
||||
// rateLimiterEntry tracks a rate limiter and its last use time for cleanup
|
||||
type rateLimiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
var ipUserLimiters sync.Map
|
||||
|
||||
// getIPUserLimiter returns a rate limiter for the given IP and username combination.
|
||||
// Rate limiters are created on demand and track 5 attempts per 15 minutes.
|
||||
func getIPUserLimiter(ip, username string) *rate.Limiter {
|
||||
key := ip + ":" + username
|
||||
now := time.Now()
|
||||
|
||||
if entry, ok := ipUserLimiters.Load(key); ok {
|
||||
rle := entry.(*rateLimiterEntry)
|
||||
rle.lastUsed = now
|
||||
return rle.limiter
|
||||
}
|
||||
|
||||
// More aggressive rate limiting: 5 attempts per 15 minutes
|
||||
newLimiter := rate.NewLimiter(rate.Every(15*time.Minute/5), 5)
|
||||
ipUserLimiters.Store(key, &rateLimiterEntry{
|
||||
limiter: newLimiter,
|
||||
lastUsed: now,
|
||||
})
|
||||
return newLimiter
|
||||
}
|
||||
|
||||
// cleanupOldRateLimiters removes rate limiters that haven't been used recently
|
||||
func cleanupOldRateLimiters(olderThan time.Time) {
|
||||
ipUserLimiters.Range(func(key, value any) bool {
|
||||
entry := value.(*rateLimiterEntry)
|
||||
if entry.lastUsed.Before(olderThan) {
|
||||
ipUserLimiters.Delete(key)
|
||||
cclog.Debugf("Cleaned up rate limiter for %v", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// startRateLimiterCleanup starts a background goroutine to clean up old rate limiters
|
||||
func startRateLimiterCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
// Clean up limiters not used in the last 24 hours
|
||||
cleanupOldRateLimiters(time.Now().Add(-24 * time.Hour))
|
||||
}
|
||||
}()
|
||||
}
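A short sketch of how a login path would consult the limiter; the handler wiring is hypothetical, only getIPUserLimiter is from this commit:

// allowLoginAttempt is illustrative: one call per attempt, so a given
// IP/username pair is limited to 5 attempts per 15 minutes.
func allowLoginAttempt(ip, username string) bool {
	return getIPUserLimiter(ip, username).Allow()
}

// Usage sketch inside a hypothetical login handler:
//   if !allowLoginAttempt(clientIP, username) {
//       // respond with 429 Too Many Requests and abort the login
//   }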
|
||||
|
||||
// AuthConfig contains configuration for all authentication methods
|
||||
type AuthConfig struct {
|
||||
LdapConfig *LdapConfig `json:"ldap"`
|
||||
JwtConfig *JWTAuthConfig `json:"jwts"`
|
||||
OpenIDConfig *OpenIDConfig `json:"oidc"`
|
||||
}
|
||||
|
||||
// Keys holds the global authentication configuration
|
||||
var Keys AuthConfig
|
||||
|
||||
// Authentication manages all authentication methods and session handling
|
||||
type Authentication struct {
|
||||
sessionStore *sessions.CookieStore
|
||||
LdapAuth *LdapAuthenticator
|
||||
@@ -86,10 +139,31 @@ func (auth *Authentication) AuthViaSession(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO: Check if session keys exist
|
||||
username, _ := session.Values["username"].(string)
|
||||
projects, _ := session.Values["projects"].([]string)
|
||||
roles, _ := session.Values["roles"].([]string)
|
||||
// Validate session data with proper type checking
|
||||
username, ok := session.Values["username"].(string)
|
||||
if !ok || username == "" {
|
||||
cclog.Warn("Invalid session: missing or invalid username")
|
||||
// Invalidate the corrupted session
|
||||
session.Options.MaxAge = -1
|
||||
_ = auth.sessionStore.Save(r, rw, session)
|
||||
return nil, errors.New("invalid session data")
|
||||
}
|
||||
|
||||
projects, ok := session.Values["projects"].([]string)
|
||||
if !ok {
|
||||
cclog.Warn("Invalid session: projects not found or invalid type, using empty list")
|
||||
projects = []string{}
|
||||
}
|
||||
|
||||
roles, ok := session.Values["roles"].([]string)
|
||||
if !ok || len(roles) == 0 {
|
||||
cclog.Warn("Invalid session: missing or invalid roles")
|
||||
// Invalidate the corrupted session
|
||||
session.Options.MaxAge = -1
|
||||
_ = auth.sessionStore.Save(r, rw, session)
|
||||
return nil, errors.New("invalid session data")
|
||||
}
|
||||
|
||||
return &schema.User{
|
||||
Username: username,
|
||||
Projects: projects,
|
||||
@@ -102,6 +176,9 @@ func (auth *Authentication) AuthViaSession(
|
||||
func Init(authCfg *json.RawMessage) {
|
||||
initOnce.Do(func() {
|
||||
authInstance = &Authentication{}
|
||||
|
||||
// Start background cleanup of rate limiters
|
||||
startRateLimiterCleanup()
|
||||
|
||||
sessKey := os.Getenv("SESSION_KEY")
|
||||
if sessKey == "" {
|
||||
@@ -185,38 +262,36 @@ func GetAuthInstance() *Authentication {
|
||||
return authInstance
|
||||
}
|
||||
|
||||
func handleTokenUser(tokenUser *schema.User) {
|
||||
// handleUserSync syncs or updates a user in the database based on configuration.
|
||||
// This is used for both JWT and OIDC authentication when syncUserOnLogin or updateUserOnLogin is enabled.
|
||||
func handleUserSync(user *schema.User, syncUserOnLogin, updateUserOnLogin bool) {
|
||||
r := repository.GetUserRepository()
|
||||
dbUser, err := r.GetUser(tokenUser.Username)
|
||||
dbUser, err := r.GetUser(user.Username)
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
cclog.Errorf("Error while loading user '%s': %v", tokenUser.Username, err)
|
||||
} else if err == sql.ErrNoRows && Keys.JwtConfig.SyncUserOnLogin { // Adds New User
|
||||
if err := r.AddUser(tokenUser); err != nil {
|
||||
cclog.Errorf("Error while adding user '%s' to DB: %v", tokenUser.Username, err)
|
||||
cclog.Errorf("Error while loading user '%s': %v", user.Username, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err == sql.ErrNoRows && syncUserOnLogin { // Add new user
|
||||
if err := r.AddUser(user); err != nil {
|
||||
cclog.Errorf("Error while adding user '%s' to DB: %v", user.Username, err)
|
||||
}
|
||||
} else if err == nil && Keys.JwtConfig.UpdateUserOnLogin { // Update Existing User
|
||||
if err := r.UpdateUser(dbUser, tokenUser); err != nil {
|
||||
cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err)
|
||||
} else if err == nil && updateUserOnLogin { // Update existing user
|
||||
if err := r.UpdateUser(dbUser, user); err != nil {
|
||||
cclog.Errorf("Error while updating user '%s' in DB: %v", dbUser.Username, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleOIDCUser(OIDCUser *schema.User) {
|
||||
r := repository.GetUserRepository()
|
||||
dbUser, err := r.GetUser(OIDCUser.Username)
|
||||
// handleTokenUser syncs JWT token user with database
|
||||
func handleTokenUser(tokenUser *schema.User) {
|
||||
handleUserSync(tokenUser, Keys.JwtConfig.SyncUserOnLogin, Keys.JwtConfig.UpdateUserOnLogin)
|
||||
}
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
cclog.Errorf("Error while loading user '%s': %v", OIDCUser.Username, err)
|
||||
} else if err == sql.ErrNoRows && Keys.OpenIDConfig.SyncUserOnLogin { // Adds New User
|
||||
if err := r.AddUser(OIDCUser); err != nil {
|
||||
cclog.Errorf("Error while adding user '%s' to DB: %v", OIDCUser.Username, err)
|
||||
}
|
||||
} else if err == nil && Keys.OpenIDConfig.UpdateUserOnLogin { // Update Existing User
|
||||
if err := r.UpdateUser(dbUser, OIDCUser); err != nil {
|
||||
cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err)
|
||||
}
|
||||
}
|
||||
// handleOIDCUser syncs OIDC user with database
|
||||
func handleOIDCUser(OIDCUser *schema.User) {
|
||||
handleUserSync(OIDCUser, Keys.OpenIDConfig.SyncUserOnLogin, Keys.OpenIDConfig.UpdateUserOnLogin)
|
||||
}
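With the sync logic factored into handleUserSync, an additional backend only has to pass its own flags; the example below is hypothetical:

// handleLdapUser is hypothetical: it shows how a further authenticator could
// reuse handleUserSync with its own sync/update configuration flags.
func handleLdapUser(ldapUser *schema.User, syncOnLogin, updateOnLogin bool) {
	handleUserSync(ldapUser, syncOnLogin, updateOnLogin)
}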
|
||||
|
||||
func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request, user *schema.User) error {
|
||||
@@ -231,6 +306,7 @@ func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request,
|
||||
session.Options.MaxAge = int(auth.SessionMaxAge.Seconds())
|
||||
}
|
||||
if config.Keys.HTTPSCertFile == "" && config.Keys.HTTPSKeyFile == "" {
|
||||
cclog.Warn("HTTPS not configured - session cookies will not have Secure flag set (insecure for production)")
|
||||
session.Options.Secure = false
|
||||
}
|
||||
session.Options.SameSite = http.SameSiteStrictMode
|
||||
@@ -532,10 +608,13 @@ func securedCheck(user *schema.User, r *http.Request) error {
|
||||
IPAddress = r.RemoteAddr
|
||||
}
|
||||
|
||||
// FIXME: IPV6 not handled
|
||||
if strings.Contains(IPAddress, ":") {
|
||||
IPAddress = strings.Split(IPAddress, ":")[0]
|
||||
// Handle both IPv4 and IPv6 addresses properly
|
||||
// For IPv6, this will strip the port and brackets
|
||||
// For IPv4, this will strip the port
|
||||
if host, _, err := net.SplitHostPort(IPAddress); err == nil {
|
||||
IPAddress = host
|
||||
}
|
||||
// If SplitHostPort fails, IPAddress is already just a host (no port)
|
||||
|
||||
// If nothing declared in config: deny all request to this api endpoint
|
||||
if len(config.Keys.APIAllowedIPs) == 0 {
|
||||
|
||||
internal/auth/auth_test.go (new file, 176 lines)
@@ -0,0 +1,176 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestGetIPUserLimiter tests the rate limiter creation and retrieval
|
||||
func TestGetIPUserLimiter(t *testing.T) {
|
||||
ip := "192.168.1.1"
|
||||
username := "testuser"
|
||||
|
||||
// Get limiter for the first time
|
||||
limiter1 := getIPUserLimiter(ip, username)
|
||||
if limiter1 == nil {
|
||||
t.Fatal("Expected limiter to be created")
|
||||
}
|
||||
|
||||
// Get the same limiter again
|
||||
limiter2 := getIPUserLimiter(ip, username)
|
||||
if limiter1 != limiter2 {
|
||||
t.Error("Expected to get the same limiter instance")
|
||||
}
|
||||
|
||||
// Get a different limiter for different user
|
||||
limiter3 := getIPUserLimiter(ip, "otheruser")
|
||||
if limiter1 == limiter3 {
|
||||
t.Error("Expected different limiter for different user")
|
||||
}
|
||||
|
||||
// Get a different limiter for different IP
|
||||
limiter4 := getIPUserLimiter("192.168.1.2", username)
|
||||
if limiter1 == limiter4 {
|
||||
t.Error("Expected different limiter for different IP")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiterBehavior tests that rate limiting works correctly
|
||||
func TestRateLimiterBehavior(t *testing.T) {
|
||||
ip := "10.0.0.1"
|
||||
username := "ratelimituser"
|
||||
|
||||
limiter := getIPUserLimiter(ip, username)
|
||||
|
||||
// Should allow first 5 attempts
|
||||
for i := 0; i < 5; i++ {
|
||||
if !limiter.Allow() {
|
||||
t.Errorf("Request %d should be allowed within rate limit", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// 6th attempt should be blocked
|
||||
if limiter.Allow() {
|
||||
t.Error("Request 6 should be blocked by rate limiter")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanupOldRateLimiters tests the cleanup function
|
||||
func TestCleanupOldRateLimiters(t *testing.T) {
|
||||
// Clear all existing limiters first to avoid interference from other tests
|
||||
cleanupOldRateLimiters(time.Now().Add(24 * time.Hour))
|
||||
|
||||
// Create some new rate limiters
|
||||
limiter1 := getIPUserLimiter("1.1.1.1", "user1")
|
||||
limiter2 := getIPUserLimiter("2.2.2.2", "user2")
|
||||
|
||||
if limiter1 == nil || limiter2 == nil {
|
||||
t.Fatal("Failed to create test limiters")
|
||||
}
|
||||
|
||||
// Cleanup limiters older than 1 second from now (should keep both)
|
||||
time.Sleep(10 * time.Millisecond) // Small delay to ensure timestamp difference
|
||||
cleanupOldRateLimiters(time.Now().Add(-1 * time.Second))
|
||||
|
||||
// Verify they still exist (should get same instance)
|
||||
if getIPUserLimiter("1.1.1.1", "user1") != limiter1 {
|
||||
t.Error("Limiter 1 was incorrectly cleaned up")
|
||||
}
|
||||
if getIPUserLimiter("2.2.2.2", "user2") != limiter2 {
|
||||
t.Error("Limiter 2 was incorrectly cleaned up")
|
||||
}
|
||||
|
||||
// Cleanup limiters older than 1 hour from now (should remove both)
|
||||
cleanupOldRateLimiters(time.Now().Add(2 * time.Hour))
|
||||
|
||||
// Getting them again should create new instances
|
||||
newLimiter1 := getIPUserLimiter("1.1.1.1", "user1")
|
||||
if newLimiter1 == limiter1 {
|
||||
t.Error("Old limiter should have been cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIPv4Extraction tests extracting IPv4 addresses
|
||||
func TestIPv4Extraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"IPv4 with port", "192.168.1.1:8080", "192.168.1.1"},
|
||||
{"IPv4 without port", "192.168.1.1", "192.168.1.1"},
|
||||
{"Localhost with port", "127.0.0.1:3000", "127.0.0.1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input
|
||||
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||
result = host
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIPv6Extraction tests extracting IPv6 addresses
|
||||
func TestIPv6Extraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"IPv6 with port", "[2001:db8::1]:8080", "2001:db8::1"},
|
||||
{"IPv6 localhost with port", "[::1]:3000", "::1"},
|
||||
{"IPv6 without port", "2001:db8::1", "2001:db8::1"},
|
||||
{"IPv6 localhost", "::1", "::1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input
|
||||
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||
result = host
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIPExtractionEdgeCases tests edge cases for IP extraction
|
||||
func TestIPExtractionEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Hostname without port", "example.com", "example.com"},
|
||||
{"Empty string", "", ""},
|
||||
{"Just port", ":8080", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input
|
||||
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||
result = host
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
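The new auth_test.go above exercises a per-(IP, user) login rate limiter: five attempts pass, the sixth is blocked, and stale limiters can be purged. The limiter implementation itself is not part of this excerpt; a minimal sketch of how it could look on top of golang.org/x/time/rate (the function names and the burst of five mirror the tests; the refill interval, map layout and locking are assumptions) would be:

```go
// Sketch only, not the production code from this commit.
package auth

import (
	"sync"
	"time"

	"golang.org/x/time/rate"
)

type limiterEntry struct {
	limiter  *rate.Limiter
	lastSeen time.Time
}

var (
	limitersMu sync.Mutex
	limiters   = map[string]*limiterEntry{}
)

// getIPUserLimiter returns the limiter for an (ip, username) pair, creating it
// on first use: a burst of 5 attempts, then one new attempt per 12 minutes.
func getIPUserLimiter(ip, username string) *rate.Limiter {
	key := ip + ":" + username

	limitersMu.Lock()
	defer limitersMu.Unlock()

	if e, ok := limiters[key]; ok {
		e.lastSeen = time.Now()
		return e.limiter
	}
	e := &limiterEntry{
		limiter:  rate.NewLimiter(rate.Every(time.Hour/5), 5),
		lastSeen: time.Now(),
	}
	limiters[key] = e
	return e.limiter
}

// cleanupOldRateLimiters drops every entry that has not been used since cutoff.
func cleanupOldRateLimiters(cutoff time.Time) {
	limitersMu.Lock()
	defer limitersMu.Unlock()

	for key, e := range limiters {
		if e.lastSeen.Before(cutoff) {
			delete(limiters, key)
		}
	}
}
```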
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -102,38 +101,21 @@ func (ja *JWTAuthenticator) AuthViaJWT(
|
||||
|
||||
// Token is valid, extract payload
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
sub, _ := claims["sub"].(string)
|
||||
|
||||
var roles []string
|
||||
|
||||
// Validate user + roles from JWT against database?
|
||||
if Keys.JwtConfig.ValidateUser {
|
||||
ur := repository.GetUserRepository()
|
||||
user, err := ur.GetUser(sub)
|
||||
// Deny any logins for unknown usernames
|
||||
if err != nil {
|
||||
cclog.Warn("Could not find user from JWT in internal database.")
|
||||
return nil, errors.New("unknown user")
|
||||
}
|
||||
// Take user roles from database instead of trusting the JWT
|
||||
roles = user.Roles
|
||||
} else {
|
||||
// Extract roles from JWT (if present)
|
||||
if rawroles, ok := claims["roles"].([]any); ok {
|
||||
for _, rr := range rawroles {
|
||||
if r, ok := rr.(string); ok {
|
||||
roles = append(roles, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use shared helper to get user from JWT claims
|
||||
var user *schema.User
|
||||
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthToken, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &schema.User{
|
||||
Username: sub,
|
||||
Roles: roles,
|
||||
AuthType: schema.AuthToken,
|
||||
AuthSource: -1,
|
||||
}, nil
|
||||
|
||||
// If not validating user, we only get roles from JWT (no projects for this auth method)
|
||||
if !Keys.JwtConfig.ValidateUser {
|
||||
user.Roles = extractRolesFromClaims(claims, false)
|
||||
user.Projects = nil // Standard JWT auth doesn't include projects
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ProvideJWT generates a new JWT that can be used for authentication
|
||||
|
||||
@@ -7,14 +7,11 @@ package auth
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -149,57 +146,16 @@ func (ja *JWTCookieSessionAuthenticator) Login(
|
||||
}
|
||||
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
sub, _ := claims["sub"].(string)
|
||||
|
||||
var roles []string
|
||||
projects := make([]string, 0)
|
||||
|
||||
if jc.ValidateUser {
|
||||
var err error
|
||||
user, err = repository.GetUserRepository().GetUser(sub)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
cclog.Errorf("Error while loading user '%v'", sub)
|
||||
}
|
||||
|
||||
// Deny any logins for unknown usernames
|
||||
if user == nil {
|
||||
cclog.Warn("Could not find user from JWT in internal database.")
|
||||
return nil, errors.New("unknown user")
|
||||
}
|
||||
} else {
|
||||
var name string
|
||||
if wrap, ok := claims["name"].(map[string]any); ok {
|
||||
if vals, ok := wrap["values"].([]any); ok {
|
||||
if len(vals) != 0 {
|
||||
name = fmt.Sprintf("%v", vals[0])
|
||||
|
||||
for i := 1; i < len(vals); i++ {
|
||||
name += fmt.Sprintf(" %v", vals[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles from JWT (if present)
|
||||
if rawroles, ok := claims["roles"].([]any); ok {
|
||||
for _, rr := range rawroles {
|
||||
if r, ok := rr.(string); ok {
|
||||
roles = append(roles, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
user = &schema.User{
|
||||
Username: sub,
|
||||
Name: name,
|
||||
Roles: roles,
|
||||
Projects: projects,
|
||||
AuthType: schema.AuthSession,
|
||||
AuthSource: schema.AuthViaToken,
|
||||
}
|
||||
|
||||
if jc.SyncUserOnLogin || jc.UpdateUserOnLogin {
|
||||
handleTokenUser(user)
|
||||
}
|
||||
|
||||
// Use shared helper to get user from JWT claims
|
||||
user, err = getUserFromJWT(claims, jc.ValidateUser, schema.AuthSession, schema.AuthViaToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sync or update user if configured
|
||||
if !jc.ValidateUser && (jc.SyncUserOnLogin || jc.UpdateUserOnLogin) {
|
||||
handleTokenUser(user)
|
||||
}
|
||||
|
||||
// (Ask browser to) Delete JWT cookie
|
||||
|
||||
136 internal/auth/jwtHelpers.go Normal file
@@ -0,0 +1,136 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// extractStringFromClaims extracts a string value from JWT claims
|
||||
func extractStringFromClaims(claims jwt.MapClaims, key string) string {
|
||||
if val, ok := claims[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractRolesFromClaims extracts roles from JWT claims
|
||||
// If validateRoles is true, only valid roles are returned
|
||||
func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string {
|
||||
var roles []string
|
||||
|
||||
if rawroles, ok := claims["roles"].([]any); ok {
|
||||
for _, rr := range rawroles {
|
||||
if r, ok := rr.(string); ok {
|
||||
if validateRoles {
|
||||
if schema.IsValidRole(r) {
|
||||
roles = append(roles, r)
|
||||
}
|
||||
} else {
|
||||
roles = append(roles, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return roles
|
||||
}
|
||||
|
||||
// extractProjectsFromClaims extracts projects from JWT claims
|
||||
func extractProjectsFromClaims(claims jwt.MapClaims) []string {
|
||||
projects := make([]string, 0)
|
||||
|
||||
if rawprojs, ok := claims["projects"].([]any); ok {
|
||||
for _, pp := range rawprojs {
|
||||
if p, ok := pp.(string); ok {
|
||||
projects = append(projects, p)
|
||||
}
|
||||
}
|
||||
} else if rawprojs, ok := claims["projects"]; ok {
|
||||
if projSlice, ok := rawprojs.([]string); ok {
|
||||
projects = append(projects, projSlice...)
|
||||
}
|
||||
}
|
||||
|
||||
return projects
|
||||
}
|
||||
|
||||
// extractNameFromClaims extracts name from JWT claims
|
||||
// Handles both simple string and complex nested structure
|
||||
func extractNameFromClaims(claims jwt.MapClaims) string {
|
||||
// Try simple string first
|
||||
if name, ok := claims["name"].(string); ok {
|
||||
return name
|
||||
}
|
||||
|
||||
// Try nested structure: {name: {values: [...]}}
|
||||
if wrap, ok := claims["name"].(map[string]any); ok {
|
||||
if vals, ok := wrap["values"].([]any); ok {
|
||||
if len(vals) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%v", vals[0])
|
||||
for i := 1; i < len(vals); i++ {
|
||||
name += fmt.Sprintf(" %v", vals[i])
|
||||
}
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getUserFromJWT creates or retrieves a user based on JWT claims
|
||||
// If validateUser is true, the user must exist in the database
|
||||
// Otherwise, a new user object is created from claims
|
||||
// authSource should be a schema.AuthSource constant (like schema.AuthViaToken)
|
||||
func getUserFromJWT(claims jwt.MapClaims, validateUser bool, authType schema.AuthType, authSource schema.AuthSource) (*schema.User, error) {
|
||||
sub := extractStringFromClaims(claims, "sub")
|
||||
if sub == "" {
|
||||
return nil, errors.New("missing 'sub' claim in JWT")
|
||||
}
|
||||
|
||||
if validateUser {
|
||||
// Validate user against database
|
||||
ur := repository.GetUserRepository()
|
||||
user, err := ur.GetUser(sub)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
cclog.Errorf("Error while loading user '%v': %v", sub, err)
|
||||
return nil, fmt.Errorf("database error: %w", err)
|
||||
}
|
||||
|
||||
// Deny any logins for unknown usernames
|
||||
if user == nil || err == sql.ErrNoRows {
|
||||
cclog.Warn("Could not find user from JWT in internal database.")
|
||||
return nil, errors.New("unknown user")
|
||||
}
|
||||
|
||||
// Return database user (with database roles)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Create user from JWT claims
|
||||
name := extractNameFromClaims(claims)
|
||||
roles := extractRolesFromClaims(claims, true) // Validate roles
|
||||
projects := extractProjectsFromClaims(claims)
|
||||
|
||||
return &schema.User{
|
||||
Username: sub,
|
||||
Name: name,
|
||||
Roles: roles,
|
||||
Projects: projects,
|
||||
AuthType: authType,
|
||||
AuthSource: authSource,
|
||||
}, nil
|
||||
}
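The three authenticators touched in this change are rewired to go through this helper. A small, hypothetical caller-side sketch (the wrapper function is invented; only getUserFromJWT and the schema constants come from the file above, and the token is assumed to be verified already):

```go
// Sketch of a caller inside package auth.
package auth

import (
	"github.com/ClusterCockpit/cc-lib/schema"
	"github.com/golang-jwt/jwt/v5"
)

// userFromVerifiedToken is a hypothetical wrapper around the shared helper.
func userFromVerifiedToken(token *jwt.Token, validateUser bool) (*schema.User, error) {
	claims := token.Claims.(jwt.MapClaims)

	// validateUser=true: the database is authoritative and unknown users are rejected.
	// validateUser=false: the user is built from the claims (roles validated, projects passed through).
	return getUserFromJWT(claims, validateUser, schema.AuthSession, schema.AuthViaToken)
}
```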
|
||||
281 internal/auth/jwtHelpers_test.go Normal file
@@ -0,0 +1,281 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// TestExtractStringFromClaims tests extracting string values from JWT claims
|
||||
func TestExtractStringFromClaims(t *testing.T) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "testuser",
|
||||
"email": "test@example.com",
|
||||
"age": 25, // not a string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{"Existing string", "sub", "testuser"},
|
||||
{"Another string", "email", "test@example.com"},
|
||||
{"Non-existent key", "missing", ""},
|
||||
{"Non-string value", "age", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractStringFromClaims(claims, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractRolesFromClaims tests role extraction and validation
|
||||
func TestExtractRolesFromClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims jwt.MapClaims
|
||||
validateRoles bool
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Valid roles without validation",
|
||||
claims: jwt.MapClaims{
|
||||
"roles": []any{"admin", "user", "invalid_role"},
|
||||
},
|
||||
validateRoles: false,
|
||||
expected: []string{"admin", "user", "invalid_role"},
|
||||
},
|
||||
{
|
||||
name: "Valid roles with validation",
|
||||
claims: jwt.MapClaims{
|
||||
"roles": []any{"admin", "user", "api"},
|
||||
},
|
||||
validateRoles: true,
|
||||
expected: []string{"admin", "user", "api"},
|
||||
},
|
||||
{
|
||||
name: "Invalid roles with validation",
|
||||
claims: jwt.MapClaims{
|
||||
"roles": []any{"invalid_role", "fake_role"},
|
||||
},
|
||||
validateRoles: true,
|
||||
expected: []string{}, // Should filter out invalid roles
|
||||
},
|
||||
{
|
||||
name: "No roles claim",
|
||||
claims: jwt.MapClaims{},
|
||||
validateRoles: false,
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Non-array roles",
|
||||
claims: jwt.MapClaims{
|
||||
"roles": "admin",
|
||||
},
|
||||
validateRoles: false,
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractRolesFromClaims(tt.claims, tt.validateRoles)
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d roles, got %d", len(tt.expected), len(result))
|
||||
return
|
||||
}
|
||||
|
||||
for i, role := range result {
|
||||
if i >= len(tt.expected) || role != tt.expected[i] {
|
||||
t.Errorf("Expected role %s at position %d, got %s", tt.expected[i], i, role)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractProjectsFromClaims tests project extraction from claims
|
||||
func TestExtractProjectsFromClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims jwt.MapClaims
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Projects as array of interfaces",
|
||||
claims: jwt.MapClaims{
|
||||
"projects": []any{"project1", "project2", "project3"},
|
||||
},
|
||||
expected: []string{"project1", "project2", "project3"},
|
||||
},
|
||||
{
|
||||
name: "Projects as string array",
|
||||
claims: jwt.MapClaims{
|
||||
"projects": []string{"projectA", "projectB"},
|
||||
},
|
||||
expected: []string{"projectA", "projectB"},
|
||||
},
|
||||
{
|
||||
name: "No projects claim",
|
||||
claims: jwt.MapClaims{},
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Mixed types in projects array",
|
||||
claims: jwt.MapClaims{
|
||||
"projects": []any{"project1", 123, "project2"},
|
||||
},
|
||||
expected: []string{"project1", "project2"}, // Should skip non-strings
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractProjectsFromClaims(tt.claims)
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d projects, got %d", len(tt.expected), len(result))
|
||||
return
|
||||
}
|
||||
|
||||
for i, project := range result {
|
||||
if i >= len(tt.expected) || project != tt.expected[i] {
|
||||
t.Errorf("Expected project %s at position %d, got %s", tt.expected[i], i, project)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractNameFromClaims tests name extraction from various formats
|
||||
func TestExtractNameFromClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims jwt.MapClaims
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple string name",
|
||||
claims: jwt.MapClaims{
|
||||
"name": "John Doe",
|
||||
},
|
||||
expected: "John Doe",
|
||||
},
|
||||
{
|
||||
name: "Nested name structure",
|
||||
claims: jwt.MapClaims{
|
||||
"name": map[string]any{
|
||||
"values": []any{"John", "Doe"},
|
||||
},
|
||||
},
|
||||
expected: "John Doe",
|
||||
},
|
||||
{
|
||||
name: "Nested name with single value",
|
||||
claims: jwt.MapClaims{
|
||||
"name": map[string]any{
|
||||
"values": []any{"Alice"},
|
||||
},
|
||||
},
|
||||
expected: "Alice",
|
||||
},
|
||||
{
|
||||
name: "No name claim",
|
||||
claims: jwt.MapClaims{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty nested values",
|
||||
claims: jwt.MapClaims{
|
||||
"name": map[string]any{
|
||||
"values": []any{},
|
||||
},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Nested with non-string values",
|
||||
claims: jwt.MapClaims{
|
||||
"name": map[string]any{
|
||||
"values": []any{123, "Smith"},
|
||||
},
|
||||
},
|
||||
expected: "123 Smith", // Should convert to string
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractNameFromClaims(tt.claims)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserFromJWT_NoValidation tests getUserFromJWT without database validation
|
||||
func TestGetUserFromJWT_NoValidation(t *testing.T) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "testuser",
|
||||
"name": "Test User",
|
||||
"roles": []any{"user", "admin"},
|
||||
"projects": []any{"project1", "project2"},
|
||||
}
|
||||
|
||||
user, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if user.Username != "testuser" {
|
||||
t.Errorf("Expected username 'testuser', got '%s'", user.Username)
|
||||
}
|
||||
|
||||
if user.Name != "Test User" {
|
||||
t.Errorf("Expected name 'Test User', got '%s'", user.Name)
|
||||
}
|
||||
|
||||
if len(user.Roles) != 2 {
|
||||
t.Errorf("Expected 2 roles, got %d", len(user.Roles))
|
||||
}
|
||||
|
||||
if len(user.Projects) != 2 {
|
||||
t.Errorf("Expected 2 projects, got %d", len(user.Projects))
|
||||
}
|
||||
|
||||
if user.AuthType != schema.AuthToken {
|
||||
t.Errorf("Expected AuthType %v, got %v", schema.AuthToken, user.AuthType)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserFromJWT_MissingSub tests error when sub claim is missing
|
||||
func TestGetUserFromJWT_MissingSub(t *testing.T) {
|
||||
claims := jwt.MapClaims{
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
_, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing sub claim")
|
||||
}
|
||||
|
||||
if err.Error() != "missing 'sub' claim in JWT" {
|
||||
t.Errorf("Expected specific error message, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -77,70 +75,16 @@ func (ja *JWTSessionAuthenticator) Login(
|
||||
}
|
||||
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
sub, _ := claims["sub"].(string)
|
||||
|
||||
var roles []string
|
||||
projects := make([]string, 0)
|
||||
|
||||
if Keys.JwtConfig.ValidateUser {
|
||||
var err error
|
||||
user, err = repository.GetUserRepository().GetUser(sub)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
cclog.Errorf("Error while loading user '%v'", sub)
|
||||
}
|
||||
|
||||
// Deny any logins for unknown usernames
|
||||
if user == nil {
|
||||
cclog.Warn("Could not find user from JWT in internal database.")
|
||||
return nil, errors.New("unknown user")
|
||||
}
|
||||
} else {
|
||||
var name string
|
||||
if wrap, ok := claims["name"].(map[string]any); ok {
|
||||
if vals, ok := wrap["values"].([]any); ok {
|
||||
if len(vals) != 0 {
|
||||
name = fmt.Sprintf("%v", vals[0])
|
||||
|
||||
for i := 1; i < len(vals); i++ {
|
||||
name += fmt.Sprintf(" %v", vals[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles from JWT (if present)
|
||||
if rawroles, ok := claims["roles"].([]any); ok {
|
||||
for _, rr := range rawroles {
|
||||
if r, ok := rr.(string); ok {
|
||||
if schema.IsValidRole(r) {
|
||||
roles = append(roles, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rawprojs, ok := claims["projects"].([]any); ok {
|
||||
for _, pp := range rawprojs {
|
||||
if p, ok := pp.(string); ok {
|
||||
projects = append(projects, p)
|
||||
}
|
||||
}
|
||||
} else if rawprojs, ok := claims["projects"]; ok {
|
||||
projects = append(projects, rawprojs.([]string)...)
|
||||
}
|
||||
|
||||
user = &schema.User{
|
||||
Username: sub,
|
||||
Name: name,
|
||||
Roles: roles,
|
||||
Projects: projects,
|
||||
AuthType: schema.AuthSession,
|
||||
AuthSource: schema.AuthViaToken,
|
||||
}
|
||||
|
||||
if Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin {
|
||||
handleTokenUser(user)
|
||||
}
|
||||
|
||||
// Use shared helper to get user from JWT claims
|
||||
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthSession, schema.AuthViaToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sync or update user if configured
|
||||
if !Keys.JwtConfig.ValidateUser && (Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin) {
|
||||
handleTokenUser(user)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
@@ -54,8 +54,13 @@ func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value strin
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authenticator with the configured provider
|
||||
func NewOIDC(a *Authentication) *OIDC {
|
||||
provider, err := oidc.NewProvider(context.Background(), Keys.OpenIDConfig.Provider)
|
||||
// Use context with timeout for provider initialization
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, Keys.OpenIDConfig.Provider)
|
||||
if err != nil {
|
||||
cclog.Fatal(err)
|
||||
}
|
||||
@@ -111,13 +116,18 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {
|
||||
http.Error(rw, "Code not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
token, err := oa.client.Exchange(context.Background(), code, oauth2.VerifierOption(codeVerifier))
|
||||
// Exchange authorization code for token with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
token, err := oa.client.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo, err := oa.provider.UserInfo(context.Background(), oauth2.StaticTokenSource(token))
|
||||
// Get user info from OIDC provider with same timeout
|
||||
userInfo, err := oa.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -180,8 +190,8 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
oa.authentication.SaveSession(rw, r, user)
|
||||
cclog.Infof("login successfull: user: %#v (roles: %v, projects: %v)", user.Username, user.Roles, user.Projects)
|
||||
ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
|
||||
http.RedirectHandler("/", http.StatusTemporaryRedirect).ServeHTTP(rw, r.WithContext(ctx))
|
||||
userCtx := context.WithValue(r.Context(), repository.ContextUserKey, user)
|
||||
http.RedirectHandler("/", http.StatusTemporaryRedirect).ServeHTTP(rw, r.WithContext(userCtx))
|
||||
}
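Both OIDC call sites now run against a 30-second context instead of context.Background(), so a slow or unreachable identity provider can no longer hang the callback handler. The same pattern, reduced to a standalone sketch with a plain HTTP request (the URL is a placeholder):

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Bound the whole outbound call, mirroring the OIDC provider/exchange changes above.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel() // release the timer even on the success path

	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
		"https://idp.example.com/.well-known/openid-configuration", nil) // placeholder URL
	if err != nil {
		fmt.Println("build request:", err)
		return
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		// A slow or unreachable provider surfaces here as context.DeadlineExceeded.
		fmt.Println("request failed or timed out:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```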
|
||||
|
||||
func (oa *OIDC) OAuth2Login(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
File diff suppressed because it is too large
@@ -2,7 +2,7 @@ package graph
|
||||
|
||||
// This file will be automatically regenerated based on the schema, any resolver implementations
|
||||
// will be copied through when generating and any unknown code will be moved to the end.
|
||||
// Code generated by github.com/99designs/gqlgen version v0.17.78
|
||||
// Code generated by github.com/99designs/gqlgen version v0.17.81
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
@@ -81,14 +81,14 @@ func setup(t *testing.T) *repository.JobRepository {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
jobarchive := filepath.Join(tmpdir, "job-archive")
|
||||
if err := os.Mkdir(jobarchive, 0777); err != nil {
|
||||
if err := os.Mkdir(jobarchive, 0o777); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(jobarchive, "version.txt"), []byte(fmt.Sprintf("%d", 2)), 0666); err != nil {
|
||||
if err := os.WriteFile(filepath.Join(jobarchive, "version.txt"), fmt.Appendf(nil, "%d", 3), 0o666); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fritzArchive := filepath.Join(tmpdir, "job-archive", "fritz")
|
||||
if err := os.Mkdir(fritzArchive, 0777); err != nil {
|
||||
if err := os.Mkdir(fritzArchive, 0o777); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := copyFile(filepath.Join("testdata", "cluster-fritz.json"),
|
||||
@@ -103,7 +103,7 @@ func setup(t *testing.T) *repository.JobRepository {
|
||||
}
|
||||
|
||||
cfgFilePath := filepath.Join(tmpdir, "config.json")
|
||||
if err := os.WriteFile(cfgFilePath, []byte(testconfig), 0666); err != nil {
|
||||
if err := os.WriteFile(cfgFilePath, []byte(testconfig), 0o666); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,18 @@ func InitDB() error {
|
||||
|
||||
// Bundle 100 inserts into one transaction for better performance
|
||||
if i%100 == 0 {
|
||||
r.TransactionCommit(t)
|
||||
if i > 0 {
|
||||
if err := t.Commit(); err != nil {
|
||||
cclog.Errorf("transaction commit error: %v", err)
|
||||
return err
|
||||
}
|
||||
// Start a new transaction for the next batch
|
||||
t, err = r.TransactionInit()
|
||||
if err != nil {
|
||||
cclog.Errorf("transaction init error: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
fmt.Printf("%d jobs inserted...\r", i)
|
||||
}
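The InitDB change commits the open transaction every 100 inserts and immediately starts a new one, instead of the old mid-loop TransactionCommit. The batching pattern in a generic, self-contained form (table, column and the helper are placeholders; the batch size mirrors the 100 above):

```go
// Sketch of commit-every-N batching with database/sql.
package demo

import "database/sql"

func bulkInsert(db *sql.DB, values []int) error {
	const batchSize = 100

	tx, err := db.Begin()
	if err != nil {
		return err
	}

	for i, v := range values {
		if _, err := tx.Exec("INSERT INTO demo(val) VALUES (?)", v); err != nil {
			tx.Rollback()
			return err
		}
		// Commit a full batch, then start a fresh transaction for the next one.
		if (i+1)%batchSize == 0 {
			if err := tx.Commit(); err != nil {
				return err
			}
			if tx, err = db.Begin(); err != nil {
				return err
			}
		}
	}

	return tx.Commit() // flush the final, possibly partial batch
}
```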
|
||||
|
||||
|
||||
@@ -75,10 +75,10 @@ func ArchiveCheckpoints(checkpointsDir, archiveDir string, from int64, deleteIns
|
||||
|
||||
var wg sync.WaitGroup
|
||||
n, errs := int32(0), int32(0)
|
||||
work := make(chan workItem, NumWorkers)
|
||||
work := make(chan workItem, Keys.NumWorkers)
|
||||
|
||||
wg.Add(NumWorkers)
|
||||
for worker := 0; worker < NumWorkers; worker++ {
|
||||
wg.Add(Keys.NumWorkers)
|
||||
for worker := 0; worker < Keys.NumWorkers; worker++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for workItem := range work {
|
||||
@@ -116,7 +116,7 @@ func ArchiveCheckpoints(checkpointsDir, archiveDir string, from int64, deleteIns
|
||||
}
|
||||
|
||||
if errs > 0 {
|
||||
return int(n), fmt.Errorf("%d errors happend while archiving (%d successes)", errs, n)
|
||||
return int(n), fmt.Errorf("%d errors happened while archiving (%d successes)", errs, n)
|
||||
}
|
||||
return int(n), nil
|
||||
}
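ArchiveCheckpoints keeps its worker-pool shape (a buffered work channel, one goroutine per worker, a WaitGroup and two counters) and only swaps the package-level NumWorkers for Keys.NumWorkers. The pattern in isolation, as a runnable sketch with made-up work items:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	numWorkers := 4 // stands in for Keys.NumWorkers
	work := make(chan int, numWorkers)

	var wg sync.WaitGroup
	var done, errs int32

	wg.Add(numWorkers)
	for w := 0; w < numWorkers; w++ {
		go func() {
			defer wg.Done()
			for item := range work {
				if item%7 == 0 { // stand-in for a failed archive operation
					atomic.AddInt32(&errs, 1)
					continue
				}
				atomic.AddInt32(&done, 1)
			}
		}()
	}

	for i := 0; i < 100; i++ {
		work <- i
	}
	close(work) // lets every worker drain the channel and leave its range loop
	wg.Wait()

	fmt.Printf("%d done, %d errors\n", done, errs)
}
```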
|
||||
@@ -147,11 +147,11 @@ func archiveCheckpoints(dir string, archiveDir string, from int64, deleteInstead
|
||||
}
|
||||
|
||||
filename := filepath.Join(archiveDir, fmt.Sprintf("%d.zip", from))
|
||||
f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, CheckpointFilePerms)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
err = os.MkdirAll(archiveDir, 0o755)
|
||||
err = os.MkdirAll(archiveDir, CheckpointDirPerms)
|
||||
if err == nil {
|
||||
f, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
f, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, CheckpointFilePerms)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -105,46 +105,6 @@ func (b *buffer) firstWrite() int64 {
|
||||
|
||||
func (b *buffer) close() {}
|
||||
|
||||
/*
|
||||
func (b *buffer) close() {
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
b.closed = true
|
||||
n, sum, min, max := 0, 0., math.MaxFloat64, -math.MaxFloat64
|
||||
for _, x := range b.data {
|
||||
if x.IsNaN() {
|
||||
continue
|
||||
}
|
||||
|
||||
n += 1
|
||||
f := float64(x)
|
||||
sum += f
|
||||
min = math.Min(min, f)
|
||||
max = math.Max(max, f)
|
||||
}
|
||||
|
||||
b.statisticts.samples = n
|
||||
if n > 0 {
|
||||
b.statisticts.avg = Float(sum / float64(n))
|
||||
b.statisticts.min = Float(min)
|
||||
b.statisticts.max = Float(max)
|
||||
} else {
|
||||
b.statisticts.avg = NaN
|
||||
b.statisticts.min = NaN
|
||||
b.statisticts.max = NaN
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// func interpolate(idx int, data []Float) Float {
|
||||
// if idx == 0 || idx+1 == len(data) {
|
||||
// return NaN
|
||||
// }
|
||||
// return (data[idx-1] + data[idx+1]) / 2.0
|
||||
// }
|
||||
|
||||
// Return all known values from `from` to `to`. Gaps of information are represented as NaN.
|
||||
// Simple linear interpolation is done between the two neighboring cells if possible.
|
||||
// If values at the start or end are missing, instead of NaN values, the second and thrid
|
||||
|
||||
@@ -28,6 +28,17 @@ import (
|
||||
"github.com/linkedin/goavro/v2"
|
||||
)
|
||||
|
||||
// File operation constants
|
||||
const (
|
||||
// CheckpointFilePerms defines default permissions for checkpoint files
|
||||
CheckpointFilePerms = 0o644
|
||||
// CheckpointDirPerms defines default permissions for checkpoint directories
|
||||
CheckpointDirPerms = 0o755
|
||||
// GCTriggerInterval determines how often GC is forced during checkpoint loading
|
||||
// GC is triggered every GCTriggerInterval*NumWorkers loaded hosts
|
||||
GCTriggerInterval = 100
|
||||
)
|
||||
|
||||
// Whenever changed, update MarshalJSON as well!
|
||||
type CheckpointMetrics struct {
|
||||
Data []schema.Float `json:"data"`
|
||||
@@ -171,9 +182,9 @@ func (m *MemoryStore) ToCheckpoint(dir string, from, to int64) (int, error) {
|
||||
n, errs := int32(0), int32(0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(NumWorkers)
|
||||
work := make(chan workItem, NumWorkers*2)
|
||||
for worker := 0; worker < NumWorkers; worker++ {
|
||||
wg.Add(Keys.NumWorkers)
|
||||
work := make(chan workItem, Keys.NumWorkers*2)
|
||||
for worker := 0; worker < Keys.NumWorkers; worker++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -205,7 +216,7 @@ func (m *MemoryStore) ToCheckpoint(dir string, from, to int64) (int, error) {
|
||||
wg.Wait()
|
||||
|
||||
if errs > 0 {
|
||||
return int(n), fmt.Errorf("[METRICSTORE]> %d errors happend while creating checkpoints (%d successes)", errs, n)
|
||||
return int(n), fmt.Errorf("[METRICSTORE]> %d errors happened while creating checkpoints (%d successes)", errs, n)
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
@@ -285,11 +296,11 @@ func (l *Level) toCheckpoint(dir string, from, to int64, m *MemoryStore) error {
|
||||
}
|
||||
|
||||
filepath := path.Join(dir, fmt.Sprintf("%d.json", from))
|
||||
f, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
f, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY, CheckpointFilePerms)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
err = os.MkdirAll(dir, 0o755)
|
||||
err = os.MkdirAll(dir, CheckpointDirPerms)
|
||||
if err == nil {
|
||||
f, err = os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
f, err = os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY, CheckpointFilePerms)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -307,11 +318,11 @@ func (l *Level) toCheckpoint(dir string, from, to int64, m *MemoryStore) error {
|
||||
|
||||
func (m *MemoryStore) FromCheckpoint(dir string, from int64, extension string) (int, error) {
|
||||
var wg sync.WaitGroup
|
||||
work := make(chan [2]string, NumWorkers)
|
||||
work := make(chan [2]string, Keys.NumWorkers)
|
||||
n, errs := int32(0), int32(0)
|
||||
|
||||
wg.Add(NumWorkers)
|
||||
for worker := 0; worker < NumWorkers; worker++ {
|
||||
wg.Add(Keys.NumWorkers)
|
||||
for worker := 0; worker < Keys.NumWorkers; worker++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for host := range work {
|
||||
@@ -347,7 +358,7 @@ func (m *MemoryStore) FromCheckpoint(dir string, from int64, extension string) (
|
||||
}
|
||||
|
||||
i++
|
||||
if i%NumWorkers == 0 && i > 100 {
|
||||
if i%Keys.NumWorkers == 0 && i > GCTriggerInterval {
|
||||
// Forcing garbage collection runs here regulary during the loading of checkpoints
|
||||
// will decrease the total heap size after loading everything back to memory is done.
|
||||
// While loading data, the heap will grow fast, so the GC target size will double
|
||||
@@ -368,7 +379,7 @@ done:
|
||||
}
|
||||
|
||||
if errs > 0 {
|
||||
return int(n), fmt.Errorf("[METRICSTORE]> %d errors happend while creating checkpoints (%d successes)", errs, n)
|
||||
return int(n), fmt.Errorf("[METRICSTORE]> %d errors happened while creating checkpoints (%d successes)", errs, n)
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
@@ -379,7 +390,7 @@ done:
|
||||
func (m *MemoryStore) FromCheckpointFiles(dir string, from int64) (int, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
// The directory does not exist, so create it using os.MkdirAll()
|
||||
err := os.MkdirAll(dir, 0o755) // 0755 sets the permissions for the directory
|
||||
err := os.MkdirAll(dir, CheckpointDirPerms) // CheckpointDirPerms sets the permissions for the directory
|
||||
if err != nil {
|
||||
cclog.Fatalf("[METRICSTORE]> Error creating directory: %#v\n", err)
|
||||
}
|
||||
@@ -464,7 +475,7 @@ func (l *Level) loadAvroFile(m *MemoryStore, f *os.File, from int64) error {
|
||||
// Create a new OCF reader from the buffered reader
|
||||
ocfReader, err := goavro.NewOCFReader(br)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("[METRICSTORE]> error creating OCF reader: %w", err)
|
||||
}
|
||||
|
||||
metricsData := make(map[string]schema.FloatArray)
|
||||
@@ -477,7 +488,7 @@ func (l *Level) loadAvroFile(m *MemoryStore, f *os.File, from int64) error {
|
||||
|
||||
record, ok := datum.(map[string]any)
|
||||
if !ok {
|
||||
panic("[METRICSTORE]> failed to assert datum as map[string]interface{}")
|
||||
return fmt.Errorf("[METRICSTORE]> failed to assert datum as map[string]interface{}")
|
||||
}
|
||||
|
||||
for key, value := range record {
|
||||
@@ -559,7 +570,7 @@ func (l *Level) createBuffer(m *MemoryStore, metricName string, floatArray schem
|
||||
l.metrics[minfo.offset] = b
|
||||
} else {
|
||||
if prev.start > b.start {
|
||||
return errors.New("wooops")
|
||||
return fmt.Errorf("[METRICSTORE]> buffer start time %d is before previous buffer start %d", b.start, prev.start)
|
||||
}
|
||||
|
||||
b.prev = prev
|
||||
@@ -623,7 +634,7 @@ func (l *Level) loadFile(cf *CheckpointFile, m *MemoryStore) error {
|
||||
l.metrics[minfo.offset] = b
|
||||
} else {
|
||||
if prev.start > b.start {
|
||||
return errors.New("wooops")
|
||||
return fmt.Errorf("[METRICSTORE]> buffer start time %d is before previous buffer start %d", b.start, prev.start)
|
||||
}
|
||||
|
||||
b.prev = prev
|
||||
@@ -700,13 +711,17 @@ func (l *Level) fromCheckpoint(m *MemoryStore, dir string, from int64, extension
|
||||
loader := loaders[extension]
|
||||
|
||||
for _, filename := range files {
|
||||
f, err := os.Open(path.Join(dir, filename))
|
||||
if err != nil {
|
||||
return filesLoaded, err
|
||||
}
|
||||
defer f.Close()
|
||||
// Use a closure to ensure file is closed immediately after use
|
||||
err := func() error {
|
||||
f, err := os.Open(path.Join(dir, filename))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err = loader(m, f, from); err != nil {
|
||||
return loader(m, f, from)
|
||||
}()
|
||||
if err != nil {
|
||||
return filesLoaded, err
|
||||
}
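The loop now wraps each file in an immediately invoked closure so that defer f.Close() fires at the end of every iteration instead of piling up until fromCheckpoint returns. A standalone sketch of that pattern (the real checkpoint loader is replaced by a simple copy):

```go
package demo

import (
	"io"
	"os"
	"path/filepath"
)

func processAll(dir string, names []string) error {
	for _, name := range names {
		err := func() error {
			f, err := os.Open(filepath.Join(dir, name))
			if err != nil {
				return err
			}
			defer f.Close() // runs when this closure returns, i.e. once per file

			_, err = io.Copy(io.Discard, f) // stand-in for the real loader
			return err
		}()
		if err != nil {
			return err
		}
	}
	return nil
}
```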
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
var InternalCCMSFlag bool = false
|
||||
|
||||
type MetricStoreConfig struct {
|
||||
// Number of concurrent workers for checkpoint and archive operations.
|
||||
// If not set or 0, defaults to min(runtime.NumCPU()/2+1, 10)
|
||||
NumWorkers int `json:"num-workers"`
|
||||
Checkpoints struct {
|
||||
FileFormat string `json:"file-format"`
|
||||
Interval string `json:"interval"`
|
||||
@@ -62,7 +65,7 @@ const (
|
||||
AvgAggregation
|
||||
)
|
||||
|
||||
func AssignAggregationStratergy(str string) (AggregationStrategy, error) {
|
||||
func AssignAggregationStrategy(str string) (AggregationStrategy, error) {
|
||||
switch str {
|
||||
case "":
|
||||
return NoAggregation, nil
|
||||
|
||||
@@ -39,7 +39,7 @@ func (l *Level) findLevelOrCreate(selector []string, nMetrics int) *Level {
|
||||
// Children map needs to be created...
|
||||
l.lock.RUnlock()
|
||||
} else {
|
||||
child, ok := l.children[selector[0]]
|
||||
child, ok = l.children[selector[0]]
|
||||
l.lock.RUnlock()
|
||||
if ok {
|
||||
return child.findLevelOrCreate(selector[1:], nMetrics)
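The one-character change above (':=' to '=') removes a classic Go shadowing bug: the inner ':=' declared a new child scoped to that block, so the variable declared earlier was never assigned. A tiny illustration of the difference:

```go
package main

import "fmt"

func main() {
	m := map[string]int{"a": 1}

	var child int
	var ok bool

	if len(m) > 0 {
		child, ok := m["a"] // ':=' declares NEW variables scoped to this block
		_ = child
		_ = ok
	}
	fmt.Println(child, ok) // 0 false: the outer variables were never assigned

	if len(m) > 0 {
		child, ok = m["a"] // '=' assigns the outer variables, as in the fix above
	}
	fmt.Println(child, ok) // 1 true
}
```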
|
||||
|
||||
@@ -3,6 +3,20 @@
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package memorystore provides an efficient in-memory time-series metric storage system
|
||||
// with support for hierarchical data organization, checkpointing, and archiving.
|
||||
//
|
||||
// The package organizes metrics in a tree structure (cluster → host → component) and
|
||||
// provides concurrent read/write access to metric data with configurable aggregation strategies.
|
||||
// Background goroutines handle periodic checkpointing (JSON or Avro format), archiving old data,
|
||||
// and enforcing retention policies.
|
||||
//
|
||||
// Key features:
|
||||
// - In-memory metric storage with configurable retention
|
||||
// - Hierarchical data organization (selectors)
|
||||
// - Concurrent checkpoint/archive workers
|
||||
// - Support for sum and average aggregation
|
||||
// - NATS integration for metric ingestion
|
||||
package memorystore
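The new package comment describes the store's tree layout and aggregation strategies. Based only on the AddMetric/MetricConfig calls visible further down in this diff, a hypothetical registration helper inside the package might look like this (metric names and timestep are illustrative):

```go
// Hypothetical helper inside package memorystore; only AddMetric, MetricConfig,
// the aggregation constants and cclog are taken from this diff.
func registerExampleMetrics() {
	for name, cfg := range map[string]MetricConfig{
		"flops_any": {Frequency: 60, Aggregation: AvgAggregation}, // timestep in seconds
		"mem_bw":    {Frequency: 60, Aggregation: SumAggregation}, // summed across nodes
	} {
		if err := AddMetric(name, cfg); err != nil {
			cclog.Warnf("registering metric %s failed: %s", name, err.Error())
		}
	}
}
```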
|
||||
|
||||
import (
|
||||
@@ -10,18 +24,14 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ClusterCockpit/cc-backend/internal/config"
|
||||
"github.com/ClusterCockpit/cc-backend/pkg/archive"
|
||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||
"github.com/ClusterCockpit/cc-lib/resampler"
|
||||
"github.com/ClusterCockpit/cc-lib/runtimeEnv"
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
"github.com/ClusterCockpit/cc-lib/util"
|
||||
)
|
||||
@@ -29,14 +39,12 @@ import (
|
||||
var (
|
||||
singleton sync.Once
|
||||
msInstance *MemoryStore
|
||||
// shutdownFunc stores the context cancellation function created in Init
|
||||
// and is called during Shutdown to cancel all background goroutines
|
||||
shutdownFunc context.CancelFunc
|
||||
)
|
||||
|
||||
var NumWorkers int = 4
|
||||
|
||||
func init() {
|
||||
maxWorkers := 10
|
||||
NumWorkers = min(runtime.NumCPU()/2+1, maxWorkers)
|
||||
}
|
||||
|
||||
type Metric struct {
|
||||
Name string
|
||||
@@ -61,30 +69,34 @@ func Init(rawConfig json.RawMessage, wg *sync.WaitGroup) {
|
||||
}
|
||||
}
|
||||
|
||||
// Set NumWorkers from config or use default
|
||||
if Keys.NumWorkers <= 0 {
|
||||
maxWorkers := 10
|
||||
Keys.NumWorkers = min(runtime.NumCPU()/2+1, maxWorkers)
|
||||
}
|
||||
cclog.Debugf("[METRICSTORE]> Using %d workers for checkpoint/archive operations\n", Keys.NumWorkers)
|
||||
|
||||
// Helper function to add metric configuration
|
||||
addMetricConfig := func(mc schema.MetricConfig) {
|
||||
agg, err := AssignAggregationStrategy(mc.Aggregation)
|
||||
if err != nil {
|
||||
cclog.Warnf("Could not find aggregation strategy for metric config '%s': %s", mc.Name, err.Error())
|
||||
}
|
||||
|
||||
AddMetric(mc.Name, MetricConfig{
|
||||
Frequency: int64(mc.Timestep),
|
||||
Aggregation: agg,
|
||||
})
|
||||
}
|
||||
|
||||
for _, c := range archive.Clusters {
|
||||
for _, mc := range c.MetricConfig {
|
||||
agg, err := AssignAggregationStratergy(mc.Aggregation)
|
||||
if err != nil {
|
||||
cclog.Warnf("Could not find aggregation stratergy for metric config '%s': %s", mc.Name, err.Error())
|
||||
}
|
||||
|
||||
AddMetric(mc.Name, MetricConfig{
|
||||
Frequency: int64(mc.Timestep),
|
||||
Aggregation: agg,
|
||||
})
|
||||
addMetricConfig(*mc)
|
||||
}
|
||||
|
||||
for _, sc := range c.SubClusters {
|
||||
for _, mc := range sc.MetricConfig {
|
||||
agg, err := AssignAggregationStratergy(mc.Aggregation)
|
||||
if err != nil {
|
||||
cclog.Warnf("Could not find aggregation stratergy for metric config '%s': %s", mc.Name, err.Error())
|
||||
}
|
||||
|
||||
AddMetric(mc.Name, MetricConfig{
|
||||
Frequency: int64(mc.Timestep),
|
||||
Aggregation: agg,
|
||||
})
|
||||
addMetricConfig(mc)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -126,15 +138,11 @@ func Init(rawConfig json.RawMessage, wg *sync.WaitGroup) {
|
||||
Archiving(wg, ctx)
|
||||
DataStaging(wg, ctx)
|
||||
|
||||
wg.Add(1)
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-sigs
|
||||
runtimeEnv.SystemdNotifiy(false, "[METRICSTORE]> Shutting down ...")
|
||||
shutdown()
|
||||
}()
|
||||
// Note: Signal handling has been removed from this function.
|
||||
// The caller is responsible for handling shutdown signals and calling
|
||||
// the shutdown() function when appropriate.
|
||||
// Store the shutdown function for later use by Shutdown()
|
||||
shutdownFunc = shutdown
|
||||
|
||||
if Keys.Nats != nil {
|
||||
for _, natsConf := range Keys.Nats {
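Because Init no longer installs its own SIGINT/SIGTERM handler, the embedding application has to wire shutdown itself. A hedged sketch of a caller (only memorystore.Init and memorystore.Shutdown are taken from this diff; the surrounding main and the nil config are placeholders):

```go
package main

import (
	"os"
	"os/signal"
	"sync"
	"syscall"

	"github.com/ClusterCockpit/cc-backend/internal/memorystore"
)

func main() {
	var wg sync.WaitGroup

	// Real callers pass the metric-store JSON config here; nil is only a placeholder.
	memorystore.Init(nil, &wg)

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

	<-sigs                 // block until the process is asked to stop
	memorystore.Shutdown() // cancels background goroutines and writes a final checkpoint
	wg.Wait()
}
```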
|
||||
@@ -190,6 +198,11 @@ func GetMemoryStore() *MemoryStore {
|
||||
}
|
||||
|
||||
func Shutdown() {
|
||||
// Cancel the context to signal all background goroutines to stop
|
||||
if shutdownFunc != nil {
|
||||
shutdownFunc()
|
||||
}
|
||||
|
||||
cclog.Infof("[METRICSTORE]> Writing to '%s'...\n", Keys.Checkpoints.RootDir)
|
||||
var files int
|
||||
var err error
|
||||
@@ -207,70 +220,8 @@ func Shutdown() {
|
||||
cclog.Errorf("[METRICSTORE]> Writing checkpoint failed: %s\n", err.Error())
|
||||
}
|
||||
cclog.Infof("[METRICSTORE]> Done! (%d files written)\n", files)
|
||||
|
||||
// ms.PrintHeirarchy()
|
||||
}
|
||||
|
||||
// func (m *MemoryStore) PrintHeirarchy() {
|
||||
// m.root.lock.Lock()
|
||||
// defer m.root.lock.Unlock()
|
||||
|
||||
// fmt.Printf("Root : \n")
|
||||
|
||||
// for lvl1, sel1 := range m.root.children {
|
||||
// fmt.Printf("\t%s\n", lvl1)
|
||||
// for lvl2, sel2 := range sel1.children {
|
||||
// fmt.Printf("\t\t%s\n", lvl2)
|
||||
// if lvl1 == "fritz" && lvl2 == "f0201" {
|
||||
|
||||
// for name, met := range m.Metrics {
|
||||
// mt := sel2.metrics[met.Offset]
|
||||
|
||||
// fmt.Printf("\t\t\t\t%s\n", name)
|
||||
// fmt.Printf("\t\t\t\t")
|
||||
|
||||
// for mt != nil {
|
||||
// // if name == "cpu_load" {
|
||||
// fmt.Printf("%d(%d) -> %#v", mt.start, len(mt.data), mt.data)
|
||||
// // }
|
||||
// mt = mt.prev
|
||||
// }
|
||||
// fmt.Printf("\n")
|
||||
|
||||
// }
|
||||
// }
|
||||
// for lvl3, sel3 := range sel2.children {
|
||||
// if lvl1 == "fritz" && lvl2 == "f0201" && lvl3 == "hwthread70" {
|
||||
|
||||
// fmt.Printf("\t\t\t\t\t%s\n", lvl3)
|
||||
|
||||
// for name, met := range m.Metrics {
|
||||
// mt := sel3.metrics[met.Offset]
|
||||
|
||||
// fmt.Printf("\t\t\t\t\t\t%s\n", name)
|
||||
|
||||
// fmt.Printf("\t\t\t\t\t\t")
|
||||
|
||||
// for mt != nil {
|
||||
// // if name == "clock" {
|
||||
// fmt.Printf("%d(%d) -> %#v", mt.start, len(mt.data), mt.data)
|
||||
|
||||
// mt = mt.prev
|
||||
// }
|
||||
// fmt.Printf("\n")
|
||||
|
||||
// }
|
||||
|
||||
// // for i, _ := range sel3.metrics {
|
||||
// // fmt.Printf("\t\t\t\t\t%s\n", getName(configmetrics, i))
|
||||
// // }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// }
|
||||
|
||||
func getName(m *MemoryStore, i int) string {
|
||||
for key, val := range m.Metrics {
|
||||
if val.offset == i {
|
||||
|
||||
156 internal/memorystore/memorystore_test.go Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package memorystore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ClusterCockpit/cc-lib/schema"
|
||||
)
|
||||
|
||||
func TestAssignAggregationStrategy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected AggregationStrategy
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty string", "", NoAggregation, false},
|
||||
{"sum", "sum", SumAggregation, false},
|
||||
{"avg", "avg", AvgAggregation, false},
|
||||
{"invalid", "invalid", NoAggregation, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := AssignAggregationStrategy(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AssignAggregationStrategy(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("AssignAggregationStrategy(%q) = %v, want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddMetric(t *testing.T) {
|
||||
// Reset Metrics before test
|
||||
Metrics = make(map[string]MetricConfig)
|
||||
|
||||
err := AddMetric("test_metric", MetricConfig{
|
||||
Frequency: 60,
|
||||
Aggregation: SumAggregation,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddMetric() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := Metrics["test_metric"]; !ok {
|
||||
t.Error("AddMetric() did not add metric to Metrics map")
|
||||
}
|
||||
|
||||
// Test updating with higher frequency
|
||||
err = AddMetric("test_metric", MetricConfig{
|
||||
Frequency: 120,
|
||||
Aggregation: SumAggregation,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddMetric() error = %v", err)
|
||||
}
|
||||
|
||||
if Metrics["test_metric"].Frequency != 120 {
|
||||
t.Errorf("AddMetric() frequency = %d, want 120", Metrics["test_metric"].Frequency)
|
||||
}
|
||||
|
||||
// Test updating with lower frequency (should not update)
|
||||
err = AddMetric("test_metric", MetricConfig{
|
||||
Frequency: 30,
|
||||
Aggregation: SumAggregation,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddMetric() error = %v", err)
|
||||
}
|
||||
|
||||
if Metrics["test_metric"].Frequency != 120 {
|
||||
t.Errorf("AddMetric() frequency = %d, want 120 (should not downgrade)", Metrics["test_metric"].Frequency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetricFrequency(t *testing.T) {
|
||||
// Reset Metrics before test
|
||||
Metrics = map[string]MetricConfig{
|
||||
"test_metric": {
|
||||
Frequency: 60,
|
||||
Aggregation: SumAggregation,
|
||||
},
|
||||
}
|
||||
|
||||
freq, err := GetMetricFrequency("test_metric")
|
||||
if err != nil {
|
||||
t.Errorf("GetMetricFrequency() error = %v", err)
|
||||
}
|
||||
if freq != 60 {
|
||||
t.Errorf("GetMetricFrequency() = %d, want 60", freq)
|
||||
}
|
||||
|
||||
_, err = GetMetricFrequency("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("GetMetricFrequency() expected error for nonexistent metric")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferWrite(t *testing.T) {
|
||||
b := newBuffer(100, 10)
|
||||
|
||||
// Test writing value
|
||||
nb, err := b.write(100, schema.Float(42.0))
|
||||
if err != nil {
|
||||
t.Errorf("buffer.write() error = %v", err)
|
||||
}
|
||||
if nb != b {
|
||||
t.Error("buffer.write() created new buffer unexpectedly")
|
||||
}
|
||||
if len(b.data) != 1 {
|
||||
t.Errorf("buffer.write() len(data) = %d, want 1", len(b.data))
|
||||
}
|
||||
if b.data[0] != schema.Float(42.0) {
|
||||
t.Errorf("buffer.write() data[0] = %v, want 42.0", b.data[0])
|
||||
}
|
||||
|
||||
// Test writing value from past (should error)
|
||||
_, err = b.write(50, schema.Float(10.0))
|
||||
if err == nil {
|
||||
t.Error("buffer.write() expected error for past timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferRead(t *testing.T) {
|
||||
b := newBuffer(100, 10)
|
||||
|
||||
// Write some test data
|
||||
b.write(100, schema.Float(1.0))
|
||||
b.write(110, schema.Float(2.0))
|
||||
b.write(120, schema.Float(3.0))
|
||||
|
||||
// Read data
|
||||
data := make([]schema.Float, 3)
|
||||
result, from, to, err := b.read(100, 130, data)
|
||||
if err != nil {
|
||||
t.Errorf("buffer.read() error = %v", err)
|
||||
}
|
||||
// Buffer read should return from as firstWrite (start + freq/2)
|
||||
if from != 100 {
|
||||
t.Errorf("buffer.read() from = %d, want 100", from)
|
||||
}
|
||||
if to != 130 {
|
||||
t.Errorf("buffer.read() to = %d, want 130", to)
|
||||
}
|
||||
if len(result) != 3 {
|
||||
t.Errorf("buffer.read() len(result) = %d, want 3", len(result))
|
||||
}
|
||||
}
|
||||
68 internal/repository/config.go Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package repository
|
||||
|
||||
import "time"
|
||||
|
||||
// RepositoryConfig holds configuration for repository operations.
|
||||
// All fields have sensible defaults, so this configuration is optional.
|
||||
type RepositoryConfig struct {
|
||||
// CacheSize is the LRU cache size in bytes for job metadata and energy footprints.
|
||||
// Default: 1MB (1024 * 1024 bytes)
|
||||
CacheSize int
|
||||
|
||||
// MaxOpenConnections is the maximum number of open database connections.
|
||||
// Default: 4
|
||||
MaxOpenConnections int
|
||||
|
||||
// MaxIdleConnections is the maximum number of idle database connections.
|
||||
// Default: 4
|
||||
MaxIdleConnections int
|
||||
|
||||
// ConnectionMaxLifetime is the maximum amount of time a connection may be reused.
|
||||
// Default: 1 hour
|
||||
ConnectionMaxLifetime time.Duration
|
||||
|
||||
// ConnectionMaxIdleTime is the maximum amount of time a connection may be idle.
|
||||
// Default: 1 hour
|
||||
ConnectionMaxIdleTime time.Duration
|
||||
|
||||
// MinRunningJobDuration is the minimum duration in seconds for a job to be
|
||||
// considered in "running jobs" queries. This filters out very short jobs.
|
||||
// Default: 600 seconds (10 minutes)
|
||||
MinRunningJobDuration int
|
||||
}
|
||||
|
||||
// DefaultConfig returns the default repository configuration.
|
||||
// These values are optimized for typical deployments.
|
||||
func DefaultConfig() *RepositoryConfig {
|
||||
return &RepositoryConfig{
|
||||
CacheSize: 1 * 1024 * 1024, // 1MB
|
||||
MaxOpenConnections: 4,
|
||||
MaxIdleConnections: 4,
|
||||
ConnectionMaxLifetime: time.Hour,
|
||||
ConnectionMaxIdleTime: time.Hour,
|
||||
MinRunningJobDuration: 600, // 10 minutes
|
||||
}
|
||||
}
|
||||
|
||||
// repoConfig is the package-level configuration instance.
|
||||
// It is initialized with defaults and can be overridden via SetConfig.
|
||||
var repoConfig *RepositoryConfig = DefaultConfig()
|
||||
|
||||
// SetConfig sets the repository configuration.
|
||||
// This must be called before any repository initialization (Connect, GetJobRepository, etc.).
|
||||
// If not called, default values from DefaultConfig() are used.
|
||||
func SetConfig(cfg *RepositoryConfig) {
|
||||
if cfg != nil {
|
||||
repoConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig returns the current repository configuration.
|
||||
func GetConfig() *RepositoryConfig {
|
||||
return repoConfig
|
||||
}
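A short, illustrative way to use the new configuration hooks before the first database access; the functions are the ones defined above, the values are made up:

```go
// Illustrative override: fields left at their DefaultConfig() values keep the defaults.
cfg := repository.DefaultConfig()
cfg.MaxOpenConnections = 8
cfg.CacheSize = 4 * 1024 * 1024 // 4 MB LRU cache for job metadata

repository.SetConfig(cfg)                     // must run before the first repository use
repository.Connect("sqlite3", "./var/job.db") // driver and path as used elsewhere in cc-backend
```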
|
||||
@@ -2,6 +2,7 @@
|
||||
// All rights reserved. This file is part of cc-backend.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
@@ -35,21 +36,15 @@ type DatabaseOptions struct {
|
||||
ConnectionMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func setupSqlite(db *sql.DB) (err error) {
|
||||
func setupSqlite(db *sql.DB) error {
|
||||
pragmas := []string{
|
||||
// "journal_mode = WAL",
|
||||
// "busy_timeout = 5000",
|
||||
// "synchronous = NORMAL",
|
||||
// "cache_size = 1000000000", // 1GB
|
||||
// "foreign_keys = true",
|
||||
"temp_store = memory",
|
||||
// "mmap_size = 3000000000",
|
||||
}
|
||||
|
||||
for _, pragma := range pragmas {
|
||||
_, err = db.Exec("PRAGMA " + pragma)
|
||||
_, err := db.Exec("PRAGMA " + pragma)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,24 +58,24 @@ func Connect(driver string, db string) {
|
||||
dbConnOnce.Do(func() {
|
||||
opts := DatabaseOptions{
|
||||
URL: db,
|
||||
MaxOpenConnections: 4,
|
||||
MaxIdleConnections: 4,
|
||||
ConnectionMaxLifetime: time.Hour,
|
||||
ConnectionMaxIdleTime: time.Hour,
|
||||
MaxOpenConnections: repoConfig.MaxOpenConnections,
|
||||
MaxIdleConnections: repoConfig.MaxIdleConnections,
|
||||
ConnectionMaxLifetime: repoConfig.ConnectionMaxLifetime,
|
||||
ConnectionMaxIdleTime: repoConfig.ConnectionMaxIdleTime,
|
||||
}
|
||||
|
||||
switch driver {
|
||||
case "sqlite3":
|
||||
// TODO: Have separate DB handles for Writes and Reads
|
||||
// Optimize SQLite connection: https://kerkour.com/sqlite-for-servers
|
||||
connectionUrlParams := make(url.Values)
|
||||
connectionUrlParams.Add("_txlock", "immediate")
|
||||
connectionUrlParams.Add("_journal_mode", "WAL")
|
||||
connectionUrlParams.Add("_busy_timeout", "5000")
|
||||
connectionUrlParams.Add("_synchronous", "NORMAL")
|
||||
connectionUrlParams.Add("_cache_size", "1000000000")
|
||||
connectionUrlParams.Add("_foreign_keys", "true")
|
||||
opts.URL = fmt.Sprintf("file:%s?%s", opts.URL, connectionUrlParams.Encode())
|
||||
connectionURLParams := make(url.Values)
|
||||
connectionURLParams.Add("_txlock", "immediate")
|
||||
connectionURLParams.Add("_journal_mode", "WAL")
|
||||
connectionURLParams.Add("_busy_timeout", "5000")
|
||||
connectionURLParams.Add("_synchronous", "NORMAL")
|
||||
connectionURLParams.Add("_cache_size", "1000000000")
|
||||
connectionURLParams.Add("_foreign_keys", "true")
|
||||
opts.URL = fmt.Sprintf("file:%s?%s", opts.URL, connectionURLParams.Encode())
|
||||
|
||||
if cclog.Loglevel() == "debug" {
|
||||
sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
|
||||
@@ -89,7 +84,10 @@ func Connect(driver string, db string) {
|
||||
dbHandle, err = sqlx.Open("sqlite3", opts.URL)
|
||||
}
|
||||
|
||||
setupSqlite(dbHandle.DB)
|
||||
err = setupSqlite(dbHandle.DB)
|
||||
if err != nil {
|
||||
cclog.Abortf("Failed sqlite db setup.\nError: %s\n", err.Error())
|
||||
}
|
||||
case "mysql":
|
||||
opts.URL += "?multiStatements=true"
|
||||
dbHandle, err = sqlx.Open("mysql", opts.URL)
|
||||
|
||||
@@ -2,6 +2,63 @@
// All rights reserved. This file is part of cc-backend.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

// Package repository provides the data access layer for cc-backend using the repository pattern.
//
// The repository pattern abstracts database operations and provides a clean interface for
// data access. Each major entity (Job, User, Node, Tag) has its own repository with CRUD
// operations and specialized queries.
//
// # Database Connection
//
// Initialize the database connection before using any repository:
//
// repository.Connect("sqlite3", "./var/job.db")
// // or for MySQL:
// repository.Connect("mysql", "user:password@tcp(localhost:3306)/dbname")
//
// # Configuration
//
// Optional: Configure repository settings before initialization:
//
// repository.SetConfig(&repository.RepositoryConfig{
// CacheSize: 2 * 1024 * 1024, // 2MB cache
// MaxOpenConnections: 8, // Connection pool size
// MinRunningJobDuration: 300, // Filter threshold
// })
//
// If not configured, sensible defaults are used automatically.
//
// # Repositories
//
// - JobRepository: Job lifecycle management and querying
// - UserRepository: User management and authentication
// - NodeRepository: Cluster node state tracking
// - Tags: Job tagging and categorization
//
// # Caching
//
// Repositories use LRU caching to improve performance. Cache keys are constructed
// as "type:id" (e.g., "metadata:123"). Cache is automatically invalidated on
// mutations to maintain consistency.
//
// # Transaction Support
//
// For batch operations, use transactions:
//
// t, err := jobRepo.TransactionInit()
// if err != nil {
//     return err
// }
// defer t.Rollback() // Rollback if not committed
//
// // Perform operations...
// jobRepo.TransactionAdd(t, query, args...)
//
// // Commit when done
// if err := t.Commit(); err != nil {
//     return err
// }
package repository

import (
@@ -45,7 +102,7 @@ func GetJobRepository() *JobRepository {
driver: db.Driver,

stmtCache: sq.NewStmtCache(db.DB),
cache: lrucache.New(1024 * 1024),
cache: lrucache.New(repoConfig.CacheSize),
}
})
return jobRepoInstance
@@ -267,7 +324,31 @@ func (r *JobRepository) FetchEnergyFootprint(job *schema.Job) (map[string]float6
func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
var cnt int
q := sq.Select("count(*)").From("job").Where("job.start_time < ?", startTime)
q.RunWith(r.DB).QueryRow().Scan(cnt)
if err := q.RunWith(r.DB).QueryRow().Scan(&cnt); err != nil {
cclog.Errorf("Error counting jobs before %d: %v", startTime, err)
return 0, err
}

// Invalidate cache for jobs being deleted (get job IDs first)
if cnt > 0 {
var jobIds []int64
rows, err := sq.Select("id").From("job").Where("job.start_time < ?", startTime).RunWith(r.DB).Query()
if err == nil {
defer rows.Close()
for rows.Next() {
var id int64
if err := rows.Scan(&id); err == nil {
jobIds = append(jobIds, id)
}
}
// Invalidate cache entries
for _, id := range jobIds {
r.cache.Del(fmt.Sprintf("metadata:%d", id))
r.cache.Del(fmt.Sprintf("energyFootprint:%d", id))
}
}
}

qd := sq.Delete("job").Where("job.start_time < ?", startTime)
_, err := qd.RunWith(r.DB).Exec()

@@ -281,6 +362,10 @@ func (r *JobRepository) DeleteJobsBefore(startTime int64) (int, error) {
}

func (r *JobRepository) DeleteJobById(id int64) error {
// Invalidate cache entries before deletion
r.cache.Del(fmt.Sprintf("metadata:%d", id))
r.cache.Del(fmt.Sprintf("energyFootprint:%d", id))

qd := sq.Delete("job").Where("job.id = ?", id)
_, err := qd.RunWith(r.DB).Exec()

@@ -450,13 +535,14 @@ func (r *JobRepository) AllocatedNodes(cluster string) (map[string]map[string]in
// FIXME: Set duration to requested walltime?
func (r *JobRepository) StopJobsExceedingWalltimeBy(seconds int) error {
start := time.Now()
currentTime := time.Now().Unix()
res, err := sq.Update("job").
Set("monitoring_status", schema.MonitoringStatusArchivingFailed).
Set("duration", 0).
Set("job_state", schema.JobStateFailed).
Where("job.job_state = 'running'").
Where("job.walltime > 0").
Where(fmt.Sprintf("(%d - job.start_time) > (job.walltime + %d)", time.Now().Unix(), seconds)).
Where("(? - job.start_time) > (job.walltime + ?)", currentTime, seconds).
RunWith(r.DB).Exec()
if err != nil {
cclog.Warn("Error while stopping jobs exceeding walltime")
@@ -505,21 +591,21 @@ func (r *JobRepository) FindJobIdsByTag(tagId int64) ([]int64, error) {
// FIXME: Reconsider filtering short jobs with harcoded threshold
func (r *JobRepository) FindRunningJobs(cluster string) ([]*schema.Job, error) {
query := sq.Select(jobColumns...).From("job").
Where(fmt.Sprintf("job.cluster = '%s'", cluster)).
Where("job.cluster = ?", cluster).
Where("job.job_state = 'running'").
Where("job.duration > 600")
Where("job.duration > ?", repoConfig.MinRunningJobDuration)

rows, err := query.RunWith(r.stmtCache).Query()
if err != nil {
cclog.Error("Error while running query")
return nil, err
}
defer rows.Close()

jobs := make([]*schema.Job, 0, 50)
for rows.Next() {
job, err := scanJob(rows)
if err != nil {
rows.Close()
cclog.Warn("Error while scanning rows")
return nil, err
}
@@ -552,12 +638,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64

if startTimeBegin == 0 {
cclog.Infof("Find jobs before %d", startTimeEnd)
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
"job.start_time < %d", startTimeEnd))
query = sq.Select(jobColumns...).From("job").Where("job.start_time < ?", startTimeEnd)
} else {
cclog.Infof("Find jobs between %d and %d", startTimeBegin, startTimeEnd)
query = sq.Select(jobColumns...).From("job").Where(fmt.Sprintf(
"job.start_time BETWEEN %d AND %d", startTimeBegin, startTimeEnd))
query = sq.Select(jobColumns...).From("job").Where("job.start_time BETWEEN ? AND ?", startTimeBegin, startTimeEnd)
}

rows, err := query.RunWith(r.stmtCache).Query()
@@ -565,12 +649,12 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
cclog.Error("Error while running query")
return nil, err
}
defer rows.Close()

jobs := make([]*schema.Job, 0, 50)
for rows.Next() {
job, err := scanJob(rows)
if err != nil {
rows.Close()
cclog.Warn("Error while scanning rows")
return nil, err
}
@@ -582,6 +666,10 @@ func (r *JobRepository) FindJobsBetween(startTimeBegin int64, startTimeEnd int64
}

func (r *JobRepository) UpdateMonitoringStatus(job int64, monitoringStatus int32) (err error) {
// Invalidate cache entries as monitoring status affects job state
r.cache.Del(fmt.Sprintf("metadata:%d", job))
r.cache.Del(fmt.Sprintf("energyFootprint:%d", job))

stmt := sq.Update("job").
Set("monitoring_status", monitoringStatus).
Where("job.id = ?", job)

@@ -31,8 +31,9 @@ const NamedJobInsert string = `INSERT INTO job (

func (r *JobRepository) InsertJob(job *schema.Job) (int64, error) {
r.Mutex.Lock()
defer r.Mutex.Unlock()

res, err := r.DB.NamedExec(NamedJobCacheInsert, job)
r.Mutex.Unlock()
if err != nil {
cclog.Warn("Error while NamedJobInsert")
return 0, err
@@ -57,12 +58,12 @@ func (r *JobRepository) SyncJobs() ([]*schema.Job, error) {
cclog.Errorf("Error while running query %v", err)
return nil, err
}
defer rows.Close()

jobs := make([]*schema.Job, 0, 50)
for rows.Next() {
job, err := scanJob(rows)
if err != nil {
rows.Close()
cclog.Warn("Error while scanning rows")
return nil, err
}
@@ -113,6 +114,10 @@ func (r *JobRepository) Stop(
state schema.JobState,
monitoringStatus int32,
) (err error) {
// Invalidate cache entries as job state is changing
r.cache.Del(fmt.Sprintf("metadata:%d", jobId))
r.cache.Del(fmt.Sprintf("energyFootprint:%d", jobId))

stmt := sq.Update("job").
Set("job_state", state).
Set("duration", duration).
@@ -129,11 +134,13 @@ func (r *JobRepository) StopCached(
state schema.JobState,
monitoringStatus int32,
) (err error) {
// Note: StopCached updates job_cache table, not the main job table
// Cache invalidation happens when job is synced to main table
stmt := sq.Update("job_cache").
Set("job_state", state).
Set("duration", duration).
Set("monitoring_status", monitoringStatus).
Where("job.id = ?", jobId)
Where("job_cache.id = ?", jobId)

_, err = stmt.RunWith(r.stmtCache).Exec()
return err

@@ -89,6 +89,7 @@ func (r *JobRepository) FindAll(
cclog.Error("Error while running query")
return nil, err
}
defer rows.Close()

jobs := make([]*schema.Job, 0, 10)
for rows.Next() {
@@ -103,25 +104,31 @@ func (r *JobRepository) FindAll(
return jobs, nil
}

// Get complete joblist only consisting of db ids.
// GetJobList returns job IDs for non-running jobs.
// This is useful to process large job counts and intended to be used
// together with FindById to process jobs one by one
func (r *JobRepository) GetJobList() ([]int64, error) {
// together with FindById to process jobs one by one.
// Use limit and offset for pagination. Use limit=0 to get all results (not recommended for large datasets).
func (r *JobRepository) GetJobList(limit int, offset int) ([]int64, error) {
query := sq.Select("id").From("job").
Where("job.job_state != 'running'")

// Add pagination if limit is specified
if limit > 0 {
query = query.Limit(uint64(limit)).Offset(uint64(offset))
}

rows, err := query.RunWith(r.stmtCache).Query()
if err != nil {
cclog.Error("Error while running query")
return nil, err
}
defer rows.Close()

jl := make([]int64, 0, 1000)
for rows.Next() {
var id int64
err := rows.Scan(&id)
if err != nil {
rows.Close()
cclog.Warn("Error while scanning rows")
return nil, err
}
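A minimal caller sketch for the new pagination parameters (the helper name, the page size, and the per-job callback are illustrative; the doc comment suggests pairing this with the repository's FindById lookup):

```go
// processAll walks all non-running job IDs in pages instead of loading them at once.
// jobRepo is the *repository.JobRepository singleton; process is whatever per-job
// work the caller wants to do (e.g. a FindById lookup followed by re-tagging).
func processAll(jobRepo *repository.JobRepository, process func(id int64)) error {
	const pageSize = 1000
	for offset := 0; ; offset += pageSize {
		ids, err := jobRepo.GetJobList(pageSize, offset)
		if err != nil {
			return err
		}
		if len(ids) == 0 {
			return nil // past the last page
		}
		for _, id := range ids {
			process(id)
		}
	}
}
```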
@@ -256,6 +263,7 @@ func (r *JobRepository) FindConcurrentJobs(
cclog.Errorf("Error while running query: %v", err)
return nil, err
}
defer rows.Close()

items := make([]*model.JobLink, 0, 10)
queryString := fmt.Sprintf("cluster=%s", job.Cluster)
@@ -283,6 +291,7 @@ func (r *JobRepository) FindConcurrentJobs(
cclog.Errorf("Error while running query: %v", err)
return nil, err
}
defer rows.Close()

for rows.Next() {
var id, jobId, startTime sql.NullInt64

@@ -43,7 +43,7 @@ func GetNodeRepository() *NodeRepository {
driver: db.Driver,

stmtCache: sq.NewStmtCache(db.DB),
cache: lrucache.New(1024 * 1024),
cache: lrucache.New(repoConfig.CacheSize),
}
})
return nodeRepoInstance
@@ -77,43 +77,6 @@ func (r *NodeRepository) FetchMetadata(hostname string, cluster string) (map[str
return MetaData, nil
}

//
// func (r *NodeRepository) UpdateMetadata(node *schema.Node, key, val string) (err error) {
// cachekey := fmt.Sprintf("metadata:%d", node.ID)
// r.cache.Del(cachekey)
// if node.MetaData == nil {
// if _, err = r.FetchMetadata(node); err != nil {
// cclog.Warnf("Error while fetching metadata for node, DB ID '%v'", node.ID)
// return err
// }
// }
//
// if node.MetaData != nil {
// cpy := make(map[string]string, len(node.MetaData)+1)
// maps.Copy(cpy, node.MetaData)
// cpy[key] = val
// node.MetaData = cpy
// } else {
// node.MetaData = map[string]string{key: val}
// }
//
// if node.RawMetaData, err = json.Marshal(node.MetaData); err != nil {
// cclog.Warnf("Error while marshaling metadata for node, DB ID '%v'", node.ID)
// return err
// }
//
// if _, err = sq.Update("node").
// Set("meta_data", node.RawMetaData).
// Where("node.id = ?", node.ID).
// RunWith(r.stmtCache).Exec(); err != nil {
// cclog.Warnf("Error while updating metadata for node, DB ID '%v'", node.ID)
// return err
// }
//
// r.cache.Put(cachekey, node.MetaData, len(node.RawMetaData), 24*time.Hour)
// return nil
// }

func (r *NodeRepository) GetNode(hostname string, cluster string, withMeta bool) (*schema.Node, error) {
node := &schema.Node{}
var timestamp int

@@ -115,7 +115,7 @@ func nodeTestSetup(t *testing.T) {
}

if err := os.WriteFile(filepath.Join(jobarchive, "version.txt"),
fmt.Appendf(nil, "%d", 2), 0o666); err != nil {
fmt.Appendf(nil, "%d", 3), 0o666); err != nil {
t.Fatal(err)
}

@@ -114,16 +114,6 @@ func (r *JobRepository) buildStatsQuery(
return query
}

// func (r *JobRepository) getUserName(ctx context.Context, id string) string {
// user := GetUserFromContext(ctx)
// name, _ := r.FindColumnValue(user, id, "hpc_user", "name", "username", false)
// if name != "" {
// return name
// } else {
// return "-"
// }
// }

func (r *JobRepository) getCastType() string {
var castType string

@@ -5,6 +5,7 @@
package repository

import (
"errors"
"fmt"
"strings"

@@ -14,65 +15,32 @@ import (
sq "github.com/Masterminds/squirrel"
)

// Add the tag with id `tagId` to the job with the database id `jobId`.
// AddTag adds the tag with id `tagId` to the job with the database id `jobId`.
// Requires user authentication for security checks.
func (r *JobRepository) AddTag(user *schema.User, job int64, tag int64) ([]*schema.Tag, error) {
j, err := r.FindByIdWithUser(user, job)
if err != nil {
cclog.Warn("Error while finding job by id")
cclog.Warnf("Error finding job %d for user %s: %v", job, user.Username, err)
return nil, err
}

q := sq.Insert("jobtag").Columns("job_id", "tag_id").Values(job, tag)

if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
s, _, _ := q.ToSql()
cclog.Errorf("Error adding tag with %s: %v", s, err)
return nil, err
}

tags, err := r.GetTags(user, &job)
if err != nil {
cclog.Warn("Error while getting tags for job")
return nil, err
}

archiveTags, err := r.getArchiveTags(&job)
if err != nil {
cclog.Warn("Error while getting tags for job")
return nil, err
}

return tags, archive.UpdateTags(j, archiveTags)
return r.addJobTag(job, tag, j, func() ([]*schema.Tag, error) {
return r.GetTags(user, &job)
})
}

// AddTagDirect adds a tag without user security checks.
// Use only for internal/admin operations.
func (r *JobRepository) AddTagDirect(job int64, tag int64) ([]*schema.Tag, error) {
j, err := r.FindByIdDirect(job)
if err != nil {
cclog.Warn("Error while finding job by id")
cclog.Warnf("Error finding job %d: %v", job, err)
return nil, err
}

q := sq.Insert("jobtag").Columns("job_id", "tag_id").Values(job, tag)

if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
s, _, _ := q.ToSql()
cclog.Errorf("Error adding tag with %s: %v", s, err)
return nil, err
}

tags, err := r.GetTagsDirect(&job)
if err != nil {
cclog.Warn("Error while getting tags for job")
return nil, err
}

archiveTags, err := r.getArchiveTags(&job)
if err != nil {
cclog.Warn("Error while getting tags for job")
return nil, err
}

return tags, archive.UpdateTags(j, archiveTags)
return r.addJobTag(job, tag, j, func() ([]*schema.Tag, error) {
return r.GetTagsDirect(&job)
})
}

// Removes a tag from a job by tag id.
@@ -260,15 +228,18 @@ func (r *JobRepository) CountTags(user *schema.User) (tags []schema.Tag, counts
LeftJoin("jobtag jt ON t.id = jt.tag_id").
GroupBy("t.tag_name")

// Handle Scope Filtering
scopeList := "\"global\""
// Build scope list for filtering
var scopeBuilder strings.Builder
scopeBuilder.WriteString(`"global"`)
if user != nil {
scopeList += ",\"" + user.Username + "\""
scopeBuilder.WriteString(`,"`)
scopeBuilder.WriteString(user.Username)
scopeBuilder.WriteString(`"`)
if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) {
scopeBuilder.WriteString(`,"admin"`)
}
}
if user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) {
scopeList += ",\"admin\""
}
q = q.Where("t.tag_scope IN (" + scopeList + ")")
q = q.Where("t.tag_scope IN (" + scopeBuilder.String() + ")")

// Handle Job Ownership
if user != nil && user.HasAnyRole([]schema.Role{schema.RoleAdmin, schema.RoleSupport}) { // ADMIN || SUPPORT: Count all jobs
@@ -302,6 +273,41 @@ func (r *JobRepository) CountTags(user *schema.User) (tags []schema.Tag, counts
return tags, counts, err
}

var (
ErrTagNotFound = errors.New("the tag does not exist")
ErrJobNotOwned = errors.New("user is not owner of job")
ErrTagNoAccess = errors.New("user not permitted to use that tag")
ErrTagPrivateScope = errors.New("tag is private to another user")
ErrTagAdminScope = errors.New("tag requires admin privileges")
ErrTagsIncompatScopes = errors.New("combining admin and non-admin scoped tags not allowed")
)

// addJobTag is a helper function that inserts a job-tag association and updates the archive.
// Returns the updated tag list for the job.
func (r *JobRepository) addJobTag(jobId int64, tagId int64, job *schema.Job, getTags func() ([]*schema.Tag, error)) ([]*schema.Tag, error) {
q := sq.Insert("jobtag").Columns("job_id", "tag_id").Values(jobId, tagId)

if _, err := q.RunWith(r.stmtCache).Exec(); err != nil {
s, _, _ := q.ToSql()
cclog.Errorf("Error adding tag with %s: %v", s, err)
return nil, err
}

tags, err := getTags()
if err != nil {
cclog.Warnf("Error getting tags for job %d: %v", jobId, err)
return nil, err
}

archiveTags, err := r.getArchiveTags(&jobId)
if err != nil {
cclog.Warnf("Error getting archive tags for job %d: %v", jobId, err)
return nil, err
}

return tags, archive.UpdateTags(job, archiveTags)
}

// AddTagOrCreate adds the tag with the specified type and name to the job with the database id `jobId`.
// If such a tag does not yet exist, it is created.
func (r *JobRepository) AddTagOrCreate(user *schema.User, jobId int64, tagType string, tagName string, tagScope string) (tagId int64, err error) {

@@ -5,84 +5,96 @@
package repository

import (
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
"fmt"

"github.com/jmoiron/sqlx"
)

// Transaction wraps a database transaction for job-related operations.
type Transaction struct {
tx *sqlx.Tx
stmt *sqlx.NamedStmt
tx *sqlx.Tx
}

// TransactionInit begins a new transaction.
func (r *JobRepository) TransactionInit() (*Transaction, error) {
var err error
t := new(Transaction)

t.tx, err = r.DB.Beginx()
tx, err := r.DB.Beginx()
if err != nil {
cclog.Warn("Error while bundling transactions")
return nil, err
return nil, fmt.Errorf("beginning transaction: %w", err)
}
return t, nil
return &Transaction{tx: tx}, nil
}

func (r *JobRepository) TransactionCommit(t *Transaction) error {
var err error
if t.tx != nil {
if err = t.tx.Commit(); err != nil {
cclog.Warn("Error while committing transactions")
return err
}
// Commit commits the transaction.
// After calling Commit, the transaction should not be used again.
func (t *Transaction) Commit() error {
if t.tx == nil {
return fmt.Errorf("transaction already committed or rolled back")
}

t.tx, err = r.DB.Beginx()
err := t.tx.Commit()
t.tx = nil // Mark as completed
if err != nil {
cclog.Warn("Error while bundling transactions")
return err
return fmt.Errorf("committing transaction: %w", err)
}

return nil
}

// Rollback rolls back the transaction.
// It's safe to call Rollback on an already committed or rolled back transaction.
func (t *Transaction) Rollback() error {
if t.tx == nil {
return nil // Already committed/rolled back
}
err := t.tx.Rollback()
t.tx = nil // Mark as completed
if err != nil {
return fmt.Errorf("rolling back transaction: %w", err)
}
return nil
}

// TransactionEnd commits the transaction.
// Deprecated: Use Commit() instead.
func (r *JobRepository) TransactionEnd(t *Transaction) error {
if err := t.tx.Commit(); err != nil {
cclog.Warn("Error while committing SQL transactions")
return err
}
return nil
return t.Commit()
}

// TransactionAddNamed executes a named query within the transaction.
func (r *JobRepository) TransactionAddNamed(
t *Transaction,
query string,
args ...interface{},
) (int64, error) {
if t.tx == nil {
return 0, fmt.Errorf("transaction is nil or already completed")
}

res, err := t.tx.NamedExec(query, args)
if err != nil {
cclog.Errorf("Named Exec failed: %v", err)
return 0, err
return 0, fmt.Errorf("named exec: %w", err)
}

id, err := res.LastInsertId()
if err != nil {
cclog.Errorf("repository initDB(): %v", err)
return 0, err
return 0, fmt.Errorf("getting last insert id: %w", err)
}

return id, nil
}

// TransactionAdd executes a query within the transaction.
func (r *JobRepository) TransactionAdd(t *Transaction, query string, args ...interface{}) (int64, error) {
if t.tx == nil {
return 0, fmt.Errorf("transaction is nil or already completed")
}

res, err := t.tx.Exec(query, args...)
if err != nil {
cclog.Errorf("TransactionAdd(), Exec() Error: %v", err)
return 0, err
return 0, fmt.Errorf("exec: %w", err)
}

id, err := res.LastInsertId()
if err != nil {
cclog.Errorf("TransactionAdd(), LastInsertId() Error: %v", err)
return 0, err
return 0, fmt.Errorf("getting last insert id: %w", err)
}

return id, nil

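A caller sketch of the reworked transaction API (it mirrors the package documentation earlier in this change; the helper name, the jobID/tagIDs arguments, and the use of the jobtag table are illustrative):

```go
// addTagsBatch inserts a batch of job/tag associations in one transaction.
// Rollback is safe to defer: it becomes a no-op once Commit has run.
func addTagsBatch(jobRepo *repository.JobRepository, jobID int64, tagIDs []int64) error {
	t, err := jobRepo.TransactionInit()
	if err != nil {
		return err
	}
	defer t.Rollback()

	for _, tagID := range tagIDs {
		if _, err := jobRepo.TransactionAdd(t,
			"INSERT INTO jobtag (job_id, tag_id) VALUES (?, ?)", jobID, tagID); err != nil {
			return err
		}
	}
	return t.Commit()
}
```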
@@ -24,10 +24,14 @@ import (
)

//go:embed jobclasses/*
var jobclassFiles embed.FS
var jobClassFiles embed.FS

// Variable defines a named expression that can be computed and reused in rules.
// Variables are evaluated before the main rule and their results are added to the environment.
type Variable struct {
// Name is the variable identifier used in rule expressions
Name string `json:"name"`
// Expr is the expression to evaluate (must return a numeric value)
Expr string `json:"expr"`
}

@@ -36,14 +40,25 @@ type ruleVariable struct {
expr *vm.Program
}

// RuleFormat defines the JSON structure for job classification rules.
// Each rule specifies requirements, metrics to analyze, variables to compute,
// and the final rule expression that determines if the job matches the classification.
type RuleFormat struct {
// Name is a human-readable description of the rule
Name string `json:"name"`
// Tag is the classification tag to apply if the rule matches
Tag string `json:"tag"`
// Parameters are shared values referenced in the rule (e.g., thresholds)
Parameters []string `json:"parameters"`
// Metrics are the job metrics required for this rule (e.g., "cpu_load", "mem_used")
Metrics []string `json:"metrics"`
// Requirements are boolean expressions that must be true for the rule to apply
Requirements []string `json:"requirements"`
// Variables are computed values used in the rule expression
Variables []Variable `json:"variables"`
// Rule is the boolean expression that determines if the job matches
Rule string `json:"rule"`
// Hint is a template string that generates a message when the rule matches
Hint string `json:"hint"`
}

@@ -56,11 +71,35 @@ type ruleInfo struct {
hint *template.Template
}

// JobRepository defines the interface for job database operations needed by the tagger.
// This interface allows for easier testing and decoupling from the concrete repository implementation.
type JobRepository interface {
// HasTag checks if a job already has a specific tag
HasTag(jobId int64, tagType string, tagName string) bool
// AddTagOrCreateDirect adds a tag to a job or creates it if it doesn't exist
AddTagOrCreateDirect(jobId int64, tagType string, tagName string) (tagId int64, err error)
// UpdateMetadata updates job metadata with a key-value pair
UpdateMetadata(job *schema.Job, key, val string) (err error)
}

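Because Register later assigns repository.GetJobRepository() to this field, the concrete repository has to satisfy the interface. A compile-time assertion (illustrative only, not part of this change) would make that contract explicit:

```go
// Fails to compile if *repository.JobRepository stops matching the tagger's interface.
var _ JobRepository = (*repository.JobRepository)(nil)
```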
// JobClassTagger classifies jobs based on configurable rules that evaluate job metrics and properties.
// Rules are loaded from embedded JSON files and can be dynamically reloaded from a watched directory.
// When a job matches a rule, it is tagged with the corresponding classification and an optional hint message.
type JobClassTagger struct {
rules map[string]ruleInfo
parameters map[string]any
tagType string
cfgPath string
// rules maps classification tags to their compiled rule information
rules map[string]ruleInfo
// parameters are shared values (e.g., thresholds) used across multiple rules
parameters map[string]any
// tagType is the type of tag ("jobClass")
tagType string
// cfgPath is the path to watch for configuration changes
cfgPath string
// repo provides access to job database operations
repo JobRepository
// getStatistics retrieves job statistics for analysis
getStatistics func(job *schema.Job) (map[string]schema.JobStatistics, error)
// getMetricConfig retrieves metric configuration (limits) for a cluster
getMetricConfig func(cluster, subCluster string) map[string]*schema.Metric
}

func (t *JobClassTagger) prepareRule(b []byte, fns string) {
@@ -127,10 +166,14 @@ func (t *JobClassTagger) prepareRule(b []byte, fns string) {
t.rules[rule.Tag] = ri
}

// EventMatch checks if a filesystem event should trigger configuration reload.
// It returns true if the event path contains "jobclasses".
func (t *JobClassTagger) EventMatch(s string) bool {
return strings.Contains(s, "jobclasses")
}

// EventCallback is triggered when the configuration directory changes.
// It reloads parameters and all rule files from the watched directory.
// FIXME: Only process the file that caused the event
func (t *JobClassTagger) EventCallback() {
files, err := os.ReadDir(t.cfgPath)
@@ -170,7 +213,7 @@ func (t *JobClassTagger) EventCallback() {

func (t *JobClassTagger) initParameters() error {
cclog.Info("Initialize parameters")
b, err := jobclassFiles.ReadFile("jobclasses/parameters.json")
b, err := jobClassFiles.ReadFile("jobclasses/parameters.json")
if err != nil {
cclog.Warnf("prepareRule() > open file error: %v", err)
return err
@@ -184,6 +227,10 @@ func (t *JobClassTagger) initParameters() error {
return nil
}

// Register initializes the JobClassTagger by loading parameters and classification rules.
// It loads embedded configuration files and sets up a file watch on ./var/tagger/jobclasses
// if it exists, allowing for dynamic configuration updates without restarting the application.
// Returns an error if the embedded configuration files cannot be read or parsed.
func (t *JobClassTagger) Register() error {
t.cfgPath = "./var/tagger/jobclasses"
t.tagType = "jobClass"
@@ -194,18 +241,18 @@ func (t *JobClassTagger) Register() error {
return err
}

files, err := jobclassFiles.ReadDir("jobclasses")
files, err := jobClassFiles.ReadDir("jobclasses")
if err != nil {
return fmt.Errorf("error reading app folder: %#v", err)
}
t.rules = make(map[string]ruleInfo, 0)
t.rules = make(map[string]ruleInfo)
for _, fn := range files {
fns := fn.Name()
if fns != "parameters.json" {
filename := fmt.Sprintf("jobclasses/%s", fns)
cclog.Infof("Process: %s", fns)

b, err := jobclassFiles.ReadFile(filename)
b, err := jobClassFiles.ReadFile(filename)
if err != nil {
cclog.Warnf("prepareRule() > open file error: %v", err)
return err
@@ -220,13 +267,30 @@ func (t *JobClassTagger) Register() error {
util.AddListener(t.cfgPath, t)
}

t.repo = repository.GetJobRepository()
t.getStatistics = archive.GetStatistics
t.getMetricConfig = archive.GetMetricConfigSubCluster

return nil
}

// Match evaluates all classification rules against a job and applies matching tags.
// It retrieves job statistics and metric configurations, then tests each rule's requirements
// and main expression. For each matching rule, it:
//   - Applies the classification tag to the job
//   - Generates and stores a hint message based on the rule's template
//
// The function constructs an evaluation environment containing:
//   - Job properties (duration, cores, nodes, state, etc.)
//   - Metric statistics (min, max, avg) and their configured limits
//   - Shared parameters defined in parameters.json
//   - Computed variables from the rule definition
//
// Rules are evaluated in arbitrary order. If multiple rules match, only the first
// encountered match is applied (FIXME: this should handle multiple matches).
func (t *JobClassTagger) Match(job *schema.Job) {
r := repository.GetJobRepository()
jobstats, err := archive.GetStatistics(job)
metricsList := archive.GetMetricConfigSubCluster(job.Cluster, job.SubCluster)
jobStats, err := t.getStatistics(job)
metricsList := t.getMetricConfig(job.Cluster, job.SubCluster)
cclog.Infof("Enter match rule with %d rules for job %d", len(t.rules), job.JobID)
if err != nil {
cclog.Errorf("job classification failed for job %d: %#v", job.JobID, err)
@@ -251,7 +315,7 @@ func (t *JobClassTagger) Match(job *schema.Job) {

// add metrics to env
for _, m := range ri.metrics {
stats, ok := jobstats[m]
stats, ok := jobStats[m]
if !ok {
cclog.Errorf("job classification failed for job %d: missing metric '%s'", job.JobID, m)
return
@@ -302,8 +366,11 @@ func (t *JobClassTagger) Match(job *schema.Job) {
if match.(bool) {
cclog.Info("Rule matches!")
id := *job.ID
if !r.HasTag(id, t.tagType, tag) {
r.AddTagOrCreateDirect(id, t.tagType, tag)
if !t.repo.HasTag(id, t.tagType, tag) {
_, err := t.repo.AddTagOrCreateDirect(id, t.tagType, tag)
if err != nil {
return
}
}

// process hint template
@@ -314,7 +381,11 @@ func (t *JobClassTagger) Match(job *schema.Job) {
}

// FIXME: Handle case where multiple tags apply
r.UpdateMetadata(job, "message", msg.String())
// FIXME: Handle case where multiple tags apply
err = t.repo.UpdateMetadata(job, "message", msg.String())
if err != nil {
return
}
} else {
cclog.Info("Rule does not match!")
}

internal/tagger/classifyJob_test.go (new file, 162 lines)
@@ -0,0 +1,162 @@
package tagger

import (
"testing"

"github.com/ClusterCockpit/cc-lib/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// MockJobRepository is a mock implementation of the JobRepository interface
type MockJobRepository struct {
mock.Mock
}

func (m *MockJobRepository) HasTag(jobId int64, tagType string, tagName string) bool {
args := m.Called(jobId, tagType, tagName)
return args.Bool(0)
}

func (m *MockJobRepository) AddTagOrCreateDirect(jobId int64, tagType string, tagName string) (tagId int64, err error) {
args := m.Called(jobId, tagType, tagName)
return args.Get(0).(int64), args.Error(1)
}

func (m *MockJobRepository) UpdateMetadata(job *schema.Job, key, val string) (err error) {
args := m.Called(job, key, val)
return args.Error(0)
}

func TestPrepareRule(t *testing.T) {
tagger := &JobClassTagger{
rules: make(map[string]ruleInfo),
parameters: make(map[string]any),
}

// Valid rule JSON
validRule := []byte(`{
"name": "Test Rule",
"tag": "test_tag",
"parameters": [],
"metrics": ["flops_any"],
"requirements": ["job.numNodes > 1"],
"variables": [{"name": "avg_flops", "expr": "flops_any.avg"}],
"rule": "avg_flops > 100",
"hint": "High FLOPS"
}`)

tagger.prepareRule(validRule, "test_rule.json")

assert.Contains(t, tagger.rules, "test_tag")
rule := tagger.rules["test_tag"]
assert.Equal(t, 1, len(rule.metrics))
assert.Equal(t, 1, len(rule.requirements))
assert.Equal(t, 1, len(rule.variables))
assert.NotNil(t, rule.rule)
assert.NotNil(t, rule.hint)
}

func TestClassifyJobMatch(t *testing.T) {
mockRepo := new(MockJobRepository)
tagger := &JobClassTagger{
rules: make(map[string]ruleInfo),
parameters: make(map[string]any),
tagType: "jobClass",
repo: mockRepo,
getStatistics: func(job *schema.Job) (map[string]schema.JobStatistics, error) {
return map[string]schema.JobStatistics{
"flops_any": {Min: 0, Max: 200, Avg: 150},
}, nil
},
getMetricConfig: func(cluster, subCluster string) map[string]*schema.Metric {
return map[string]*schema.Metric{
"flops_any": {Peak: 1000, Normal: 100, Caution: 50, Alert: 10},
}
},
}

// Add a rule manually or via prepareRule
validRule := []byte(`{
"name": "Test Rule",
"tag": "high_flops",
"parameters": [],
"metrics": ["flops_any"],
"requirements": [],
"variables": [{"name": "avg_flops", "expr": "flops_any.avg"}],
"rule": "avg_flops > 100",
"hint": "High FLOPS: {{.avg_flops}}"
}`)
tagger.prepareRule(validRule, "test_rule.json")

jobID := int64(123)
job := &schema.Job{
ID: &jobID,
JobID: 123,
Cluster: "test_cluster",
SubCluster: "test_subcluster",
NumNodes: 2,
NumHWThreads: 4,
State: schema.JobStateCompleted,
}

// Expectation: Rule matches
// 1. Check if tag exists (return false)
mockRepo.On("HasTag", jobID, "jobClass", "high_flops").Return(false)
// 2. Add tag
mockRepo.On("AddTagOrCreateDirect", jobID, "jobClass", "high_flops").Return(int64(1), nil)
// 3. Update metadata
mockRepo.On("UpdateMetadata", job, "message", mock.Anything).Return(nil)

tagger.Match(job)

mockRepo.AssertExpectations(t)
}

func TestMatch_NoMatch(t *testing.T) {
mockRepo := new(MockJobRepository)
tagger := &JobClassTagger{
rules: make(map[string]ruleInfo),
parameters: make(map[string]any),
tagType: "jobClass",
repo: mockRepo,
getStatistics: func(job *schema.Job) (map[string]schema.JobStatistics, error) {
return map[string]schema.JobStatistics{
"flops_any": {Min: 0, Max: 50, Avg: 20}, // Avg 20 < 100
}, nil
},
getMetricConfig: func(cluster, subCluster string) map[string]*schema.Metric {
return map[string]*schema.Metric{
"flops_any": {Peak: 1000, Normal: 100, Caution: 50, Alert: 10},
}
},
}

validRule := []byte(`{
"name": "Test Rule",
"tag": "high_flops",
"parameters": [],
"metrics": ["flops_any"],
"requirements": [],
"variables": [{"name": "avg_flops", "expr": "flops_any.avg"}],
"rule": "avg_flops > 100",
"hint": "High FLOPS"
}`)
tagger.prepareRule(validRule, "test_rule.json")

jobID := int64(123)
job := &schema.Job{
ID: &jobID,
JobID: 123,
Cluster: "test_cluster",
SubCluster: "test_subcluster",
NumNodes: 2,
NumHWThreads: 4,
State: schema.JobStateCompleted,
}

// Expectation: Rule does NOT match, so no repo calls
tagger.Match(job)

mockRepo.AssertExpectations(t)
}
@@ -2,6 +2,7 @@
// All rights reserved. This file is part of cc-backend.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package tagger

import (
@@ -28,9 +29,16 @@ type appInfo struct {
strings []string
}

// AppTagger detects applications by matching patterns in job scripts.
// It loads application patterns from embedded files and can dynamically reload
// configuration from a watched directory. When a job script matches a pattern,
// the corresponding application tag is automatically applied.
type AppTagger struct {
// apps maps application tags to their matching patterns
apps map[string]appInfo
// tagType is the type of tag ("app")
tagType string
// cfgPath is the path to watch for configuration changes
cfgPath string
}

@@ -45,10 +53,14 @@ func (t *AppTagger) scanApp(f fs.File, fns string) {
t.apps[ai.tag] = ai
}

// EventMatch checks if a filesystem event should trigger configuration reload.
// It returns true if the event path contains "apps".
func (t *AppTagger) EventMatch(s string) bool {
return strings.Contains(s, "apps")
}

// EventCallback is triggered when the configuration directory changes.
// It reloads all application pattern files from the watched directory.
// FIXME: Only process the file that caused the event
func (t *AppTagger) EventCallback() {
files, err := os.ReadDir(t.cfgPath)
@@ -67,6 +79,10 @@ func (t *AppTagger) EventCallback() {
}
}

// Register initializes the AppTagger by loading application patterns from embedded files.
// It also sets up a file watch on ./var/tagger/apps if it exists, allowing for
// dynamic configuration updates without restarting the application.
// Returns an error if the embedded application files cannot be read.
func (t *AppTagger) Register() error {
t.cfgPath = "./var/tagger/apps"
t.tagType = "app"
@@ -96,6 +112,11 @@ func (t *AppTagger) Register() error {
return nil
}

// Match attempts to detect the application used by a job by analyzing its job script.
// It fetches the job metadata, extracts the job script, and matches it against
// all configured application patterns using regular expressions.
// If a match is found, the corresponding application tag is added to the job.
// Only the first matching application is tagged.
func (t *AppTagger) Match(job *schema.Job) {
r := repository.GetJobRepository()
metadata, err := r.FetchMetadata(job)

@@ -2,6 +2,11 @@
// All rights reserved. This file is part of cc-backend.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

// Package tagger provides automatic job tagging functionality for cc-backend.
// It supports detecting applications and classifying jobs based on configurable rules.
// Tags are automatically applied when jobs start or stop, or can be applied retroactively
// to existing jobs using RunTaggers.
package tagger

import (
@@ -12,8 +17,15 @@ import (
"github.com/ClusterCockpit/cc-lib/schema"
)

// Tagger is the interface that must be implemented by all tagging components.
// Taggers can be registered at job start or stop events to automatically apply tags.
type Tagger interface {
// Register initializes the tagger and loads any required configuration.
// It should be called once before the tagger is used.
Register() error

// Match evaluates the tagger's rules against a job and applies appropriate tags.
// It is called for each job that needs to be evaluated.
Match(job *schema.Job)
}

@@ -22,8 +34,12 @@ var (
jobTagger *JobTagger
)

// JobTagger coordinates multiple taggers that run at different job lifecycle events.
// It maintains separate lists of taggers that run when jobs start and when they stop.
type JobTagger struct {
// startTaggers are applied when a job starts (e.g., application detection)
startTaggers []Tagger
// stopTaggers are applied when a job completes (e.g., job classification)
stopTaggers []Tagger
}

@@ -42,6 +58,9 @@ func newTagger() {
}
}

// Init initializes the job tagger system and registers it with the job repository.
// This function is safe to call multiple times; initialization only occurs once.
// It should be called during application startup.
func Init() {
initOnce.Do(func() {
newTagger()
@@ -49,22 +68,30 @@ func Init() {
})
}

// JobStartCallback is called when a job starts.
// It runs all registered start taggers (e.g., application detection) on the job.
func (jt *JobTagger) JobStartCallback(job *schema.Job) {
for _, tagger := range jt.startTaggers {
tagger.Match(job)
}
}

// JobStopCallback is called when a job completes.
// It runs all registered stop taggers (e.g., job classification) on the job.
func (jt *JobTagger) JobStopCallback(job *schema.Job) {
for _, tagger := range jt.stopTaggers {
tagger.Match(job)
}
}

// RunTaggers applies all configured taggers to all existing jobs in the repository.
// This is useful for retroactively applying tags to jobs that were created before
// the tagger system was initialized or when new tagging rules are added.
// It fetches all jobs and runs both start and stop taggers on each one.
func RunTaggers() error {
newTagger()
r := repository.GetJobRepository()
jl, err := r.GetJobList()
jl, err := r.GetJobList(0, 0) // 0 limit means get all jobs (no pagination)
if err != nil {
cclog.Errorf("Error while getting job list %s", err)
return err
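A minimal usage sketch following the package documentation above (the startup and admin-trigger call sites are illustrative, not part of this change):

```go
// At application startup: wire the taggers into the job repository callbacks.
tagger.Init()

// Optionally, e.g. behind an admin flag: retag all existing jobs with the current rules.
if err := tagger.RunTaggers(); err != nil {
	log.Printf("retroactive tagging failed: %v", err)
}
```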