Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

martian/header: optimized ViaModifier implementation #974

Merged
merged 2 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions internal/martian/header/via_modifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,42 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"

"github.com/saucelabs/forwarder/internal/martian"
)

var whitespace = regexp.MustCompile("[\t ]+")
const (
h20Prefix = "2.0"
h11Prefix = "1.1"
h10Prefix = "1.0"
)

// ViaModifier is a header modifier that checks for proxy redirect loops.
type ViaModifier struct {
requestedBy string
boundary string
tag string
}

// NewViaModifier returns a new Via modifier.
func NewViaModifier(requestedBy string) *ViaModifier {
return NewViaModifierWithBoundary(requestedBy, randomBoundary())
}

// randomBoundary generates a 10 character string to ensure that Martians that
// are chained together with the same requestedBy value do not collide. This func
// panics if io.Readfull fails.
func randomBoundary() string {
var buf [10]byte
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err)
}
return hex.EncodeToString(buf[:])
}

func NewViaModifierWithBoundary(requestedBy, boundary string) *ViaModifier {
return &ViaModifier{
requestedBy: requestedBy,
boundary: randomBoundary(),
tag: requestedBy + "-" + boundary,
}
}

Expand All @@ -49,59 +66,51 @@ func NewViaModifier(requestedBy string) *ViaModifier {
//
// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9
func (m *ViaModifier) ModifyRequest(req *http.Request) error {
via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary)
via := req.Header.Get("Via")

var sb strings.Builder
sb.Grow(m.nextLen(via))

if v := req.Header.Get("Via"); v != "" {
if m.hasLoop(v) {
if via != "" {
if strings.Contains(via, m.tag) {
req.Close = true
return martian.ErrorStatus{
Err: fmt.Errorf("via: detected request loop, header contains %s", via),
Status: 400,
}
}

via = fmt.Sprintf("%s, %s", v, via)
sb.WriteString(via)
sb.WriteString(", ")
}

switch req.ProtoMajor*10 + req.ProtoMinor {
case 20:
sb.WriteString(h20Prefix)
case 11:
sb.WriteString(h11Prefix)
case 10:
sb.WriteString(h10Prefix)
default:
fmt.Fprintf(&sb, "%d.%d", req.ProtoMajor, req.ProtoMinor)
}

req.Header.Set("Via", via)
sb.WriteByte(' ')
sb.WriteString(m.tag)

req.Header.Set("Via", sb.String())

return nil
}

// hasLoop parses via and attempts to match requestedBy against the contained
// pseudonyms/host:port pairs.
func (m *ViaModifier) hasLoop(via string) bool {
for _, v := range strings.Split(via, ",") {
parts := whitespace.Split(strings.TrimSpace(v), 3)

// No pseudonym or host:port, assume there is no loop.
if len(parts) < 2 {
continue
}
func (m *ViaModifier) nextLen(via string) int {
l := 0

if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] {
return true
}
if via != "" {
l += len(via) + 2
}

return false
}
l += len(h11Prefix) + 1 + len(m.tag)

// SetBoundary sets the boundary string (random 10 character by default) used to
// disabiguate Martians that are chained together with identical requestedBy values.
// This should only be used for testing.
func (m *ViaModifier) SetBoundary(boundary string) {
m.boundary = boundary
}

// randomBoundary generates a 10 character string to ensure that Martians that
// are chained together with the same requestedBy value do not collide. This func
// panics if io.Readfull fails.
func randomBoundary() string {
var buf [10]byte
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err)
}
return hex.EncodeToString(buf[:])
return l
}
21 changes: 20 additions & 1 deletion internal/martian/header/via_modifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestViaModifier(t *testing.T) {
t.Errorf("req.Header.Get(%q): got %q, want prefixed with %q", "Via", got, want)
}

m.SetBoundary("boundary")
m = NewViaModifierWithBoundary("martian", "boundary")
req.Header.Set("Via", "1.0\talpha\t(martian), 1.1 martian-boundary, 1.1 beta")
if err := m.ModifyRequest(req); err == nil {
t.Fatal("ModifyRequest(): got nil, want request loop error")
Expand All @@ -51,3 +51,22 @@ func TestViaModifier(t *testing.T) {
t.Fatalf("req.Close: got %v, want true", req.Close)
}
}

func BenchmarkViaModifier(b *testing.B) {
const via = "1.0\talpha\t(martian), 1.1 martian-boundary, 1.1 beta"

req, err := http.NewRequest(http.MethodGet, "/", http.NoBody)
if err != nil {
b.Fatalf("http.NewRequest(): got %v, want no error", err)
}

m := NewViaModifier("martian")

b.ResetTimer()
for i := 0; i < b.N; i++ {
req.Header.Set("Via", via)
if err := m.ModifyRequest(req); err != nil {
b.Fatalf("ModifyRequest(): got %v, want no error", err)
}
}
}
Loading