Skip to content

Commit

Permalink
Merge pull request #221 from cerberauth/openapi-security-schemes
Browse files Browse the repository at this point in the history
Refactor security schemes
  • Loading branch information
emmanuelgautier authored Nov 17, 2024
2 parents c03793a + 8a7e70f commit 841b62c
Show file tree
Hide file tree
Showing 88 changed files with 1,549 additions and 1,449 deletions.
9 changes: 4 additions & 5 deletions api/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"net/http"

"github.com/cerberauth/vulnapi/internal/analytics"
"github.com/cerberauth/vulnapi/internal/auth"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/openapi"
"github.com/cerberauth/vulnapi/scan"
Expand Down Expand Up @@ -33,15 +32,15 @@ func (h *Handler) ScanOpenAPI(ctx *gin.Context) {
traceCtx, span := tracer.Start(ctx, "Scan OpenAPI")
defer span.End()

openapi, err := openapi.LoadFromData(traceCtx, []byte(form.Schema))
doc, err := openapi.LoadFromData(traceCtx, []byte(form.Schema))
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if err := openapi.Validate(ctx); err != nil {
if err := doc.Validate(ctx); err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
Expand All @@ -59,8 +58,8 @@ func (h *Handler) ScanOpenAPI(ctx *gin.Context) {
values[key] = &value.Value
}
}
securitySchemesValues := auth.NewSecuritySchemeValues(values)
s, err := scenario.NewOpenAPIScan(openapi, securitySchemesValues, client, &scan.ScanOptions{
securitySchemesValues := openapi.NewSecuritySchemeValues(values)
s, err := scenario.NewOpenAPIScan(doc, securitySchemesValues, client, &scan.ScanOptions{
IncludeScans: form.Opts.Scans,
ExcludeScans: form.Opts.ExcludeScans,
})
Expand Down
9 changes: 4 additions & 5 deletions cmd/scan/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"os"

"github.com/cerberauth/vulnapi/internal/analytics"
"github.com/cerberauth/vulnapi/internal/auth"
internalCmd "github.com/cerberauth/vulnapi/internal/cmd"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/openapi"
Expand Down Expand Up @@ -47,14 +46,14 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) {
ctx, span := tracer.Start(cmd.Context(), "Scan OpenAPI")
defer span.End()

openapi, err := openapi.LoadOpenAPI(ctx, openapiUrlOrPath)
doc, err := openapi.LoadOpenAPI(ctx, openapiUrlOrPath)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
log.Fatal(err)
}

if err := openapi.Validate(ctx); err != nil {
if err := doc.Validate(ctx); err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
log.Fatal(err)
Expand All @@ -69,7 +68,7 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) {
for key, value := range securitySchemesValueArg {
values[key] = &value
}
securitySchemesValues := auth.NewSecuritySchemeValues(values).WithDefault(validToken)
securitySchemesValues := openapi.NewSecuritySchemeValues(values).WithDefault(validToken)

client, err := internalCmd.NewHTTPClientFromArgs(internalCmd.GetRateLimit(), internalCmd.GetProxy(), internalCmd.GetHeaders(), internalCmd.GetCookies())
if err != nil {
Expand All @@ -79,7 +78,7 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) {
}
request.SetDefaultClient(client)

s, err := scenario.NewOpenAPIScan(openapi, securitySchemesValues, client, &scan.ScanOptions{
s, err := scenario.NewOpenAPIScan(doc, securitySchemesValues, client, &scan.ScanOptions{
IncludeScans: internalCmd.GetIncludeScans(),
ExcludeScans: internalCmd.GetExcludeScans(),
})
Expand Down
27 changes: 27 additions & 0 deletions internal/auth/api_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package auth

func NewAPIKeySecurityScheme(name string, in SchemeIn, value *string) (*SecurityScheme, error) {
tokenFormat := NoneTokenFormat
securityScheme, err := NewSecurityScheme(name, nil, ApiKey, NoneScheme, &in, &tokenFormat)
if err != nil {
return nil, err
}

if value != nil && *value != "" {
err = securityScheme.SetValidValue(*value)
if err != nil {
return nil, err
}
}

return securityScheme, nil
}

func MustNewAPIKeySecurityScheme(name string, in SchemeIn, value *string) *SecurityScheme {
securityScheme, err := NewAPIKeySecurityScheme(name, in, value)
if err != nil {
panic(err)
}

return securityScheme
}
61 changes: 61 additions & 0 deletions internal/auth/api_key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package auth_test

import (
"testing"

"github.com/cerberauth/vulnapi/internal/auth"
"github.com/stretchr/testify/assert"
)

func TestNewAPIKeySecurityScheme(t *testing.T) {
name := "token"
value := "abc123"
tokenFormat := auth.NoneTokenFormat

securityScheme, err := auth.NewAPIKeySecurityScheme(name, auth.InHeader, &value)

assert.NoError(t, err)
assert.Equal(t, auth.ApiKey, securityScheme.GetType())
assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme())
assert.Equal(t, auth.InHeader, *securityScheme.GetIn())
assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat())
assert.Equal(t, name, securityScheme.GetName())
assert.Equal(t, value, securityScheme.GetValidValue().(string))
assert.Equal(t, nil, securityScheme.GetAttackValue())
}

func TestTestNewAPIKeySecurityScheme_WhenNilValue(t *testing.T) {
name := "token"

securityScheme, err := auth.NewAPIKeySecurityScheme(name, auth.InHeader, nil)

assert.NoError(t, err)
assert.Equal(t, nil, securityScheme.GetValidValue())
assert.Equal(t, nil, securityScheme.GetAttackValue())
}

func TestNewAuthorizationBearerSecurityScheme_WhenInCooke(t *testing.T) {
name := "token"
value := "abc123"

securityScheme, err := auth.NewAPIKeySecurityScheme(name, auth.InQuery, &value)

assert.NoError(t, err)
assert.Equal(t, auth.InQuery, *securityScheme.GetIn())
}

func TestMustNewAPIKeySecurityScheme(t *testing.T) {
name := "token"
value := "abc123"
tokenFormat := auth.NoneTokenFormat

securityScheme := auth.MustNewAPIKeySecurityScheme(name, auth.InHeader, &value)

assert.Equal(t, auth.ApiKey, securityScheme.GetType())
assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme())
assert.Equal(t, auth.InHeader, *securityScheme.GetIn())
assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat())
assert.Equal(t, name, securityScheme.GetName())
assert.Equal(t, value, securityScheme.GetValidValue().(string))
assert.Equal(t, nil, securityScheme.GetAttackValue())
}
99 changes: 27 additions & 72 deletions internal/auth/bearer.go
Original file line number Diff line number Diff line change
@@ -1,86 +1,41 @@
package auth

import (
"fmt"
"net/http"
"github.com/cerberauth/vulnapi/jwt"
)

type BearerSecurityScheme struct {
Type Type `json:"type" yaml:"type"`
Scheme SchemeName `json:"scheme" yaml:"scheme"`
In SchemeIn `json:"in" yaml:"in"`
Name string `json:"name" yaml:"name"`
ValidValue *string `json:"-" yaml:"-"`
AttackValue string `json:"-" yaml:"-"`
}

var _ SecurityScheme = (*BearerSecurityScheme)(nil)

func NewAuthorizationBearerSecurityScheme(name string, value *string) *BearerSecurityScheme {
return &BearerSecurityScheme{
Type: HttpType,
Scheme: BearerScheme,
In: InHeader,
Name: name,
ValidValue: value,
AttackValue: "",
}
}

func (ss *BearerSecurityScheme) GetType() Type {
return ss.Type
}

func (ss *BearerSecurityScheme) GetScheme() SchemeName {
return ss.Scheme
}

func (ss *BearerSecurityScheme) GetIn() *SchemeIn {
return &ss.In
}

func (ss *BearerSecurityScheme) GetName() string {
return ss.Name
}

func (ss *BearerSecurityScheme) GetHeaders() http.Header {
header := http.Header{}
attackValue := ss.GetAttackValue().(string)
if attackValue == "" && ss.HasValidValue() {
attackValue = ss.GetValidValue().(string)
func NewAuthorizationBearerSecurityScheme(name string, value *string) (*SecurityScheme, error) {
in := InHeader
securityScheme, err := NewSecurityScheme(name, nil, HttpType, BearerScheme, &in, nil)
if err != nil {
return nil, err
}

if attackValue != "" {
header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, attackValue))
if value != nil && *value != "" {
err = securityScheme.SetValidValue(*value)
if err != nil {
return nil, err
}

var tokenFormat TokenFormat
if jwt.IsJWT(*value) {
tokenFormat = JWTTokenFormat
} else {
tokenFormat = NoneTokenFormat
}
if err = securityScheme.SetTokenFormat(tokenFormat); err != nil {
return nil, err
}
}

return header
}

func (ss *BearerSecurityScheme) GetCookies() []*http.Cookie {
return []*http.Cookie{}
}

func (ss *BearerSecurityScheme) HasValidValue() bool {
return ss.ValidValue != nil && *ss.ValidValue != ""
return securityScheme, nil
}

func (ss *BearerSecurityScheme) GetValidValue() interface{} {
if !ss.HasValidValue() {
return nil
func MustNewAuthorizationBearerSecurityScheme(name string, value *string) *SecurityScheme {
securityScheme, err := NewAuthorizationBearerSecurityScheme(name, value)
if err != nil {
panic(err)
}

return *ss.ValidValue
}

func (ss *BearerSecurityScheme) GetValidValueWriter() interface{} {
return nil
}

func (ss *BearerSecurityScheme) SetAttackValue(v interface{}) {
ss.AttackValue = v.(string)
}

func (ss *BearerSecurityScheme) GetAttackValue() interface{} {
return ss.AttackValue
return securityScheme
}
Loading

0 comments on commit 841b62c

Please sign in to comment.