Skip to content

Commit 7c8355e

Browse files
committed
add client check for code and refresh token lookup
1 parent 32c68b5 commit 7c8355e

4 files changed

+12
-12
lines changed

authorization_code_grant_type.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (gt AuthorizationCodeGrantType) TokenHandler(c *Client, ew *EncoderResponse
6969
return
7070
}
7171

72-
auth, err := gt.persistence.GetAuthorizationByCode(code)
72+
auth, err := gt.persistence.GetAuthorizationByCode(c, code)
7373
if err != nil {
7474
log.Println("couldn't find authorization for code:", err)
7575
ew.Encode(ErrInvalidGrant)

backend.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ type PersistenceBackend interface {
3333
//*
3434
// Authorization persistence
3535
//*
36-
GetAuthorizationByCode(code string) (*Authorization, error)
36+
GetAuthorizationByCode(c *Client, code string) (*Authorization, error)
37+
GetAuthorizationByRefreshToken(c *Client, refreshToken string) (*Authorization, error)
3738
GetAuthorizationByAccessToken(accessToken string) (*Authorization, error)
38-
GetAuthorizationByRefreshToken(refreshToken string) (*Authorization, error)
3939
SaveAuthorization(a *Authorization) error
4040

4141
//*

in_memory_persistence.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -61,33 +61,33 @@ func (b *InMemoryPersistence) SaveAuthorization(a *Authorization) error {
6161
}
6262

6363
// GetAuthorizationByCode takes a code and look it up
64-
func (b *InMemoryPersistence) GetAuthorizationByCode(code string) (*Authorization, error) {
64+
func (b *InMemoryPersistence) GetAuthorizationByCode(c *Client, code string) (*Authorization, error) {
6565
for _, a := range b.authorizations {
66-
if a.Code == code {
66+
if a.Client == c && a.Code == code {
6767
return a, nil
6868
}
6969
}
7070

7171
return nil, ErrNotFound
7272
}
7373

74-
// GetAuthorizationByAccessToken takes an access token and returns the authorization
74+
// GetAuthorizationByRefreshToken takes an access token and returns the authorization
7575
// it represents, if exists.
76-
func (b *InMemoryPersistence) GetAuthorizationByAccessToken(accessToken string) (*Authorization, error) {
76+
func (b *InMemoryPersistence) GetAuthorizationByRefreshToken(c *Client, refreshToken string) (*Authorization, error) {
7777
for _, a := range b.authorizations {
78-
if a.AccessToken == accessToken {
78+
if a.Client == c && a.RefreshToken == refreshToken {
7979
return a, nil
8080
}
8181
}
8282

8383
return nil, ErrNotFound
8484
}
8585

86-
// GetAuthorizationByRefreshToken takes an access token and returns the authorization
86+
// GetAuthorizationByAccessToken takes an access token and returns the authorization
8787
// it represents, if exists.
88-
func (b *InMemoryPersistence) GetAuthorizationByRefreshToken(refreshToken string) (*Authorization, error) {
88+
func (b *InMemoryPersistence) GetAuthorizationByAccessToken(accessToken string) (*Authorization, error) {
8989
for _, a := range b.authorizations {
90-
if a.RefreshToken == refreshToken {
90+
if a.AccessToken == accessToken {
9191
return a, nil
9292
}
9393
}

refresh_token_grant_type.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (gt RefreshTokenGrantType) TokenHandler(c *Client, ew *EncoderResponseWrite
5252
return
5353
}
5454

55-
auth, err := gt.persistence.GetAuthorizationByRefreshToken(refreshToken)
55+
auth, err := gt.persistence.GetAuthorizationByRefreshToken(c, refreshToken)
5656
if err != nil {
5757
log.Println("invalid refresh token:", refreshToken)
5858
ew.Encode(ErrInvalidGrant)

0 commit comments

Comments
 (0)