diff --git a/apitest.go b/apitest.go index 80d32a4..096352e 100644 --- a/apitest.go +++ b/apitest.go @@ -137,6 +137,12 @@ func (a *APITest) Handler(handler http.Handler) *APITest { return a } +// HandlerFunc defines the http handler that is invoked when the test is run +func (a *APITest) HandlerFunc(handlerFunc http.HandlerFunc) *APITest { + a.handler = handlerFunc + return a +} + // Mocks is a builder method for setting the mocks func (a *APITest) Mocks(mocks ...*Mock) *APITest { var m []*Mock @@ -689,10 +695,7 @@ func (r *Response) runTest() *http.Response { a.assertResponse(res) a.assertHeaders(res) a.assertCookies(res) - err := a.assertFunc(res, req) - if err != nil { - a.t.Fatal(err.Error()) - } + a.assertFunc(res, req) return copyHttpResponse(res) } @@ -705,16 +708,15 @@ func (a *APITest) assertMocks() { } } -func (a *APITest) assertFunc(res *http.Response, req *http.Request) error { +func (a *APITest) assertFunc(res *http.Response, req *http.Request) { if len(a.response.assert) > 0 { for _, assertFn := range a.response.assert { err := assertFn(copyHttpResponse(res), copyHttpRequest(req)) if err != nil { - return err + a.verifier.Equal(a.t, nil, err) } } } - return nil } func (a *APITest) doRequest() (*http.Response, *http.Request) { diff --git a/apitest_test.go b/apitest_test.go index 476f95d..3875662 100644 --- a/apitest_test.go +++ b/apitest_test.go @@ -2,8 +2,6 @@ package apitest_test import ( "fmt" - "github.com/steinfletcher/apitest" - "github.com/steinfletcher/apitest/mocks" "io/ioutil" "net/http" "net/http/cookiejar" @@ -11,6 +9,9 @@ import ( "testing" "time" + "github.com/steinfletcher/apitest" + "github.com/steinfletcher/apitest/mocks" + "github.com/stretchr/testify/assert" ) @@ -546,19 +547,17 @@ func TestApiTest_EndReturnsTheResult(t *testing.T) { type resBody struct { B string `json:"b"` } - handler := http.NewServeMux() - handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusCreated) - w.Header().Set("Content-Type", "application/json") - _, err := w.Write([]byte(`{"a": 12345, "b": "hi"}`)) - if err != nil { - panic(err) - } - }) var r resBody apitest.New(). - Handler(handler). + HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{"a": 12345, "b": "hi"}`)) + if err != nil { + panic(err) + } + }). Get("/hello"). Expect(t). Body(`{ diff --git a/examples/go.sum b/examples/go.sum index b283228..0e80249 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -159,6 +159,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/ugorji/go v1.1.2 h1:JON3E2/GPW2iDNGoSAusl1KDf5TRQ8k8q7Tp097pZGs= github.com/ugorji/go v1.1.2/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= github.com/ugorji/go/codec v0.0.0-20190126102652-8fd0f8d918c8 h1:X8lhf4a2HZiqw4DKNWz9aFZdssVV69au98QlhPXrEp8=