Skip to content

Commit

Permalink
Merge pull request #212 from cerberauth/client-improvements
Browse files Browse the repository at this point in the history
Add setter and getter for default client
  • Loading branch information
emmanuelgautier authored Oct 27, 2024
2 parents 39cb8e3 + 76f0e55 commit fb9a55d
Show file tree
Hide file tree
Showing 29 changed files with 158 additions and 75 deletions.
1 change: 1 addition & 0 deletions api/curl.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (h *Handler) ScanURL(ctx *gin.Context) {
opts.Header = ctx.Request.Header
opts.Cookies = ctx.Request.Cookies()
client := request.NewClient(opts)

s, err := scenario.NewURLScan(form.Method, form.URL, form.Data, client, &scan.ScanOptions{
IncludeScans: form.Opts.Scans,
ExcludeScans: form.Opts.ExcludeScans,
Expand Down
1 change: 1 addition & 0 deletions api/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func (h *Handler) ScanGraphQL(ctx *gin.Context) {
opts.Header = ctx.Request.Header
opts.Cookies = ctx.Request.Cookies()
client := request.NewClient(opts)

s, err := scenario.NewGraphQLScan(form.Endpoint, client, &scan.ScanOptions{
IncludeScans: form.Opts.Scans,
ExcludeScans: form.Opts.ExcludeScans,
Expand Down
2 changes: 2 additions & 0 deletions cmd/scan/curl.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"log"

internalCmd "github.com/cerberauth/vulnapi/internal/cmd"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/scan"
"github.com/cerberauth/vulnapi/scenario"
"github.com/cerberauth/x/analyticsx"
Expand Down Expand Up @@ -40,6 +41,7 @@ func NewCURLScanCmd() (scanCmd *cobra.Command) {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
}
request.SetDefaultClient(client)

s, err := scenario.NewURLScan(curlMethod, curlUrl, curlData, client, &scan.ScanOptions{
IncludeScans: internalCmd.GetIncludeScans(),
Expand Down
2 changes: 2 additions & 0 deletions cmd/scan/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"log"

internalCmd "github.com/cerberauth/vulnapi/internal/cmd"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/scan"
"github.com/cerberauth/vulnapi/scenario"
"github.com/cerberauth/x/analyticsx"
Expand Down Expand Up @@ -32,6 +33,7 @@ func NewGraphQLScanCmd() (scanCmd *cobra.Command) {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
}
request.SetDefaultClient(client)

s, err := scenario.NewGraphQLScan(graphqlEndpoint, client, &scan.ScanOptions{
IncludeScans: internalCmd.GetIncludeScans(),
Expand Down
2 changes: 2 additions & 0 deletions cmd/scan/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/cerberauth/vulnapi/internal/auth"
internalCmd "github.com/cerberauth/vulnapi/internal/cmd"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/openapi"
"github.com/cerberauth/vulnapi/scan"
"github.com/cerberauth/vulnapi/scenario"
Expand Down Expand Up @@ -74,6 +75,7 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
}
request.SetDefaultClient(client)

s, err := scenario.NewOpenAPIScan(openapi, securitySchemesValues, client, &scan.ScanOptions{
IncludeScans: internalCmd.GetIncludeScans(),
Expand Down
16 changes: 14 additions & 2 deletions internal/request/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@ import (

var rl = ratelimit.New(10)

var DefaultClient = NewClient(NewClientOptions{})
var defaultClient *Client = nil

func GetDefaultClient() *Client {
if defaultClient == nil {
defaultClient = NewClient(NewClientOptions{})
}

return defaultClient
}

func SetDefaultClient(client *Client) {
defaultClient = client
}

type Client struct {
*http.Client
Expand Down Expand Up @@ -53,7 +65,7 @@ func NewClient(opts NewClientOptions) *Client {

return &Client{
&http.Client{
Timeout: 10 * time.Second,
Timeout: opts.Timeout,

Transport: &http.Transport{
Proxy: proxy,
Expand Down
63 changes: 63 additions & 0 deletions internal/request/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package request

import (
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestNewClient_DefaultOptions(t *testing.T) {
client := NewClient(NewClientOptions{})

assert.NotNil(t, client)
assert.Equal(t, 10*time.Second, client.Timeout)
assert.Equal(t, 100, client.Transport.(*http.Transport).MaxIdleConns)
assert.Equal(t, 100, client.Transport.(*http.Transport).MaxIdleConnsPerHost)
assert.Empty(t, client.Header)
assert.Empty(t, client.Cookies)
}

func TestNewClient_CustomOptions(t *testing.T) {
header := http.Header{"Custom-Header": []string{"value"}}
cookies := []*http.Cookie{{Name: "test", Value: "cookie"}}

client := NewClient(NewClientOptions{
Timeout: 5 * time.Second,
Header: header,
Cookies: cookies,
})

assert.NotNil(t, client)
assert.Equal(t, 5*time.Second, client.Timeout)
assert.Equal(t, header, client.Header)
assert.Equal(t, cookies, client.Cookies)
}

func TestGetClient(t *testing.T) {
client := GetDefaultClient()
assert.NotNil(t, client)
}

func TestSetClient(t *testing.T) {
newClient := NewClient(NewClientOptions{})
SetDefaultClient(newClient)
assert.Equal(t, newClient, GetDefaultClient())
}

func TestClient_WithHeader(t *testing.T) {
client := NewClient(NewClientOptions{})
header := http.Header{"Custom-Header": []string{"value"}}
client = client.WithHeader(header)

assert.Equal(t, header, client.Header)
}

func TestClient_WithCookies(t *testing.T) {
client := NewClient(NewClientOptions{})
cookies := []*http.Cookie{{Name: "test", Value: "cookie"}}
client = client.WithCookies(cookies)

assert.Equal(t, cookies, client.Cookies)
}
2 changes: 1 addition & 1 deletion internal/request/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Operation struct {

func NewOperation(method string, operationUrl string, body *bytes.Buffer, client *Client) (*Operation, error) {
if client == nil {
client = DefaultClient
client = GetDefaultClient()
}

parsedUrl, err := url.Parse(operationUrl)
Expand Down
2 changes: 1 addition & 1 deletion internal/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type Request struct {

func NewRequest(method string, reqUrl string, body io.Reader, client *Client) (*Request, error) {
if client == nil {
client = DefaultClient
client = GetDefaultClient()
}

req, err := http.NewRequest(method, reqUrl, body)
Expand Down
14 changes: 7 additions & 7 deletions internal/request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestWithSecurityScheme(t *testing.T) {
}

func TestDo(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -93,7 +93,7 @@ func TestDo(t *testing.T) {
}

func TestDoWithHeaders(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down Expand Up @@ -152,7 +152,7 @@ func TestDoWithClientHeaders(t *testing.T) {
}

func TestDoWithBody(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -178,7 +178,7 @@ func TestDoWithBody(t *testing.T) {
}

func TestDoWithSecuritySchemeHeaders(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -204,7 +204,7 @@ func TestDoWithSecuritySchemeHeaders(t *testing.T) {
}

func TestDoWithHeadersSecuritySchemeHeaders(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down Expand Up @@ -236,7 +236,7 @@ func TestDoWithHeadersSecuritySchemeHeaders(t *testing.T) {
}

func TestDoWithCookiesSecuritySchemeHeaders(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down Expand Up @@ -269,7 +269,7 @@ func TestDoWithCookiesSecuritySchemeHeaders(t *testing.T) {
}

func TestDoWithCookies(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestAuthenticationByPassScanHandler_Skipped_WhenNoAuthSecurityScheme(t *tes
}

func TestAuthenticationByPassScanHandler_Failed_WhenAuthIsByPassed(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -39,7 +39,7 @@ func TestAuthenticationByPassScanHandler_Failed_WhenAuthIsByPassed(t *testing.T)
}

func TestAuthenticationByPassScanHandler_Passed_WhenAuthIsNotByPassed(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down
8 changes: 4 additions & 4 deletions scan/broken_authentication/jwt/alg_none/alg_none_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestAlgNoneJwtScanHandler_WithoutSecurityScheme(t *testing.T) {
}

func TestAlgNoneJwtScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -41,7 +41,7 @@ func TestAlgNoneJwtScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testin
}

func TestAlgNoneJwtScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -58,7 +58,7 @@ func TestAlgNoneJwtScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) {
}

func TestAlgNoneJwtScanHandler_Failed_WhenOKResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -78,7 +78,7 @@ func TestAlgNoneJwtScanHandler_Failed_WhenOKResponse(t *testing.T) {
}

func TestAlgNoneJwtScanHandler_Failed_WhenOKResponseAndAlgNone(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestBlankSecretScanHandler_WithoutSecurityScheme(t *testing.T) {
}

func TestBlankSecretScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -39,7 +39,7 @@ func TestBlankSecretScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testi
}

func TestBlankSecretScanHandler_Passed_WhenNoJWTAndOKResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -55,7 +55,7 @@ func TestBlankSecretScanHandler_Passed_WhenNoJWTAndOKResponse(t *testing.T) {
}

func TestBlankSecretScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -72,7 +72,7 @@ func TestBlankSecretScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) {
}

func TestBlankSecretScanHandler_Failed_WhenOKResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestNotVerifiedScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testi
}

func TestNotVerifiedScanHandler_Failed_WhenUnauthorizedThenOK(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -55,7 +55,7 @@ func TestNotVerifiedScanHandler_Failed_WhenUnauthorizedThenOK(t *testing.T) {
}

func TestNotVerifiedScanHandler_Skipped_WhenOKFirstRequest(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -77,7 +77,7 @@ func TestNotVerifiedScanHandler_Skipped_WhenOKFirstRequest(t *testing.T) {
}

func TestNotVerifiedScanHandler_Failed_WhenUnauthorizedThenUnauthorized(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestNullSignatureScanHandler_WithoutSecurityScheme(t *testing.T) {
}

func TestNullSignatureScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -39,7 +39,7 @@ func TestNullSignatureScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *tes
}

func TestNullSignatureScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand All @@ -56,7 +56,7 @@ func TestNullSignatureScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T)
}

func TestNullSignatureScanHandler_Failed_WhenOKResponse(t *testing.T) {
client := request.DefaultClient
client := request.GetDefaultClient()
httpmock.ActivateNonDefault(client.Client)
defer httpmock.DeactivateAndReset()

Expand Down
Loading

0 comments on commit fb9a55d

Please sign in to comment.