diff --git a/internal/martian/header/via_modifier.go b/internal/martian/header/via_modifier.go index 633c3daf..462245cf 100644 --- a/internal/martian/header/via_modifier.go +++ b/internal/martian/header/via_modifier.go @@ -20,25 +20,42 @@ import ( "fmt" "io" "net/http" - "regexp" "strings" "github.com/saucelabs/forwarder/internal/martian" ) -var whitespace = regexp.MustCompile("[\t ]+") +const ( + h20Prefix = "2.0" + h11Prefix = "1.1" + h10Prefix = "1.0" +) // ViaModifier is a header modifier that checks for proxy redirect loops. type ViaModifier struct { - requestedBy string - boundary string + tag string } // NewViaModifier returns a new Via modifier. func NewViaModifier(requestedBy string) *ViaModifier { + return NewViaModifierWithBoundary(requestedBy, randomBoundary()) +} + +// randomBoundary generates a 10 character string to ensure that Martians that +// are chained together with the same requestedBy value do not collide. This func +// panics if io.Readfull fails. +func randomBoundary() string { + var buf [10]byte + _, err := io.ReadFull(rand.Reader, buf[:]) + if err != nil { + panic(err) + } + return hex.EncodeToString(buf[:]) +} + +func NewViaModifierWithBoundary(requestedBy, boundary string) *ViaModifier { return &ViaModifier{ - requestedBy: requestedBy, - boundary: randomBoundary(), + tag: requestedBy + "-" + boundary, } } @@ -49,10 +66,13 @@ func NewViaModifier(requestedBy string) *ViaModifier { // // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9 func (m *ViaModifier) ModifyRequest(req *http.Request) error { - via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary) + via := req.Header.Get("Via") + + var sb strings.Builder + sb.Grow(m.nextLen(via)) - if v := req.Header.Get("Via"); v != "" { - if m.hasLoop(v) { + if via != "" { + if strings.Contains(via, m.tag) { req.Close = true return martian.ErrorStatus{ Err: fmt.Errorf("via: detected request loop, header contains %s", via), @@ -60,48 +80,37 @@ func (m *ViaModifier) ModifyRequest(req *http.Request) error { } } - via = fmt.Sprintf("%s, %s", v, via) + sb.WriteString(via) + sb.WriteString(", ") + } + + switch req.ProtoMajor*10 + req.ProtoMinor { + case 20: + sb.WriteString(h20Prefix) + case 11: + sb.WriteString(h11Prefix) + case 10: + sb.WriteString(h10Prefix) + default: + fmt.Fprintf(&sb, "%d.%d", req.ProtoMajor, req.ProtoMinor) } - req.Header.Set("Via", via) + sb.WriteByte(' ') + sb.WriteString(m.tag) + + req.Header.Set("Via", sb.String()) return nil } -// hasLoop parses via and attempts to match requestedBy against the contained -// pseudonyms/host:port pairs. -func (m *ViaModifier) hasLoop(via string) bool { - for _, v := range strings.Split(via, ",") { - parts := whitespace.Split(strings.TrimSpace(v), 3) - - // No pseudonym or host:port, assume there is no loop. - if len(parts) < 2 { - continue - } +func (m *ViaModifier) nextLen(via string) int { + l := 0 - if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] { - return true - } + if via != "" { + l += len(via) + 2 } - return false -} + l += len(h11Prefix) + 1 + len(m.tag) -// SetBoundary sets the boundary string (random 10 character by default) used to -// disabiguate Martians that are chained together with identical requestedBy values. -// This should only be used for testing. -func (m *ViaModifier) SetBoundary(boundary string) { - m.boundary = boundary -} - -// randomBoundary generates a 10 character string to ensure that Martians that -// are chained together with the same requestedBy value do not collide. This func -// panics if io.Readfull fails. -func randomBoundary() string { - var buf [10]byte - _, err := io.ReadFull(rand.Reader, buf[:]) - if err != nil { - panic(err) - } - return hex.EncodeToString(buf[:]) + return l } diff --git a/internal/martian/header/via_modifier_test.go b/internal/martian/header/via_modifier_test.go index b8e62ca2..6377030e 100644 --- a/internal/martian/header/via_modifier_test.go +++ b/internal/martian/header/via_modifier_test.go @@ -42,7 +42,7 @@ func TestViaModifier(t *testing.T) { t.Errorf("req.Header.Get(%q): got %q, want prefixed with %q", "Via", got, want) } - m.SetBoundary("boundary") + m = NewViaModifierWithBoundary("martian", "boundary") req.Header.Set("Via", "1.0\talpha\t(martian), 1.1 martian-boundary, 1.1 beta") if err := m.ModifyRequest(req); err == nil { t.Fatal("ModifyRequest(): got nil, want request loop error") @@ -51,3 +51,22 @@ func TestViaModifier(t *testing.T) { t.Fatalf("req.Close: got %v, want true", req.Close) } } + +func BenchmarkViaModifier(b *testing.B) { + const via = "1.0\talpha\t(martian), 1.1 martian-boundary, 1.1 beta" + + req, err := http.NewRequest(http.MethodGet, "/", http.NoBody) + if err != nil { + b.Fatalf("http.NewRequest(): got %v, want no error", err) + } + + m := NewViaModifier("martian") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req.Header.Set("Via", via) + if err := m.ModifyRequest(req); err != nil { + b.Fatalf("ModifyRequest(): got %v, want no error", err) + } + } +}