|
4 | 4 | "context"
|
5 | 5 | "errors"
|
6 | 6 | "fmt"
|
| 7 | + "io/ioutil" |
7 | 8 | "log"
|
8 | 9 | "net"
|
9 | 10 | "net/http"
|
@@ -743,6 +744,171 @@ func TestRetriesNotAttemptedIfContextIsCancelled(t *testing.T) {
|
743 | 744 | }
|
744 | 745 | }
|
745 | 746 |
|
| 747 | +type roundTripperFunc func(r *http.Request) (*http.Response, error) |
| 748 | + |
| 749 | +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { |
| 750 | + return f(r) |
| 751 | +} |
| 752 | + |
| 753 | +func TestRetriesContextCancelledDuringWait(t *testing.T) { |
| 754 | + t.Parallel() |
| 755 | + // in order for this test to work we need to be able to reliably put the client in a |
| 756 | + // waiting state. To achieve this, we create a client that will fail fast |
| 757 | + // via a custom RoundTripper that always fails and pair it with a custom BackoffStrategy |
| 758 | + // that waits for a long time. This results in a client that should spend |
| 759 | + // almost all of its time waiting. |
| 760 | + |
| 761 | + ctx, cancel := context.WithCancel(context.Background()) |
| 762 | + |
| 763 | + c := NewExtendedClient(&http.Client{ |
| 764 | + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { |
| 765 | + return nil, fmt.Errorf("always fail") |
| 766 | + }), |
| 767 | + Timeout: 5 * time.Second, |
| 768 | + }) |
| 769 | + c.MaxRetries = 2 |
| 770 | + c.Backoff = func(retry int) time.Duration { |
| 771 | + return 5 * time.Second |
| 772 | + } |
| 773 | + // req details don't really matter, round-tripper will fail it anyway |
| 774 | + req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost", nil) |
| 775 | + if err != nil { |
| 776 | + t.Fatalf("unable to create request %v", err) |
| 777 | + } |
| 778 | + |
| 779 | + // we want to perform the call in a goroutine so we can explicitly check for indefinite |
| 780 | + // blocking behaviour. Since you cannot use t.Fatal/t.Error/etc. in a goroutine, we |
| 781 | + // create a channel to communicate back to our main goroutine what happened |
| 782 | + errReturn := make(chan error) |
| 783 | + go func() { |
| 784 | + // perform call in goroutine to check for indefinite blocks |
| 785 | + _, err := c.Do(req) |
| 786 | + errReturn <- err |
| 787 | + }() |
| 788 | + |
| 789 | + // wait a hundred ms to let the client fail and get into a waiting state |
| 790 | + <-time.After(100 * time.Millisecond) |
| 791 | + // cancel our context |
| 792 | + cancel() |
| 793 | + |
| 794 | + // if all has gone well, we should have aborted our wait period and the |
| 795 | + // err channel should contain a Context-cancellation error |
| 796 | + |
| 797 | + select { |
| 798 | + case recdErr := <-errReturn: |
| 799 | + if recdErr == nil { |
| 800 | + t.Fatal("nil error returned from Do(req) routine") |
| 801 | + } |
| 802 | + // check that it is the right error message |
| 803 | + if context.Canceled != recdErr { |
| 804 | + t.Fatalf("unexpected error returned: %v", recdErr) |
| 805 | + } |
| 806 | + case <-time.After(time.Second): |
| 807 | + // give it a second, then treat this as failing to return |
| 808 | + t.Fatal("failed to receive error return") |
| 809 | + } |
| 810 | +} |
| 811 | + |
| 812 | +func TestRetriesWithBodies_Do(t *testing.T) { |
| 813 | + t.Parallel() |
| 814 | + |
| 815 | + const testContent = "TestRetriesWithBodies_Do" |
| 816 | + // using a channel to route these errors back into this goroutine |
| 817 | + // it is important that this channel have enough capacity to hold all |
| 818 | + // of the errors that will be generated by the test so that we do not |
| 819 | + // deadlock. Therefore, MaxAttempts must be the same size as the channel capacity |
| 820 | + // and each execution must only put at most one error on the channel. |
| 821 | + serverReqErrCh := make(chan error, 4) |
| 822 | + port, closeFn, err := middlewareServer( |
| 823 | + contentVerificationMiddleware(serverReqErrCh, testContent), |
| 824 | + always500RequestMiddleware(), |
| 825 | + ) |
| 826 | + if err != nil { |
| 827 | + t.Fatal("unable to start timeout server", err) |
| 828 | + } |
| 829 | + defer closeFn() |
| 830 | + |
| 831 | + <-time.After(2 * time.Second) |
| 832 | + |
| 833 | + iseUrl := fmt.Sprintf("http://localhost:%d", port) |
| 834 | + |
| 835 | + req, err := http.NewRequest("POST", iseUrl, strings.NewReader(testContent)) |
| 836 | + if err != nil { |
| 837 | + t.Fatalf("unable to create request %v", err) |
| 838 | + } |
| 839 | + |
| 840 | + c := New() |
| 841 | + c.MaxRetries = cap(serverReqErrCh) |
| 842 | + c.KeepLog = true |
| 843 | + c.Backoff = func(retry int) time.Duration { |
| 844 | + // backoff isn't important for this test |
| 845 | + return 0 |
| 846 | + } |
| 847 | + |
| 848 | + resp, err := c.Do(req) |
| 849 | + if err != nil { |
| 850 | + t.Errorf("unexpected error: %v", err) |
| 851 | + } |
| 852 | + if resp == nil { |
| 853 | + t.Error("response was unexpectedly nil") |
| 854 | + } else if resp.StatusCode != http.StatusInternalServerError { |
| 855 | + t.Errorf("unexpected response StatusCode: %v", resp.StatusCode) |
| 856 | + } |
| 857 | + // we're done making requests, so close the return channel and drain it |
| 858 | + close(serverReqErrCh) |
| 859 | + for v := range serverReqErrCh { |
| 860 | + if v != nil { |
| 861 | + t.Errorf("unexpected error occurred when server processed request: %v", v) |
| 862 | + } |
| 863 | + } |
| 864 | +} |
| 865 | + |
| 866 | +func TestRetriesWithBodies_POST(t *testing.T) { |
| 867 | + t.Parallel() |
| 868 | + |
| 869 | + const testContent = "TestRetriesWithBodies_POST" |
| 870 | + // using a channel to route these errors back into this goroutine |
| 871 | + // it is important that this channel have enough capacity to hold all |
| 872 | + // of the errors that will be generated by the test so that we do not |
| 873 | + // deadlock. Therefore, MaxAttempts must be the same size as the channel capacity |
| 874 | + // and each execution must only put at most one error on the channel. |
| 875 | + serverReqErrCh := make(chan error, 4) |
| 876 | + port, closeFn, err := middlewareServer( |
| 877 | + contentVerificationMiddleware(serverReqErrCh, testContent), |
| 878 | + always500RequestMiddleware(), |
| 879 | + ) |
| 880 | + if err != nil { |
| 881 | + t.Fatal("unable to start timeout server", err) |
| 882 | + } |
| 883 | + defer closeFn() |
| 884 | + |
| 885 | + c := New() |
| 886 | + c.MaxRetries = cap(serverReqErrCh) |
| 887 | + c.KeepLog = true |
| 888 | + c.Backoff = func(retry int) time.Duration { |
| 889 | + // backoff isn't important for this test |
| 890 | + return 0 |
| 891 | + } |
| 892 | + |
| 893 | + iseUrl := fmt.Sprintf("http://localhost:%d", port) |
| 894 | + resp, err := c.Post(iseUrl, "text/plain", strings.NewReader(testContent)) |
| 895 | + if err != nil { |
| 896 | + t.Errorf("unexpected error: %v", err) |
| 897 | + } |
| 898 | + if resp == nil { |
| 899 | + t.Error("response was unexpectedly nil") |
| 900 | + } else if resp.StatusCode != http.StatusInternalServerError { |
| 901 | + t.Errorf("unexpected response StatusCode: %v", resp.StatusCode) |
| 902 | + } |
| 903 | + // we're done making requests, so close the return channel and drain it |
| 904 | + close(serverReqErrCh) |
| 905 | + for v := range serverReqErrCh { |
| 906 | + if v != nil { |
| 907 | + t.Errorf("unexpected error occurred when server processed request: %v", v) |
| 908 | + } |
| 909 | + } |
| 910 | +} |
| 911 | + |
746 | 912 | func withinEpsilon(got, want int64, epslion float64) bool {
|
747 | 913 | if want <= int64(epslion*float64(got)) || want >= int64(epslion*float64(got)) {
|
748 | 914 | return false
|
@@ -880,3 +1046,61 @@ func serverWith400() (int, error) {
|
880 | 1046 |
|
881 | 1047 | return port, nil
|
882 | 1048 | }
|
| 1049 | + |
| 1050 | +func contentVerificationMiddleware(errorCh chan<- error, expectedContent string) http.Handler { |
| 1051 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 1052 | + content, err := ioutil.ReadAll(r.Body) |
| 1053 | + defer r.Body.Close() |
| 1054 | + if err != nil { |
| 1055 | + errorCh <- err |
| 1056 | + } else if string(content) != expectedContent { |
| 1057 | + errorCh <- fmt.Errorf( |
| 1058 | + "unexpected body content: expected \"%v\", got \"%v\"", |
| 1059 | + expectedContent, |
| 1060 | + string(content), |
| 1061 | + ) |
| 1062 | + } |
| 1063 | + }) |
| 1064 | +} |
| 1065 | + |
| 1066 | +func always500RequestMiddleware() http.Handler { |
| 1067 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 1068 | + w.WriteHeader(http.StatusInternalServerError) |
| 1069 | + w.Write([]byte("500 Internal Server Error")) |
| 1070 | + }) |
| 1071 | +} |
| 1072 | + |
| 1073 | +// middlewareServer stands up a server that accepts varags of middleware that conforms to the |
| 1074 | +// http.Handler interface |
| 1075 | +func middlewareServer(requestMiddleware ...http.Handler) (int, func(), error) { |
| 1076 | + mux := http.NewServeMux() |
| 1077 | + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { |
| 1078 | + for _, v := range requestMiddleware { |
| 1079 | + v.ServeHTTP(w, r) |
| 1080 | + } |
| 1081 | + }) |
| 1082 | + l, err := net.Listen("tcp", ":0") |
| 1083 | + if err != nil { |
| 1084 | + return -1, nil, fmt.Errorf("unable to secure listener %v", err) |
| 1085 | + } |
| 1086 | + server := &http.Server{ |
| 1087 | + Handler: mux, |
| 1088 | + } |
| 1089 | + go func() { |
| 1090 | + if err := server.Serve(l); err != nil && err != http.ErrServerClosed { |
| 1091 | + log.Fatalf("middleware-server error %v", err) |
| 1092 | + } |
| 1093 | + }() |
| 1094 | + |
| 1095 | + var port int |
| 1096 | + _, sport, err := net.SplitHostPort(l.Addr().String()) |
| 1097 | + if err == nil { |
| 1098 | + port, err = strconv.Atoi(sport) |
| 1099 | + } |
| 1100 | + |
| 1101 | + if err != nil { |
| 1102 | + return -1, nil, fmt.Errorf("unable to determine port %v", err) |
| 1103 | + } |
| 1104 | + |
| 1105 | + return port, func() { server.Close() }, nil |
| 1106 | +} |
0 commit comments