Skip to content

Commit

Permalink
martian/header: optimized ViaModifier implementation
Browse files Browse the repository at this point in the history
- Remove regex
- Remove splitting
- Allocate single chunk of memory
- Precompute chunk size, never grow

Fixes #924
  • Loading branch information
mmatczuk committed Dec 4, 2024
1 parent 4f94833 commit c2ef68e
Showing 1 changed file with 32 additions and 28 deletions.
60 changes: 32 additions & 28 deletions internal/martian/header/via_modifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,32 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"

"github.com/saucelabs/forwarder/internal/martian"
)

var whitespace = regexp.MustCompile("[\t ]+")

// ViaModifier is a header modifier that checks for proxy redirect loops.
type ViaModifier struct {
requestedBy string
boundary string
len int
}

// NewViaModifier returns a new Via modifier.
func NewViaModifier(requestedBy string) *ViaModifier {
return &ViaModifier{
m := &ViaModifier{
requestedBy: requestedBy,
boundary: randomBoundary(),
}

var sb strings.Builder
if _, err := m.writeTo(&sb, &http.Request{ProtoMajor: 1, ProtoMinor: 1}); err != nil {
panic(err) // fatal programming error
}
m.len = sb.Len()

return m
}

// ModifyRequest sets the Via header and provides loop-detection. If Via is
Expand All @@ -49,42 +55,36 @@ func NewViaModifier(requestedBy string) *ViaModifier {
//
// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9
func (m *ViaModifier) ModifyRequest(req *http.Request) error {
via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary)
via := req.Header.Get("Via")

l := m.len
if via != "" {
l += len(via) + 2
}

var sb strings.Builder
sb.Grow(l)

if v := req.Header.Get("Via"); v != "" {
if m.hasLoop(v) {
if via != "" {
if strings.Contains(via, m.requestedBy+"-"+m.boundary) {
req.Close = true
return martian.ErrorStatus{
Err: fmt.Errorf("via: detected request loop, header contains %s", via),
Status: 400,
}
}

via = fmt.Sprintf("%s, %s", v, via)
sb.WriteString(via)
sb.WriteString(", ")
}

req.Header.Set("Via", via)

return nil
}

// hasLoop parses via and attempts to match requestedBy against the contained
// pseudonyms/host:port pairs.
func (m *ViaModifier) hasLoop(via string) bool {
for _, v := range strings.Split(via, ",") {
parts := whitespace.Split(strings.TrimSpace(v), 3)

// No pseudonym or host:port, assume there is no loop.
if len(parts) < 2 {
continue
}

if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] {
return true
}
if _, err := m.writeTo(&sb, req); err != nil {
return err
}

return false
req.Header.Set("Via", sb.String())

return nil
}

// SetBoundary sets the boundary string (random 10 character by default) used to
Expand All @@ -105,3 +105,7 @@ func randomBoundary() string {
}
return hex.EncodeToString(buf[:])
}

func (m *ViaModifier) writeTo(w io.Writer, req *http.Request) (int, error) {
return fmt.Fprintf(w, "%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary)
}

0 comments on commit c2ef68e

Please sign in to comment.