Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added exponential backoff of calls to api gateway api #1645

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions enumeration/remote/aws/repository/api_gateway_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ package repository

import (
"fmt"
"math"
"strings"
"time"

"github.com/sirupsen/logrus"
"github.com/snyk/driftctl/enumeration/remote/cache"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -30,21 +35,48 @@ type apigatewayRepository struct {
cache cache.Cache
}

const MaxRetries = 5

func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewayRepository {
return &apigatewayRepository{
apigateway.New(session),
c,
}
}

func retryOnFailure(callback func() error, message string) error {
retries := 0
retry := true

var err error
for retry && retries < MaxRetries {
sleepTime := time.Duration(math.Pow(2, float64(retries))) * 2 * time.Second
logrus.Warn(message, "Attempt number ", retries+1, "/", MaxRetries, ". Retrying after sleeping for ", sleepTime, "...")
time.Sleep(sleepTime)
logrus.Debug("Awake! Attempting to make API call again.")

err = callback()
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
retry = true
} else {
retry = false
}

retries++
}
return err
}

func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) {
cacheKey := "apigatewayListAllRestApis"
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
logrus.Debug("Getting all rest APIs from cache")
return v.([]*apigateway.RestApi), nil
}

logrus.Debug("Making a call to get rest APIs not found in cache")
var restApis []*apigateway.RestApi
input := apigateway.GetRestApisInput{}
err := r.client.GetRestApisPages(&input,
Expand All @@ -53,6 +85,20 @@ func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error)
return !lastPage
},
)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetRestApisPages(&input,
func(resp *apigateway.GetRestApisOutput, lastPage bool) bool {
restApis = append(restApis, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetRestApisPages!")
}

if err != nil {
return nil, err
}
Expand All @@ -67,6 +113,16 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) {
}

account, err := r.client.GetAccount(&apigateway.GetAccountInput{})

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
input := apigateway.GetAccountInput{}
account, err = r.client.GetAccount(&input)
return err
}, "Error caught during GetAccount!")
}

if err != nil {
return nil, err
}
Expand All @@ -77,6 +133,7 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) {

func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
if v := r.cache.Get("apigatewayListAllApiKeys"); v != nil {
logrus.Debug("Getting api keys from cache")
return v.([]*apigateway.ApiKey), nil
}

Expand All @@ -88,6 +145,20 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
return !lastPage
},
)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetApiKeysPages(&input,
func(resp *apigateway.GetApiKeysOutput, lastPage bool) bool {
apiKeys = append(apiKeys, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetApiKeysPages!")
}

if err != nil {
return nil, err
}
Expand All @@ -99,13 +170,24 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
logrus.Debug("Getting api authorizers from cache ", apiId)
return v.([]*apigateway.Authorizer), nil
}

logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId)
input := &apigateway.GetAuthorizersInput{
RestApiId: &apiId,
}
resources, err := r.client.GetAuthorizers(input)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId)
resources, err = r.client.GetAuthorizers(input)
return err
}, "Error caught during GetAuthorizers with input "+apiId+"!")
}

if err != nil {
return nil, err
}
Expand All @@ -119,14 +201,26 @@ func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
logrus.Debug("Getting api stages from cache ", apiId)
return v.([]*apigateway.Stage), nil
}

logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId)
input := &apigateway.GetStagesInput{
RestApiId: &apiId,
}
resources, err := r.client.GetStages(input)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId)
resources, err = r.client.GetStages(input)
return err
}, "Error caught during GetStages with input "+apiId+"!")
}

if err != nil {
logrus.Error("error in api stage")
return nil, err
}

Expand All @@ -139,9 +233,11 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
logrus.Debug("Getting api resource from cache ", apiId)
return v.([]*apigateway.Resource), nil
}

logrus.Debug("Making a call to API for specific resource not found in cache ", apiId)
var resources []*apigateway.Resource
input := &apigateway.GetResourcesInput{
RestApiId: &apiId,
Expand All @@ -151,6 +247,18 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate
resources = append(resources, res.Items...)
return !lastPage
})

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool {
resources = append(resources, res.Items...)
return !lastPage
})
return err
}, "Error caught during GetResourcesPages with input "+apiId+"!")
}

if err != nil {
return nil, err
}
Expand All @@ -175,6 +283,20 @@ func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, e
return !lastPage
},
)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetDomainNamesPages(&input,
func(resp *apigateway.GetDomainNamesOutput, lastPage bool) bool {
domainNames = append(domainNames, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetDomainNamesPages!")
}

if err != nil {
return nil, err
}
Expand All @@ -196,6 +318,20 @@ func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOut
return !lastPage
},
)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetVpcLinksPages(&input,
func(resp *apigateway.GetVpcLinksOutput, lastPage bool) bool {
vpcLinks = append(vpcLinks, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetVpcLinksPages!")
}

if err != nil {
return nil, err
}
Expand All @@ -214,6 +350,15 @@ func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([]
RestApiId: &apiId,
}
resources, err := r.client.GetRequestValidators(input)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
resources, err = r.client.GetRequestValidators(input)
return err
}, "Error caught during GetRequestValidators with input "+apiId+"!")
}

if err != nil {
return nil, err
}
Expand All @@ -236,6 +381,18 @@ func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName stri
mappings = append(mappings, res.Items...)
return !lastPage
})

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool {
mappings = append(mappings, res.Items...)
return !lastPage
})
return err
}, "Error caught during GetBasePathMappingsPages with input "+domainName+"!")
}

if err != nil {
return nil, err
}
Expand All @@ -258,6 +415,17 @@ func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway
resources = append(resources, res.Items...)
return !lastPage
})

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool {
resources = append(resources, res.Items...)
return !lastPage
})
return err
}, "Error caught during GetModelsPages with input "+apiId+"!")
}
if err != nil {
return nil, err
}
Expand All @@ -276,6 +444,15 @@ func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]*
RestApiId: &apiId,
}
resources, err := r.client.GetGatewayResponses(input)

if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
resources, err = r.client.GetGatewayResponses(input)
return err
}, "Error caught during GetGatewayResponses with input "+apiId+"!")
}

if err != nil {
return nil, err
}
Expand Down