Skip to content

Commit e88d85b

Browse files
authored
adding context.Context to loghook to allow tracing id extraction for custom loghooks (#40)
Co-authored-by: mdreikorn <[email protected]>
1 parent 68a33a0 commit e88d85b

File tree

2 files changed

+74
-16
lines changed

2 files changed

+74
-16
lines changed

pester.go

+32-16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package pester
44

55
import (
66
"bytes"
7+
"context"
78
"errors"
89
"fmt"
910
"io"
@@ -36,11 +37,12 @@ type Client struct {
3637
Timeout time.Duration
3738

3839
// pester specific
39-
Concurrency int
40-
MaxRetries int
41-
Backoff BackoffStrategy
42-
KeepLog bool
43-
LogHook LogHook
40+
Concurrency int
41+
MaxRetries int
42+
Backoff BackoffStrategy
43+
KeepLog bool
44+
LogHook LogHook
45+
ContextLogHook ContextLogHook
4446

4547
SuccessReqNum int
4648
SuccessRetryNum int
@@ -115,6 +117,9 @@ func NewExtendedClient(hc *http.Client) *Client {
115117
// however, if KeepLog is set to true.
116118
type LogHook func(e ErrEntry)
117119

120+
// ContextLogHook does the same as LogHook but with passed Context
121+
type ContextLogHook func(ctx context.Context, e ErrEntry)
122+
118123
// BackoffStrategy is used to determine how long a retry request should wait until attempted
119124
type BackoffStrategy func(retry int) time.Duration
120125

@@ -286,16 +291,23 @@ func (c *Client) pester(p params) (*http.Response, error) {
286291
return
287292
}
288293

289-
c.log(ErrEntry{
290-
Time: time.Now(),
291-
Method: p.method,
292-
Verb: p.verb,
293-
URL: p.url,
294-
Request: n,
295-
Retry: i + 1, // would remove, but would break backward compatibility
296-
Attempt: i,
297-
Err: err,
298-
})
294+
loggingContext := context.Background()
295+
if p.req != nil {
296+
loggingContext = p.req.Context()
297+
}
298+
299+
c.log(
300+
loggingContext,
301+
ErrEntry{
302+
Time: time.Now(),
303+
Method: p.method,
304+
Verb: p.verb,
305+
URL: p.url,
306+
Request: n,
307+
Retry: i + 1, // would remove, but would break backward compatibility
308+
Attempt: i,
309+
Err: err,
310+
})
299311

300312
// if it is the last iteration, grab the result (which is an error at this point)
301313
if i == AttemptLimit {
@@ -387,11 +399,15 @@ func (c *Client) EmbedHTTPClient(hc *http.Client) {
387399
c.hc = hc
388400
}
389401

390-
func (c *Client) log(e ErrEntry) {
402+
func (c *Client) log(ctx context.Context, e ErrEntry) {
391403
if c.KeepLog {
392404
c.Lock()
393405
defer c.Unlock()
394406
c.ErrLog = append(c.ErrLog, e)
407+
} else if c.ContextLogHook != nil {
408+
// NOTE: There is a possibility that Log Printing hook slows it down.
409+
// but the consumer can always do the Job in a go-routine.
410+
c.ContextLogHook(ctx, e)
395411
} else if c.LogHook != nil {
396412
// NOTE: There is a possibility that Log Printing hook slows it down.
397413
// but the consumer can always do the Job in a go-routine.

pester_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,48 @@ func TestCustomLogHook(t *testing.T) {
412412
}
413413
}
414414

415+
func TestCustomContextLogHook(t *testing.T) {
416+
t.Parallel()
417+
418+
expectedRetries := 5
419+
errorLines := []ErrEntry{}
420+
testContextKey := "testContextKey"
421+
testContextValue := "testContextValue"
422+
ctx := context.WithValue(context.Background(), testContextKey, testContextValue)
423+
424+
c := New()
425+
c.MaxRetries = expectedRetries
426+
c.Backoff = func(_ int) time.Duration {
427+
return 10 * time.Microsecond
428+
}
429+
430+
c.ContextLogHook = func(ctx context.Context, e ErrEntry) {
431+
if testContextValue != ctx.Value(testContextKey) {
432+
t.Fatalf("Value %s not found under key %s in context", testContextValue, testContextKey)
433+
}
434+
errorLines = append(errorLines, e)
435+
}
436+
437+
nonExistantURL := "http://localhost:9000/foo"
438+
httpRequest, err := http.NewRequest(http.MethodGet, nonExistantURL, nil)
439+
httpRequest = httpRequest.WithContext(ctx)
440+
441+
if err != nil {
442+
t.Fatal("unexpected error on request creation")
443+
}
444+
445+
_, err = c.Do(httpRequest)
446+
if err == nil {
447+
t.Fatal("expected to get an error")
448+
}
449+
c.Wait()
450+
451+
// in the event of an error, let's see what the logs were
452+
if expectedRetries != len(errorLines) {
453+
t.Errorf("Expected %d lines to be emitted. Got %d", expectedRetries, len(errorLines))
454+
}
455+
}
456+
415457
func TestDefaultLogHook(t *testing.T) {
416458
t.Parallel()
417459

0 commit comments

Comments
 (0)