Expand Up
@@ -2,6 +2,11 @@ package repository
import (
"fmt"
"math"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/snyk/driftctl/enumeration/remote/cache"
"github.com/aws/aws-sdk-go/aws"
Expand Down
Expand Up
@@ -30,21 +35,48 @@ type apigatewayRepository struct {
cache cache.Cache
}
const MaxRetries = 5
func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewayRepository {
return &apigatewayRepository{
apigateway.New(session),
c,
}
}
func retryOnFailure(callback func() error, message string) error {
retries := 0
retry := true
var err error
for retry && retries < MaxRetries {
sleepTime := time.Duration(math.Pow(2, float64(retries))) * 2 * time.Second
logrus.Warn(message, "Attempt number ", retries+1, "/", MaxRetries, ". Retrying after sleeping for ", sleepTime, "...")
time.Sleep(sleepTime)
logrus.Debug("Awake! Attempting to make API call again.")
err = callback()
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
retry = true
} else {
retry = false
}
retries++
}
return err
}
func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) {
cacheKey := "apigatewayListAllRestApis"
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
logrus.Debug("Getting all rest APIs from cache")
return v.([]*apigateway.RestApi), nil
}
logrus.Debug("Making a call to get rest APIs not found in cache")
var restApis []*apigateway.RestApi
input := apigateway.GetRestApisInput{}
err := r.client.GetRestApisPages(&input,
Expand All
@@ -53,6 +85,20 @@ func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error)
return !lastPage
},
)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetRestApisPages(&input,
func(resp *apigateway.GetRestApisOutput, lastPage bool) bool {
restApis = append(restApis, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetRestApisPages!")
}
if err != nil {
return nil, err
}
Expand All
@@ -67,6 +113,16 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) {
}
account, err := r.client.GetAccount(&apigateway.GetAccountInput{})
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
input := apigateway.GetAccountInput{}
account, err = r.client.GetAccount(&input)
return err
}, "Error caught during GetAccount!")
}
if err != nil {
return nil, err
}
Expand All
@@ -77,6 +133,7 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) {
func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
if v := r.cache.Get("apigatewayListAllApiKeys"); v != nil {
logrus.Debug("Getting api keys from cache")
return v.([]*apigateway.ApiKey), nil
}
Expand All
@@ -88,6 +145,20 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
return !lastPage
},
)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetApiKeysPages(&input,
func(resp *apigateway.GetApiKeysOutput, lastPage bool) bool {
apiKeys = append(apiKeys, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetApiKeysPages!")
}
if err != nil {
return nil, err
}
Expand All
@@ -99,13 +170,24 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
logrus.Debug("Getting api authorizers from cache ", apiId)
return v.([]*apigateway.Authorizer), nil
}
logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId)
input := &apigateway.GetAuthorizersInput{
RestApiId: &apiId,
}
resources, err := r.client.GetAuthorizers(input)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId)
resources, err = r.client.GetAuthorizers(input)
return err
}, "Error caught during GetAuthorizers with input "+apiId+"!")
}
if err != nil {
return nil, err
}
Expand All
@@ -119,14 +201,26 @@ func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
logrus.Debug("Getting api stages from cache ", apiId)
return v.([]*apigateway.Stage), nil
}
logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId)
input := &apigateway.GetStagesInput{
RestApiId: &apiId,
}
resources, err := r.client.GetStages(input)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId)
resources, err = r.client.GetStages(input)
return err
}, "Error caught during GetStages with input "+apiId+"!")
}
if err != nil {
logrus.Error("error in api stage")
return nil, err
}
Expand All
@@ -139,9 +233,11 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
logrus.Debug("Getting api resource from cache ", apiId)
return v.([]*apigateway.Resource), nil
}
logrus.Debug("Making a call to API for specific resource not found in cache ", apiId)
var resources []*apigateway.Resource
input := &apigateway.GetResourcesInput{
RestApiId: &apiId,
Expand All
@@ -151,6 +247,18 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate
resources = append(resources, res.Items...)
return !lastPage
})
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool {
resources = append(resources, res.Items...)
return !lastPage
})
return err
}, "Error caught during GetResourcesPages with input "+apiId+"!")
}
if err != nil {
return nil, err
}
Expand All
@@ -175,6 +283,20 @@ func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, e
return !lastPage
},
)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetDomainNamesPages(&input,
func(resp *apigateway.GetDomainNamesOutput, lastPage bool) bool {
domainNames = append(domainNames, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetDomainNamesPages!")
}
if err != nil {
return nil, err
}
Expand All
@@ -196,6 +318,20 @@ func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOut
return !lastPage
},
)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetVpcLinksPages(&input,
func(resp *apigateway.GetVpcLinksOutput, lastPage bool) bool {
vpcLinks = append(vpcLinks, resp.Items...)
return !lastPage
},
)
return err
}, "Error caught during GetVpcLinksPages!")
}
if err != nil {
return nil, err
}
Expand All
@@ -214,6 +350,15 @@ func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([]
RestApiId: &apiId,
}
resources, err := r.client.GetRequestValidators(input)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
resources, err = r.client.GetRequestValidators(input)
return err
}, "Error caught during GetRequestValidators with input "+apiId+"!")
}
if err != nil {
return nil, err
}
Expand All
@@ -236,6 +381,18 @@ func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName stri
mappings = append(mappings, res.Items...)
return !lastPage
})
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool {
mappings = append(mappings, res.Items...)
return !lastPage
})
return err
}, "Error caught during GetBasePathMappingsPages with input "+domainName+"!")
}
if err != nil {
return nil, err
}
Expand All
@@ -258,6 +415,17 @@ func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway
resources = append(resources, res.Items...)
return !lastPage
})
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
err = r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool {
resources = append(resources, res.Items...)
return !lastPage
})
return err
}, "Error caught during GetModelsPages with input "+apiId+"!")
}
if err != nil {
return nil, err
}
Expand All
@@ -276,6 +444,15 @@ func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]*
RestApiId: &apiId,
}
resources, err := r.client.GetGatewayResponses(input)
if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") {
err = retryOnFailure(func() error {
logrus.Debug("Making a call to get rest APIs not found in cache")
resources, err = r.client.GetGatewayResponses(input)
return err
}, "Error caught during GetGatewayResponses with input "+apiId+"!")
}
if err != nil {
return nil, err
}
Expand Down