From b85832646fea2f53441adf9802bc9160e1d5e8b1 Mon Sep 17 00:00:00 2001 From: Lou Knauer Date: Wed, 16 Feb 2022 09:32:38 +0100 Subject: [PATCH] Show correct count of jobs/tags --- graph/resolver.go | 1 + repository/job.go | 10 ++++++---- repository/job_test.go | 2 +- server.go | 12 +++++++++++- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/graph/resolver.go b/graph/resolver.go index 5f125fe..d4af150 100644 --- a/graph/resolver.go +++ b/graph/resolver.go @@ -102,6 +102,7 @@ func (r *Resolver) queryJobs(ctx context.Context, filters []*model.JobFilter, pa for _, f := range filters { query = buildWhereClause(f, query) } + query = securityCheck(ctx, query) var count int if err := query.RunWith(r.DB).Scan(&count); err != nil { return nil, 0, err diff --git a/repository/job.go b/repository/job.go index d38a133..8e61025 100644 --- a/repository/job.go +++ b/repository/job.go @@ -154,7 +154,7 @@ func (r *JobRepository) CreateTag(tagType string, tagName string) (tagId int64, return res.LastInsertId() } -func (r *JobRepository) GetTags() (tags []schema.Tag, counts map[string]int, err error) { +func (r *JobRepository) GetTags(user *string) (tags []schema.Tag, counts map[string]int, err error) { tags = make([]schema.Tag, 0, 100) xrows, err := r.DB.Queryx("SELECT * FROM tag") if err != nil { @@ -173,11 +173,13 @@ func (r *JobRepository) GetTags() (tags []schema.Tag, counts map[string]int, err From("tag t"). LeftJoin("jobtag jt ON t.id = jt.tag_id"). GroupBy("t.tag_name") + if user != nil { + q = q.Where("jt.job_id IN (SELECT id FROM job WHERE job.user = ?)", *user) + } - qs, _, _ := q.ToSql() - rows, err := r.DB.Query(qs) + rows, err := q.RunWith(r.DB).Query() if err != nil { - fmt.Println(err) + return nil, nil, err } counts = make(map[string]int) diff --git a/repository/job_test.go b/repository/job_test.go index 5ea43ca..8a43617 100644 --- a/repository/job_test.go +++ b/repository/job_test.go @@ -57,7 +57,7 @@ func TestFindById(t *testing.T) { func TestGetTags(t *testing.T) { r := setup(t) - tags, counts, err := r.GetTags() + tags, counts, err := r.GetTags(nil) if err != nil { t.Fatal(err) } diff --git a/server.go b/server.go index af581f3..da4fe92 100644 --- a/server.go +++ b/server.go @@ -154,8 +154,18 @@ func setupAnalysisRoute(i InfoType, r *http.Request) InfoType { } func setupTaglistRoute(i InfoType, r *http.Request) InfoType { - tags, counts, _ := jobRepo.GetTags() + var username *string = nil + if user := auth.GetUser(r.Context()); user != nil && !user.HasRole(auth.RoleAdmin) { + username = &user.Username + } + + tags, counts, err := jobRepo.GetTags(username) tagMap := make(map[string][]map[string]interface{}) + if err != nil { + log.Errorf("GetTags failed: %s", err.Error()) + i["tagmap"] = tagMap + return i + } for _, tag := range tags { tagItem := map[string]interface{}{