Skip to content

Commit

Permalink
check if batch query is empty, otherwise skip (#2252)
Browse files Browse the repository at this point in the history
Signed-off-by: pxp928 <[email protected]>
  • Loading branch information
pxp928 authored Nov 4, 2024
1 parent fad3dd5 commit a5fe089
Showing 1 changed file with 68 additions and 32 deletions.
100 changes: 68 additions & 32 deletions pkg/assembler/backends/ent/backend/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ func (b *EntBackend) FindPackagesThatNeedScanning(ctx context.Context, queryType
func (b *EntBackend) QueryPackagesListForScan(ctx context.Context, pkgIDs []string, after *string, first *int) (*model.PackageConnection, error) {
var afterCursor *entgql.Cursor[uuid.UUID]

// if empty pkgIDs slice is passed in return nothing
if len(pkgIDs) == 0 {
return nil, nil
}

if after != nil {
globalID := fromGlobalID(*after)
if globalID.nodeType != packageversion.Table {
Expand Down Expand Up @@ -244,14 +249,17 @@ func (b *EntBackend) QueryPackagesListForScan(ctx context.Context, pkgIDs []stri
shortenedQueryList = append(shortenedQueryList, convertedID)
}
}
var queryErr error
pkgConn, queryErr = b.client.PackageVersion.Query().
Where(packageversion.IDIn(shortenedQueryList...)).
WithName(func(q *ent.PackageNameQuery) {}).
Paginate(ctx, afterCursor, first, nil, nil)

if queryErr != nil {
return nil, fmt.Errorf("failed package query based on package IDs that need scanning with error: %w", queryErr)
if len(shortenedQueryList) > 0 {
var queryErr error
pkgConn, queryErr = b.client.PackageVersion.Query().
Where(packageversion.IDIn(shortenedQueryList...)).
WithName(func(q *ent.PackageNameQuery) {}).
Paginate(ctx, afterCursor, first, nil, nil)

if queryErr != nil {
return nil, fmt.Errorf("failed package query based on package IDs that need scanning with error: %w", queryErr)
}
}

// if not found return nil
Expand Down Expand Up @@ -295,6 +303,11 @@ func constructPkgConn(pkgConn *ent.PackageVersionConnection, totalCount int, has

func (b *EntBackend) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []string) ([]*model.CertifyVuln, error) {

// if empty pkgIDs slice is passed in return nothing
if len(pkgIDs) == 0 {
return nil, nil
}

// static ID for noVuln that is generated from type = novuln and vulnid = ""
// this is generated via:
vulnIDs := helpers.GetKey[*model.VulnerabilityInputSpec, helpers.VulnIds](&model.VulnerabilityInputSpec{Type: NoVuln, VulnerabilityID: ""}, helpers.VulnServerKey)
Expand Down Expand Up @@ -343,25 +356,32 @@ func (b *EntBackend) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []st
))
}

certVulnConn, err := b.client.CertifyVuln.Query().
Where(certifyvuln.Or(predicates...)).
WithVulnerability(func(query *ent.VulnerabilityIDQuery) {}).
WithPackage(func(q *ent.PackageVersionQuery) {
q.WithName(func(q *ent.PackageNameQuery) {})
}).All(ctx)

if err != nil {
return nil, fmt.Errorf("failed certifyVuln query based on package IDs with error: %w", err)
}
var collectedCertVuln []*model.CertifyVuln
for _, entCertVuln := range certVulnConn {
collectedCertVuln = append(collectedCertVuln, toModelCertifyVulnerability(entCertVuln))
if len(predicates) > 0 {
certVulnConn, err := b.client.CertifyVuln.Query().
Where(certifyvuln.Or(predicates...)).
WithVulnerability(func(query *ent.VulnerabilityIDQuery) {}).
WithPackage(func(q *ent.PackageVersionQuery) {
q.WithName(func(q *ent.PackageNameQuery) {})
}).All(ctx)

if err != nil {
return nil, fmt.Errorf("failed certifyVuln query based on package IDs with error: %w", err)
}
for _, entCertVuln := range certVulnConn {
collectedCertVuln = append(collectedCertVuln, toModelCertifyVulnerability(entCertVuln))
}
}
return collectedCertVuln, nil
}

func (b *EntBackend) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []string) ([]*model.CertifyLegal, error) {

// if empty pkgIDs slice is passed in return nothing
if len(pkgIDs) == 0 {
return nil, nil
}

var queryList []uuid.UUID

for _, id := range pkgIDs {
Expand Down Expand Up @@ -408,26 +428,36 @@ func (b *EntBackend) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []s
))
}

certLegalConn, err := b.client.CertifyLegal.Query().
Where(certifylegal.Or(predicates...)).
WithPackage(func(q *ent.PackageVersionQuery) {
q.WithName(func(q *ent.PackageNameQuery) {})
}).
WithDeclaredLicenses().
WithDiscoveredLicenses().All(ctx)
var collectedCertLegal []*model.CertifyLegal

if err != nil {
return nil, fmt.Errorf("failed certifyLegal query based on package IDs with error: %w", err)
}
if len(predicates) > 0 {
certLegalConn, err := b.client.CertifyLegal.Query().
Where(certifylegal.Or(predicates...)).
WithPackage(func(q *ent.PackageVersionQuery) {
q.WithName(func(q *ent.PackageNameQuery) {})
}).
WithDeclaredLicenses().
WithDiscoveredLicenses().All(ctx)

var collectedCertLegal []*model.CertifyLegal
for _, entCertLegal := range certLegalConn {
collectedCertLegal = append(collectedCertLegal, toModelCertifyLegal(entCertLegal))
if err != nil {
return nil, fmt.Errorf("failed certifyLegal query based on package IDs with error: %w", err)
}

for _, entCertLegal := range certLegalConn {
collectedCertLegal = append(collectedCertLegal, toModelCertifyLegal(entCertLegal))
}
}

return collectedCertLegal, nil
}

func (b *EntBackend) BatchQuerySubjectPkgDependency(ctx context.Context, pkgIDs []string) ([]*model.IsDependency, error) {

// if empty pkgIDs slice is passed in return nothing
if len(pkgIDs) == 0 {
return nil, nil
}

var queryList []uuid.UUID

for _, id := range pkgIDs {
Expand Down Expand Up @@ -456,6 +486,12 @@ func (b *EntBackend) BatchQuerySubjectPkgDependency(ctx context.Context, pkgIDs
}

func (b *EntBackend) BatchQueryDepPkgDependency(ctx context.Context, pkgIDs []string) ([]*model.IsDependency, error) {

// if empty pkgIDs slice is passed in return nothing
if len(pkgIDs) == 0 {
return nil, nil
}

var queryList []uuid.UUID

for _, id := range pkgIDs {
Expand Down

0 comments on commit a5fe089

Please sign in to comment.