Skip to content

Commit 4a691e1

Browse files
authored
Use proper HTTP client for fetching credentials (#2041)
* Use proper HTTP client for fetching credentials * Allow custom `http.Client` in credential providers.
1 parent 8dc4193 commit 4a691e1

24 files changed

+185
-130
lines changed

api-presigned.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ func (c *Client) PresignedPostPolicy(ctx context.Context, p *PostPolicy) (u *url
140140
}
141141

142142
// Get credentials from the configured credentials provider.
143-
credValues, err := c.credsProvider.Get()
143+
credValues, err := c.credsProvider.GetWithContext(c.CredContext())
144144
if err != nil {
145145
return nil, nil, err
146146
}

api.go

+15-4
Original file line numberDiff line numberDiff line change
@@ -600,9 +600,9 @@ func (c *Client) executeMethod(ctx context.Context, method string, metadata requ
600600
return nil, errors.New(c.endpointURL.String() + " is offline.")
601601
}
602602

603-
var retryable bool // Indicates if request can be retried.
604-
var bodySeeker io.Seeker // Extracted seeker from io.Reader.
605-
var reqRetry = c.maxRetries // Indicates how many times we can retry the request
603+
var retryable bool // Indicates if request can be retried.
604+
var bodySeeker io.Seeker // Extracted seeker from io.Reader.
605+
reqRetry := c.maxRetries // Indicates how many times we can retry the request
606606

607607
if metadata.contentBody != nil {
608608
// Check if body is seekable then it is retryable.
@@ -808,7 +808,7 @@ func (c *Client) newRequest(ctx context.Context, method string, metadata request
808808
}
809809

810810
// Get credentials from the configured credentials provider.
811-
value, err := c.credsProvider.Get()
811+
value, err := c.credsProvider.GetWithContext(c.CredContext())
812812
if err != nil {
813813
return nil, err
814814
}
@@ -1018,3 +1018,14 @@ func (c *Client) isVirtualHostStyleRequest(url url.URL, bucketName string) bool
10181018
// path style requests
10191019
return s3utils.IsVirtualHostSupported(url, bucketName)
10201020
}
1021+
1022+
// CredContext returns the context for fetching credentials
1023+
func (c *Client) CredContext() *credentials.CredContext {
1024+
httpClient := c.httpClient
1025+
if httpClient == nil {
1026+
httpClient = http.DefaultClient
1027+
}
1028+
return &credentials.CredContext{
1029+
Client: httpClient,
1030+
}
1031+
}

bucket-cache.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ func (c *Client) getBucketLocationRequest(ctx context.Context, bucketName string
212212
c.setUserAgent(req)
213213

214214
// Get credentials from the configured credentials provider.
215-
value, err := c.credsProvider.Get()
215+
value, err := c.credsProvider.GetWithContext(c.CredContext())
216216
if err != nil {
217217
return nil, err
218218
}

bucket-cache_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestGetBucketLocationRequest(t *testing.T) {
9797
c.setUserAgent(req)
9898

9999
// Get credentials from the configured credentials provider.
100-
value, err := c.credsProvider.Get()
100+
value, err := c.credsProvider.GetWithContext(c.CredContext())
101101
if err != nil {
102102
return nil, err
103103
}

pkg/credentials/assume_role.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ type AssumeRoleResult struct {
7676
type STSAssumeRole struct {
7777
Expiry
7878

79-
// Required http Client to use when connecting to MinIO STS service.
79+
// Optional http Client to use when connecting to MinIO STS service
80+
// (overrides default client in CredContext)
8081
Client *http.Client
8182

8283
// STS endpoint to fetch STS credentials.
@@ -115,9 +116,6 @@ func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentia
115116
return nil, errors.New("AssumeRole credentials access/secretkey is mandatory")
116117
}
117118
return New(&STSAssumeRole{
118-
Client: &http.Client{
119-
Transport: http.DefaultTransport,
120-
},
121119
STSEndpoint: stsEndpoint,
122120
Options: opts,
123121
}), nil
@@ -224,8 +222,12 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
224222

225223
// Retrieve retrieves credentials from the MinIO service.
226224
// Error will be returned if the request fails.
227-
func (m *STSAssumeRole) Retrieve() (Value, error) {
228-
a, err := getAssumeRoleCredentials(m.Client, m.STSEndpoint, m.Options)
225+
func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) {
226+
client := m.Client
227+
if client == nil {
228+
client = cc.Client
229+
}
230+
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
229231
if err != nil {
230232
return Value{}, err
231233
}

pkg/credentials/chain.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ func NewChainCredentials(providers []Provider) *Credentials {
6060
//
6161
// If a provider is found with credentials, it will be cached and any calls
6262
// to IsExpired() will return the expired state of the cached provider.
63-
func (c *Chain) Retrieve() (Value, error) {
63+
func (c *Chain) Retrieve(cc *CredContext) (Value, error) {
6464
for _, p := range c.Providers {
65-
creds, _ := p.Retrieve()
65+
creds, _ := p.Retrieve(cc)
6666
// Always prioritize non-anonymous providers, if any.
6767
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
6868
continue

pkg/credentials/chain_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ type testCredProvider struct {
2828
err error
2929
}
3030

31-
func (s *testCredProvider) Retrieve() (Value, error) {
31+
func (s *testCredProvider) Retrieve(_ *CredContext) (Value, error) {
3232
s.expired = false
3333
return s.creds, s.err
3434
}
@@ -59,7 +59,7 @@ func TestChainGet(t *testing.T) {
5959
},
6060
}
6161

62-
creds, err := p.Retrieve()
62+
creds, err := p.Retrieve(defaultCredContext)
6363
if err != nil {
6464
t.Fatal(err)
6565
}
@@ -95,7 +95,7 @@ func TestChainIsExpired(t *testing.T) {
9595
t.Fatal("Expected expired to be true before any Retrieve")
9696
}
9797

98-
_, err := p.Retrieve()
98+
_, err := p.Retrieve(defaultCredContext)
9999
if err != nil {
100100
t.Fatal(err)
101101
}
@@ -112,7 +112,7 @@ func TestChainWithNoProvider(t *testing.T) {
112112
if !p.IsExpired() {
113113
t.Fatal("Expected to be expired with no providers")
114114
}
115-
_, err := p.Retrieve()
115+
_, err := p.Retrieve(defaultCredContext)
116116
if err != nil {
117117
if err.Error() != "No valid providers found []" {
118118
t.Error(err)
@@ -136,7 +136,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
136136
t.Fatal("Expected to be expired with no providers")
137137
}
138138

139-
_, err := p.Retrieve()
139+
_, err := p.Retrieve(defaultCredContext)
140140
if err != nil {
141141
if err.Error() != "No valid providers found [FirstError SecondError]" {
142142
t.Error(err)

pkg/credentials/credentials.go

+32-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package credentials
1919

2020
import (
21+
"net/http"
2122
"sync"
2223
"time"
2324
)
@@ -30,6 +31,10 @@ const (
3031
defaultExpiryWindow = 0.8
3132
)
3233

34+
// defaultCredContext is used when the credential context doesn't
35+
// actually matter or the default context is suitable.
36+
var defaultCredContext = &CredContext{Client: http.DefaultClient}
37+
3338
// A Value is the S3 credentials value for individual credential fields.
3439
type Value struct {
3540
// S3 Access key ID
@@ -54,13 +59,21 @@ type Value struct {
5459
type Provider interface {
5560
// Retrieve returns nil if it successfully retrieved the value.
5661
// Error is returned if the value were not obtainable, or empty.
57-
Retrieve() (Value, error)
62+
Retrieve(cc *CredContext) (Value, error)
5863

5964
// IsExpired returns if the credentials are no longer valid, and need
6065
// to be retrieved.
6166
IsExpired() bool
6267
}
6368

69+
// CredContext is passed to the Retrieve function of a provider to provide
70+
// some additional context to retrieve credentials.
71+
type CredContext struct {
72+
// Client specifies the HTTP client that should be used if an HTTP
73+
// request is to be made to fetch the credentials.
74+
Client *http.Client
75+
}
76+
6477
// A Expiry provides shared expiration logic to be used by credentials
6578
// providers to implement expiry functionality.
6679
//
@@ -146,7 +159,24 @@ func New(provider Provider) *Credentials {
146159
//
147160
// If Credentials.Expire() was called the credentials Value will be force
148161
// expired, and the next call to Get() will cause them to be refreshed.
162+
//
163+
// Deprecated: Get() exists for historical compatibility and should not be
164+
// used. To get new credentials use the Credentials.GetWithContext function
165+
// to ensure the proper context (i.e. HTTP client) will be used.
149166
func (c *Credentials) Get() (Value, error) {
167+
return c.GetWithContext(defaultCredContext)
168+
}
169+
170+
// GetWithContext returns the credentials value, or error if the
171+
// credentials Value failed to be retrieved.
172+
//
173+
// Will return the cached credentials Value if it has not expired. If the
174+
// credentials Value has expired the Provider's Retrieve() will be called
175+
// to refresh the credentials.
176+
//
177+
// If Credentials.Expire() was called the credentials Value will be force
178+
// expired, and the next call to Get() will cause them to be refreshed.
179+
func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) {
150180
if c == nil {
151181
return Value{}, nil
152182
}
@@ -155,7 +185,7 @@ func (c *Credentials) Get() (Value, error) {
155185
defer c.Unlock()
156186

157187
if c.isExpired() {
158-
creds, err := c.provider.Retrieve()
188+
creds, err := c.provider.Retrieve(cc)
159189
if err != nil {
160190
return Value{}, err
161191
}

pkg/credentials/credentials_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ type credProvider struct {
2828
err error
2929
}
3030

31-
func (s *credProvider) Retrieve() (Value, error) {
31+
func (s *credProvider) Retrieve(_ *CredContext) (Value, error) {
3232
s.expired = false
3333
return s.creds, s.err
3434
}
@@ -47,7 +47,7 @@ func TestCredentialsGet(t *testing.T) {
4747
expired: true,
4848
})
4949

50-
creds, err := c.Get()
50+
creds, err := c.GetWithContext(defaultCredContext)
5151
if err != nil {
5252
t.Fatal(err)
5353
}
@@ -65,7 +65,7 @@ func TestCredentialsGet(t *testing.T) {
6565
func TestCredentialsGetWithError(t *testing.T) {
6666
c := New(&credProvider{err: errors.New("Custom error")})
6767

68-
_, err := c.Get()
68+
_, err := c.GetWithContext(defaultCredContext)
6969
if err != nil {
7070
if err.Error() != "Custom error" {
7171
t.Errorf("Expected \"Custom error\", got %s", err.Error())

pkg/credentials/env_aws.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func NewEnvAWS() *Credentials {
3838
}
3939

4040
// Retrieve retrieves the keys from the environment.
41-
func (e *EnvAWS) Retrieve() (Value, error) {
41+
func (e *EnvAWS) Retrieve(_ *CredContext) (Value, error) {
4242
e.retrieved = false
4343

4444
id := os.Getenv("AWS_ACCESS_KEY_ID")

pkg/credentials/env_minio.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func NewEnvMinio() *Credentials {
3939
}
4040

4141
// Retrieve retrieves the keys from the environment.
42-
func (e *EnvMinio) Retrieve() (Value, error) {
42+
func (e *EnvMinio) Retrieve(_ *CredContext) (Value, error) {
4343
e.retrieved = false
4444

4545
id := os.Getenv("MINIO_ROOT_USER")

pkg/credentials/env_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
3434
t.Error("Expect creds to be expired before retrieve.")
3535
}
3636

37-
creds, err := e.Retrieve()
37+
creds, err := e.Retrieve(defaultCredContext)
3838
if err != nil {
3939
t.Fatal(err)
4040
}
@@ -63,7 +63,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
6363
SignerType: SignatureV4,
6464
}
6565

66-
creds, err = e.Retrieve()
66+
creds, err = e.Retrieve(defaultCredContext)
6767
if err != nil {
6868
t.Fatal(err)
6969
}
@@ -84,7 +84,7 @@ func TestEnvMinioRetrieve(t *testing.T) {
8484
t.Error("Expect creds to be expired before retrieve.")
8585
}
8686

87-
creds, err := e.Retrieve()
87+
creds, err := e.Retrieve(defaultCredContext)
8888
if err != nil {
8989
t.Fatal(err)
9090
}

pkg/credentials/file_aws_credentials.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func NewFileAWSCredentials(filename, profile string) *Credentials {
7373

7474
// Retrieve reads and extracts the shared credentials from the current
7575
// users home directory.
76-
func (p *FileAWSCredentials) Retrieve() (Value, error) {
76+
func (p *FileAWSCredentials) Retrieve(_ *CredContext) (Value, error) {
7777
if p.Filename == "" {
7878
p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
7979
if p.Filename == "" {

pkg/credentials/file_minio_client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func NewFileMinioClient(filename, alias string) *Credentials {
5858

5959
// Retrieve reads and extracts the shared credentials from the current
6060
// users home directory.
61-
func (p *FileMinioClient) Retrieve() (Value, error) {
61+
func (p *FileMinioClient) Retrieve(_ *CredContext) (Value, error) {
6262
if p.Filename == "" {
6363
if value, ok := os.LookupEnv("MINIO_SHARED_CREDENTIALS_FILE"); ok {
6464
p.Filename = value

0 commit comments

Comments
 (0)