diff --git a/internal/api/rest.go b/internal/api/rest.go index b374aa9..077a705 100644 --- a/internal/api/rest.go +++ b/internal/api/rest.go @@ -595,7 +595,7 @@ func (api *RestApi) updateConfiguration(rw http.ResponseWriter, r *http.Request) fmt.Printf("KEY: %#v\nVALUE: %#v\n", key, value) - if err := repository.GetUserCfgRepo().UpdateConfig(key, value, r.Context()); err != nil { + if err := repository.GetUserCfgRepo().UpdateConfig(key, value, auth.GetUser(r.Context())); err != nil { http.Error(rw, err.Error(), http.StatusUnprocessableEntity) return } diff --git a/internal/repository/user.go b/internal/repository/user.go index c0b7f9b..2e036d9 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -5,14 +5,13 @@ package repository import ( - "context" "encoding/json" "log" - "net/http" "sync" "time" "github.com/ClusterCockpit/cc-backend/internal/auth" + "github.com/ClusterCockpit/cc-backend/internal/config" "github.com/ClusterCockpit/cc-backend/pkg/lrucache" "github.com/jmoiron/sqlx" ) @@ -52,9 +51,10 @@ func GetUserCfgRepo() *UserCfgRepo { } userCfgRepoInstance = &UserCfgRepo{ - DB: db.DB, - Lookup: lookupConfigStmt, - cache: lrucache.New(1024), + DB: db.DB, + Lookup: lookupConfigStmt, + uiDefaults: config.Keys.UiDefaults, + cache: lrucache.New(1024), } }) @@ -63,8 +63,7 @@ func GetUserCfgRepo() *UserCfgRepo { // Return the personalised UI config for the currently authenticated // user or return the plain default config. -func (uCfg *UserCfgRepo) GetUIConfig(r *http.Request) (map[string]interface{}, error) { - user := auth.GetUser(r.Context()) +func (uCfg *UserCfgRepo) GetUIConfig(user *auth.User) (map[string]interface{}, error) { if user == nil { uCfg.lock.RLock() copy := make(map[string]interface{}, len(uCfg.uiDefaults)) @@ -116,8 +115,10 @@ func (uCfg *UserCfgRepo) GetUIConfig(r *http.Request) (map[string]interface{}, e // If the context does not have a user, update the global ui configuration // without persisting it! If there is a (authenticated) user, update only his // configuration. -func (uCfg *UserCfgRepo) UpdateConfig(key, value string, ctx context.Context) error { - user := auth.GetUser(ctx) +func (uCfg *UserCfgRepo) UpdateConfig( + key, value string, + user *auth.User) error { + if user == nil { var val interface{} if err := json.Unmarshal([]byte(value), &val); err != nil { diff --git a/internal/repository/user_test.go b/internal/repository/user_test.go new file mode 100644 index 0000000..e43090a --- /dev/null +++ b/internal/repository/user_test.go @@ -0,0 +1,59 @@ +// Copyright (C) 2022 NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. +package repository + +import ( + "os" + "path/filepath" + "testing" + + "github.com/ClusterCockpit/cc-backend/internal/auth" + "github.com/ClusterCockpit/cc-backend/internal/config" +) + +func init() { + Connect("sqlite3", "../../test/test.db") +} + +func setupUserTest(t *testing.T) *UserCfgRepo { + const testconfig = `{ + "addr": "0.0.0.0:8080", + "archive": { + "kind": "file", + "path": "./var/job-archive" + }, + "clusters": [ + { + "name": "testcluster", + "metricDataRepository": {"kind": "test"} + } + ] +}` + tmpdir := t.TempDir() + cfgFilePath := filepath.Join(tmpdir, "config.json") + if err := os.WriteFile(cfgFilePath, []byte(testconfig), 0666); err != nil { + t.Fatal(err) + } + + config.Init(cfgFilePath) + return GetUserCfgRepo() +} +func TestGetUIConfig(t *testing.T) { + r := setupUserTest(t) + u := auth.User{Username: "jan"} + + cfg, err := r.GetUIConfig(&u) + if err != nil { + t.Fatal("No config") + } + + tmp := cfg["plot_list_selectedMetrics"] + metrics := tmp.([]interface{}) + + str := metrics[2].(string) + if str != "mem_bw" { + t.Errorf("wrong config\ngot: %s \nwant: mem_bw", str) + } +} diff --git a/internal/routerConfig/routes.go b/internal/routerConfig/routes.go index 669aaf6..9a4557a 100644 --- a/internal/routerConfig/routes.go +++ b/internal/routerConfig/routes.go @@ -258,7 +258,7 @@ func SetupRoutes(router *mux.Router) { for _, route := range routes { route := route router.HandleFunc(route.Route, func(rw http.ResponseWriter, r *http.Request) { - conf, err := userCfgRepo.GetUIConfig(r) + conf, err := userCfgRepo.GetUIConfig(auth.GetUser(r.Context())) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return