diff --git a/pkg/assembler/backends/ent/backend/search.go b/pkg/assembler/backends/ent/backend/search.go index bbe2d0aeed..cc46e8e0d6 100644 --- a/pkg/assembler/backends/ent/backend/search.go +++ b/pkg/assembler/backends/ent/backend/search.go @@ -199,6 +199,11 @@ func (b *EntBackend) FindPackagesThatNeedScanning(ctx context.Context, queryType func (b *EntBackend) QueryPackagesListForScan(ctx context.Context, pkgIDs []string, after *string, first *int) (*model.PackageConnection, error) { var afterCursor *entgql.Cursor[uuid.UUID] + // if empty pkgIDs slice is passed in return nothing + if len(pkgIDs) == 0 { + return nil, nil + } + if after != nil { globalID := fromGlobalID(*after) if globalID.nodeType != packageversion.Table { @@ -244,14 +249,17 @@ func (b *EntBackend) QueryPackagesListForScan(ctx context.Context, pkgIDs []stri shortenedQueryList = append(shortenedQueryList, convertedID) } } - var queryErr error - pkgConn, queryErr = b.client.PackageVersion.Query(). - Where(packageversion.IDIn(shortenedQueryList...)). - WithName(func(q *ent.PackageNameQuery) {}). - Paginate(ctx, afterCursor, first, nil, nil) - if queryErr != nil { - return nil, fmt.Errorf("failed package query based on package IDs that need scanning with error: %w", queryErr) + if len(shortenedQueryList) > 0 { + var queryErr error + pkgConn, queryErr = b.client.PackageVersion.Query(). + Where(packageversion.IDIn(shortenedQueryList...)). + WithName(func(q *ent.PackageNameQuery) {}). + Paginate(ctx, afterCursor, first, nil, nil) + + if queryErr != nil { + return nil, fmt.Errorf("failed package query based on package IDs that need scanning with error: %w", queryErr) + } } // if not found return nil @@ -295,6 +303,11 @@ func constructPkgConn(pkgConn *ent.PackageVersionConnection, totalCount int, has func (b *EntBackend) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []string) ([]*model.CertifyVuln, error) { + // if empty pkgIDs slice is passed in return nothing + if len(pkgIDs) == 0 { + return nil, nil + } + // static ID for noVuln that is generated from type = novuln and vulnid = "" // this is generated via: vulnIDs := helpers.GetKey[*model.VulnerabilityInputSpec, helpers.VulnIds](&model.VulnerabilityInputSpec{Type: NoVuln, VulnerabilityID: ""}, helpers.VulnServerKey) @@ -343,25 +356,32 @@ func (b *EntBackend) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []st )) } - certVulnConn, err := b.client.CertifyVuln.Query(). - Where(certifyvuln.Or(predicates...)). - WithVulnerability(func(query *ent.VulnerabilityIDQuery) {}). - WithPackage(func(q *ent.PackageVersionQuery) { - q.WithName(func(q *ent.PackageNameQuery) {}) - }).All(ctx) - - if err != nil { - return nil, fmt.Errorf("failed certifyVuln query based on package IDs with error: %w", err) - } var collectedCertVuln []*model.CertifyVuln - for _, entCertVuln := range certVulnConn { - collectedCertVuln = append(collectedCertVuln, toModelCertifyVulnerability(entCertVuln)) + if len(predicates) > 0 { + certVulnConn, err := b.client.CertifyVuln.Query(). + Where(certifyvuln.Or(predicates...)). + WithVulnerability(func(query *ent.VulnerabilityIDQuery) {}). + WithPackage(func(q *ent.PackageVersionQuery) { + q.WithName(func(q *ent.PackageNameQuery) {}) + }).All(ctx) + + if err != nil { + return nil, fmt.Errorf("failed certifyVuln query based on package IDs with error: %w", err) + } + for _, entCertVuln := range certVulnConn { + collectedCertVuln = append(collectedCertVuln, toModelCertifyVulnerability(entCertVuln)) + } } return collectedCertVuln, nil } func (b *EntBackend) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []string) ([]*model.CertifyLegal, error) { + // if empty pkgIDs slice is passed in return nothing + if len(pkgIDs) == 0 { + return nil, nil + } + var queryList []uuid.UUID for _, id := range pkgIDs { @@ -408,26 +428,36 @@ func (b *EntBackend) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []s )) } - certLegalConn, err := b.client.CertifyLegal.Query(). - Where(certifylegal.Or(predicates...)). - WithPackage(func(q *ent.PackageVersionQuery) { - q.WithName(func(q *ent.PackageNameQuery) {}) - }). - WithDeclaredLicenses(). - WithDiscoveredLicenses().All(ctx) + var collectedCertLegal []*model.CertifyLegal - if err != nil { - return nil, fmt.Errorf("failed certifyLegal query based on package IDs with error: %w", err) - } + if len(predicates) > 0 { + certLegalConn, err := b.client.CertifyLegal.Query(). + Where(certifylegal.Or(predicates...)). + WithPackage(func(q *ent.PackageVersionQuery) { + q.WithName(func(q *ent.PackageNameQuery) {}) + }). + WithDeclaredLicenses(). + WithDiscoveredLicenses().All(ctx) - var collectedCertLegal []*model.CertifyLegal - for _, entCertLegal := range certLegalConn { - collectedCertLegal = append(collectedCertLegal, toModelCertifyLegal(entCertLegal)) + if err != nil { + return nil, fmt.Errorf("failed certifyLegal query based on package IDs with error: %w", err) + } + + for _, entCertLegal := range certLegalConn { + collectedCertLegal = append(collectedCertLegal, toModelCertifyLegal(entCertLegal)) + } } + return collectedCertLegal, nil } func (b *EntBackend) BatchQuerySubjectPkgDependency(ctx context.Context, pkgIDs []string) ([]*model.IsDependency, error) { + + // if empty pkgIDs slice is passed in return nothing + if len(pkgIDs) == 0 { + return nil, nil + } + var queryList []uuid.UUID for _, id := range pkgIDs { @@ -456,6 +486,12 @@ func (b *EntBackend) BatchQuerySubjectPkgDependency(ctx context.Context, pkgIDs } func (b *EntBackend) BatchQueryDepPkgDependency(ctx context.Context, pkgIDs []string) ([]*model.IsDependency, error) { + + // if empty pkgIDs slice is passed in return nothing + if len(pkgIDs) == 0 { + return nil, nil + } + var queryList []uuid.UUID for _, id := range pkgIDs {