Skip to content

Commit

Permalink
feature: net.Client before and after hooks (#3140)
Browse files Browse the repository at this point in the history
  • Loading branch information
szuecs authored Jul 17, 2024
1 parent b4389e5 commit 7fde930
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 18 deletions.
27 changes: 22 additions & 5 deletions net/httpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ type Options struct {

// Log is used for error logging
Log logging.Logger

// BeforeSend is a hook function that runs just before executing RoundTrip(*http.Request)
BeforeSend func(*http.Request)
// AfterResponse is a hook function that runs just after executing RoundTrip(*http.Request)
AfterResponse func(*http.Response, error)
}

// Transport wraps an http.Transport and adds support for tracing and
Expand All @@ -219,6 +224,8 @@ type Transport struct {
spanName string
componentName string
bearerToken string
beforeSend func(*http.Request)
afterResponse func(*http.Response, error)
}

// NewTransport creates a wrapped http.Transport, with regular DNS
Expand Down Expand Up @@ -275,10 +282,12 @@ func NewTransport(options Options) *Transport {
}

t := &Transport{
once: sync.Once{},
quit: make(chan struct{}),
tr: htransport,
tracer: options.Tracer,
once: sync.Once{},
quit: make(chan struct{}),
tr: htransport,
tracer: options.Tracer,
beforeSend: options.BeforeSend,
afterResponse: options.AfterResponse,
}

if t.tracer != nil {
Expand Down Expand Up @@ -342,6 +351,8 @@ func (t *Transport) shallowCopy() *Transport {
spanName: t.spanName,
componentName: t.componentName,
bearerToken: t.bearerToken,
beforeSend: t.beforeSend,
afterResponse: t.afterResponse,
}
}

Expand Down Expand Up @@ -369,7 +380,13 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+t.bearerToken)
}
if t.beforeSend != nil {
t.beforeSend(req)
}
rsp, err := t.tr.RoundTrip(req)
if t.afterResponse != nil {
t.afterResponse(rsp, err)
}
if span != nil {
span.LogKV("http_do", "stop")
if rsp != nil {
Expand All @@ -388,10 +405,10 @@ func (t *Transport) injectSpan(req *http.Request) (*http.Request, opentracing.Sp
string(ext.HTTPUrl): req.URL.String(),
}}
if parentSpan := opentracing.SpanFromContext(req.Context()); parentSpan != nil {
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), parentSpan))
spanOpts = append(spanOpts, opentracing.ChildOf(parentSpan.Context()))
}
span := t.tracer.StartSpan(t.spanName, spanOpts...)
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))

_ = t.tracer.Inject(span.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))

Expand Down
96 changes: 96 additions & 0 deletions net/httpclient_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,24 @@ import (

"github.com/lightstep/lightstep-tracer-go"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/opentracing/opentracing-go/mocktracer"
"github.com/sirupsen/logrus"
"github.com/zalando/skipper/net"
"github.com/zalando/skipper/secrets"
)

func waitForSpanViaMockTracer(mockTracer *mocktracer.MockTracer) {
for i := 0; i < 20; i++ {
if n := len(mockTracer.FinishedSpans()); n > 0 {
logrus.Printf("found %d spans", n)
return
}
time.Sleep(100 * time.Millisecond)
}
logrus.Println("no span found")
}

func ExampleTransport() {
tracer := lightstep.NewTracer(lightstep.Options{})

Expand Down Expand Up @@ -220,6 +233,9 @@ func ExampleClient_customTracer() {

cli.Get("http://" + srv.Listener.Addr().String() + "/")

// wait for the span to be finished
waitForSpanViaMockTracer(mockTracer)

fmt.Printf("customtag: %s", mockTracer.FinishedSpans()[0].Tags()["customtag"])

// Output:
Expand Down Expand Up @@ -325,3 +341,83 @@ func ExampleClient_hostSecret() {
time.Sleep(1 * time.Second)
}
}

func ExampleClient_withBeforeSendHook() {
mockTracer := mocktracer.New()
peerService := "my-peer-service"
cli := net.NewClient(net.Options{
Tracer: &customTracer{mockTracer},
OpentracingComponentTag: "testclient",
OpentracingSpanName: "clientSpan",
IdleConnTimeout: 2 * time.Second,
BeforeSend: func(req *http.Request) {
req.Header.Set("X-Foo", "qux")
if span := opentracing.SpanFromContext(req.Context()); span != nil {
logrus.Println("BeforeSend: found span")
span.SetTag(string(ext.PeerService), peerService)
} else {
logrus.Println("BeforeSend: no span found")
}
},
})
defer cli.Close()

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("X-Foo: %s\n", r.Header.Get("X-Foo"))
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

cli.Get("http://" + srv.Listener.Addr().String() + "/")

// wait for the span to be finished
waitForSpanViaMockTracer(mockTracer)

fmt.Printf("request tag %q set to %q", string(ext.PeerService), mockTracer.FinishedSpans()[0].Tags()[string(ext.PeerService)])

// Output:
// X-Foo: qux
// request tag "peer.service" set to "my-peer-service"
}

func ExampleClient_withAfterResponseHook() {
mockTracer := mocktracer.New()
cli := net.NewClient(net.Options{
Tracer: &customTracer{mockTracer},
OpentracingComponentTag: "testclient",
OpentracingSpanName: "clientSpan",
BearerTokenRefreshInterval: 10 * time.Second,
BearerTokenFile: "/tmp/foo.token",
IdleConnTimeout: 2 * time.Second,
AfterResponse: func(rsp *http.Response, err error) {
if span := opentracing.SpanFromContext(rsp.Request.Context()); span != nil {
span.SetTag("status.code", rsp.StatusCode)
if err != nil {
span.SetTag("error", err.Error())
}
}
rsp.StatusCode = 255
},
})
defer cli.Close()

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

rsp, err := cli.Get("http://" + srv.Listener.Addr().String() + "/")
if err != nil {
log.Fatalf("Failed to get: %v", err)
}

// wait for the span to be finished
waitForSpanViaMockTracer(mockTracer)

fmt.Printf("response code: %d\n", rsp.StatusCode)
fmt.Printf("span status.code: %d", mockTracer.FinishedSpans()[0].Tags()["status.code"])

// Output:
// response code: 255
// span status.code: 200
}
128 changes: 115 additions & 13 deletions net/httpclient_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package net

import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -10,6 +11,9 @@ import (
"time"

"github.com/AlexanderYastrebov/noleak"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/opentracing/opentracing-go/mocktracer"
"github.com/zalando/skipper/secrets"
"github.com/zalando/skipper/tracing/tracers/basic"
)
Expand Down Expand Up @@ -204,19 +208,23 @@ func TestClient(t *testing.T) {
}

func TestTransport(t *testing.T) {
mtracer := mocktracer.New()
tracer, err := basic.InitTracer(nil)
if err != nil {
t.Fatalf("Failed to get a tracer: %v", err)
}
defer tracer.Close()

for _, tt := range []struct {
name string
options Options
spanName string
bearerToken string
req *http.Request
wantErr bool
name string
options Options
spanName string
bearerToken string
req *http.Request
wantErr bool
checkRequestOnServer func(*http.Request) error
checkRequest func(*http.Request) error
checkResponse func(*http.Response) error
}{
{
name: "All defaults, with request should have a response",
Expand Down Expand Up @@ -248,6 +256,82 @@ func TestTransport(t *testing.T) {
req: httptest.NewRequest("GET", "http://example.com/", nil),
wantErr: false,
},
{
name: "With hooks, should have request header and respose changed",
options: Options{
BeforeSend: func(req *http.Request) {
if req != nil {
req.Header.Set("X-Foo", "bar")
}
},
AfterResponse: func(rsp *http.Response, err error) {
if rsp != nil {
rsp.StatusCode = 255
}
},
},
req: httptest.NewRequest("GET", "http://example.com/", nil),
wantErr: false,
checkRequestOnServer: func(req *http.Request) error {
if v := req.Header.Get("X-Foo"); v != "bar" {
return fmt.Errorf(`failed to patch request want "X-Foo": "bar", but got: %s`, v)
}
return nil
},
checkResponse: func(rsp *http.Response) error {
if rsp.StatusCode != 255 {
return fmt.Errorf("failed to get status code 255, got: %d", rsp.StatusCode)
}
return nil
},
},
{
name: "With hooks and opentracing, should have request header and response changed",
options: Options{
Tracer: mtracer,
BeforeSend: func(req *http.Request) {
if req != nil {
if span := opentracing.SpanFromContext(req.Context()); span != nil {
span.SetTag(string(ext.PeerService), "my-app")
*req = *req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
return
}
}
},
AfterResponse: func(rsp *http.Response, err error) {
if rsp != nil {
if span := opentracing.SpanFromContext(rsp.Request.Context()); span != nil {
span.SetTag("my.status", 255)
*rsp.Request = *rsp.Request.WithContext(opentracing.ContextWithSpan(rsp.Request.Context(), span))
return
}
}
},
},
spanName: "myspan",
req: httptest.NewRequest("GET", "http://example.com/", nil),
wantErr: false,
checkRequest: func(req *http.Request) error {
if span := opentracing.SpanFromContext(req.Context()); span != nil {
peerService := mtracer.FinishedSpans()[0].Tags()[string(ext.PeerService)]
if peerService != "my-app" {
return fmt.Errorf(`failed to get Tag %s value: "my-app", got %q`, ext.PeerService, peerService)
}
return nil
}
return fmt.Errorf("failed get span from request")
},
checkResponse: func(rsp *http.Response) error {
if span := opentracing.SpanFromContext(rsp.Request.Context()); span != nil {
status := mtracer.FinishedSpans()[0].Tags()["my.status"]
if status != 255 {
return fmt.Errorf(`failed to get Tag "my.status" value: "255", got %d`, status)
}
return nil
}
return fmt.Errorf("failed get span from request")
},
},
} {
t.Run(tt.name, func(t *testing.T) {
s := startTestServer(func(r *http.Request) {
Expand All @@ -259,19 +343,26 @@ func TestTransport(t *testing.T) {
return
}

if tt.spanName != "" && tt.options.Tracer != nil {
if r.Header.Get("Ot-Tracer-Sampled") == "" ||
r.Header.Get("Ot-Tracer-Traceid") == "" ||
r.Header.Get("Ot-Tracer-Spanid") == "" {
t.Errorf("One of the OT Tracer headers are missing: %v", r.Header)
if tt.spanName != "" {
if tt.options.Tracer == tracer {
if r.Header.Get("Ot-Tracer-Sampled") == "" ||
r.Header.Get("Ot-Tracer-Traceid") == "" ||
r.Header.Get("Ot-Tracer-Spanid") == "" {
t.Errorf("One of the OT Tracer headers are missing: %v", r.Header)
}
}
}

if tt.bearerToken != "" {
if r.Header.Get("Authorization") != "Bearer "+string(testToken) {
t.Errorf("Failed to have a token, but want to have it, got: %v, want: %v", r.Header.Get("Authorization"), "Bearer "+tt.bearerToken)
}
}

if tt.checkRequestOnServer != nil {
if err := tt.checkRequestOnServer(r); err != nil {
t.Errorf("Failed to check request: %v", err)
}
}
})

defer s.Close()
Expand All @@ -289,11 +380,22 @@ func TestTransport(t *testing.T) {
if tt.req != nil {
tt.req.URL.Host = s.Listener.Addr().String()
}
_, err := rt.RoundTrip(tt.req)
rsp, err := rt.RoundTrip(tt.req)
if (err != nil) != tt.wantErr {
t.Errorf("Transport.RoundTrip() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.checkRequest != nil {
if err := tt.checkRequest(rsp.Request); err != nil {
t.Errorf("Failed to check request: %v", err)
}
}

if tt.checkResponse != nil {
if err := tt.checkResponse(rsp); err != nil {
t.Errorf("Failed to check response: %v", err)
}
}
})
}
}
Expand Down

0 comments on commit 7fde930

Please sign in to comment.