Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code to save the managed instance's region to the stored cred profile #97

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"testing"
"time"

"github.com/aws/amazon-ssm-agent/agent/fileutil"
"github.com/aws/amazon-ssm-agent/agent/managedInstances/sharedCredentials"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/stretchr/testify/assert"
)
Expand All @@ -29,8 +31,17 @@ var (
accessKeyID = "accessKeyID"
secretAccessKey = "secretAccessKey"
sessionToken = "sessionToken"
region = "us-east-1"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable is not used

)

func cleanupCredFile() {
if credPath, err := sharedCredentials.Filename(); err == nil {
if credPath != "" && fileutil.Exists(credPath) {
fileutil.DeleteFile(credPath)
}
}
}

func TestRetrieve_ShouldReturnValidToken(t *testing.T) {
updateKeyPair := false
tokenExpirationDate := time.Now().Add(1 * time.Hour)
Expand All @@ -51,6 +62,8 @@ func TestRetrieve_ShouldReturnValidToken(t *testing.T) {
assert.Equal(t, accessKeyID, cred.AccessKeyID)
assert.Equal(t, secretAccessKey, cred.SecretAccessKey)
assert.Equal(t, sessionToken, cred.SessionToken)

cleanupCredFile()
}

func TestRetrieve_ShouldUpdateKeyPair(t *testing.T) {
Expand All @@ -60,6 +73,7 @@ func TestRetrieve_ShouldUpdateKeyPair(t *testing.T) {
publicKey: "publicKey",
privateKey: "privateKey",
keyType: "Rsa",
region: "us-east-1",
}
client := &RsaSignedServiceStub{
roleResponse: ssm.RequestManagedInstanceRoleTokenOutput{
Expand All @@ -76,6 +90,7 @@ func TestRetrieve_ShouldUpdateKeyPair(t *testing.T) {
_, err := testProvider.Retrieve()
assert.NoError(t, err)
assert.True(t, client.updateCalled)
cleanupCredFile()
}

func TestRetrieve_ShouldFailOnError(t *testing.T) {
Expand Down
21 changes: 17 additions & 4 deletions agent/managedInstances/sharedCredentials/shared_Credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"path/filepath"

"github.com/aws/amazon-ssm-agent/agent/fileutil"
"github.com/aws/amazon-ssm-agent/agent/managedInstances/registration"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/go-ini/ini"
)
Expand All @@ -29,12 +30,12 @@ const (
awsAccessKeyID = "aws_access_key_id"
awsSecretAccessKey = "aws_secret_access_key"
awsSessionToken = "aws_session_token"
awsRegion = "region"
)

// filename returns the filename to use to read AWS shared credentials.
//
// Filename returns the filename to use to read AWS shared credentials.
// Will return an error if the user's home directory path cannot be found.
func filename() (string, error) {
func Filename() (string, error) {
if credPath := os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); credPath != "" {
return credPath, nil
}
Expand Down Expand Up @@ -68,7 +69,7 @@ func Store(accessKeyID, secretAccessKey, sessionToken, profile string) error {
profile = defaultProfile
}

credPath, err := filename()
credPath, err := Filename()
if err != nil {
return err
}
Expand Down Expand Up @@ -98,6 +99,18 @@ func Store(accessKeyID, secretAccessKey, sessionToken, profile string) error {

iniProfile.Key(awsSessionToken).SetValue(sessionToken)

// Save the instance's region to the profile so that the FallbackRegionFactory can find it.
// Scripts that use the .NET Cmdlets and aws command line tools will automatically detect
// the AWS Region from the EC2 instance profile, however, this is not the case for on-prem
// servers, since they don't have the EC2 Metadata service. By adding the Region to the
// shared credentials file, the SDK will be able to discover the region automatically.
// This will ensure that scripts that run on on-prem servers will run the same way as
// they would on EC2 instances, without any modification.
region := registration.Region()
if region != "" {
iniProfile.Key(awsRegion).SetValue(region)
}

err = config.SaveTo(credPath)
if err != nil {
return awserr.New("SharedCredentialsStore", "failed to save profile", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
accessKey = "DummyAccessKey"
accessSecretKey = "DummyAccessSecretKey"
token = "DummyToken"
region = "us-east-1"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable is not used

profile = "DummyProfile"
testFilePath = "example.ini"
)
Expand Down
4 changes: 2 additions & 2 deletions agent/s3util/s3util.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func (u *AmazonS3Util) S3Upload(log log.T, bucketName string, objectKey string,
func GetBucketRegion(log log.T, bucketName string, httpProvider HttpProvider) (region string) {
instanceRegion, err := getRegion()
if err != nil {
log.Error("Cannot get the current instance region information")
return instanceRegion // Default
log.Error(fmt.Errorf("Cannot get the current instance region information: %v", err))
return "us-east-1" // Default
}
log.Infof("Instance region is %v", instanceRegion)

Expand Down