Skip to content

Commit 32a1beb

Browse files
authored
Fix POST retries (#45)
* Add test to verify backoff Context cancellation behaviour * Verify that call will return error when Context-cancelled during backoff * Verify that error is due to Context-cancellation * Fix Context-cancellation during backoff * Context cancellation would result in no result being sent, thus causing the operation to hang in many circumstances * Add tests to verify retry behaviour with POST requests * Add test for POST-requests provided by Do() with a body * Add test for POST-requests provided by Post() with a body * Add middlewareServer to support POST tests * Fix retries for requests with bodies * Provide each concurrent request with its own Request so they do not interfere with each other. E.g. reading/closing each other's bodies. * Reset body before retrying when requests have bodies * Fix first request with Do and a body failing due to copyBody closing it
1 parent c02ad50 commit 32a1beb

File tree

2 files changed

+292
-20
lines changed

2 files changed

+292
-20
lines changed

pester.go

+68-20
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,16 @@ func (c *Client) copyBody(src io.ReadCloser) ([]byte, error) {
198198
return b, nil
199199
}
200200

201+
// resetBody resets the Body and GetBody fields of an http.Request to new Readers over
202+
// the originalBody. This is used to refresh http.Requests that may have had their
203+
// bodies closed already.
204+
func resetBody(request *http.Request, originalBody []byte) {
205+
request.Body = io.NopCloser(bytes.NewBuffer(originalBody))
206+
request.GetBody = func() (io.ReadCloser, error) {
207+
return io.NopCloser(bytes.NewBuffer(originalBody)), nil
208+
}
209+
}
210+
201211
// pester provides all the logic of retries, concurrency, backoff, and logging
202212
func (c *Client) pester(p params) (*http.Response, error) {
203213
resultCh := make(chan result)
@@ -242,7 +252,6 @@ func (c *Client) pester(p params) (*http.Response, error) {
242252

243253
// if we have a request body, we need to save it for later
244254
var (
245-
request *http.Request
246255
originalBody []byte
247256
err error
248257
)
@@ -252,23 +261,52 @@ func (c *Client) pester(p params) (*http.Response, error) {
252261
} else if p.body != nil {
253262
originalBody, err = c.copyBody(p.body)
254263
}
264+
if err != nil {
265+
return nil, err
266+
}
255267

268+
// check to make sure that we aren't trying to use an unsupported method
256269
switch p.method {
257-
case methodDo:
258-
request = p.req
259-
case methodGet, methodHead:
260-
request, err = http.NewRequest(p.verb, p.url, nil)
261-
case methodPostForm, methodPost:
262-
request, err = http.NewRequest(http.MethodPost, p.url, ioutil.NopCloser(bytes.NewBuffer(originalBody)))
270+
case methodDo, methodGet, methodHead, methodPostForm, methodPost:
263271
default:
264-
err = ErrUnexpectedMethod
265-
}
266-
if err != nil {
267-
return nil, err
272+
return nil, ErrUnexpectedMethod
268273
}
269274

270-
if len(p.bodyType) > 0 {
271-
request.Header.Set(headerKeyContentType, p.bodyType)
275+
// provideRequest returns an HTTP request to be used when retrying.
276+
// if concurrency is 1, it will return the same request that was supplied to the Do() method
277+
// for Do() calls, otherwise it will generate a Clone() of the request each time it is called.
278+
// For non-Do() calls, it creates a new request each time it is called. This re-creation behaviour
279+
// is because requests are not supposed to be used again until the RoundTripper is finished
280+
// with them, which cannot be guaranteed with concurrent callers
281+
// https://pkg.go.dev/net/http#RoundTripper
282+
provideRequest := func() (request *http.Request, err error) {
283+
switch p.method {
284+
case methodDo:
285+
if concurrency > 1 {
286+
request = p.req.Clone(p.req.Context())
287+
} else {
288+
request = p.req
289+
}
290+
if request.Body != nil {
291+
// reset the body since Clone() doesn't do that for us
292+
// and we drained it earlier when performing the Copy
293+
// ex: https://go.dev/play/p/jlc6A-fjaOi
294+
resetBody(request, originalBody)
295+
}
296+
case methodGet, methodHead:
297+
request, err = http.NewRequest(p.verb, p.url, nil)
298+
case methodPostForm, methodPost:
299+
request, err = http.NewRequest(http.MethodPost, p.url, bytes.NewBuffer(originalBody))
300+
}
301+
if err != nil {
302+
return
303+
}
304+
305+
if len(p.bodyType) > 0 {
306+
request.Header.Set(headerKeyContentType, p.bodyType)
307+
}
308+
309+
return
272310
}
273311

274312
AttemptLimit := c.MaxRetries
@@ -279,9 +317,15 @@ func (c *Client) pester(p params) (*http.Response, error) {
279317
for n := 0; n < concurrency; n++ {
280318
c.wg.Add(1)
281319
totalSentRequests.Add(1)
282-
go func(n int, req *http.Request) {
320+
go func(n int) {
283321
defer c.wg.Done()
284322
defer totalSentRequests.Done()
323+
req, err := provideRequest()
324+
// couldn't get a request to use, so don't proceed
325+
if err != nil {
326+
multiplexCh <- result{err: err, req: n}
327+
return
328+
}
285329

286330
for i := 1; i <= AttemptLimit; i++ {
287331
c.wg.Add(1)
@@ -340,15 +384,19 @@ func (c *Client) pester(p params) (*http.Response, error) {
340384
case <-time.After(c.Backoff(i) + 1*time.Microsecond):
341385
// allow context cancellation to cancel during backoff
342386
case <-req.Context().Done():
387+
multiplexCh <- result{resp: resp, err: req.Context().Err()}
343388
return
344389
}
345-
}
346-
}(n, request)
347390

348-
// rehydrate the body (it is drained each read)
349-
if request.Body != nil {
350-
request.Body = ioutil.NopCloser(bytes.NewBuffer(originalBody))
351-
}
391+
// we are about to retry, if we had a Body, we will need to restore it
392+
// to a non-closed one in order to work reliably. If you do not do this,
393+
// there are a number of curious edge cases depending on the type of the
394+
// underlying reader: https://go.dev/play/p/gZLVUe2EXSE
395+
if req.Body != nil {
396+
resetBody(req, originalBody)
397+
}
398+
}
399+
}(n)
352400
}
353401

354402
// spin off the go routine so it can continually listen in on late results and close the response bodies

pester_test.go

+224
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"io/ioutil"
78
"log"
89
"net"
910
"net/http"
@@ -743,6 +744,171 @@ func TestRetriesNotAttemptedIfContextIsCancelled(t *testing.T) {
743744
}
744745
}
745746

747+
type roundTripperFunc func(r *http.Request) (*http.Response, error)
748+
749+
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
750+
return f(r)
751+
}
752+
753+
func TestRetriesContextCancelledDuringWait(t *testing.T) {
754+
t.Parallel()
755+
// in order for this test to work we need to be able to reliably put the client in a
756+
// waiting state. To achieve this, we create a client that will fail fast
757+
// via a custom RoundTripper that always fails and pair it with a custom BackoffStrategy
758+
// that waits for a long time. This results in a client that should spend
759+
// almost all of its time waiting.
760+
761+
ctx, cancel := context.WithCancel(context.Background())
762+
763+
c := NewExtendedClient(&http.Client{
764+
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
765+
return nil, fmt.Errorf("always fail")
766+
}),
767+
Timeout: 5 * time.Second,
768+
})
769+
c.MaxRetries = 2
770+
c.Backoff = func(retry int) time.Duration {
771+
return 5 * time.Second
772+
}
773+
// req details don't really matter, round-tripper will fail it anyway
774+
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost", nil)
775+
if err != nil {
776+
t.Fatalf("unable to create request %v", err)
777+
}
778+
779+
// we want to perform the call in a goroutine so we can explicitly check for indefinite
780+
// blocking behaviour. Since you cannot use t.Fatal/t.Error/etc. in a goroutine, we
781+
// create a channel to communicate back to our main goroutine what happened
782+
errReturn := make(chan error)
783+
go func() {
784+
// perform call in goroutine to check for indefinite blocks
785+
_, err := c.Do(req)
786+
errReturn <- err
787+
}()
788+
789+
// wait a hundred ms to let the client fail and get into a waiting state
790+
<-time.After(100 * time.Millisecond)
791+
// cancel our context
792+
cancel()
793+
794+
// if all has gone well, we should have aborted our wait period and the
795+
// err channel should contain a Context-cancellation error
796+
797+
select {
798+
case recdErr := <-errReturn:
799+
if recdErr == nil {
800+
t.Fatal("nil error returned from Do(req) routine")
801+
}
802+
// check that it is the right error message
803+
if context.Canceled != recdErr {
804+
t.Fatalf("unexpected error returned: %v", recdErr)
805+
}
806+
case <-time.After(time.Second):
807+
// give it a second, then treat this as failing to return
808+
t.Fatal("failed to receive error return")
809+
}
810+
}
811+
812+
func TestRetriesWithBodies_Do(t *testing.T) {
813+
t.Parallel()
814+
815+
const testContent = "TestRetriesWithBodies_Do"
816+
// using a channel to route these errors back into this goroutine
817+
// it is important that this channel have enough capacity to hold all
818+
// of the errors that will be generated by the test so that we do not
819+
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
820+
// and each execution must only put at most one error on the channel.
821+
serverReqErrCh := make(chan error, 4)
822+
port, closeFn, err := middlewareServer(
823+
contentVerificationMiddleware(serverReqErrCh, testContent),
824+
always500RequestMiddleware(),
825+
)
826+
if err != nil {
827+
t.Fatal("unable to start timeout server", err)
828+
}
829+
defer closeFn()
830+
831+
<-time.After(2 * time.Second)
832+
833+
iseUrl := fmt.Sprintf("http://localhost:%d", port)
834+
835+
req, err := http.NewRequest("POST", iseUrl, strings.NewReader(testContent))
836+
if err != nil {
837+
t.Fatalf("unable to create request %v", err)
838+
}
839+
840+
c := New()
841+
c.MaxRetries = cap(serverReqErrCh)
842+
c.KeepLog = true
843+
c.Backoff = func(retry int) time.Duration {
844+
// backoff isn't important for this test
845+
return 0
846+
}
847+
848+
resp, err := c.Do(req)
849+
if err != nil {
850+
t.Errorf("unexpected error: %v", err)
851+
}
852+
if resp == nil {
853+
t.Error("response was unexpectedly nil")
854+
} else if resp.StatusCode != http.StatusInternalServerError {
855+
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
856+
}
857+
// we're done making requests, so close the return channel and drain it
858+
close(serverReqErrCh)
859+
for v := range serverReqErrCh {
860+
if v != nil {
861+
t.Errorf("unexpected error occurred when server processed request: %v", v)
862+
}
863+
}
864+
}
865+
866+
func TestRetriesWithBodies_POST(t *testing.T) {
867+
t.Parallel()
868+
869+
const testContent = "TestRetriesWithBodies_POST"
870+
// using a channel to route these errors back into this goroutine
871+
// it is important that this channel have enough capacity to hold all
872+
// of the errors that will be generated by the test so that we do not
873+
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
874+
// and each execution must only put at most one error on the channel.
875+
serverReqErrCh := make(chan error, 4)
876+
port, closeFn, err := middlewareServer(
877+
contentVerificationMiddleware(serverReqErrCh, testContent),
878+
always500RequestMiddleware(),
879+
)
880+
if err != nil {
881+
t.Fatal("unable to start timeout server", err)
882+
}
883+
defer closeFn()
884+
885+
c := New()
886+
c.MaxRetries = cap(serverReqErrCh)
887+
c.KeepLog = true
888+
c.Backoff = func(retry int) time.Duration {
889+
// backoff isn't important for this test
890+
return 0
891+
}
892+
893+
iseUrl := fmt.Sprintf("http://localhost:%d", port)
894+
resp, err := c.Post(iseUrl, "text/plain", strings.NewReader(testContent))
895+
if err != nil {
896+
t.Errorf("unexpected error: %v", err)
897+
}
898+
if resp == nil {
899+
t.Error("response was unexpectedly nil")
900+
} else if resp.StatusCode != http.StatusInternalServerError {
901+
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
902+
}
903+
// we're done making requests, so close the return channel and drain it
904+
close(serverReqErrCh)
905+
for v := range serverReqErrCh {
906+
if v != nil {
907+
t.Errorf("unexpected error occurred when server processed request: %v", v)
908+
}
909+
}
910+
}
911+
746912
func withinEpsilon(got, want int64, epslion float64) bool {
747913
if want <= int64(epslion*float64(got)) || want >= int64(epslion*float64(got)) {
748914
return false
@@ -880,3 +1046,61 @@ func serverWith400() (int, error) {
8801046

8811047
return port, nil
8821048
}
1049+
1050+
func contentVerificationMiddleware(errorCh chan<- error, expectedContent string) http.Handler {
1051+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1052+
content, err := ioutil.ReadAll(r.Body)
1053+
defer r.Body.Close()
1054+
if err != nil {
1055+
errorCh <- err
1056+
} else if string(content) != expectedContent {
1057+
errorCh <- fmt.Errorf(
1058+
"unexpected body content: expected \"%v\", got \"%v\"",
1059+
expectedContent,
1060+
string(content),
1061+
)
1062+
}
1063+
})
1064+
}
1065+
1066+
func always500RequestMiddleware() http.Handler {
1067+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1068+
w.WriteHeader(http.StatusInternalServerError)
1069+
w.Write([]byte("500 Internal Server Error"))
1070+
})
1071+
}
1072+
1073+
// middlewareServer stands up a server that accepts varargs of middleware that conforms to the
1074+
// http.Handler interface
1075+
func middlewareServer(requestMiddleware ...http.Handler) (int, func(), error) {
1076+
mux := http.NewServeMux()
1077+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
1078+
for _, v := range requestMiddleware {
1079+
v.ServeHTTP(w, r)
1080+
}
1081+
})
1082+
l, err := net.Listen("tcp", ":0")
1083+
if err != nil {
1084+
return -1, nil, fmt.Errorf("unable to secure listener %v", err)
1085+
}
1086+
server := &http.Server{
1087+
Handler: mux,
1088+
}
1089+
go func() {
1090+
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
1091+
log.Fatalf("middleware-server error %v", err)
1092+
}
1093+
}()
1094+
1095+
var port int
1096+
_, sport, err := net.SplitHostPort(l.Addr().String())
1097+
if err == nil {
1098+
port, err = strconv.Atoi(sport)
1099+
}
1100+
1101+
if err != nil {
1102+
return -1, nil, fmt.Errorf("unable to determine port %v", err)
1103+
}
1104+
1105+
return port, func() { server.Close() }, nil
1106+
}

0 commit comments

Comments
 (0)