diff --git a/README.md b/README.md index f649f2d..8f57f7c 100644 --- a/README.md +++ b/README.md @@ -308,6 +308,20 @@ func TestApi(t *testing.T) { } ``` +#### Provide a multipart/form-data + +```go +func TestApi(t *testing.T) { + apitest.Handler(handler). + Post("/hello"). + MultipartFormData("a", "1", "2"). + MultipartFile("file", "path/to/some.file1", "path/to/some.file2"). + Expect(t). + Status(http.StatusOK). + End() +} +``` + #### Capture the request and response data ```go diff --git a/apitest.go b/apitest.go index 55223fd..56f35d3 100644 --- a/apitest.go +++ b/apitest.go @@ -5,12 +5,16 @@ import ( "encoding/json" "fmt" "hash/fnv" + "io" "io/ioutil" + "mime/multipart" "net/http" "net/http/httptest" "net/http/httputil" "net/textproto" "net/url" + "os" + "path/filepath" "runtime/debug" "sort" "strings" @@ -221,6 +225,8 @@ type Request struct { queryCollection map[string][]string headers map[string][]string formData map[string][]string + multipartBody *bytes.Buffer + multipart *multipart.Writer cookies []*Cookie basicAuth string apiTest *APITest @@ -481,11 +487,71 @@ func (r *Request) BasicAuth(username, password string) *Request { // FormData is a builder method to set the body form data // Also sets the content type of the request to application/x-www-form-urlencoded func (r *Request) FormData(name string, values ...string) *Request { + defer r.checkCombineFormDataWithMultipart() + r.ContentType("application/x-www-form-urlencoded") r.formData[name] = append(r.formData[name], values...) return r } +// MultipartFormData is a builder method to set the field in multipart form data +// Also sets the content type of the request to multipart/form-data +func (r *Request) MultipartFormData(name string, values ...string) *Request { + defer r.checkCombineFormDataWithMultipart() + + r.setMultipartWriter() + + for _, value := range values { + if err := r.multipart.WriteField(name, value); err != nil { + r.apiTest.t.Fatal(err) + } + } + + return r +} + +// MultipartFile is a builder method to set the file in multipart form data +// Also sets the content type of the request to multipart/form-data +func (r *Request) MultipartFile(name string, ff ...string) *Request { + defer r.checkCombineFormDataWithMultipart() + + r.setMultipartWriter() + + for _, f := range ff { + func() { + file, err := os.Open(f) + if err != nil { + r.apiTest.t.Fatal(err) + } + defer file.Close() + + part, err := r.multipart.CreateFormFile(name, filepath.Base(file.Name())) + if err != nil { + r.apiTest.t.Fatal(err) + } + + if _, err = io.Copy(part, file); err != nil { + r.apiTest.t.Fatal(err) + } + }() + } + + return r +} + +func (r *Request) setMultipartWriter() { + if r.multipart == nil { + r.multipartBody = &bytes.Buffer{} + r.multipart = multipart.NewWriter(r.multipartBody) + } +} + +func (r *Request) checkCombineFormDataWithMultipart() { + if r.multipart != nil && len(r.formData) > 0 { + r.apiTest.t.Fatal("FormData (application/x-www-form-urlencoded) and MultiPartFormData(multipart/form-data) cannot be combined") + } +} + // Expect marks the request spec as complete and following code will define the expected response func (r *Request) Expect(t TestingT) *Response { r.apiTest.t = t @@ -896,7 +962,17 @@ func (a *APITest) buildRequest() *http.Request { form.Add(k, value) } } - a.request.body = form.Encode() + a.request.Body(form.Encode()) + } + + if a.request.multipart != nil { + err := a.request.multipart.Close() + if err != nil { + a.request.apiTest.t.Fatal(err) + } + + a.request.Header("Content-Type", a.request.multipart.FormDataContentType()) + a.request.Body(a.request.multipartBody.String()) } req, _ := http.NewRequest(a.request.method, a.request.url, bytes.NewBufferString(a.request.body)) diff --git a/apitest_test.go b/apitest_test.go index 7981a3d..7d627bb 100644 --- a/apitest_test.go +++ b/apitest_test.go @@ -8,6 +8,8 @@ import ( "net/http" "net/http/cookiejar" "net/http/httptest" + "os" + "os/exec" "reflect" "strings" "testing" @@ -1139,6 +1141,115 @@ func TestApiTest_AddsUrlEncodedFormBody(t *testing.T) { End() } +func TestApiTest_AddsMultipartFormData(t *testing.T) { + handler := http.NewServeMux() + handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header["Content-Type"][0], "multipart/form-data") { + w.WriteHeader(http.StatusBadRequest) + return + } + + err := r.ParseMultipartForm(2 << 32) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + expectedPostFormData := map[string][]string{ + "name": {"John"}, + "age": {"99"}, + "children": {"Jack", "Ann"}, + "pets": {"Toby", "Henry", "Alice"}, + } + + for key := range expectedPostFormData { + if !reflect.DeepEqual(expectedPostFormData[key], r.MultipartForm.Value[key]) { + w.WriteHeader(http.StatusBadRequest) + return + } + } + + for _, exp := range []struct { + filename string + data string + }{ + { + filename: "response_body", + data: `{"a": 12345}`, + }, + { + filename: "mock_request_body", + data: `{"bodyKey": "bodyVal"}`, + }, + } { + for _, file := range r.MultipartForm.File[exp.filename] { + assert.Equal(t, exp.filename+".json", file.Filename) + + f, err := file.Open() + if err != nil { + t.Fatal(err) + } + data, err := ioutil.ReadAll(f) + if err != nil { + t.Fatal(err) + } + assert.JSONEq(t, exp.data, string(data)) + } + } + + w.WriteHeader(http.StatusOK) + }) + + apitest.New(). + Handler(handler). + Post("/hello"). + MultipartFormData("name", "John"). + MultipartFormData("age", "99"). + MultipartFormData("children", "Jack"). + MultipartFormData("children", "Ann"). + MultipartFormData("pets", "Toby", "Henry", "Alice"). + MultipartFile("request_body", "testdata/request_body.json", "testdata/request_body.json"). + MultipartFile("mock_request_body", "testdata/mock_request_body.json"). + Expect(t). + Status(http.StatusOK). + End() +} + +func TestApiTest_CombineFormDataWithMultipart(t *testing.T) { + if os.Getenv("RUN_FATAL_TEST") == "FormData" { + apitest.New(). + Post("/hello"). + MultipartFormData("name", "John"). + FormData("name", "John") + return + } + if os.Getenv("RUN_FATAL_TEST") == "File" { + apitest.New(). + Post("/hello"). + MultipartFile("file", "testdata/request_body.json"). + FormData("name", "John") + return + } + + tests := map[string]string{ + "formdata_with_multiple_formdata": "FormData", + "formdata_with_multiple_file": "File", + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + + cmd := exec.Command(os.Args[0], "-test.run=TestApiTest_CombineFormDataWithMultipart") + cmd.Env = append(os.Environ(), "RUN_FATAL_TEST="+tt) + err := cmd.Run() + if e, ok := err.(*exec.ExitError); ok && !e.Success() { + return + } + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + } +} + func TestApiTest_ErrorIfMockInvocationsDoNotMatchTimes(t *testing.T) { getUser := apitest.NewMock(). Get("http://localhost:8080").