diff --git a/internal/martian/header/via_modifier.go b/internal/martian/header/via_modifier.go index 633c3daf..12784652 100644 --- a/internal/martian/header/via_modifier.go +++ b/internal/martian/header/via_modifier.go @@ -20,26 +20,32 @@ import ( "fmt" "io" "net/http" - "regexp" "strings" "github.com/saucelabs/forwarder/internal/martian" ) -var whitespace = regexp.MustCompile("[\t ]+") - // ViaModifier is a header modifier that checks for proxy redirect loops. type ViaModifier struct { requestedBy string boundary string + len int } // NewViaModifier returns a new Via modifier. func NewViaModifier(requestedBy string) *ViaModifier { - return &ViaModifier{ + m := &ViaModifier{ requestedBy: requestedBy, boundary: randomBoundary(), } + + var sb strings.Builder + if _, err := m.writeTo(&sb, &http.Request{ProtoMajor: 1, ProtoMinor: 1}); err != nil { + panic(err) // fatal programming error + } + m.len = sb.Len() + + return m } // ModifyRequest sets the Via header and provides loop-detection. If Via is @@ -49,10 +55,18 @@ func NewViaModifier(requestedBy string) *ViaModifier { // // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9 func (m *ViaModifier) ModifyRequest(req *http.Request) error { - via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary) + via := req.Header.Get("Via") + + l := m.len + if via != "" { + l += len(via) + 2 + } + + var sb strings.Builder + sb.Grow(l) - if v := req.Header.Get("Via"); v != "" { - if m.hasLoop(v) { + if via != "" { + if strings.Contains(via, m.requestedBy+"-"+m.boundary) { req.Close = true return martian.ErrorStatus{ Err: fmt.Errorf("via: detected request loop, header contains %s", via), @@ -60,31 +74,17 @@ func (m *ViaModifier) ModifyRequest(req *http.Request) error { } } - via = fmt.Sprintf("%s, %s", v, via) + sb.WriteString(via) + sb.WriteString(", ") } - req.Header.Set("Via", via) - - return nil -} - -// hasLoop parses via and attempts to match requestedBy against the contained -// pseudonyms/host:port pairs. -func (m *ViaModifier) hasLoop(via string) bool { - for _, v := range strings.Split(via, ",") { - parts := whitespace.Split(strings.TrimSpace(v), 3) - - // No pseudonym or host:port, assume there is no loop. - if len(parts) < 2 { - continue - } - - if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] { - return true - } + if _, err := m.writeTo(&sb, req); err != nil { + return err } - return false + req.Header.Set("Via", sb.String()) + + return nil } // SetBoundary sets the boundary string (random 10 character by default) used to @@ -105,3 +105,7 @@ func randomBoundary() string { } return hex.EncodeToString(buf[:]) } + +func (m *ViaModifier) writeTo(w io.Writer, req *http.Request) (int, error) { + return fmt.Fprintf(w, "%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary) +}