From b9d7189cb91630b14975be9f6c550c9e8ca440b6 Mon Sep 17 00:00:00 2001 From: Michael Reimsbach Date: Wed, 7 Aug 2024 13:40:13 +0200 Subject: [PATCH] feat(issue): add issueType count (#119) Co-authored-by: David Rochow --- .../api/graphql/graph/baseResolver/issue.go | 36 +++- internal/api/graphql/graph/generated.go | 195 +++++++++++++++++- .../api/graphql/graph/model/models_gen.go | 10 +- .../issue/directRelations.graphql | 3 + .../api/graphql/graph/schema/issue.graphqls | 3 + internal/app/interface.go | 2 +- internal/app/issue.go | 34 +-- internal/app/issue_test.go | 34 ++- internal/database/interface.go | 1 + internal/database/mariadb/entity.go | 13 ++ internal/database/mariadb/issue.go | 48 +++++ internal/database/mariadb/issue_test.go | 35 ++++ internal/e2e/issue_query_test.go | 16 +- internal/entity/common.go | 3 +- internal/entity/issue.go | 31 ++- internal/mocks/mock_Database.go | 58 ++++++ internal/mocks/mock_Heureka.go | 118 +++++++++++ 17 files changed, 594 insertions(+), 46 deletions(-) diff --git a/internal/api/graphql/graph/baseResolver/issue.go b/internal/api/graphql/graph/baseResolver/issue.go index d5e2c870..4073bb39 100644 --- a/internal/api/graphql/graph/baseResolver/issue.go +++ b/internal/api/graphql/graph/baseResolver/issue.go @@ -14,6 +14,14 @@ import ( "k8s.io/utils/pointer" ) +func GetIssueListOptions(requestedFields []string) *entity.IssueListOptions { + listOptions := GetListOptions(requestedFields) + return &entity.IssueListOptions{ + ListOptions: *listOptions, + ShowIssueTypeCounts: lo.Contains(requestedFields, "issueTypeCounts"), + } +} + func SingleIssueBaseResolver(app app.Heureka, ctx context.Context, parent *model.NodeParent) (*model.Issue, error) { requestedFields := GetPreloads(ctx) logrus.WithFields(logrus.Fields{ @@ -29,7 +37,7 @@ func SingleIssueBaseResolver(app app.Heureka, ctx context.Context, parent *model Id: parent.ChildIds, } - opt := &entity.ListOptions{} + opt := &entity.IssueListOptions{} issues, err := app.ListIssues(f, opt) @@ -102,7 +110,7 @@ func IssueBaseResolver(app app.Heureka, ctx context.Context, filter *model.Issue IssueMatchTargetRemediationDate: nil, //@todo Implement } - opt := GetListOptions(requestedFields) + opt := GetIssueListOptions(requestedFields) issues, err := app.ListIssues(f, opt) @@ -121,17 +129,29 @@ func IssueBaseResolver(app app.Heureka, ctx context.Context, filter *model.Issue edges = append(edges, &edge) } - tc := 0 + totalCount := 0 if issues.TotalCount != nil { - tc = int(*issues.TotalCount) + totalCount = int(*issues.TotalCount) + } + + vulnerabilityCount := 0 + policiyViolationCount := 0 + securityEventCount := 0 + + if issues.VulnerabilityCount != nil && issues.PolicyViolationCount != nil && issues.SecurityEventCount != nil { + vulnerabilityCount = int(*issues.VulnerabilityCount) + policiyViolationCount = int(*issues.PolicyViolationCount) + securityEventCount = int(*issues.SecurityEventCount) } connection := model.IssueConnection{ - TotalCount: tc, - Edges: edges, - PageInfo: model.NewPageInfo(issues.PageInfo), + TotalCount: totalCount, + VulnerabilityCount: vulnerabilityCount, + PolicyViolationCount: policiyViolationCount, + SecurityEventCount: securityEventCount, + Edges: edges, + PageInfo: model.NewPageInfo(issues.PageInfo), } return &connection, nil - } diff --git a/internal/api/graphql/graph/generated.go b/internal/api/graphql/graph/generated.go index 0e0dfc25..bcfa65fe 100644 --- a/internal/api/graphql/graph/generated.go +++ b/internal/api/graphql/graph/generated.go @@ -229,9 +229,12 @@ type ComplexityRoot struct { } IssueConnection struct { - Edges func(childComplexity int) int - PageInfo func(childComplexity int) int - TotalCount func(childComplexity int) int + Edges func(childComplexity int) int + PageInfo func(childComplexity int) int + PolicyViolationCount func(childComplexity int) int + SecurityEventCount func(childComplexity int) int + TotalCount func(childComplexity int) int + VulnerabilityCount func(childComplexity int) int } IssueEdge struct { @@ -1461,6 +1464,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.IssueConnection.PageInfo(childComplexity), true + case "IssueConnection.policyViolationCount": + if e.complexity.IssueConnection.PolicyViolationCount == nil { + break + } + + return e.complexity.IssueConnection.PolicyViolationCount(childComplexity), true + + case "IssueConnection.securityEventCount": + if e.complexity.IssueConnection.SecurityEventCount == nil { + break + } + + return e.complexity.IssueConnection.SecurityEventCount(childComplexity), true + case "IssueConnection.totalCount": if e.complexity.IssueConnection.TotalCount == nil { break @@ -1468,6 +1485,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.IssueConnection.TotalCount(childComplexity), true + case "IssueConnection.vulnerabilityCount": + if e.complexity.IssueConnection.VulnerabilityCount == nil { + break + } + + return e.complexity.IssueConnection.VulnerabilityCount(childComplexity), true + case "IssueEdge.cursor": if e.complexity.IssueEdge.Cursor == nil { break @@ -5980,6 +6004,12 @@ func (ec *executionContext) fieldContext_Activity_issues(ctx context.Context, fi switch field.Name { case "totalCount": return ec.fieldContext_IssueConnection_totalCount(ctx, field) + case "vulnerabilityCount": + return ec.fieldContext_IssueConnection_vulnerabilityCount(ctx, field) + case "policyViolationCount": + return ec.fieldContext_IssueConnection_policyViolationCount(ctx, field) + case "securityEventCount": + return ec.fieldContext_IssueConnection_securityEventCount(ctx, field) case "edges": return ec.fieldContext_IssueConnection_edges(ctx, field) case "pageInfo": @@ -9054,6 +9084,12 @@ func (ec *executionContext) fieldContext_ComponentVersion_issues(ctx context.Con switch field.Name { case "totalCount": return ec.fieldContext_IssueConnection_totalCount(ctx, field) + case "vulnerabilityCount": + return ec.fieldContext_IssueConnection_vulnerabilityCount(ctx, field) + case "policyViolationCount": + return ec.fieldContext_IssueConnection_policyViolationCount(ctx, field) + case "securityEventCount": + return ec.fieldContext_IssueConnection_securityEventCount(ctx, field) case "edges": return ec.fieldContext_IssueConnection_edges(ctx, field) case "pageInfo": @@ -10646,6 +10682,138 @@ func (ec *executionContext) fieldContext_IssueConnection_totalCount(_ context.Co return fc, nil } +func (ec *executionContext) _IssueConnection_vulnerabilityCount(ctx context.Context, field graphql.CollectedField, obj *model.IssueConnection) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_IssueConnection_vulnerabilityCount(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.VulnerabilityCount, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(int) + fc.Result = res + return ec.marshalNInt2int(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_IssueConnection_vulnerabilityCount(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "IssueConnection", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Int does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _IssueConnection_policyViolationCount(ctx context.Context, field graphql.CollectedField, obj *model.IssueConnection) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_IssueConnection_policyViolationCount(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.PolicyViolationCount, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(int) + fc.Result = res + return ec.marshalNInt2int(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_IssueConnection_policyViolationCount(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "IssueConnection", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Int does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _IssueConnection_securityEventCount(ctx context.Context, field graphql.CollectedField, obj *model.IssueConnection) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_IssueConnection_securityEventCount(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.SecurityEventCount, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(int) + fc.Result = res + return ec.marshalNInt2int(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_IssueConnection_securityEventCount(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "IssueConnection", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Int does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _IssueConnection_edges(ctx context.Context, field graphql.CollectedField, obj *model.IssueConnection) (ret graphql.Marshaler) { fc, err := ec.fieldContext_IssueConnection_edges(ctx, field) if err != nil { @@ -18431,6 +18599,12 @@ func (ec *executionContext) fieldContext_Query_Issues(ctx context.Context, field switch field.Name { case "totalCount": return ec.fieldContext_IssueConnection_totalCount(ctx, field) + case "vulnerabilityCount": + return ec.fieldContext_IssueConnection_vulnerabilityCount(ctx, field) + case "policyViolationCount": + return ec.fieldContext_IssueConnection_policyViolationCount(ctx, field) + case "securityEventCount": + return ec.fieldContext_IssueConnection_securityEventCount(ctx, field) case "edges": return ec.fieldContext_IssueConnection_edges(ctx, field) case "pageInfo": @@ -25789,6 +25963,21 @@ func (ec *executionContext) _IssueConnection(ctx context.Context, sel ast.Select if out.Values[i] == graphql.Null { out.Invalids++ } + case "vulnerabilityCount": + out.Values[i] = ec._IssueConnection_vulnerabilityCount(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "policyViolationCount": + out.Values[i] = ec._IssueConnection_policyViolationCount(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "securityEventCount": + out.Values[i] = ec._IssueConnection_securityEventCount(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } case "edges": out.Values[i] = ec._IssueConnection_edges(ctx, field, obj) if out.Values[i] == graphql.Null { diff --git a/internal/api/graphql/graph/model/models_gen.go b/internal/api/graphql/graph/model/models_gen.go index 816b0852..b9fc7d37 100644 --- a/internal/api/graphql/graph/model/models_gen.go +++ b/internal/api/graphql/graph/model/models_gen.go @@ -3,6 +3,7 @@ // Code generated by github.com/99designs/gqlgen, DO NOT EDIT. + package model import ( @@ -309,9 +310,12 @@ func (Issue) IsNode() {} func (this Issue) GetID() string { return this.ID } type IssueConnection struct { - TotalCount int `json:"totalCount"` - Edges []*IssueEdge `json:"edges"` - PageInfo *PageInfo `json:"pageInfo,omitempty"` + TotalCount int `json:"totalCount"` + VulnerabilityCount int `json:"vulnerabilityCount"` + PolicyViolationCount int `json:"policyViolationCount"` + SecurityEventCount int `json:"securityEventCount"` + Edges []*IssueEdge `json:"edges"` + PageInfo *PageInfo `json:"pageInfo,omitempty"` } func (IssueConnection) IsConnection() {} diff --git a/internal/api/graphql/graph/queryCollection/issue/directRelations.graphql b/internal/api/graphql/graph/queryCollection/issue/directRelations.graphql index 9207121c..31c599d7 100644 --- a/internal/api/graphql/graph/queryCollection/issue/directRelations.graphql +++ b/internal/api/graphql/graph/queryCollection/issue/directRelations.graphql @@ -8,6 +8,9 @@ query ($filter: IssueFilter, $first: Int, $after: String) { after: $after ) { totalCount + vulnerabilityCount + policyViolationCount + securityEventCount edges { node { id diff --git a/internal/api/graphql/graph/schema/issue.graphqls b/internal/api/graphql/graph/schema/issue.graphqls index e5364fb5..5332baba 100644 --- a/internal/api/graphql/graph/schema/issue.graphqls +++ b/internal/api/graphql/graph/schema/issue.graphqls @@ -34,6 +34,9 @@ input IssueInput { type IssueConnection implements Connection { totalCount: Int! + vulnerabilityCount: Int! + policyViolationCount: Int! + securityEventCount: Int! edges: [IssueEdge]! pageInfo: PageInfo } diff --git a/internal/app/interface.go b/internal/app/interface.go index 82961c27..31033760 100644 --- a/internal/app/interface.go +++ b/internal/app/interface.go @@ -8,7 +8,7 @@ import ( ) type Heureka interface { - ListIssues(*entity.IssueFilter, *entity.ListOptions) (*entity.List[entity.IssueResult], error) + ListIssues(*entity.IssueFilter, *entity.IssueListOptions) (*entity.IssueList, error) CreateIssue(*entity.Issue) (*entity.Issue, error) UpdateIssue(*entity.Issue) (*entity.Issue, error) DeleteIssue(int64) error diff --git a/internal/app/issue.go b/internal/app/issue.go index 9d200e9e..21bc0ad2 100644 --- a/internal/app/issue.go +++ b/internal/app/issue.go @@ -54,7 +54,7 @@ func (h *HeurekaApp) GetIssue(id int64) (*entity.Issue, error) { "id": id, }) - issues, err := h.ListIssues(&entity.IssueFilter{Id: []*int64{&id}}, &entity.ListOptions{}) + issues, err := h.ListIssues(&entity.IssueFilter{Id: []*int64{&id}}, &entity.IssueListOptions{}) if err != nil { l.Error(err) @@ -68,11 +68,13 @@ func (h *HeurekaApp) GetIssue(id int64) (*entity.Issue, error) { return issues.Elements[0].Issue, nil } -func (h *HeurekaApp) ListIssues(filter *entity.IssueFilter, options *entity.ListOptions) (*entity.List[entity.IssueResult], error) { - var count int64 +func (h *HeurekaApp) ListIssues(filter *entity.IssueFilter, options *entity.IssueListOptions) (*entity.IssueList, error) { var pageInfo *entity.PageInfo var res []entity.IssueResult var err error + issueList := entity.IssueList{ + List: &entity.List[entity.IssueResult]{}, + } l := logrus.WithFields(logrus.Fields{ "event": "app.ListIssues", @@ -95,6 +97,8 @@ func (h *HeurekaApp) ListIssues(filter *entity.IssueFilter, options *entity.List } } + issueList.Elements = res + if options.ShowPageInfo { if len(res) > 0 { ids, err := h.database.GetAllIssueIds(filter) @@ -103,21 +107,23 @@ func (h *HeurekaApp) ListIssues(filter *entity.IssueFilter, options *entity.List return nil, heurekaError("Error while getting all Ids") } pageInfo = getPageInfo(res, ids, *filter.First, *filter.After) - count = int64(len(ids)) + issueList.PageInfo = pageInfo } - } else if options.ShowTotalCount { - count, err = h.database.CountIssues(filter) + } + if options.ShowPageInfo || options.ShowTotalCount || options.ShowIssueTypeCounts { + counts, err := h.database.CountIssueTypes(filter) if err != nil { l.Error(err) - return nil, heurekaError("Error while total count of issues") + return nil, heurekaError("Error while count of issues") } + tc := counts.TotalIssueCount() + issueList.PolicyViolationCount = &counts.PolicyViolationCount + issueList.SecurityEventCount = &counts.SecurityEventCount + issueList.VulnerabilityCount = &counts.VulnerabilityCount + issueList.TotalCount = &tc } - return &entity.List[entity.IssueResult]{ - TotalCount: &count, - PageInfo: pageInfo, - Elements: res, - }, nil + return &issueList, nil } func (h *HeurekaApp) CreateIssue(issue *entity.Issue) (*entity.Issue, error) { @@ -131,7 +137,7 @@ func (h *HeurekaApp) CreateIssue(issue *entity.Issue) (*entity.Issue, error) { "filter": f, }) - issues, err := h.ListIssues(f, &entity.ListOptions{}) + issues, err := h.ListIssues(f, &entity.IssueListOptions{}) if err != nil { l.Error(err) @@ -165,7 +171,7 @@ func (h *HeurekaApp) UpdateIssue(issue *entity.Issue) (*entity.Issue, error) { return nil, heurekaError("Internal error while updating issue.") } - issueResult, err := h.ListIssues(&entity.IssueFilter{Id: []*int64{&issue.Id}}, &entity.ListOptions{}) + issueResult, err := h.ListIssues(&entity.IssueFilter{Id: []*int64{&issue.Id}}, &entity.IssueListOptions{}) if err != nil { l.Error(err) diff --git a/internal/app/issue_test.go b/internal/app/issue_test.go index 783c03a8..28b19cc5 100644 --- a/internal/app/issue_test.go +++ b/internal/app/issue_test.go @@ -31,18 +31,37 @@ func getIssueFilter() *entity.IssueFilter { } } +func getIssueListOptions() *entity.IssueListOptions { + listOptions := getListOptions() + return &entity.IssueListOptions{ + ListOptions: *listOptions, + ShowIssueTypeCounts: false, + } +} + +func getIssueTypeCounts() *entity.IssueTypeCounts { + return &entity.IssueTypeCounts{ + VulnerabilityCount: 1000, + PolicyViolationCount: 300, + SecurityEventCount: 37, + } +} + var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { var ( - db *mocks.MockDatabase - heureka app.Heureka - filter *entity.IssueFilter - options *entity.ListOptions + db *mocks.MockDatabase + heureka app.Heureka + filter *entity.IssueFilter + options *entity.IssueListOptions + issueTypeCounts *entity.IssueTypeCounts ) BeforeEach(func() { db = mocks.NewMockDatabase(GinkgoT()) - options = getListOptions() + options = getIssueListOptions() filter = getIssueFilter() + issueTypeCounts = getIssueTypeCounts() + }) When("the list option does include the totalCount", func() { @@ -50,7 +69,7 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { BeforeEach(func() { options.ShowTotalCount = true db.On("GetIssues", filter).Return([]entity.Issue{}, nil) - db.On("CountIssues", filter).Return(int64(1337), nil) + db.On("CountIssueTypes", filter).Return(issueTypeCounts, nil) }) It("shows the total count in the results", func() { @@ -64,6 +83,7 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { When("the list option does include the PageInfo", func() { BeforeEach(func() { options.ShowPageInfo = true + db.On("CountIssueTypes", filter).Return(issueTypeCounts, nil) }) DescribeTable("pagination information is correct", func(pageSize int, dbElements int, resElements int, hasNextPage bool) { filter.First = &pageSize @@ -286,7 +306,7 @@ var _ = Describe("When deleting Issue", Label("app", "DeleteIssue"), func() { Expect(err).To(BeNil(), "no error should be thrown") filter.Id = []*int64{&id} - issues, err := heureka.ListIssues(filter, &entity.ListOptions{}) + issues, err := heureka.ListIssues(filter, &entity.IssueListOptions{}) Expect(err).To(BeNil(), "no error should be thrown") Expect(issues.Elements).To(BeEmpty(), "no error should be thrown") }) diff --git a/internal/database/interface.go b/internal/database/interface.go index e5d35c7f..fc309b04 100644 --- a/internal/database/interface.go +++ b/internal/database/interface.go @@ -9,6 +9,7 @@ type Database interface { GetIssues(*entity.IssueFilter) ([]entity.Issue, error) GetIssuesWithAggregations(*entity.IssueFilter) ([]entity.IssueWithAggregations, error) CountIssues(*entity.IssueFilter) (int64, error) + CountIssueTypes(*entity.IssueFilter) (*entity.IssueTypeCounts, error) GetAllIssueIds(*entity.IssueFilter) ([]int64, error) CreateIssue(*entity.Issue) (*entity.Issue, error) UpdateIssue(*entity.Issue) error diff --git a/internal/database/mariadb/entity.go b/internal/database/mariadb/entity.go index fcecacac..4b27cae8 100644 --- a/internal/database/mariadb/entity.go +++ b/internal/database/mariadb/entity.go @@ -52,6 +52,7 @@ func GetUserTypeValue(v sql.NullInt64) entity.UserType { type DatabaseRow interface { IssueRow | + IssueCountRow | GetIssuesByRow | IssueMatchRow | IssueAggregationsRow | @@ -157,6 +158,18 @@ func (ibr *GetIssuesByRow) AsIssue() entity.Issue { } } +type IssueCountRow struct { + Count sql.NullInt64 `db:"issue_count"` + Type sql.NullString `db:"issue_type"` +} + +func (icr *IssueCountRow) AsIssueCount() entity.IssueCount { + return entity.IssueCount{ + Count: GetInt64Value(icr.Count), + Type: entity.NewIssueType(GetStringValue(icr.Type)), + } +} + func (ir *IssueRow) FromIssue(i *entity.Issue) { ir.Id = sql.NullInt64{Int64: i.Id, Valid: true} ir.PrimaryName = sql.NullString{String: i.PrimaryName, Valid: true} diff --git a/internal/database/mariadb/issue.go b/internal/database/mariadb/issue.go index f7e4b12c..4bc45330 100644 --- a/internal/database/mariadb/issue.go +++ b/internal/database/mariadb/issue.go @@ -240,6 +240,54 @@ func (s *SqlDatabase) CountIssues(filter *entity.IssueFilter) (int64, error) { return performCountScan(stmt, filterParameters, l) } +func (s *SqlDatabase) CountIssueTypes(filter *entity.IssueFilter) (*entity.IssueTypeCounts, error) { + l := logrus.WithFields(logrus.Fields{ + "event": "database.CountIssueTypes", + }) + + baseQuery := ` + SELECT I.issue_type, COUNT(distinct I.issue_id) as issue_count FROM Issue I + %s + %s + GROUP BY I.issue_type + ` + + stmt, filterParameters, err := s.buildIssueStatement(baseQuery, filter, []string{}, false, l) + + if err != nil { + return nil, err + } + + defer stmt.Close() + + counts, err := performListScan( + stmt, + filterParameters, + l, + func(l []entity.IssueCount, e IssueCountRow) []entity.IssueCount { + return append(l, e.AsIssueCount()) + }, + ) + + if err != nil { + return nil, err + } + + var issueTypeCounts entity.IssueTypeCounts + for _, count := range counts { + switch count.Type { + case entity.IssueTypeVulnerability: + issueTypeCounts.VulnerabilityCount = count.Count + case entity.IssueTypePolicyViolation: + issueTypeCounts.PolicyViolationCount = count.Count + case entity.IssueTypeSecurityEvent: + issueTypeCounts.SecurityEventCount = count.Count + } + } + + return &issueTypeCounts, nil +} + func (s *SqlDatabase) GetAllIssueIds(filter *entity.IssueFilter) ([]int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.GetIssueIds", diff --git a/internal/database/mariadb/issue_test.go b/internal/database/mariadb/issue_test.go index debc39e8..e7ead789 100644 --- a/internal/database/mariadb/issue_test.go +++ b/internal/database/mariadb/issue_test.go @@ -450,6 +450,41 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { }) }) }) + Context("and counting issue types", func() { + var seedCollection *test.SeedCollection + BeforeEach(func() { + seedCollection = seeder.SeedDbWithNFakeData(20) + }) + It("returns the correct count for each issue type", func() { + vulnerabilityCount := 0 + policyViolationCount := 0 + securityEventCount := 0 + + for _, issue := range seedCollection.IssueRows { + switch issue.Type.String { + case entity.IssueTypeVulnerability.String(): + vulnerabilityCount++ + case entity.IssueTypePolicyViolation.String(): + policyViolationCount++ + case entity.IssueTypeSecurityEvent.String(): + securityEventCount++ + } + } + + issueTypeCounts, err := db.CountIssueTypes(nil) + + By("throwing no error", func() { + Expect(err).To(BeNil()) + }) + + By("returning the correct counts", func() { + Expect(issueTypeCounts.VulnerabilityCount).To(BeEquivalentTo(vulnerabilityCount)) + Expect(issueTypeCounts.PolicyViolationCount).To(BeEquivalentTo(policyViolationCount)) + Expect(issueTypeCounts.SecurityEventCount).To(BeEquivalentTo(securityEventCount)) + }) + + }) + }) }) When("Insert Issue", Label("InsertIssue"), func() { Context("and we have 10 Issues in the database", func() { diff --git a/internal/e2e/issue_query_test.go b/internal/e2e/issue_query_test.go index 8e073b91..62d8363c 100644 --- a/internal/e2e/issue_query_test.go +++ b/internal/e2e/issue_query_test.go @@ -8,20 +8,20 @@ import ( "fmt" "os" - "github.wdf.sap.corp/cc/heureka/internal/entity" - testentity "github.wdf.sap.corp/cc/heureka/internal/entity/test" - "github.wdf.sap.corp/cc/heureka/internal/util" - util2 "github.wdf.sap.corp/cc/heureka/pkg/util" - - "github.com/machinebox/graphql" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + + "github.com/machinebox/graphql" "github.com/samber/lo" "github.com/sirupsen/logrus" "github.wdf.sap.corp/cc/heureka/internal/api/graphql/graph/model" "github.wdf.sap.corp/cc/heureka/internal/database/mariadb" "github.wdf.sap.corp/cc/heureka/internal/database/mariadb/test" + "github.wdf.sap.corp/cc/heureka/internal/entity" + testentity "github.wdf.sap.corp/cc/heureka/internal/entity/test" "github.wdf.sap.corp/cc/heureka/internal/server" + "github.wdf.sap.corp/cc/heureka/internal/util" + util2 "github.wdf.sap.corp/cc/heureka/pkg/util" ) var _ = Describe("Getting Issues via API", Label("e2e", "Issues"), func() { @@ -182,12 +182,12 @@ var _ = Describe("Getting Issues via API", Label("e2e", "Issues"), func() { }) Expect(issueMatchFound).To(BeTrue(), "attached IssueMatch is correct") - _, componentINstanceFound := lo.Find(seedCollection.ComponentInstanceRows, func(row mariadb.ComponentInstanceRow) bool { + _, componentInstanceFound := lo.Find(seedCollection.ComponentInstanceRows, func(row mariadb.ComponentInstanceRow) bool { return fmt.Sprintf("%d", row.Id.Int64) == im.Node.ComponentInstance.ID && row.CCRN.String == *im.Node.ComponentInstance.Ccrn && int(row.Count.Int16) == *im.Node.ComponentInstance.Count }) - Expect(componentINstanceFound).To(BeTrue(), "attached Component instance is correct") + Expect(componentInstanceFound).To(BeTrue(), "attached Component instance is correct") } } }) diff --git a/internal/entity/common.go b/internal/entity/common.go index a561ccfb..0569efc4 100644 --- a/internal/entity/common.go +++ b/internal/entity/common.go @@ -41,7 +41,8 @@ type HeurekaEntity interface { Issue | IssueMatch | IssueMatchChange | - HeurekaFilter + HeurekaFilter | + IssueCount } type HeurekaFilter interface { diff --git a/internal/entity/issue.go b/internal/entity/issue.go index 5e5a9182..b29970a3 100644 --- a/internal/entity/issue.go +++ b/internal/entity/issue.go @@ -3,7 +3,9 @@ package entity -import "time" +import ( + "time" +) type IssueWithAggregations struct { IssueAggregations @@ -85,3 +87,30 @@ type Issue struct { DeletedAt time.Time `json:"deleted_at,omitempty"` UpdatedAt time.Time `json:"updated_lsat"` } + +type IssueCount struct { + Count int64 `json:"count"` + Type IssueType `json:"type"` +} + +type IssueTypeCounts struct { + VulnerabilityCount int64 `json:"vulnerability_count"` + PolicyViolationCount int64 `json:"policy_violation_count"` + SecurityEventCount int64 `json:"security_event_count"` +} + +func (itc *IssueTypeCounts) TotalIssueCount() int64 { + return itc.VulnerabilityCount + itc.PolicyViolationCount + itc.SecurityEventCount +} + +type IssueList struct { + *List[IssueResult] + VulnerabilityCount *int64 + PolicyViolationCount *int64 + SecurityEventCount *int64 +} + +type IssueListOptions struct { + ListOptions + ShowIssueTypeCounts bool +} diff --git a/internal/mocks/mock_Database.go b/internal/mocks/mock_Database.go index d61728b8..1ebb54ad 100644 --- a/internal/mocks/mock_Database.go +++ b/internal/mocks/mock_Database.go @@ -893,6 +893,64 @@ func (_c *MockDatabase_CountIssueRepositories_Call) RunAndReturn(run func(*entit return _c } +// CountIssueTypes provides a mock function with given fields: _a0 +func (_m *MockDatabase) CountIssueTypes(_a0 *entity.IssueFilter) (*entity.IssueTypeCounts, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for CountIssueTypes") + } + + var r0 *entity.IssueTypeCounts + var r1 error + if rf, ok := ret.Get(0).(func(*entity.IssueFilter) (*entity.IssueTypeCounts, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(*entity.IssueFilter) *entity.IssueTypeCounts); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*entity.IssueTypeCounts) + } + } + + if rf, ok := ret.Get(1).(func(*entity.IssueFilter) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDatabase_CountIssueTypes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CountIssueTypes' +type MockDatabase_CountIssueTypes_Call struct { + *mock.Call +} + +// CountIssueTypes is a helper method to define mock.On call +// - _a0 *entity.IssueFilter +func (_e *MockDatabase_Expecter) CountIssueTypes(_a0 interface{}) *MockDatabase_CountIssueTypes_Call { + return &MockDatabase_CountIssueTypes_Call{Call: _e.mock.On("CountIssueTypes", _a0)} +} + +func (_c *MockDatabase_CountIssueTypes_Call) Run(run func(_a0 *entity.IssueFilter)) *MockDatabase_CountIssueTypes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*entity.IssueFilter)) + }) + return _c +} + +func (_c *MockDatabase_CountIssueTypes_Call) Return(_a0 *entity.IssueTypeCounts, _a1 error) *MockDatabase_CountIssueTypes_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDatabase_CountIssueTypes_Call) RunAndReturn(run func(*entity.IssueFilter) (*entity.IssueTypeCounts, error)) *MockDatabase_CountIssueTypes_Call { + _c.Call.Return(run) + return _c +} + // CountIssueVariants provides a mock function with given fields: _a0 func (_m *MockDatabase) CountIssueVariants(_a0 *entity.IssueVariantFilter) (int64, error) { ret := _m.Called(_a0) diff --git a/internal/mocks/mock_Heureka.go b/internal/mocks/mock_Heureka.go index ba70f0e0..b149f09f 100644 --- a/internal/mocks/mock_Heureka.go +++ b/internal/mocks/mock_Heureka.go @@ -437,6 +437,65 @@ func (_c *MockHeureka_AddServiceToSupportGroup_Call) RunAndReturn(run func(int64 return _c } +// AddUserToSupportGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockHeureka) AddUserToSupportGroup(_a0 int64, _a1 int64) (*entity.SupportGroup, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for AddUserToSupportGroup") + } + + var r0 *entity.SupportGroup + var r1 error + if rf, ok := ret.Get(0).(func(int64, int64) (*entity.SupportGroup, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(int64, int64) *entity.SupportGroup); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*entity.SupportGroup) + } + } + + if rf, ok := ret.Get(1).(func(int64, int64) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHeureka_AddUserToSupportGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddUserToSupportGroup' +type MockHeureka_AddUserToSupportGroup_Call struct { + *mock.Call +} + +// AddUserToSupportGroup is a helper method to define mock.On call +// - _a0 int64 +// - _a1 int64 +func (_e *MockHeureka_Expecter) AddUserToSupportGroup(_a0 interface{}, _a1 interface{}) *MockHeureka_AddUserToSupportGroup_Call { + return &MockHeureka_AddUserToSupportGroup_Call{Call: _e.mock.On("AddUserToSupportGroup", _a0, _a1)} +} + +func (_c *MockHeureka_AddUserToSupportGroup_Call) Run(run func(_a0 int64, _a1 int64)) *MockHeureka_AddUserToSupportGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int64)) + }) + return _c +} + +func (_c *MockHeureka_AddUserToSupportGroup_Call) Return(_a0 *entity.SupportGroup, _a1 error) *MockHeureka_AddUserToSupportGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockHeureka_AddUserToSupportGroup_Call) RunAndReturn(run func(int64, int64) (*entity.SupportGroup, error)) *MockHeureka_AddUserToSupportGroup_Call { + _c.Call.Return(run) + return _c +} + // CreateActivity provides a mock function with given fields: _a0 func (_m *MockHeureka) CreateActivity(_a0 *entity.Activity) (*entity.Activity, error) { ret := _m.Called(_a0) @@ -3318,6 +3377,65 @@ func (_c *MockHeureka_RemoveServiceFromSupportGroup_Call) RunAndReturn(run func( return _c } +// RemoveUserFromSupportGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockHeureka) RemoveUserFromSupportGroup(_a0 int64, _a1 int64) (*entity.SupportGroup, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for RemoveUserFromSupportGroup") + } + + var r0 *entity.SupportGroup + var r1 error + if rf, ok := ret.Get(0).(func(int64, int64) (*entity.SupportGroup, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(int64, int64) *entity.SupportGroup); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*entity.SupportGroup) + } + } + + if rf, ok := ret.Get(1).(func(int64, int64) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHeureka_RemoveUserFromSupportGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveUserFromSupportGroup' +type MockHeureka_RemoveUserFromSupportGroup_Call struct { + *mock.Call +} + +// RemoveUserFromSupportGroup is a helper method to define mock.On call +// - _a0 int64 +// - _a1 int64 +func (_e *MockHeureka_Expecter) RemoveUserFromSupportGroup(_a0 interface{}, _a1 interface{}) *MockHeureka_RemoveUserFromSupportGroup_Call { + return &MockHeureka_RemoveUserFromSupportGroup_Call{Call: _e.mock.On("RemoveUserFromSupportGroup", _a0, _a1)} +} + +func (_c *MockHeureka_RemoveUserFromSupportGroup_Call) Run(run func(_a0 int64, _a1 int64)) *MockHeureka_RemoveUserFromSupportGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int64)) + }) + return _c +} + +func (_c *MockHeureka_RemoveUserFromSupportGroup_Call) Return(_a0 *entity.SupportGroup, _a1 error) *MockHeureka_RemoveUserFromSupportGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockHeureka_RemoveUserFromSupportGroup_Call) RunAndReturn(run func(int64, int64) (*entity.SupportGroup, error)) *MockHeureka_RemoveUserFromSupportGroup_Call { + _c.Call.Return(run) + return _c +} + // Shutdown provides a mock function with given fields: func (_m *MockHeureka) Shutdown() error { ret := _m.Called()