diff --git a/.changelog/3a4c3951c2504554a64b14ce2dddf6ef.json b/.changelog/3a4c3951c2504554a64b14ce2dddf6ef.json new file mode 100644 index 00000000000..a17adbcd3fc --- /dev/null +++ b/.changelog/3a4c3951c2504554a64b14ce2dddf6ef.json @@ -0,0 +1,9 @@ +{ + "id": "3a4c3951-c250-4554-a64b-14ce2dddf6ef", + "type": "feature", + "description": "Support account ID retrieval in IMDS credentials provider, and support new IMDS profile name config:\n\n1. environment: `AWS_EC2_INSTANCE_PROFILE_NAME`\n2. shared config: `ec2_instance_profile_name`", + "modules": [ + "config", + "credentials" + ] +} \ No newline at end of file diff --git a/config/config_source_test.go b/config/config_source_test.go index b79bdb7a515..ac56a1d581b 100644 --- a/config/config_source_test.go +++ b/config/config_source_test.go @@ -128,10 +128,10 @@ func (f imdsForwarder) Do(r *http.Request) (*http.Response, error) { header.Set(ttlHeader, r.Header.Get(ttlHeader)) return &http.Response{StatusCode: 200, Header: header, Body: io.NopCloser(strings.NewReader("validToken"))}, nil } - if r.URL.Path == "/latest/meta-data/iam/security-credentials/" { + if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/" { return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("RoleName"))}, nil } - if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" { + if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/RoleName" { return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(ecsResponse))}, nil } return f.innerClient.Do(r) diff --git a/config/config_test.go b/config/config_test.go index 84aff3d7795..e3a2360aa19 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "reflect" + "strings" "testing" "github.com/aws/aws-sdk-go-v2/aws" @@ -148,6 +149,17 @@ func TestLoadDefaultConfig(t *testing.T) { } } +func TestLoadDefaultConfig_EmptyEC2InstanceProfileName(t *testing.T) { + t.Setenv(awsEc2InstanceProfileNameEnv, "") + _, err := LoadDefaultConfig(context.TODO()) + if err == nil { + t.Fatal("expect error, got none") + } + if expect, actual := "env AWS_EC2_INSTANCE_PROFILE_NAME cannot be empty", err.Error(); !strings.Contains(actual, expect) { + t.Fatalf("expect error %s, got %s", expect, actual) + } +} + func BenchmarkLoadProfile1(b *testing.B) { benchConfigLoad(b, 1) } diff --git a/config/env_config.go b/config/env_config.go index 9db507e38ea..f72a46c15a0 100644 --- a/config/env_config.go +++ b/config/env_config.go @@ -60,6 +60,8 @@ const ( awsEc2MetadataDisabledEnv = "AWS_EC2_METADATA_DISABLED" awsEc2MetadataV1DisabledEnv = "AWS_EC2_METADATA_V1_DISABLED" + awsEc2InstanceProfileNameEnv = "AWS_EC2_INSTANCE_PROFILE_NAME" + awsS3DisableMultiRegionAccessPointsEnv = "AWS_S3_DISABLE_MULTIREGION_ACCESS_POINTS" awsUseDualStackEndpointEnv = "AWS_USE_DUALSTACK_ENDPOINT" @@ -304,6 +306,9 @@ type EnvConfig struct { // Indicates whether response checksum should be validated ResponseChecksumValidation aws.ResponseChecksumValidation + + // Profile name used for fetching IMDS credentials. + EC2InstanceProfileName string } // loadEnvConfig reads configuration values from the OS's environment variables. @@ -347,6 +352,12 @@ func NewEnvConfig() (EnvConfig, error) { cfg.AppID = os.Getenv(awsSdkUaAppIDEnv) + ec2InstanceProfileName, ok := os.LookupEnv(awsEc2InstanceProfileNameEnv) + if ok && ec2InstanceProfileName == "" { + return cfg, fmt.Errorf("env %s cannot be empty", awsEc2InstanceProfileNameEnv) + } + cfg.EC2InstanceProfileName = ec2InstanceProfileName + if err := setBoolPtrFromEnvVal(&cfg.DisableRequestCompression, []string{awsDisableRequestCompressionEnv}); err != nil { return cfg, err } @@ -916,3 +927,11 @@ func (c EnvConfig) GetS3DisableExpressAuth() (value, ok bool) { return *c.S3DisableExpressAuth, true } + +func (c EnvConfig) getEC2InstanceProfileName() (string, bool, error) { + if len(c.EC2InstanceProfileName) == 0 { + return "", false, nil + } + + return c.EC2InstanceProfileName, true, nil +} diff --git a/config/env_config_test.go b/config/env_config_test.go index 870c46509bc..6fcda8f214d 100644 --- a/config/env_config_test.go +++ b/config/env_config_test.go @@ -568,6 +568,14 @@ func TestNewEnvConfig(t *testing.T) { Config: EnvConfig{}, WantErr: true, }, + 54: { + Env: map[string]string{ + "AWS_EC2_INSTANCE_PROFILE_NAME": "ProfileName", + }, + Config: EnvConfig{ + EC2InstanceProfileName: "ProfileName", + }, + }, } for i, c := range cases { diff --git a/config/provider.go b/config/provider.go index a8ff40d846b..baa36f53501 100644 --- a/config/provider.go +++ b/config/provider.go @@ -445,6 +445,22 @@ func getEC2RoleCredentialProviderOptions(ctx context.Context, configs configs) ( return } +type ec2InstanceProfileNameProvider interface { + getEC2InstanceProfileName() (string, bool, error) +} + +func getEC2InstanceProfileName(ctx context.Context, configs configs) (v string, found bool, err error) { + for _, config := range configs { + if p, ok := config.(ec2InstanceProfileNameProvider); ok { + v, found, err = p.getEC2InstanceProfileName() + if err != nil || found { + break + } + } + } + return +} + // defaultRegionProvider is an interface for retrieving a default region if a region was not resolved from other sources type defaultRegionProvider interface { getDefaultRegion(ctx context.Context) (string, bool, error) diff --git a/config/resolve_credentials.go b/config/resolve_credentials.go index b00259df03a..1f6392b2c3b 100644 --- a/config/resolve_credentials.go +++ b/config/resolve_credentials.go @@ -189,7 +189,7 @@ func resolveCredsFromProfile(ctx context.Context, cfg *aws.Config, envConfig *En default: ctx = addCredentialSource(ctx, aws.CredentialSourceIMDS) - err = resolveEC2RoleCredentials(ctx, cfg, configs) + err = resolveEC2RoleCredentials(ctx, cfg, envConfig, sharedConfig, configs) } if err != nil { return ctx, err @@ -379,7 +379,7 @@ func resolveCredsFromSource(ctx context.Context, cfg *aws.Config, envConfig *Env switch sharedCfg.CredentialSource { case credSourceEc2Metadata: ctx = addCredentialSource(ctx, aws.CredentialSourceIMDS) - return ctx, resolveEC2RoleCredentials(ctx, cfg, configs) + return ctx, resolveEC2RoleCredentials(ctx, cfg, envConfig, sharedCfg, configs) case credSourceEnvironment: ctx = addCredentialSource(ctx, aws.CredentialSourceHTTP) @@ -402,8 +402,21 @@ func resolveCredsFromSource(ctx context.Context, cfg *aws.Config, envConfig *Env return ctx, nil } -func resolveEC2RoleCredentials(ctx context.Context, cfg *aws.Config, configs configs) error { - optFns := make([]func(*ec2rolecreds.Options), 0, 2) +func resolveEC2RoleCredentials(ctx context.Context, cfg *aws.Config, envCfg *EnvConfig, sharedCfg *SharedConfig, configs configs) error { + optFns := make([]func(*ec2rolecreds.Options), 0, 3) + + var profile string + if sharedCfg != nil && sharedCfg.EC2InstanceProfileName != "" { + profile = sharedCfg.EC2InstanceProfileName + } + if envCfg != nil && envCfg.EC2InstanceProfileName != "" { + profile = envCfg.EC2InstanceProfileName + } + if profile != "" { + optFns = append(optFns, func(o *ec2rolecreds.Options) { + o.ProfileName = profile // caller options will override + }) + } optFn, found, err := getEC2RoleCredentialProviderOptions(ctx, configs) if err != nil { diff --git a/config/resolve_credentials_test.go b/config/resolve_credentials_test.go index 839445065a2..6828091e452 100644 --- a/config/resolve_credentials_test.go +++ b/config/resolve_credentials_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/internal/awstesting" "github.com/aws/aws-sdk-go-v2/service/sso" @@ -82,9 +83,15 @@ func setupCredentialsEndpoints() (aws.EndpointResolverWithOptions, func()) { ec2MetadataServer := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" { + if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/RoleName" { w.Write([]byte(ec2MetadataResponse)) - } else if r.URL.Path == "/latest/meta-data/iam/security-credentials/" { + } else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/LoadOptions" { + w.Write([]byte(ec2MetadataResponseLoadOptions)) + } else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/EnvCfg" { + w.Write([]byte(ec2MetadataResponseEnvCfg)) + } else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/SharedCfg" { + w.Write([]byte(ec2MetadataResponseSharedCfg)) + } else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/" { w.Write([]byte("RoleName")) } else if r.URL.Path == "/latest/api/token" { header := w.Header() @@ -750,6 +757,103 @@ func TestResolveCredentialsEcsContainer(t *testing.T) { } +func TestResolveCredentialsEC2RoleCreds(t *testing.T) { + testCases := map[string]struct { + expectedAccessKey string + expectedSecretKey string + envVar map[string]string + configFile string + configProfile string + loadOptions func(*LoadOptions) error + }{ + "no config whatsoever": { + expectedAccessKey: "ec2-access-key", + expectedSecretKey: "ec2-secret-key", + envVar: map[string]string{}, + configFile: "", + }, + "env cfg": { + expectedAccessKey: "ec2-access-key-envcfg", + expectedSecretKey: "ec2-secret-key-envcfg", + envVar: map[string]string{ + "AWS_EC2_INSTANCE_PROFILE_NAME": "EnvCfg", + }, + configFile: "", + }, + "shared cfg": { + expectedAccessKey: "ec2-access-key-sharedcfg", + expectedSecretKey: "ec2-secret-key-sharedcfg", + envVar: map[string]string{}, + configFile: filepath.Join("testdata", "config_source_shared"), + configProfile: "ec2metadata-profilename", + }, + "loadopts + env cfg + shared cfg": { + expectedAccessKey: "ec2-access-key-loadopts", + expectedSecretKey: "ec2-secret-key-loadopts", + envVar: map[string]string{ + "AWS_EC2_INSTANCE_PROFILE_NAME": "EnvCfg", + }, + configFile: filepath.Join("testdata", "config_source_shared"), + configProfile: "ec2metadata-profilename", + loadOptions: WithEC2RoleCredentialOptions(func(o *ec2rolecreds.Options) { + o.ProfileName = "LoadOptions" + }), + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + endpointResolver, cleanupFn := setupCredentialsEndpoints() + defer cleanupFn() + + // setupCredentialsEndpoints sets this above and then we hold onto + // it for this test + ec2MetadataURL := os.Getenv("AWS_EC2_METADATA_SERVICE_ENDPOINT") + + restoreEnv := awstesting.StashEnv() + defer awstesting.PopEnv(restoreEnv) + + os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", ec2MetadataURL) + for k, v := range tc.envVar { + os.Setenv(k, v) + } + var sharedConfigFiles []string + if tc.configFile != "" { + sharedConfigFiles = append(sharedConfigFiles, tc.configFile) + } + opts := []func(*LoadOptions) error{ + WithEndpointResolverWithOptions(endpointResolver), + WithRetryer(func() aws.Retryer { return aws.NopRetryer{} }), + WithSharedConfigFiles(sharedConfigFiles), + WithSharedCredentialsFiles([]string{}), + } + if len(tc.configProfile) != 0 { + opts = append(opts, WithSharedConfigProfile(tc.configProfile)) + } + + if tc.loadOptions != nil { + opts = append(opts, tc.loadOptions) + } + + cfg, err := LoadDefaultConfig(context.TODO(), opts...) + if err != nil { + t.Fatalf("could not load config: %s", err) + } + actual, err := cfg.Credentials.Retrieve(context.TODO()) + if err != nil { + t.Fatalf("could not retrieve credentials: %s", err) + } + if actual.AccessKeyID != tc.expectedAccessKey { + t.Errorf("expected access key to be %s, got %s", tc.expectedAccessKey, actual.AccessKeyID) + } + if actual.SecretAccessKey != tc.expectedSecretKey { + t.Errorf("expected secret key to be %s, got %s", tc.expectedSecretKey, actual.SecretAccessKey) + } + }) + } + +} + type stubErrorClient struct { err error } diff --git a/config/shared_config.go b/config/shared_config.go index 00b071fe6f1..4783d700ec8 100644 --- a/config/shared_config.go +++ b/config/shared_config.go @@ -82,6 +82,8 @@ const ( ec2MetadataV1DisabledKey = "ec2_metadata_v1_disabled" + ec2InstanceProfileNameKey = "ec2_instance_profile_name" + // Use DualStack Endpoint Resolution useDualStackEndpoint = "use_dualstack_endpoint" @@ -357,6 +359,9 @@ type SharedConfig struct { // ResponseChecksumValidation indicates if the response checksum should be validated ResponseChecksumValidation aws.ResponseChecksumValidation + + // Profile name used for fetching IMDS credentials. + EC2InstanceProfileName string } func (c SharedConfig) getDefaultsMode(ctx context.Context) (value aws.DefaultsMode, ok bool, err error) { @@ -877,6 +882,7 @@ func mergeSections(dst *ini.Sections, src ini.Sections) error { ec2MetadataServiceEndpointModeKey, ec2MetadataServiceEndpointKey, ec2MetadataV1DisabledKey, + ec2InstanceProfileNameKey, useDualStackEndpoint, useFIPSEndpointKey, defaultsModeKey, @@ -1110,6 +1116,8 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er updateString(&c.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey) updateBoolPtr(&c.EC2IMDSv1Disabled, section, ec2MetadataV1DisabledKey) + updateString(&c.EC2InstanceProfileName, section, ec2InstanceProfileNameKey) + updateUseDualStackEndpoint(&c.UseDualStackEndpoint, section, useDualStackEndpoint) updateUseFIPSEndpoint(&c.UseFIPSEndpoint, section, useFIPSEndpointKey) @@ -1678,3 +1686,11 @@ func updateUseFIPSEndpoint(dst *aws.FIPSEndpointState, section ini.Section, key return } + +func (c SharedConfig) getEC2InstanceProfileName() (string, bool, error) { + if len(c.EC2InstanceProfileName) == 0 { + return "", false, nil + } + + return c.EC2InstanceProfileName, true, nil +} diff --git a/config/shared_config_test.go b/config/shared_config_test.go index 8d71818c8d4..c480e47ff90 100644 --- a/config/shared_config_test.go +++ b/config/shared_config_test.go @@ -806,6 +806,15 @@ func TestNewSharedConfig(t *testing.T) { }, Err: fmt.Errorf("invalid value for shared config profile field, response_checksum_validation=blabla, must be when_supported/when_required"), }, + + "profile with ec2 instance profile name": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "ec2_instance_profile_name", + Expected: SharedConfig{ + Profile: "ec2_instance_profile_name", + EC2InstanceProfileName: "ProfileName", + }, + }, } for name, c := range cases { diff --git a/config/shared_test.go b/config/shared_test.go index b0ee123a161..29e67a20b23 100644 --- a/config/shared_test.go +++ b/config/shared_test.go @@ -27,6 +27,36 @@ const ec2MetadataResponse = `{ "LastUpdated": "2009-11-23T00:00:00Z" }` +const ec2MetadataResponseLoadOptions = `{ + "Code": "Success", + "Type": "AWS-HMAC", + "AccessKeyId": "ec2-access-key-loadopts", + "SecretAccessKey": "ec2-secret-key-loadopts", + "Token": "token", + "Expiration": "2100-01-01T00:00:00Z", + "LastUpdated": "2009-11-23T00:00:00Z" +}` + +const ec2MetadataResponseEnvCfg = `{ + "Code": "Success", + "Type": "AWS-HMAC", + "AccessKeyId": "ec2-access-key-envcfg", + "SecretAccessKey": "ec2-secret-key-envcfg", + "Token": "token", + "Expiration": "2100-01-01T00:00:00Z", + "LastUpdated": "2009-11-23T00:00:00Z" +}` + +const ec2MetadataResponseSharedCfg = `{ + "Code": "Success", + "Type": "AWS-HMAC", + "AccessKeyId": "ec2-access-key-sharedcfg", + "SecretAccessKey": "ec2-secret-key-sharedcfg", + "Token": "token", + "Expiration": "2100-01-01T00:00:00Z", + "LastUpdated": "2009-11-23T00:00:00Z" +}` + const assumeRoleRespMsg = ` diff --git a/config/testdata/config_source_shared b/config/testdata/config_source_shared index 625b14ecdb8..15816ba0056 100644 --- a/config/testdata/config_source_shared +++ b/config/testdata/config_source_shared @@ -101,3 +101,6 @@ role_arn = webident_arn [profile webident-partial] web_identity_token_file = ./testdata/wit.txt + +[profile ec2metadata-profilename] +ec2_instance_profile_name = SharedCfg diff --git a/config/testdata/shared_config b/config/testdata/shared_config index c7159d52bff..197b72e3a68 100644 --- a/config/testdata/shared_config +++ b/config/testdata/shared_config @@ -347,3 +347,5 @@ response_checksum_validation = when_required [profile response_checksum_validation_error] response_checksum_validation = blabla +[profile ec2_instance_profile_name] +ec2_instance_profile_name = ProfileName diff --git a/credentials/ec2rolecreds/provider.go b/credentials/ec2rolecreds/provider.go index a95e6c8bdd6..35254700e9c 100644 --- a/credentials/ec2rolecreds/provider.go +++ b/credentials/ec2rolecreds/provider.go @@ -4,10 +4,13 @@ import ( "bufio" "context" "encoding/json" + "errors" "fmt" "math" "path" "strings" + "sync" + "sync/atomic" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -17,6 +20,7 @@ import ( "github.com/aws/smithy-go" "github.com/aws/smithy-go/logging" "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" ) // ProviderName provides a name of EC2Role provider @@ -38,6 +42,23 @@ type GetMetadataAPIClient interface { // }) type Provider struct { options Options + + isLegacyPath atomic.Bool + + mu sync.Mutex + cachedProfile string +} + +func (p *Provider) getCachedProfile() string { + p.mu.Lock() + defer p.mu.Unlock() + return p.cachedProfile +} + +func (p *Provider) setCachedProfile(v string) { + p.mu.Lock() + defer p.mu.Unlock() + p.cachedProfile = v } // Options is a list of user settable options for setting the behavior of the Provider. @@ -48,6 +69,12 @@ type Options struct { // If nil, the provider will default to the EC2 IMDS client. Client GetMetadataAPIClient + // Explicit EC2 instance profile name to use when fetching credentials. + // + // If unset, the provider will make an extra initial IMDS call to determine + // what profile to use. + ProfileName string + // The chain of providers that was used to create this provider // These values are for reporting purposes and are not meant to be set up directly CredentialSources []aws.CredentialSource @@ -74,18 +101,12 @@ func New(optFns ...func(*Options)) *Provider { // Retrieve retrieves credentials from the EC2 service. Error will be returned // if the request fails, or unable to extract the desired credentials. func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { - credsList, err := requestCredList(ctx, p.options.Client) + profileName, err := p.resolveProfile(ctx) if err != nil { return aws.Credentials{Source: ProviderName}, err } - if len(credsList) == 0 { - return aws.Credentials{Source: ProviderName}, - fmt.Errorf("unexpected empty EC2 IMDS role list") - } - credsName := credsList[0] - - roleCreds, err := requestCred(ctx, p.options.Client, credsName) + roleCreds, err := p.requestCred(ctx, profileName) if err != nil { return aws.Credentials{Source: ProviderName}, err } @@ -94,6 +115,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { AccessKeyID: roleCreds.AccessKeyID, SecretAccessKey: roleCreds.SecretAccessKey, SessionToken: roleCreds.Token, + AccountID: roleCreds.AccountID, Source: ProviderName, CanExpire: true, @@ -109,6 +131,59 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { return creds, nil } +func (p *Provider) resolveProfile(ctx context.Context) (string, error) { + if p.options.ProfileName != "" { + return p.options.ProfileName, nil + } + + if cached := p.getCachedProfile(); cached != "" { + return cached, nil + } + + credsList, err := p.requestCredList(ctx) + if err != nil { + return "", err + } + if len(credsList) == 0 { + return "", errors.New("unexpected empty EC2 IMDS role list") + } + + p.setCachedProfile(credsList[0]) + return credsList[0], nil +} + +// Indirects the underlying imds.GetMetadata to handle fallback to the "legacy" +// credentials metadata path. The profile MAY be empty. +func (p *Provider) getMetadata(ctx context.Context, profile string) (*imds.GetMetadataOutput, error) { + isLegacy := p.isLegacyPath.Load() + // we only need to fallback when + // 1. we haven't already + // 2. this request IS NOT using a cached profile - it's either to + // retrieve a profile, or retrieval with an explicit profile from options + canFallback := !isLegacy && (profile == "" || profile == p.options.ProfileName) + + ppath := credsPath + if isLegacy { + ppath = legacyCredsPath + } + + if profile != "" { // path.Join will strip the trailing slash, which we don't want + ppath = path.Join(ppath, profile) + } + out, err := p.options.Client.GetMetadata(ctx, &imds.GetMetadataInput{ + Path: ppath, + }) + if err != nil && is404(err) && canFallback { + p.isLegacyPath.Store(true) + return p.getMetadata(ctx, profile) + } + if err != nil { + return nil, err + } + + return out, nil +} + // HandleFailToRefresh will extend the credentials Expires time if it it is // expired. If the credentials will not expire within the minimum time, they // will be returned. @@ -166,21 +241,23 @@ type ec2RoleCredRespBody struct { AccessKeyID string SecretAccessKey string Token string + AccountID string // Error state Code string Message string } -const iamSecurityCredsPath = "/iam/security-credentials/" +const ( + legacyCredsPath = "/iam/security-credentials/" + credsPath = "/iam/security-credentials-extended/" +) // requestCredList requests a list of credentials from the EC2 service. If // there are no credentials, or there is an error making or receiving the // request -func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) { - resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{ - Path: iamSecurityCredsPath, - }) +func (p *Provider) requestCredList(ctx context.Context) ([]string, error) { + resp, err := p.getMetadata(ctx, "") if err != nil { return nil, fmt.Errorf("no EC2 IMDS role found, %w", err) } @@ -203,10 +280,18 @@ func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string // // If the credentials cannot be found, or there is an error reading the response // and error will be returned. -func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) { - resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{ - Path: path.Join(iamSecurityCredsPath, credsName), - }) +func (p *Provider) requestCred(ctx context.Context, credsName string) (ec2RoleCredRespBody, error) { + resp, err := p.getMetadata(ctx, credsName) + if err != nil && is404(err) && p.getCachedProfile() != "" { + // 404 on a cached profile means it isn't stable, so reset it and try again + p.setCachedProfile("") + credsName, err = p.resolveProfile(ctx) + if err != nil { + return ec2RoleCredRespBody{}, err + } + + resp, err = p.getMetadata(ctx, credsName) + } if err != nil { return ec2RoleCredRespBody{}, fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w", @@ -239,3 +324,11 @@ func (p *Provider) ProviderSources() []aws.CredentialSource { } // If no source has been set, assume this is used directly which means just call to assume role return p.options.CredentialSources } + +func is404(err error) bool { + var terr *smithyhttp.ResponseError + if errors.As(err, &terr) { + return terr.HTTPStatusCode() == 404 + } + return false +} diff --git a/credentials/ec2rolecreds/provider_sep_test.go b/credentials/ec2rolecreds/provider_sep_test.go new file mode 100644 index 00000000000..d7d15115e1c --- /dev/null +++ b/credentials/ec2rolecreds/provider_sep_test.go @@ -0,0 +1,784 @@ +package ec2rolecreds + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +type sepTestCase struct { + Summary string + Config struct { + EC2InstanceProfileName string + EnvVars map[string]string + } + Expectations []struct { + Get string + Response struct { + Status int + Body any + } + } + Outcomes []struct { + Result string + AccountID string + } +} + +type sepMockIMDS struct { + testCase sepTestCase + callIndex int +} + +func (m *sepMockIMDS) GetMetadata(ctx context.Context, in *imds.GetMetadataInput, opts ...func(*imds.Options)) (*imds.GetMetadataOutput, error) { + callIndex := m.callIndex + next := m.testCase.Expectations[callIndex] + m.callIndex++ + + expectPath := strings.TrimPrefix(next.Get, "/latest/meta-data") // the real IMDS client injects this + if strings.HasSuffix(expectPath, "security-credentials") || + strings.HasSuffix(expectPath, "security-credentials-extended") { + expectPath += "/" // we've always had the trailing / on these + } + if expectPath != in.Path { + return nil, fmt.Errorf("unexpected path in call %d: expect %s, got %s", callIndex, expectPath, in.Path) + } + + if next.Response.Status != 200 { + return nil, mockResponseError(next.Response.Status) + } + + switch v := next.Response.Body.(type) { + case string: + return &imds.GetMetadataOutput{ + Content: io.NopCloser(strings.NewReader(v)), + }, nil + case map[string]any: + j, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("unhandled response marshal failure in sep test case: %v", v) + } + return &imds.GetMetadataOutput{ + Content: io.NopCloser(bytes.NewReader(j)), + }, nil + default: + return nil, fmt.Errorf("unhandled body type in sep test case: %T", next.Response.Body) + } +} + +func mockResponseError(status int) error { + return &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: status, + Body: http.NoBody, + }, + }, + } +} + +const sepTestCaseJSON = `[ + { + "summary": "Test IMDS credentials provider with env vars { AWS_EC2_METADATA_DISABLED=true } returns no credentials", + "config": { + "ec2InstanceProfileName": null, + "envVars": { + "AWS_EC2_METADATA_DISABLED": "true" + } + }, + "expectations": [], + "outcomes": [ + { + "result": "no credentials" + } + ] + }, + { + "summary": "Test IMDS credentials provider returns valid credentials with account ID", + "config": { + "ec2InstanceProfileName": null + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 200, + "body": "my-profile-0001" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0001", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-12T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-12T21:53:17.832308Z", + "UnexpectedElement1": { + "Name": "ignore-me-1" + }, + "AccountId": "123456789101" + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0001", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-12T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-12T21:53:17.832308Z", + "UnexpectedElement1": { + "Name": "ignore-me-1" + }, + "AccountId": "123456789101" + } + } + } + ], + "outcomes": [ + { + "result": "credentials", + "accountId": "123456789101" + }, + { + "result": "credentials", + "accountId": "123456789101" + } + ] + }, + { + "summary": "Test IMDS credentials provider with a given profile name returns valid credentials with account ID", + "config": { + "ec2InstanceProfileName": "my-profile-0002" + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0002", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-13T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-13T21:53:17.832308Z", + "UnexpectedElement2": { + "Name": "ignore-me-2" + }, + "AccountId": "234567891011" + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0002", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-13T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-13T21:53:17.832308Z", + "UnexpectedElement2": { + "Name": "ignore-me-2" + }, + "AccountId": "234567891011" + } + } + } + ], + "outcomes": [ + { + "result": "credentials", + "accountId": "234567891011" + }, + { + "result": "credentials", + "accountId": "234567891011" + } + ] + }, + { + "summary": "Test IMDS credentials provider when profile is unstable returns valid credentials with account ID", + "config": { + "ec2InstanceProfileName": null + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 200, + "body": "my-profile-0003" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0003", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-14T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-14T21:53:17.832308Z", + "UnexpectedElement3": { + "Name": "ignore-me-3" + }, + "AccountId": "345678910112" + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0003", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 200, + "body": "my-profile-0003-b" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0003-b", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-14T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-14T21:53:17.832308Z", + "UnexpectedElement3": { + "Name": "ignore-me-3" + }, + "AccountId": "314253647589" + } + } + } + ], + "outcomes": [ + { + "result": "credentials", + "accountId": "345678910112" + }, + { + "result": "credentials", + "accountId": "314253647589" + } + ] + }, + { + "summary": "Test IMDS credentials provider with a given profile name when profile is invalid throws an error", + "config": { + "ec2InstanceProfileName": "my-profile-0004" + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0004", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0004", + "response": { + "status": 404 + } + } + ], + "outcomes": [ + { + "result": "invalid profile" + } + ] + }, + { + "summary": "Test IMDS credentials provider when account ID is unavailable returns valid credentials", + "config": { + "ec2InstanceProfileName": null + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 200, + "body": "my-profile-0005" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0005", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-16T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-16T21:53:17.832308Z", + "UnexpectedElement5": { + "Name": "ignore-me-5" + } + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0005", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-16T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-16T21:53:17.832308Z", + "UnexpectedElement5": { + "Name": "ignore-me-5" + } + } + } + } + ], + "outcomes": [ + { + "result": "credentials" + }, + { + "result": "credentials" + } + ] + }, + { + "summary": "Test IMDS credentials provider with a given profile name when account ID is unavailable returns valid credentials", + "config": { + "ec2InstanceProfileName": "my-profile-0006" + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0006", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-17T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-17T21:53:17.832308Z", + "UnexpectedElement6": { + "Name": "ignore-me-6" + } + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0006", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-17T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-17T21:53:17.832308Z", + "UnexpectedElement6": { + "Name": "ignore-me-6" + } + } + } + } + ], + "outcomes": [ + { + "result": "credentials" + }, + { + "result": "credentials" + } + ] + }, + { + "summary": "Test IMDS credentials provider when account ID is unavailable when profile is unstable returns valid credentials", + "config": { + "ec2InstanceProfileName": null + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 200, + "body": "my-profile-0007" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0007", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-18T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-18T21:53:17.832308Z", + "UnexpectedElement7": { + "Name": "ignore-me-7" + } + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0007", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 200, + "body": "my-profile-0007-b" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0007-b", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-18T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-18T21:53:17.832308Z", + "UnexpectedElement7": { + "Name": "ignore-me-7" + } + } + } + } + ], + "outcomes": [ + { + "result": "credentials" + }, + { + "result": "credentials" + } + ] + }, + { + "summary": "Test IMDS credentials provider with a given profile name when account ID is unavailable when profile is invalid throws an error", + "config": { + "ec2InstanceProfileName": "my-profile-0008" + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0008", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0008", + "response": { + "status": 404 + } + } + ], + "outcomes": [ + { + "result": "invalid profile" + } + ] + }, + { + "summary": "Test IMDS credentials provider against legacy API returns valid credentials", + "config": { + "ec2InstanceProfileName": null + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials", + "response": { + "status": 200, + "body": "my-profile-0009" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0009", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-20T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-20T21:53:17.832308Z" + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0009", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-20T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-20T21:53:17.832308Z" + } + } + } + ], + "outcomes": [ + { + "result": "credentials" + }, + { + "result": "credentials" + } + ] + }, + { + "summary": "Test IMDS credentials provider with a given profile name against legacy API returns valid credentials", + "config": { + "ec2InstanceProfileName": "my-profile-0010" + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0010", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0010", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-21T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-21T21:53:17.832308Z" + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0010", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-21T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-21T21:53:17.832308Z" + } + } + } + ], + "outcomes": [ + { + "result": "credentials" + }, + { + "result": "credentials" + } + ] + }, + { + "summary": "Test IMDS credentials provider against legacy API when profile is unstable returns valid credentials", + "config": { + "ec2InstanceProfileName": null + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials", + "response": { + "status": 200, + "body": "my-profile-0011" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0011", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-22T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-22T21:53:17.832308Z" + } + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0011", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials", + "response": { + "status": 200, + "body": "my-profile-0011-b" + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0011-b", + "response": { + "status": 200, + "body": { + "Code": "Success", + "LastUpdated": "2025-03-22T20:53:17.832308Z", + "Type": "AWS-HMAC", + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "Token": "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKw...(truncated)", + "Expiration": "2025-03-22T21:53:17.832308Z" + } + } + } + ], + "outcomes": [ + { + "result": "credentials" + }, + { + "result": "credentials" + } + ] + }, + { + "summary": "Test IMDS credentials provider with a given profile name against legacy API when profile is invalid throws an error", + "config": { + "ec2InstanceProfileName": "my-profile-0012" + }, + "expectations": [ + { + "get": "/latest/meta-data/iam/security-credentials-extended/my-profile-0012", + "response": { + "status": 404 + } + }, + { + "get": "/latest/meta-data/iam/security-credentials/my-profile-0012", + "response": { + "status": 404 + } + } + ], + "outcomes": [ + { + "result": "invalid profile" + } + ] + } +]` + +var skipSEPTestCases = map[string]string{ + "Test IMDS credentials provider with env vars { AWS_EC2_METADATA_DISABLED=true } returns no credentials": "environment variables are not considered in these unit tests", +} + +func TestProvider_SEPTestCases(t *testing.T) { + var testCases []sepTestCase + if err := json.Unmarshal([]byte(sepTestCaseJSON), &testCases); err != nil { + t.Fatal(err) + } + + for _, tt := range testCases { + t.Run(tt.Summary, func(t *testing.T) { + if reason, ok := skipSEPTestCases[tt.Summary]; ok { + t.Skip(reason) + } + + mockIMDS := &sepMockIMDS{testCase: tt} + provider := New(func(o *Options) { + o.ProfileName = tt.Config.EC2InstanceProfileName + o.Client = mockIMDS + }) + + for _, expect := range tt.Outcomes { + creds, err := provider.Retrieve(context.Background()) + switch expect.Result { + case "credentials": + if creds.AccessKeyID == "" { + t.Errorf("expected credentials, got none: %v", err) + } + if expect.AccountID != creds.AccountID { + t.Errorf("expected account id %q, got %q", expect.AccountID, creds.AccountID) + } + case "no credentials": + fallthrough + case "invalid profile": + if err == nil { + t.Error("expected error, got none") + } + } + } + }) + } +} diff --git a/credentials/ec2rolecreds/provider_test.go b/credentials/ec2rolecreds/provider_test.go index e04ef1beac6..ccab63adf7d 100644 --- a/credentials/ec2rolecreds/provider_test.go +++ b/credentials/ec2rolecreds/provider_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "net/http" "reflect" "strings" "testing" @@ -19,6 +20,7 @@ import ( "github.com/aws/smithy-go" "github.com/aws/smithy-go/logging" "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" ) const credsRespTmpl = `{ @@ -31,34 +33,76 @@ const credsRespTmpl = `{ "LastUpdated" : "2009-11-23T00:00:00Z" }` +const credsRespTmplWithAccountID = `{ + "Code": "Success", + "Type": "AWS-HMAC", + "AccessKeyId" : "accessKey", + "SecretAccessKey" : "secret", + "AccountId" : "accountId", + "Token" : "token", + "Expiration" : "%s", + "LastUpdated" : "2009-11-23T00:00:00Z" +}` + const credsFailRespTmpl = `{ "Code": "ErrorCode", "Message": "ErrorMsg", "LastUpdated": "2009-11-23T00:00:00Z" }` +var err404 = &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 404, + Body: http.NoBody, + }, + }, +} + type mockClient struct { - t *testing.T roleName string failAssume bool expireOn string + + calls []string + isLegacy bool + fail404 bool + returnAccountID bool } -func (c mockClient) GetMetadata( +func (c *mockClient) GetMetadata( ctx context.Context, params *imds.GetMetadataInput, optFns ...func(*imds.Options), ) ( *imds.GetMetadataOutput, error, ) { + c.calls = append(c.calls, params.Path) switch params.Path { - case iamSecurityCredsPath: + case credsPath: + if c.isLegacy { + return nil, err404 + } + fallthrough + case legacyCredsPath: + if c.fail404 { + return nil, err404 + } return &imds.GetMetadataOutput{ Content: ioutil.NopCloser(strings.NewReader(c.roleName)), }, nil - case iamSecurityCredsPath + c.roleName: + case credsPath + c.roleName: + if c.isLegacy { + return nil, err404 + } + fallthrough + case legacyCredsPath + c.roleName: var w strings.Builder - if c.failAssume { + if c.fail404 { + return nil, err404 + } else if c.failAssume { fmt.Fprintf(&w, credsFailRespTmpl) + } else if c.returnAccountID { + fmt.Fprintf(&w, credsRespTmplWithAccountID, c.expireOn) } else { fmt.Fprintf(&w, credsRespTmpl, c.expireOn) } @@ -70,6 +114,20 @@ func (c mockClient) GetMetadata( } } +func (c *mockClient) expectCalls(t *testing.T, calls ...string) { + t.Helper() + + if len(calls) != len(c.calls) { + t.Fatalf("expected %d calls, got %d", len(calls), len(c.calls)) + } + + for i, expect := range calls { + if expect != c.calls[i] { + t.Errorf("expect call to %s, got %s", expect, c.calls[i]) + } + } +} + var ( _ aws.AdjustExpiresByCredentialsCacheStrategy = (*Provider)(nil) _ aws.HandleFailRefreshCredentialsCacheStrategy = (*Provider)(nil) @@ -80,13 +138,92 @@ func TestProvider(t *testing.T) { defer func() { sdk.NowTime = orig }() p := New(func(options *Options) { - options.Client = mockClient{ + options.Client = &mockClient{ roleName: "RoleName", failAssume: false, expireOn: "2014-12-16T01:51:37Z", } }) + creds, err := p.Retrieve(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := "accessKey", creds.AccessKeyID; e != a { + t.Errorf("Expect access key ID to match") + } + if e, a := "secret", creds.SecretAccessKey; e != a { + t.Errorf("Expect secret access key to match") + } + if e, a := "token", creds.SessionToken; e != a { + t.Errorf("Expect session token to match") + } + if e, a := "", creds.AccountID; e != a { + t.Errorf("Expect account ID to match") + } + + sdk.NowTime = func() time.Time { + return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC) + } + + if creds.Expired() { + t.Errorf("Expect not expired") + } +} + +func TestProvider_AccountID(t *testing.T) { + orig := sdk.NowTime + defer func() { sdk.NowTime = orig }() + + p := New(func(options *Options) { + options.Client = &mockClient{ + roleName: "RoleName", + failAssume: false, + expireOn: "2014-12-16T01:51:37Z", + returnAccountID: true, + } + }) + + creds, err := p.Retrieve(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := "accessKey", creds.AccessKeyID; e != a { + t.Errorf("Expect access key ID to match") + } + if e, a := "secret", creds.SecretAccessKey; e != a { + t.Errorf("Expect secret access key to match") + } + if e, a := "token", creds.SessionToken; e != a { + t.Errorf("Expect session token to match") + } + if e, a := "accountId", creds.AccountID; e != a { + t.Errorf("Expect account ID to match") + } + + sdk.NowTime = func() time.Time { + return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC) + } + + if creds.Expired() { + t.Errorf("Expect not expired") + } +} + +func TestProvider_LegacyPath(t *testing.T) { + orig := sdk.NowTime + defer func() { sdk.NowTime = orig }() + + m := &mockClient{ + roleName: "RoleName", + failAssume: false, + expireOn: "2014-12-16T01:51:37Z", + isLegacy: true, + } + p := New(func(options *Options) { + options.Client = m + }) + creds, err := p.Retrieve(context.Background()) if err != nil { t.Fatalf("expect no error, got %v", err) @@ -108,11 +245,112 @@ func TestProvider(t *testing.T) { if creds.Expired() { t.Errorf("Expect not expired") } + + m.expectCalls(t, + "/iam/security-credentials-extended/", + "/iam/security-credentials/", + "/iam/security-credentials/RoleName", + ) +} + +func TestProvider_LegacyPath_ProfileOverride(t *testing.T) { + orig := sdk.NowTime + defer func() { sdk.NowTime = orig }() + + m := &mockClient{ + roleName: "RoleName", + failAssume: false, + expireOn: "2014-12-16T01:51:37Z", + isLegacy: true, + } + + p := New(func(options *Options) { + options.ProfileName = "RoleName" + options.Client = m + }) + + creds, err := p.Retrieve(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := "accessKey", creds.AccessKeyID; e != a { + t.Errorf("Expect access key ID to match") + } + if e, a := "secret", creds.SecretAccessKey; e != a { + t.Errorf("Expect secret access key to match") + } + if e, a := "token", creds.SessionToken; e != a { + t.Errorf("Expect session token to match") + } + + sdk.NowTime = func() time.Time { + return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC) + } + + if creds.Expired() { + t.Errorf("Expect not expired") + } + + m.expectCalls(t, + "/iam/security-credentials-extended/RoleName", + "/iam/security-credentials/RoleName", + ) +} + +func TestProvider_LegacyPath_Still404(t *testing.T) { + orig := sdk.NowTime + defer func() { sdk.NowTime = orig }() + + m := &mockClient{ + roleName: "RoleName", + expireOn: "2014-12-16T01:51:37Z", + isLegacy: true, + fail404: true, + } + p := New(func(options *Options) { + options.Client = m + }) + + _, err := p.Retrieve(context.Background()) + if err == nil { + t.Fatal("expect error, got none") + } + + m.expectCalls(t, + "/iam/security-credentials-extended/", + "/iam/security-credentials/", + ) +} + +func TestProvider_LegacyPath_ProfileOverride_Still404(t *testing.T) { + orig := sdk.NowTime + defer func() { sdk.NowTime = orig }() + + m := &mockClient{ + roleName: "RoleName", + expireOn: "2014-12-16T01:51:37Z", + isLegacy: true, + fail404: true, + } + p := New(func(options *Options) { + options.ProfileName = "RoleName" + options.Client = m + }) + + _, err := p.Retrieve(context.Background()) + if err == nil { + t.Fatal("expect error, got none") + } + + m.expectCalls(t, + "/iam/security-credentials-extended/RoleName", + "/iam/security-credentials/RoleName", + ) } func TestProvider_FailAssume(t *testing.T) { p := New(func(options *Options) { - options.Client = mockClient{ + options.Client = &mockClient{ roleName: "RoleName", failAssume: true, expireOn: "2014-12-16T01:51:37Z", @@ -156,7 +394,7 @@ func TestProvider_IsExpired(t *testing.T) { defer func() { sdk.NowTime = orig }() p := New(func(options *Options) { - options.Client = mockClient{ + options.Client = &mockClient{ roleName: "RoleName", failAssume: false, expireOn: "2014-12-16T01:51:37Z",