From 60073a18b318ae7364e4386c656ee42b32a711df Mon Sep 17 00:00:00 2001 From: Daniel Fireman Date: Thu, 28 Sep 2017 10:22:23 -0300 Subject: [PATCH] Adding headers flag. --- cmd/replay/replay.go | 43 +++++++++++++++++++++++++++++++- cmd/replay/replay_test.go | 52 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 cmd/replay/replay_test.go diff --git a/cmd/replay/replay.go b/cmd/replay/replay.go index 160b6ab..28b4c2f 100644 --- a/cmd/replay/replay.go +++ b/cmd/replay/replay.go @@ -2,6 +2,7 @@ package replay import ( "bufio" + "bytes" "context" "encoding/json" "fmt" @@ -24,6 +25,33 @@ import ( "github.com/spf13/cobra" ) +type headersFlag struct{ http.Header } + +func (h headersFlag) String() string { + buf := &bytes.Buffer{} + if err := h.Write(buf); err != nil { + return "" + } + return buf.String() +} + +func (h headersFlag) Set(value string) error { + parts := strings.SplitN(value, ":", 2) + if len(parts) < 2 { + return fmt.Errorf("[header] wrong format:'%s'", value) + } + key, val := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + if key == "" || val == "" { + return fmt.Errorf("[header] wrong format:'%s'", value) + } + h.Header.Add(key, val) + return nil +} + +func (h headersFlag) Type() string { + return "headers" +} + var ( host string resultsPath string @@ -34,6 +62,9 @@ var ( numClients int isPaused int32 continueOn400 bool + // Adding Content-Type:application/json as default. + // https://www.elastic.co/blog/strict-content-type-checking-for-elasticsearch-rest-requests + headers = headersFlag{http.Header{"Content-Type": []string{"application/json"}}} ) func init() { @@ -45,6 +76,7 @@ func init() { RootCmd.Flags().BoolVar(&debug, "debug", false, "Dump requests and responses.") RootCmd.Flags().IntVarP(&numClients, "num_clients", "c", 10, "Number of active clients making requests.") RootCmd.Flags().BoolVar(&continueOn400, "continue_on_400s", false, "Whether the loadtest should continue if it receives a 400 response.") + RootCmd.Flags().VarP(&headers, "headers", "H", "Custom HTTP headers. You can specify as many as needed by repeating the flag. \"Content-Type: application/json\" is added by default.") } var ( @@ -201,7 +233,7 @@ func (r *runner) Run() error { r.clients <- client }() - req, err := http.NewRequest("GET", entry.URL, strings.NewReader(entry.Source)) + req, err := newRequest(entry.URL, entry.Source) if err != nil { // TODO(danielfireman): Make this more elegant. Leveraging cobra error messages. fmt.Printf("Error creating request: %q\n", err) @@ -326,3 +358,12 @@ func (r *runner) Run() error { } return nil } + +func newRequest(url, source string) (*http.Request, error) { + req, err := http.NewRequest("GET", url, strings.NewReader(source)) + if err != nil { + return nil, err + } + req.Header = headers.Header + return req, nil +} diff --git a/cmd/replay/replay_test.go b/cmd/replay/replay_test.go new file mode 100644 index 0000000..60db189 --- /dev/null +++ b/cmd/replay/replay_test.go @@ -0,0 +1,52 @@ +package replay + +import ( + "io/ioutil" + "net/http" + "testing" +) + +func TestNewRequest(t *testing.T) { + t.Run("ValidRequest", func(t *testing.T) { + req, err := newRequest("url", "source") + if err != nil { + t.Fatalf("error got:%q want:nil", err) + } + if req.URL.String() != "url" { + t.Fatalf("got:%s want:url", req.URL.String()) + } + b, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatalf("error got:%q want:nil", err) + } + if string(b) != "source" { + t.Fatalf("got:%s want:source", b) + } + }) + + t.Run("InvalidRequest", func(t *testing.T) { + _, err := newRequest("%zzzzz", "source") + if err == nil { + t.Fatalf("error got:nil want:error") + } + }) +} + +func TestHeadersFlag_Set(t *testing.T) { + h := headersFlag{http.Header{}} + t.Run("ValidHeader", func(t *testing.T) { + if err := h.Set("MyHeader:Foo"); err != nil { + t.Fatalf("error got:%q want:nil", err) + } + }) + t.Run("NoValue", func(t *testing.T) { + if err := h.Set("MyHeader"); err == nil { + t.Fatalf("error got:nil want:error") + } + }) + t.Run("KeyValueAsWhitespaces", func(t *testing.T) { + if err := h.Set(" : "); err == nil { + t.Fatalf("error got:nil want:error") + } + }) +}