Skip to content

Commit

Permalink
Implemented support for awsv4 auth for http clients
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrconley committed May 10, 2024
1 parent 88c01d1 commit 002529e
Show file tree
Hide file tree
Showing 10 changed files with 570 additions and 0 deletions.
50 changes: 50 additions & 0 deletions internal/httpclient/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package httpclient

import (
"context"
"github.com/benthosdev/benthos/v4/internal/impl/aws"
"github.com/benthosdev/benthos/v4/internal/impl/aws/config"
"io/fs"
"net/http"

Expand Down Expand Up @@ -35,6 +38,7 @@ func AuthFieldSpecsExpanded() []*service.ConfigField {
oAuth2FieldSpec(),
BasicAuthField(),
jwtFieldSpec(),
awsV4FieldSpec(),
}
}

Expand Down Expand Up @@ -338,3 +342,49 @@ func jwtAuthFromParsed(conf *service.ParsedConfig) (res JWTConfig, err error) {
}
return
}

//------------------------------------------------------------------------------

const (
av4Field = "aws_v4"
av4FieldEnabled = "enabled"
av4FieldService = "service"
)

func awsV4FieldSpec() *service.ConfigField {
awsSessionFields := config.SessionFields()
regionField := awsSessionFields[0]
credentialsField := awsSessionFields[2]

return service.NewObjectField("aws_v4",
service.NewBoolField(av4FieldEnabled).
Description("Whether to use AWS V4 authentication in requests.").
Default(false),
regionField,
credentialsField,
service.NewStringField(av4FieldService).
Description("Optional service name to use for the request").
Default(""),
)
}

func awsV4FromParsed(conf *service.ParsedConfig) (res AWSV4Config, err error) {
res = NewAWSV4Config()
if !conf.Contains(av4Field) {
return
}
conf = conf.Namespace(av4Field)
if res.Enabled, err = conf.FieldBool(av4FieldEnabled); err != nil {
return
}
session, err := aws.GetSession(context.Background(), conf)
if err != nil {
return
}
if res.Service, err = conf.FieldString(av4FieldService); err != nil {
return
}
res.Region = session.Region
res.Creds, err = session.Credentials.Retrieve(context.Background())
return
}
99 changes: 99 additions & 0 deletions internal/httpclient/auth_config.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
package httpclient

import (
"bytes"
"context"
"crypto"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"io"
"io/fs"
"math/rand"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -298,3 +305,95 @@ func (oauth OAuth2Config) Client(ctx context.Context, base *http.Client) *http.C

return conf.Client(context.WithValue(ctx, oauth2.HTTPClient, base))
}

//------------------------------------------------------------------------------

// AWSV4Config holds the configuration parameters for an AWS V4 exchange.
type AWSV4Config struct {
Enabled bool
Region string
Creds aws.Credentials
Service string
}

// NewAWSV4Config returns a new AWSV4Config with default values.
func NewAWSV4Config() AWSV4Config {
return AWSV4Config{
Enabled: false,
Region: "",
Creds: aws.Credentials{},
Service: "",
}
}

func (awsv4 AWSV4Config) Client(ctx context.Context, base *http.Client) *http.Client {
if !awsv4.Enabled {
return base
}

return &http.Client{
Transport: AWSV4Transport{
client: base,
creds: awsv4.Creds,
signer: v4.NewSigner(),
region: awsv4.Region,
service: awsv4.Service,
},
}
}

// AWSV4Transport Transport is a RoundTripper that will sign requests with AWS V4 Signing
type AWSV4Transport struct {
client *http.Client
creds aws.Credentials
signer *v4.Signer
region string
service string
}

// RoundTrip uses the underlying RoundTripper transport, but signs request first with AWS V4 Signing
func (st AWSV4Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if h, ok := req.Header["Authorization"]; ok && len(h) > 0 && strings.HasPrefix(h[0], "AWS4") {
// Received a signed request, just pass it on.
return st.client.Do(req)
}

if strings.Contains(req.URL.RawPath, "%2C") {
// Escaping path
req.URL.RawPath = url.PathEscape(req.URL.RawPath)
}

hash, err := hexEncodedSha256OfRequest(req)
if err != nil {
return nil, err
}
req.Header.Set("X-Amz-Content-Sha256", hash)

if err := st.signer.SignHTTP(req.Context(), st.creds, req, hash, st.service, st.region, time.Now().UTC()); err != nil {
return nil, err
}
return st.client.Do(req)
}

func hexEncodedSha256OfRequest(r *http.Request) (string, error) {
if r.Body == nil {
return "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", nil
}

hasher := sha256.New()

reqBodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}

if err := r.Body.Close(); err != nil {
return "", err
}

r.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes))
hasher.Write(reqBodyBytes)
digest := hasher.Sum(nil)

return hex.EncodeToString(digest), nil
}
27 changes: 27 additions & 0 deletions internal/httpclient/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,30 @@ oauth2:
"third": {"and those"},
}, authConf.EndpointParams)
}

func TestAuthConfigAWSV4Parsing(t *testing.T) {
spec := service.NewConfigSpec().Field(awsV4FieldSpec())

parsedConf, err := spec.ParseYAML(`
aws_v4:
enabled: true
region: eu-west-1
service: test-service
credentials:
id: foo
secret: bar
token: baz
`, service.NewEnvironment())

require.NoError(t, err)

authConf, err := awsV4FromParsed(parsedConf)
require.NoError(t, err)

assert.True(t, authConf.Enabled)
assert.Equal(t, "eu-west-1", authConf.Region)
assert.Equal(t, "foo", authConf.Creds.AccessKeyID)
assert.Equal(t, "bar", authConf.Creds.SecretAccessKey)
assert.Equal(t, "baz", authConf.Creds.SessionToken)
assert.Equal(t, "test-service", authConf.Service)
}
3 changes: 3 additions & 0 deletions internal/httpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ func NewClientFromOldConfig(conf OldConfig, mgr *service.Resources, opts ...Requ

h.client = conf.OAuth2.Client(h.clientCtx, h.client)

// If AWSV4 is enabled we need to create a new client with the signing
h.client = conf.AWSV4.Client(h.clientCtx, h.client)

for _, c := range conf.BackoffOn {
h.backoffOn[c] = struct{}{}
}
Expand Down
37 changes: 37 additions & 0 deletions internal/httpclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,40 @@ tls:
require.NoError(t, err)
assert.Equal(t, "HELLO WORLD", string(mBytes))
}

func TestHTTPClientWithAWSV4Auth(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check the Authorization header
// We only care that it exists and looks generally correct
assert.Regexp(t, "AWS4-HMAC-SHA256 Credential=foo\\/\\d{8}\\/us-west-1\\/test-service\\/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-content-sha256;x-amz-date, Signature=.*", r.Header.Get("Authorization"))
b, err := io.ReadAll(r.Body)
require.NoError(t, err)
_, _ = w.Write(bytes.ToUpper(b))
}))
defer ts.Close()

conf := clientConfig(t, `
url: %v
aws_v4:
enabled: true
region: us-west-1
service: test-service
credentials:
id: foo
secret: bar
`, ts.URL)

h, err := NewClientFromOldConfig(conf, service.MockResources())
require.NoError(t, err)

resBatch, err := h.Send(context.Background(), service.MessageBatch{
service.NewMessage([]byte("hello world")),
})

require.NoError(t, err)
require.Len(t, resBatch, 1)

mBytes, err := resBatch[0].AsBytes()
require.NoError(t, err)
assert.Equal(t, "HELLO WORLD", string(mBytes))
}
4 changes: 4 additions & 0 deletions internal/httpclient/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ func ConfigFromParsed(pConf *service.ParsedConfig) (conf OldConfig, err error) {
if conf.OAuth2, err = oauth2FromParsed(pConf); err != nil {
return
}
if conf.AWSV4, err = awsV4FromParsed(pConf); err != nil {
return
}
return
}

Expand All @@ -181,4 +184,5 @@ type OldConfig struct {
ProxyURL string
Auth AuthConfig
OAuth2 OAuth2Config
AWSV4 AWSV4Config
}
11 changes: 11 additions & 0 deletions internal/httpclient/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ jwt:
oauth2:
enabled: true
client_key: moo
aws_v4:
enabled: true
region: us-west-1
credentials:
id: foo
secret: bar
`,
validator: func(t *testing.T, o *OldConfig) {
sURL, _ := o.URL.Static()
Expand All @@ -111,6 +117,11 @@ oauth2:

assert.True(t, o.OAuth2.Enabled)
assert.Equal(t, "moo", o.OAuth2.ClientKey)

assert.True(t, o.AWSV4.Enabled)
assert.Equal(t, "us-west-1", o.AWSV4.Region)
assert.Equal(t, "foo", o.AWSV4.Creds.AccessKeyID)
assert.Equal(t, "bar", o.AWSV4.Creds.SecretAccessKey)
},
},
}
Expand Down
Loading

0 comments on commit 002529e

Please sign in to comment.