diff --git a/internal/repository/tags.go b/internal/repository/tags.go index 588a98d..587a4e1 100644 --- a/internal/repository/tags.go +++ b/internal/repository/tags.go @@ -7,9 +7,10 @@ package repository import ( "strings" + "github.com/ClusterCockpit/cc-backend/internal/auth" "github.com/ClusterCockpit/cc-backend/pkg/archive" - "github.com/ClusterCockpit/cc-backend/pkg/schema" "github.com/ClusterCockpit/cc-backend/pkg/log" + "github.com/ClusterCockpit/cc-backend/pkg/schema" sq "github.com/Masterminds/squirrel" ) @@ -67,7 +68,7 @@ func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, return res.LastInsertId() } -func (r *JobRepository) CountTags(user *string, projects *[]string) (tags []schema.Tag, counts map[string]int, err error) { +func (r *JobRepository) CountTags(user *auth.User) (tags []schema.Tag, counts map[string]int, err error) { tags = make([]schema.Tag, 0, 100) xrows, err := r.DB.Queryx("SELECT * FROM tag") if err != nil { @@ -87,11 +88,11 @@ func (r *JobRepository) CountTags(user *string, projects *[]string) (tags []sche LeftJoin("jobtag jt ON t.id = jt.tag_id"). GroupBy("t.tag_name") - if user != nil && len(*projects) == 0 { // USER: Only count own jobs - q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user) - } else if user != nil && len(*projects) != 0 { // MANAGER: Count own jobs plus project's jobs + if user.HasRole(auth.RoleUser) { // USER: Only count own jobs + q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", user.Username) + } else if user.HasRole(auth.RoleManager) { // MANAGER: Count own jobs plus project's jobs // Build ("project1", "project2", ...) list of variable length directly in SQL string - q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ? OR job.project IN (\""+strings.Join(*projects, "\",\"")+"\"))", *user) + q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ? OR job.project IN (\""+strings.Join(user.Projects, "\",\"")+"\"))", user.Username) } // else: ADMIN || SUPPORT: Count all jobs rows, err := q.RunWith(r.stmtCache).Query() diff --git a/internal/routerConfig/routes.go b/internal/routerConfig/routes.go index 009f348..4f4e47d 100644 --- a/internal/routerConfig/routes.go +++ b/internal/routerConfig/routes.go @@ -142,20 +142,10 @@ func setupAnalysisRoute(i InfoType, r *http.Request) InfoType { } func setupTaglistRoute(i InfoType, r *http.Request) InfoType { - var username *string = nil - var projects *[]string - jobRepo := repository.GetJobRepository() user := auth.GetUser(r.Context()) - if user != nil && user.HasNotRoles([]string{auth.RoleAdmin, auth.RoleSupport, auth.RoleManager}) { - username = &user.Username - } else if user != nil && user.HasRole(auth.RoleManager) { - username = &user.Username - projects = &user.Projects - } // ADMINS && SUPPORT w/o additional conditions - - tags, counts, err := jobRepo.CountTags(username, projects) + tags, counts, err := jobRepo.CountTags(user) tagMap := make(map[string][]map[string]interface{}) if err != nil { log.Warnf("GetTags failed: %s", err.Error())