diff --git a/enumeration/remote/aws/repository/api_gateway_repository.go b/enumeration/remote/aws/repository/api_gateway_repository.go index bb69d7dd6..5da06f7bd 100644 --- a/enumeration/remote/aws/repository/api_gateway_repository.go +++ b/enumeration/remote/aws/repository/api_gateway_repository.go @@ -2,6 +2,11 @@ package repository import ( "fmt" + "math" + "strings" + "time" + + "github.com/sirupsen/logrus" "github.com/snyk/driftctl/enumeration/remote/cache" "github.com/aws/aws-sdk-go/aws" @@ -30,6 +35,8 @@ type apigatewayRepository struct { cache cache.Cache } +const MaxRetries = 5 + func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewayRepository { return &apigatewayRepository{ apigateway.New(session), @@ -37,14 +44,39 @@ func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewa } } +func retryOnFailure(callback func() error, message string) error { + retries := 0 + retry := true + + var err error + for retry && retries < MaxRetries { + sleepTime := time.Duration(math.Pow(2, float64(retries))) * 2 * time.Second + logrus.Warn(message, "Attempt number ", retries+1, "/", MaxRetries, ". Retrying after sleeping for ", sleepTime, "...") + time.Sleep(sleepTime) + logrus.Debug("Awake! Attempting to make API call again.") + + err = callback() + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + retry = true + } else { + retry = false + } + + retries++ + } + return err +} + func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) { cacheKey := "apigatewayListAllRestApis" v := r.cache.GetAndLock(cacheKey) defer r.cache.Unlock(cacheKey) if v != nil { + logrus.Debug("Getting all rest APIs from cache") return v.([]*apigateway.RestApi), nil } + logrus.Debug("Making a call to get rest APIs not found in cache") var restApis []*apigateway.RestApi input := apigateway.GetRestApisInput{} err := r.client.GetRestApisPages(&input, @@ -53,6 +85,20 @@ func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) return !lastPage }, ) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetRestApisPages(&input, + func(resp *apigateway.GetRestApisOutput, lastPage bool) bool { + restApis = append(restApis, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetRestApisPages!") + } + if err != nil { return nil, err } @@ -67,6 +113,16 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) { } account, err := r.client.GetAccount(&apigateway.GetAccountInput{}) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + input := apigateway.GetAccountInput{} + account, err = r.client.GetAccount(&input) + return err + }, "Error caught during GetAccount!") + } + if err != nil { return nil, err } @@ -77,6 +133,7 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) { func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { if v := r.cache.Get("apigatewayListAllApiKeys"); v != nil { + logrus.Debug("Getting api keys from cache") return v.([]*apigateway.ApiKey), nil } @@ -88,6 +145,20 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { return !lastPage }, ) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetApiKeysPages(&input, + func(resp *apigateway.GetApiKeysOutput, lastPage bool) bool { + apiKeys = append(apiKeys, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetApiKeysPages!") + } + if err != nil { return nil, err } @@ -99,13 +170,24 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) { cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId) if v := r.cache.Get(cacheKey); v != nil { + logrus.Debug("Getting api authorizers from cache ", apiId) return v.([]*apigateway.Authorizer), nil } + logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId) input := &apigateway.GetAuthorizersInput{ RestApiId: &apiId, } resources, err := r.client.GetAuthorizers(input) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId) + resources, err = r.client.GetAuthorizers(input) + return err + }, "Error caught during GetAuthorizers with input "+apiId+"!") + } + if err != nil { return nil, err } @@ -119,14 +201,26 @@ func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway v := r.cache.GetAndLock(cacheKey) defer r.cache.Unlock(cacheKey) if v != nil { + logrus.Debug("Getting api stages from cache ", apiId) return v.([]*apigateway.Stage), nil } + logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId) input := &apigateway.GetStagesInput{ RestApiId: &apiId, } resources, err := r.client.GetStages(input) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId) + resources, err = r.client.GetStages(input) + return err + }, "Error caught during GetStages with input "+apiId+"!") + } + if err != nil { + logrus.Error("error in api stage") return nil, err } @@ -139,9 +233,11 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate v := r.cache.GetAndLock(cacheKey) defer r.cache.Unlock(cacheKey) if v != nil { + logrus.Debug("Getting api resource from cache ", apiId) return v.([]*apigateway.Resource), nil } + logrus.Debug("Making a call to API for specific resource not found in cache ", apiId) var resources []*apigateway.Resource input := &apigateway.GetResourcesInput{ RestApiId: &apiId, @@ -151,6 +247,18 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate resources = append(resources, res.Items...) return !lastPage }) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool { + resources = append(resources, res.Items...) + return !lastPage + }) + return err + }, "Error caught during GetResourcesPages with input "+apiId+"!") + } + if err != nil { return nil, err } @@ -175,6 +283,20 @@ func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, e return !lastPage }, ) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetDomainNamesPages(&input, + func(resp *apigateway.GetDomainNamesOutput, lastPage bool) bool { + domainNames = append(domainNames, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetDomainNamesPages!") + } + if err != nil { return nil, err } @@ -196,6 +318,20 @@ func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOut return !lastPage }, ) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetVpcLinksPages(&input, + func(resp *apigateway.GetVpcLinksOutput, lastPage bool) bool { + vpcLinks = append(vpcLinks, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetVpcLinksPages!") + } + if err != nil { return nil, err } @@ -214,6 +350,15 @@ func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([] RestApiId: &apiId, } resources, err := r.client.GetRequestValidators(input) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + resources, err = r.client.GetRequestValidators(input) + return err + }, "Error caught during GetRequestValidators with input "+apiId+"!") + } + if err != nil { return nil, err } @@ -236,6 +381,18 @@ func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName stri mappings = append(mappings, res.Items...) return !lastPage }) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool { + mappings = append(mappings, res.Items...) + return !lastPage + }) + return err + }, "Error caught during GetBasePathMappingsPages with input "+domainName+"!") + } + if err != nil { return nil, err } @@ -258,6 +415,17 @@ func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway resources = append(resources, res.Items...) return !lastPage }) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool { + resources = append(resources, res.Items...) + return !lastPage + }) + return err + }, "Error caught during GetModelsPages with input "+apiId+"!") + } if err != nil { return nil, err } @@ -276,6 +444,15 @@ func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]* RestApiId: &apiId, } resources, err := r.client.GetGatewayResponses(input) + + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + resources, err = r.client.GetGatewayResponses(input) + return err + }, "Error caught during GetGatewayResponses with input "+apiId+"!") + } + if err != nil { return nil, err }