From 044c1e80f634ec4f17f05913034f901f7a1674ae Mon Sep 17 00:00:00 2001 From: Ryan Brown Date: Fri, 3 May 2019 23:43:49 +0000 Subject: [PATCH] Add basic signature validation as needed for LTI. --- auther.go | 8 +- reference_test.go | 26 ++++ token.go | 7 - token_test.go | 4 +- transport_test.go | 7 +- validator.go | 165 +++++++++++++++++++++ validator_test.go | 361 ++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 561 insertions(+), 17 deletions(-) create mode 100644 validator.go create mode 100644 validator_test.go diff --git a/auther.go b/auther.go index 202543c..681724f 100644 --- a/auther.go +++ b/auther.go @@ -102,13 +102,17 @@ func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, reque // requests with an AccessToken (token credential) according to RFC 5849 3.1. func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) error { oauthParams := a.commonOAuthParams() - oauthParams[oauthTokenParam] = accessToken.Token + var tokenSecret string + if accessToken != nil { + oauthParams[oauthTokenParam] = accessToken.Token + tokenSecret = accessToken.TokenSecret + } params, err := collectParameters(req, oauthParams) if err != nil { return err } signatureBase := signatureBase(req, params) - signature, err := a.signer().Sign(accessToken.TokenSecret, signatureBase) + signature, err := a.signer().Sign(tokenSecret, signatureBase) if err != nil { return err } diff --git a/reference_test.go b/reference_test.go index 3fbfc12..94afcfe 100644 --- a/reference_test.go +++ b/reference_test.go @@ -170,6 +170,32 @@ func TestTwitterRequestAuthHeader(t *testing.T) { assert.Equal(t, expectedVersion, params[oauthVersionParam]) } +func TestNilAuthToken(t *testing.T) { + expectedSignature := PercentEncode("+gxx4CGoDB7afZbRRRpR56orbKU=") + expectedTimestamp := "1318622958" + + auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}, &fixedNoncer{expectedNonce}} + values := url.Values{} + values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!") + + var accessToken *Token + req, err := http.NewRequest("POST", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode())) + assert.Nil(t, err) + req.Header.Set(contentType, formContentType) + err = auther.setRequestAuthHeader(req, accessToken) + // assert that request is signed and has an access token token + assert.Nil(t, err) + params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam)) + assert.NotContains(t, params, oauthTokenParam) + assert.Equal(t, expectedSignature, params[oauthSignatureParam]) + // additional OAuth parameters + assert.Equal(t, expectedTwitterConsumerKey, params[oauthConsumerKeyParam]) + assert.Equal(t, expectedNonce, params[oauthNonceParam]) + assert.Equal(t, expectedSignatureMethod, params[oauthSignatureMethodParam]) + assert.Equal(t, expectedTimestamp, params[oauthTimestampParam]) + assert.Equal(t, expectedVersion, params[oauthVersionParam]) +} + func parseOAuthParamsOrFail(t *testing.T, authHeader string) map[string]string { if !strings.HasPrefix(authHeader, authorizationPrefix) { assert.Fail(t, fmt.Sprintf("Expected Authorization header to start with \"%s\", got \"%s\"", authorizationPrefix, authHeader[:len(authorizationPrefix)+1])) diff --git a/token.go b/token.go index d010d2f..796b907 100644 --- a/token.go +++ b/token.go @@ -1,9 +1,5 @@ package oauth1 -import ( - "errors" -) - // A TokenSource can return a Token. type TokenSource interface { Token() (*Token, error) @@ -36,8 +32,5 @@ type staticTokenSource struct { } func (s staticTokenSource) Token() (*Token, error) { - if s.token == nil { - return nil, errors.New("oauth1: Token is nil") - } return s.token, nil } diff --git a/token_test.go b/token_test.go index 140dc45..7039e2f 100644 --- a/token_test.go +++ b/token_test.go @@ -25,7 +25,5 @@ func TestStaticTokenSourceEmpty(t *testing.T) { ts := StaticTokenSource(nil) tk, err := ts.Token() assert.Nil(t, tk) - if assert.Error(t, err) { - assert.Equal(t, "oauth1: Token is nil", err.Error()) - } + assert.Nil(t, err) } diff --git a/transport_test.go b/transport_test.go index 4a6fef5..a64a3d7 100644 --- a/transport_test.go +++ b/transport_test.go @@ -92,11 +92,8 @@ func TestTransport_emptySource(t *testing.T) { }, } client := &http.Client{Transport: tr} - resp, err := client.Get("http://example.com") - assert.Nil(t, resp) - if assert.Error(t, err) { - assert.Equal(t, "Get http://example.com: oauth1: Token is nil", err.Error()) - } + _, err := client.Get("http://example.com") + assert.NoError(t, err) } func TestTransport_nilAuther(t *testing.T) { diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..32c3f76 --- /dev/null +++ b/validator.go @@ -0,0 +1,165 @@ +package oauth1 + +import ( + "context" + "fmt" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" +) + +type providerRequest struct { + req *http.Request + oauthParams map[string]string + signatureToVerify string + signatureMethod string + timestamp int64 + clientKey string + nonce string +} + +// ClientStorage represents an OAuth 1 provider's database of clients. +type ClientStorage = interface { + // GetSigner returns the signer that should be used to validate the signature for a client. + // To avoid timing attacks, GetSigner should return a Signer and a non-nil error + // if the clientKey is invalid. ValidateRequest will still compute a signature + // so that the runtime of ValidateRequest is about the same regardless of the clientKey's validity. + // The http request is also available for additional validation, e.g. checking for HTTPS. + GetSigner(ctx context.Context, clientKey, signatureMethod string, req *http.Request) (Signer, error) + + // ValidateNonce returns an error if a nonce has been used before. + // + // Per Section 3.3 of the spec: + // The timestamp value MUST be a positive integer. Unless otherwise + // specified by the server's documentation, the timestamp is expressed + // in the number of seconds since January 1, 1970 00:00:00 GMT. + // + // A nonce is a random string, uniquely generated by the client to allow + // the server to verify that a request has never been made before and + // helps prevent replay attacks when requests are made over a non-secure + // channel. The nonce value MUST be unique across all requests with the + // same timestamp, client credentials, and token combinations. + // + // To avoid the need to retain an infinite number of nonce values for + // future checks, servers MAY choose to restrict the time period after + // which a request with an old timestamp is rejected. Note that this + // restriction implies a level of synchronization between the client's + // and server's clocks. + ValidateNonce(ctx context.Context, clientKey, nonce string, timestamp int64, req *http.Request) error +} + +var authorizationHeaderParamPattern = regexp.MustCompile(`^\s*([^=]+)="?(\S*?)"?\s*$`) + +func newProviderRequest(req *http.Request) (*providerRequest, error) { + authParams := make(map[string]string) + authHeader := req.Header.Get(authorizationHeaderParam) + if len(authHeader) > len(authorizationPrefix) { + authHeaderPrefix := strings.ToLower(authHeader[:len(authorizationPrefix)]) + if authHeaderPrefix == strings.ToLower(authorizationPrefix) { + authHeaderSuffix := authHeader[len(authorizationPrefix):] + for _, pair := range strings.Split(authHeaderSuffix, ",") { + if match := authorizationHeaderParamPattern.FindStringSubmatch(pair); match == nil { + return nil, fmt.Errorf("Invalid Authorization header") + } else if value, err := url.PathUnescape(match[2]); err == nil { + authParams[match[1]] = value + } else { + return nil, err + } + } + } + } + allParams, err := collectParameters(req, authParams) + if err != nil { + return nil, err + } + if err = checkMandatoryParams(allParams); err != nil { + return nil, err + } + sig := allParams[oauthSignatureParam] + delete(allParams, oauthSignatureParam) + timestamp, err := strconv.ParseInt(allParams[oauthTimestampParam], 10, 64) + if err != nil { + return nil, fmt.Errorf("unable to parse timestamp: %v", err) + } else if timestamp <= 0 { + return nil, fmt.Errorf("invalid timestamp %v", timestamp) + } + if version, ok := allParams[oauthVersionParam]; ok && version != defaultOauthVersion { + return nil, fmt.Errorf("incorrect oauth version %v", version) + } + preq := &providerRequest{ + req: req, + oauthParams: allParams, + signatureToVerify: sig, + signatureMethod: allParams[oauthSignatureMethodParam], + timestamp: timestamp, + clientKey: allParams[oauthConsumerKeyParam], + nonce: allParams[oauthNonceParam], + } + return preq, nil +} + +func checkMandatoryParams(params map[string]string) error { + var missingParams []string + for _, param := range []string{oauthSignatureParam, oauthConsumerKeyParam, oauthNonceParam, oauthTimestampParam, oauthSignatureMethodParam} { + if _, ok := params[param]; !ok { + missingParams = append(missingParams, param) + } + } + if len(missingParams) > 0 { + return fmt.Errorf("missing required oauth params %v", strings.Join(missingParams, ", ")) + } + if _, hasAccessToken := params[oauthTokenParam]; hasAccessToken { + return fmt.Errorf("token signature validation not implemented") + } + return nil +} + +var errSignatureMismatch = fmt.Errorf("signature mismatch") + +func (r providerRequest) checkSignature(signer Signer) error { + if signer == nil { + return errSignatureMismatch + } + base := signatureBase(r.req, r.oauthParams) + signature, err := signer.Sign("", base) + if err != nil { + return err + } + + // near constant time string comparison to avoid timing attacks + // https://rdist.root.org/2010/01/07/timing-independent-array-comparison/ + sigToVerify := r.signatureToVerify + if len(sigToVerify) != len(signature) { + return errSignatureMismatch + } + result := byte(0) + for i, r := range []byte(signature) { + result |= r ^ sigToVerify[i] + } + if result != 0 { + return errSignatureMismatch + } + return nil +} + +// ValidateSignature checks that req contains a valid OAUTH 1 signature. +// It returns nil if the signature is valid, or an error if the validation fails. +func ValidateSignature(ctx context.Context, req *http.Request, v ClientStorage) error { + preq, err := newProviderRequest(req) + if err != nil { + return err + } + if err = v.ValidateNonce(ctx, preq.clientKey, preq.nonce, preq.timestamp, req); err != nil { + return err + } + signer, invalidClient := v.GetSigner(ctx, preq.clientKey, preq.signatureMethod, req) + + // Check signature even if client is invalid to prevent timing attacks. + invalidSignature := preq.checkSignature(signer) + if invalidClient != nil { + return invalidClient + } + return invalidSignature +} diff --git a/validator_test.go b/validator_test.go new file mode 100644 index 0000000..ecf7f26 --- /dev/null +++ b/validator_test.go @@ -0,0 +1,361 @@ +package oauth1 + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func makeValues(m map[string]string) url.Values { + values := make(url.Values) + for k, v := range m { + values.Set(k, v) + } + return values +} + +func makeRequest(method, url, authHeader string, values url.Values) *http.Request { + req := httptest.NewRequest(method, url, strings.NewReader(values.Encode())) + if authHeader != "" { + req.Header.Set(authorizationHeaderParam, authHeader) + } + req.Header.Set(contentType, formContentType) + return req +} + +func TestNewProviderRequest_ParamSources(t *testing.T) { + expectedParams := map[string]string{ + "oauth_consumer_key": "some_client", + "oauth_signature": "sig", + "oauth_nonce": "a_nonce", + "oauth_timestamp": "1", + "oauth_signature_method": "HMAC-SHA1", + } + expectedValues := makeValues(expectedParams) + url := "https://example.com/oauth1" + queryReq := makeRequest("GET", url+"?"+expectedValues.Encode(), "", nil) + bodyReq := makeRequest("POST", url, "", expectedValues) + headerReq := makeRequest("GET", url, authHeaderValue(expectedParams), nil) + + delete(expectedParams, "oauth_signature") + + for i, req := range []*http.Request{queryReq, bodyReq, headerReq} { + preq, err := newProviderRequest(req) + assert.NoError(t, err, i) + assert.Equal(t, expectedParams, preq.oauthParams, i) + assert.Same(t, req, preq.req, i) + assert.Equal(t, preq.signatureToVerify, "sig", i) + assert.Equal(t, preq.signatureMethod, "HMAC-SHA1", i) + assert.Equal(t, preq.nonce, "a_nonce", i) + assert.Equal(t, preq.timestamp, int64(1), i) + assert.Equal(t, preq.clientKey, "some_client", i) + + } +} + +func TestNewProviderRequest_CombinedParams(t *testing.T) { + req := makeRequest("POST", "https://example.com/oauth?query=1", "OAUTH header=\"1\"", makeValues(map[string]string{ + "body": "1", + "oauth_consumer_key": "some_client", + "oauth_signature": "sig", + "oauth_nonce": "a_nonce", + "oauth_timestamp": "1", + "oauth_signature_method": "HMAC-SHA1", + })) + preq, err := newProviderRequest(req) + assert.Nil(t, err) + assert.Equal(t, map[string]string{ + "header": "1", + "query": "1", + "body": "1", + "oauth_consumer_key": "some_client", + "oauth_nonce": "a_nonce", + "oauth_timestamp": "1", + "oauth_signature_method": "HMAC-SHA1", + }, preq.oauthParams) +} + +func TestNewProviderRequest_InvalidAuthHeader(t *testing.T) { + req := makeRequest("POST", "https://example.com/oauth", "OAUTH thisisn'tvalid", makeValues(map[string]string{ + "oauth_consumer_key": "some_client", + "oauth_signature": "sig", + "oauth_nonce": "a_nonce", + "oauth_timestamp": "1", + "oauth_signature_method": "HMAC-SHA1", + })) + _, err := newProviderRequest(req) + assert.NotNil(t, err) +} + +func TestNewProviderRequest_OtherAuthHeader(t *testing.T) { + req := makeRequest("POST", "https://example.com/oauth", "bearer foobar", makeValues(map[string]string{ + "oauth_consumer_key": "some_client", + "oauth_signature": "sig", + "oauth_nonce": "a_nonce", + "oauth_timestamp": "1", + "oauth_signature_method": "HMAC-SHA1", + })) + _, err := newProviderRequest(req) + assert.Nil(t, err) +} + +func TestNewProviderRequest_MissingParams(t *testing.T) { + for _, param := range []string{"oauth_consumer_key", "oauth_signature", "oauth_nonce", "oauth_timestamp", "oauth_signature_method"} { + params := map[string]string{ + "oauth_consumer_key": "some_client", + "oauth_signature": "sig", + "oauth_nonce": "a_nonce", + "oauth_timestamp": "1", + "oauth_signature_method": "HMAC-SHA1", + } + delete(params, param) + _, err := newProviderRequest(makeRequest("POST", "https://example.com", "", makeValues(params))) + assert.NotNil(t, err, param) + assert.Contains(t, err.Error(), param) + } + +} + +func TestCheckSignature_ValidSignature(t *testing.T) { + config := &Config{ConsumerKey: "consumer_key", ConsumerSecret: "consumer_secret"} + a := newAuther(config) + + req := makeRequest("POST", "https://example.com/foo?q=bar", "", makeValues(map[string]string{"data": "something"})) + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + preq, err := newProviderRequest(req) + assert.NoError(t, err) + assert.NoError(t, preq.checkSignature(a.signer())) +} + +func TestCheckSignature_InvalidSignature(t *testing.T) { + tests := []struct{ orig, modified *http.Request }{ + { + // Change method + makeRequest("POST", "https://example.com/foo?q=bar", "", makeValues(map[string]string{"data": "something"})), + makeRequest("PUT", "https://example.com/foo?q=bar", "", makeValues(map[string]string{"data": "something"})), + }, + { + // Change body params + makeRequest("POST", "https://example.com/foo?q=bar", "", makeValues(map[string]string{"data": "something"})), + makeRequest("POST", "https://example.com/foo?q=bar", "", makeValues(map[string]string{"data": "nothing"})), + }, + { + // Change query params + makeRequest("GET", "https://example.com/foo?q=bar", "", nil), + makeRequest("GET", "https://example.com/foo?q=flowers", "", nil), + }, + } + a := &auther{ + &Config{ConsumerKey: "consumer_key", ConsumerSecret: "secret"}, + &fixedClock{time.Unix(50037133, 0)}, + &fixedNoncer{"some_nonce"}, + } + for i, test := range tests { + assert.NoError(t, a.setRequestAuthHeader(test.orig, nil)) + test.modified.Header[authorizationHeaderParam] = test.orig.Header[authorizationHeaderParam] + preq, err := newProviderRequest(test.modified) + assert.NoError(t, err, i) + err = preq.checkSignature(a.signer()) + if assert.Error(t, err, i) { + assert.Equal(t, errSignatureMismatch, err, i) + } + } +} + +func TestCheckSignature_ReversedSignature(t *testing.T) { + config := &Config{ConsumerKey: "consumer_key", ConsumerSecret: "consumer_secret"} + a := newAuther(config) + + req := makeRequest("GET", "https://example.com/foo?q=bar", "", nil) + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + preq, err := newProviderRequest(req) + assert.NoError(t, err) + + origSig := preq.signatureToVerify + reversedSig := "" + for _, v := range origSig { + reversedSig = string(v) + reversedSig + } + preq.signatureToVerify = reversedSig + + err = preq.checkSignature(a.signer()) + if assert.Error(t, err) { + assert.Equal(t, errSignatureMismatch, err) + } +} + +func TestNewProviderRequest_TimestampParsing(t *testing.T) { + params := map[string]string{ + "oauth_consumer_key": "some_client", + "oauth_signature": "sig", + "oauth_nonce": "a_nonce", + "oauth_timestamp": "-1", + "oauth_signature_method": "HMAC-SHA1", + } + _, err := newProviderRequest(makeRequest("POST", "https://example.com", "", makeValues(params))) + assert.NotNil(t, err) + params["oauth_timestamp"] = "0" + _, err = newProviderRequest(makeRequest("POST", "https://example.com", "", makeValues(params))) + assert.NotNil(t, err) + params["oauth_timestamp"] = "17" + req, err := newProviderRequest(makeRequest("POST", "https://example.com", "", makeValues(params))) + assert.Nil(t, err) + assert.Equal(t, req.timestamp, int64(17)) +} + +type mockStorage struct { + // outputs + Signer Signer + SignerErr error + NonceErr error + + // Saved inputs + SignerContext context.Context + SignerKey string + SignatureMethod string + SignerRequest *http.Request + + NonceContext context.Context + NonceKey string + Nonce string + Timestamp int64 + NonceRequest *http.Request +} + +func (m *mockStorage) GetSigner(ctx context.Context, key, method string, req *http.Request) (Signer, error) { + m.SignerContext = ctx + m.SignerKey = key + m.SignatureMethod = method + m.SignerRequest = req + return m.Signer, m.SignerErr +} + +func (m *mockStorage) ValidateNonce(ctx context.Context, key, nonce string, timestamp int64, req *http.Request) error { + m.NonceContext = ctx + m.NonceKey = key + m.Nonce = nonce + m.Timestamp = timestamp + m.NonceRequest = req + return m.NonceErr +} + +func TestValidateSignature_ClientStorageArgs(t *testing.T) { + req := makeRequest("GET", "https://example.com/foo?q=bar", "", nil) + a := &auther{ + &Config{ConsumerKey: "consumer_key", ConsumerSecret: "secret", Signer: &identitySigner{}}, + &fixedClock{time.Unix(50037133, 0)}, + &fixedNoncer{"some_nonce"}, + } + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + + storage := &mockStorage{Signer: &identitySigner{}} + assert.NoError(t, ValidateSignature(NoContext, req, storage)) + assert.Equal(t, "consumer_key", storage.SignerKey) + assert.Equal(t, "identity", storage.SignatureMethod) + assert.Same(t, req, storage.SignerRequest) + + assert.Same(t, NoContext, storage.NonceContext) + assert.Equal(t, "consumer_key", storage.NonceKey) + assert.Equal(t, "some_nonce", storage.Nonce) + assert.EqualValues(t, 50037133, storage.Timestamp) + assert.Same(t, req, storage.NonceRequest) +} +func TestValidateSignature_BadNonceOrTimestamp(t *testing.T) { + req := makeRequest("GET", "https://example.com/foo?q=bar", "", nil) + a := &auther{ + &Config{ConsumerKey: "consumer_key", ConsumerSecret: "secret", Signer: &identitySigner{}}, + &fixedClock{time.Unix(50037133, 0)}, + &fixedNoncer{"some_nonce"}, + } + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + + storage := &mockStorage{NonceErr: fmt.Errorf("i don't like your nonce")} + err = ValidateSignature(NoContext, req, storage) + if assert.Error(t, err) { + assert.Equal(t, storage.NonceErr, err) + } +} +func TestValidateSignature_BadClientKey(t *testing.T) { + req := makeRequest("GET", "https://example.com/foo?q=bar", "", nil) + a := &auther{ + &Config{ConsumerKey: "consumer_key", ConsumerSecret: "secret", Signer: &identitySigner{}}, + &fixedClock{time.Unix(50037133, 0)}, + &fixedNoncer{"some_nonce"}, + } + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + + storage := &mockStorage{SignerErr: fmt.Errorf("i don't like your key")} + err = ValidateSignature(NoContext, req, storage) + if assert.Error(t, err) { + assert.Equal(t, storage.SignerErr, err) + } +} + +type countingSigner struct { + count int +} + +func (c *countingSigner) Name() string { return "count" } +func (c *countingSigner) Sign(tokenSecret, message string) (string, error) { + c.count++ + return strconv.Itoa(c.count), nil +} + +func TestValidateSignature_SignerCalledOnBadKey(t *testing.T) { + req := makeRequest("GET", "https://example.com/foo?q=bar", "", nil) + a := &auther{ + &Config{ConsumerKey: "consumer_key", ConsumerSecret: "secret", Signer: &identitySigner{}}, + &fixedClock{time.Unix(50037133, 0)}, + &fixedNoncer{"some_nonce"}, + } + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + + signer := &countingSigner{} + storage := &mockStorage{Signer: signer, SignerErr: fmt.Errorf("i don't like your key")} + err = ValidateSignature(NoContext, req, storage) + if assert.Error(t, err) { + assert.Equal(t, storage.SignerErr, err) + } + assert.Equal(t, 1, signer.count) +} + +func TestValidSignature_ValidSignature(t *testing.T) { + config := &Config{ConsumerKey: "consumer_key", ConsumerSecret: "consumer_secret"} + a := newAuther(config) + + req := makeRequest("POST", "https://example.com/foo?q=bar", "", makeValues(map[string]string{"data": "something"})) + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + + storage := &mockStorage{Signer: a.signer()} + assert.NoError(t, ValidateSignature(NoContext, req, storage)) +} + +func TestValidSignature_InvalidSecret(t *testing.T) { + config := &Config{ConsumerKey: "consumer_key", ConsumerSecret: "consumer_secret"} + a := newAuther(config) + + req := makeRequest("POST", "https://example.com/foo?q=bar", "", makeValues(map[string]string{"data": "something"})) + err := a.setRequestAuthHeader(req, nil) + assert.NoError(t, err) + + storage := &mockStorage{Signer: &HMACSigner{ConsumerSecret: "another_secret"}} + err = ValidateSignature(NoContext, req, storage) + if assert.Error(t, err) { + assert.Equal(t, errSignatureMismatch, err) + } +}