From 1e0d9d7b327e515a68e62be3969f367f7ca87496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Tue, 2 Apr 2024 15:59:48 +0200 Subject: [PATCH] Add an optional Init func to the server, fix a data race That takes a config struct from the client. The old way of configuring the server was to pass env vars (which still works), but this was at best very cumbersome. This also fixes a data race when both sending raw (e.g. log messages) and other responses. Closes #10 --- README.md | 135 +++++++++++++------------ client.go | 100 +++++++++++++----- client_test.go | 118 ++++++++++++++------- examples/model/model.go | 18 ++++ examples/servers/raw/go.mod | 2 +- examples/servers/readmeexample/go.mod | 13 +++ examples/servers/readmeexample/go.sum | 21 ++++ examples/servers/readmeexample/main.go | 96 ++++++++++++++++++ examples/servers/typed/go.mod | 2 +- examples/servers/typed/main.go | 52 ++++------ message.go | 1 + server.go | 89 ++++++++++++---- 12 files changed, 471 insertions(+), 176 deletions(-) create mode 100644 examples/servers/readmeexample/go.mod create mode 100644 examples/servers/readmeexample/go.sum create mode 100644 examples/servers/readmeexample/main.go diff --git a/README.md b/README.md index 24df633..b70b99d 100644 --- a/README.md +++ b/README.md @@ -7,66 +7,60 @@ This library implements a simple, custom [RPC protocol](https://en.wikipedia.org A strongly typed client may look like this: ```go -package main - -import ( - "fmt" - "log" - "time" - - "github.com/bep/execrpc" - "github.com/bep/execrpc/codecs" - "github.com/bep/execrpc/examples/model" -) - -func main() { - // Define the request, message and receipt types for the RPC call. +// Define the request, message and receipt types for the RPC call. +client, err := execrpc.StartClient( client, err := execrpc.StartClient( - execrpc.ClientOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ - ClientRawOptions: execrpc.ClientRawOptions{ - Version: 1, - Cmd: "go", - Dir: "./examples/servers/typed", - Args: []string{"run", "."}, - Env: nil, - Timeout: 30 * time.Second, - }, - Codec: codecs.JSONCodec{}, + execrpc.ClientOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + ClientRawOptions: execrpc.ClientRawOptions{ + Version: 1, + Cmd: "go", + Dir: "./examples/servers/typed", + Args: []string{"run", "."}, + Env: env, + Timeout: 30 * time.Second, }, - ) - if err != nil { - log.Fatal(err) + Config: model.ExampleConfig{}, + Codec: codec, + }, +) + +if err != nil { + log.Fatal(err) +} + +// Consume standalone messages (e.g. log messages) in its own goroutine. +go func() { + for msg := range client.MessagesRaw() { + fmt.Println("got message", string(msg.Body)) } +}() - // Consume standalone messages (e.g. log messages) in its own goroutine. - go func() { - for msg := range client.MessagesRaw() { - fmt.Println("got message", string(msg.Body)) - } - }() +// Execute the request. +result := client.Execute(model.ExampleRequest{Text: "world"}) - // Execute the request. - result := client.Execute(model.ExampleRequest{Text: "world"}) +// Check for errors. +if err := result.Err(); err != nil { + log.Fatal(err) +} - // Check for errors. - if err := result.Err(); err != nil { - log.Fatal(err) - } +// Consume the messages. +for m := range result.Messages() { + fmt.Println(m) +} - // Consume the messages. - for m := range result.Messages() { - fmt.Println(m) - } +// Wait for the receipt. +receipt := <-result.Receipt() - // Wait for the receipt. - receipt := <-result.Receipt() +// Check again for errors. +if err := result.Err(); err != nil { + log.Fatal(err) +} - // Check again for errors. - if err := result.Err(); err != nil { - log.Fatal(err) - } +fmt.Println(receipt.Text) - fmt.Println(receipt.Text) +// Close the client. +if err := client.Close(); err != nil { + log.Fatal(err) } ``` @@ -75,20 +69,32 @@ To get the best performance you should keep the client open as long as its neede And the server side of the above: ```go + func main() { - getHasher := func() hash.Hash { - return fnv.New64a() - } + log.SetFlags(0) + log.SetPrefix("readme-example: ") + + var clientConfig model.ExampleConfig server, err := execrpc.NewServer( - execrpc.ServerOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ - // Optional function to get a hasher for the ETag. - GetHasher: getHasher, + execrpc.ServerOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + // Optional function to provide a hasher for the ETag. + GetHasher: func() hash.Hash { + return fnv.New64a() + }, // Allows you to delay message delivery, and drop // them after reading the receipt (e.g. the ETag matches the ETag seen by client). DelayDelivery: false, + // Optional function to initialize the server + // with the client configuration. + // This will be called once on server start. + Init: func(cfg model.ExampleConfig) error { + clientConfig = cfg + return clientConfig.Init() + }, + // Handle the incoming call. Handle: func(c *execrpc.Call[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) { // Raw messages are passed directly to the client, @@ -99,7 +105,7 @@ func main() { Version: 32, Status: 150, }, - Body: []byte("a log message"), + Body: []byte("log message"), }, ) @@ -124,12 +130,13 @@ func main() { // ETag provided by the framework. // A hash of all message bodies. - fmt.Println("Receipt:", receipt.ETag) + // fmt.Println("Receipt:", receipt.ETag) // Modify if needed. receipt.Size = uint32(123) + receipt.Text = "echoed: " + c.Request.Text - // Close the message stream. + // Close the message stream and send the receipt. // Pass true to drop any queued messages, // this is only relevant if DelayDelivery is enabled. c.Close(false, receipt) @@ -137,14 +144,18 @@ func main() { }, ) if err != nil { - log.Fatal(err) + handleErr(err) } - // Start the server. This will block. if err := server.Start(); err != nil { - log.Fatal(err) + handleErr(err) } } + +func handleErr(err error) { + log.Fatalf("error: failed to start typed echo server: %s", err) +} + ``` ## Generate ETag diff --git a/client.go b/client.go index 8476f1f..bcbeff7 100644 --- a/client.go +++ b/client.go @@ -24,7 +24,7 @@ const ( ) // StartClient starts a client for the given options. -func StartClient[Q, M, R any](opts ClientOptions[Q, M, R]) (*Client[Q, M, R], error) { +func StartClient[C, Q, M, R any](opts ClientOptions[C, Q, M, R]) (*Client[C, Q, M, R], error) { if opts.Codec == nil { return nil, errors.New("opts: Codec is required") } @@ -37,16 +37,23 @@ func StartClient[Q, M, R any](opts ClientOptions[Q, M, R]) (*Client[Q, M, R], er return nil, err } - return &Client[Q, M, R]{ + c := &Client[C, Q, M, R]{ rawClient: rawClient, opts: opts, - }, nil + } + + err = c.init(opts.Config) + if err != nil { + return nil, err + } + + return c, nil } // Client is a strongly typed RPC client. -type Client[Q, M, R any] struct { +type Client[C, Q, M, R any] struct { rawClient *ClientRaw - opts ClientOptions[Q, M, R] + opts ClientOptions[C, Q, M, R] } // Result is the result of a request @@ -85,13 +92,49 @@ func (r Result[M, R]) close() { // MessagesRaw returns the raw messages from the server. // These are not connected to the request-response flow, // typically used for log messages etc. -func (c *Client[Q, M, R]) MessagesRaw() <-chan Message { +func (c *Client[C, Q, M, R]) MessagesRaw() <-chan Message { return c.rawClient.Messages } +// init passes the configuration to the server. +func (c *Client[C, Q, M, R]) init(cfg C) error { + body, err := c.opts.Codec.Encode(cfg) + if err != nil { + return fmt.Errorf("failed to encode config: %w", err) + } + var ( + messagec = make(chan Message, 10) + errc = make(chan error, 1) + ) + + go func() { + err := c.rawClient.Execute( + func(m *Message) { + m.Body = body + m.Header.Status = MessageStatusInitServer + }, + messagec, + ) + if err != nil { + errc <- fmt.Errorf("failed to execute init: %w", err) + } + }() + + select { + case err := <-errc: + return err + case m := <-messagec: + if m.Header.Status != MessageStatusOK { + return fmt.Errorf("failed to init: %s (error code %d)", m.Body, m.Header.Status) + } + } + + return nil +} + // Execute sends the request to the server and returns the result. // You should check Err() both before and after reading from the messages and receipt channels. -func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] { +func (c *Client[C, Q, M, R]) Execute(r Q) Result[M, R] { result := Result[M, R]{ messages: make(chan M, 10), receipt: make(chan R, 1), @@ -112,20 +155,21 @@ func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] { messagesRaw := make(chan Message, 10) go func() { - err := c.rawClient.Execute(body, messagesRaw) + err := c.rawClient.Execute(func(m *Message) { m.Body = body }, messagesRaw) if err != nil { result.errc <- fmt.Errorf("failed to execute: %w", err) } }() for message := range messagesRaw { - if message.Header.Status > MessageStatusContinue && message.Header.Status <= MessageStatusSystemReservedMax { + if message.Header.Status >= MessageStatusErrDecodeFailed && message.Header.Status <= MessageStatusSystemReservedMax { // All of these are currently error situations produced by the server. result.errc <- fmt.Errorf("%s (error code %d)", message.Body, message.Header.Status) return } - if message.Header.Status == MessageStatusContinue { + switch message.Header.Status { + case MessageStatusContinue: var resp M err = c.opts.Codec.Decode(message.Body, &resp) if err != nil { @@ -133,7 +177,9 @@ func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] { return } result.messages <- resp - } else { + case MessageStatusInitServer: + panic("unexpected status") + default: // Receipt. var rec R err = c.opts.Codec.Decode(message.Body, &rec) @@ -152,7 +198,7 @@ func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] { } // Close closes the client. -func (c *Client[Q, M, R]) Close() error { +func (c *Client[C, Q, M, R]) Close() error { return c.rawClient.Close() } @@ -248,10 +294,10 @@ func (c *ClientRaw) Close() error { // Execute sends body to the server and sends any messages to the messages channel. // It's safe to call Execute from multiple goroutines. // The messages channel wil be closed when the call is done. -func (c *ClientRaw) Execute(body []byte, messages chan<- Message) error { +func (c *ClientRaw) Execute(withMessage func(m *Message), messages chan<- Message) error { defer close(messages) - call, err := c.newCall(body, messages) + call, err := c.newCall(withMessage, messages) if err != nil { return err } @@ -276,20 +322,21 @@ func (c *ClientRaw) addErrContext(op string, err error) error { return fmt.Errorf("%s: %s %s", op, err, c.conn.stdErr.String()) } -func (c *ClientRaw) newCall(body []byte, messages chan<- Message) (*call, error) { +func (c *ClientRaw) newCall(withMessage func(m *Message), messages chan<- Message) (*call, error) { c.mu.Lock() c.seq++ id := c.seq + m := Message{ + Header: Header{ + Version: c.version, + ID: id, + }, + } + withMessage(&m) call := &call{ - Done: make(chan *call, 1), - Request: Message{ - Header: Header{ - Version: c.version, - ID: id, - }, - Body: body, - }, + Done: make(chan *call, 1), + Request: m, Messages: messages, } @@ -384,8 +431,13 @@ func (c *ClientRaw) send(call *call) error { } // ClientOptions are options for the client. -type ClientOptions[Q, M, R any] struct { +type ClientOptions[C, Q, M, R any] struct { ClientRawOptions + + // The configuration to pass to the server. + Config C + + // The codec to use. Codec codecs.Codec } diff --git a/client_test.go b/client_test.go index dffcd99..0b9de81 100644 --- a/client_test.go +++ b/client_test.go @@ -12,7 +12,7 @@ import ( "golang.org/x/sync/errgroup" ) -func TestExecRaw(t *testing.T) { +func TestRaw(t *testing.T) { c := qt.New(t) newClient := func(c *qt.C) *execrpc.ClientRaw { @@ -35,7 +35,7 @@ func TestExecRaw(t *testing.T) { messages := make(chan execrpc.Message) var g errgroup.Group g.Go(func() error { - return client.Execute([]byte("hello"), messages) + return client.Execute(func(m *execrpc.Message) { m.Body = []byte("hello") }, messages) }) var i int for msg := range messages { @@ -49,7 +49,7 @@ func TestExecRaw(t *testing.T) { }) } -func TestExecStartFailed(t *testing.T) { +func TestStartFailed(t *testing.T) { c := qt.New(t) client, err := execrpc.StartClientRaw( execrpc.ClientRawOptions{ @@ -64,18 +64,19 @@ func TestExecStartFailed(t *testing.T) { c.Assert(client.Close(), qt.IsNil) } -func newTestClient(t testing.TB, codec codecs.Codec, env ...string) *execrpc.Client[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt] { +func newTestClientForServer(t testing.TB, server string, codec codecs.Codec, cfg model.ExampleConfig, env ...string) *execrpc.Client[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt] { client, err := execrpc.StartClient( - execrpc.ClientOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + execrpc.ClientOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ ClientRawOptions: execrpc.ClientRawOptions{ Version: 1, Cmd: "go", - Dir: "./examples/servers/typed", + Dir: "./examples/servers/" + server, Args: []string{"run", "."}, Env: env, Timeout: 30 * time.Second, }, - Codec: codec, + Config: cfg, + Codec: codec, }, ) if err != nil { @@ -84,17 +85,23 @@ func newTestClient(t testing.TB, codec codecs.Codec, env ...string) *execrpc.Cli t.Cleanup(func() { if err := client.Close(); err != nil { - t.Fatal(err) + if err != execrpc.ErrShutdown { + t.Fatal(err) + } } }) return client } -func TestExecTyped(t *testing.T) { +func newTestClient(t testing.TB, codec codecs.Codec, cfg model.ExampleConfig, env ...string) *execrpc.Client[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt] { + return newTestClientForServer(t, "typed", codec, cfg, env...) +} + +func TestTyped(t *testing.T) { c := qt.New(t) - runBasicTestForClient := func(c *qt.C, client *execrpc.Client[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) execrpc.Result[model.ExampleMessage, model.ExampleReceipt] { + runBasicTestForClient := func(c *qt.C, client *execrpc.Client[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) execrpc.Result[model.ExampleMessage, model.ExampleReceipt] { result := client.Execute(model.ExampleRequest{Text: "world"}) c.Assert(result.Err(), qt.IsNil) return result @@ -112,7 +119,7 @@ func TestExecTyped(t *testing.T) { } c.Run("One message", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}) + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}) result := runBasicTestForClient(c, client) assertMessages(c, result, 1) receipt := <-result.Receipt() @@ -122,7 +129,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("100 messages", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NUM_MESSAGES=100") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{NumMessages: 100}) result := runBasicTestForClient(c, client) assertMessages(c, result, 100) receipt := <-result.Receipt() @@ -132,7 +139,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("1234 messages", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NUM_MESSAGES=1234") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{NumMessages: 1234}, "EXECRPC_NUM_MESSAGES=1234") result := runBasicTestForClient(c, client) assertMessages(c, result, 1234) receipt := <-result.Receipt() @@ -143,7 +150,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("Delay delivery", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_DELAY_DELIVERY=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}, "EXECRPC_DELAY_DELIVERY=true") result := runBasicTestForClient(c, client) assertMessages(c, result, 1) receipt := <-result.Receipt() @@ -152,7 +159,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("Delay delivery, drop messages", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_DELAY_DELIVERY=true", "EXECRPC_DROP_MESSAGES=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{DropMessages: true}, "EXECRPC_DELAY_DELIVERY=true") result := runBasicTestForClient(c, client) assertMessages(c, result, 0) receipt := <-result.Receipt() @@ -164,7 +171,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("No Close", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_CLOSE=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{NoClose: true}) result := runBasicTestForClient(c, client) assertMessages(c, result, 1) receipt := <-result.Receipt() @@ -173,7 +180,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("Receipt", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}) + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}) result := runBasicTestForClient(c, client) assertMessages(c, result, 1) receipt := <-result.Receipt() @@ -186,7 +193,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("No hasher", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_HASHER=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}, "EXECRPC_NO_HASHER=true") result := runBasicTestForClient(c, client) assertMessages(c, result, 1) receipt := <-result.Receipt() @@ -194,7 +201,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("No reading Receipt", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_READING_RECEIPT=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{NoReadingReceipt: true}) result := runBasicTestForClient(c, client) assertMessages(c, result, 1) receipt := <-result.Receipt() @@ -203,7 +210,7 @@ func TestExecTyped(t *testing.T) { }) c.Run("No reading Receipt, no Close", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_READING_RECEIPT=true", "EXECRPC_NO_CLOSE=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{NoClose: true, NoReadingReceipt: true}) result := runBasicTestForClient(c, client) assertMessages(c, result, 1) receipt := <-result.Receipt() @@ -215,16 +222,16 @@ func TestExecTyped(t *testing.T) { var logMessages []execrpc.Message client, err := execrpc.StartClient( - execrpc.ClientOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + execrpc.ClientOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ ClientRawOptions: execrpc.ClientRawOptions{ Version: 1, Cmd: "go", Dir: "./examples/servers/typed", Args: []string{"run", "."}, - Env: []string{"EXECRPC_SEND_TWO_LOG_MESSAGES=true"}, Timeout: 30 * time.Second, }, - Codec: codecs.JSONCodec{}, + Config: model.ExampleConfig{SendLogMessage: true}, + Codec: codecs.JSONCodec{}, }, ) if err != nil { @@ -249,13 +256,13 @@ func TestExecTyped(t *testing.T) { }) c.Run("TOML", func(c *qt.C) { - client := newTestClient(c, codecs.TOMLCodec{}) + client := newTestClient(c, codecs.TOMLCodec{}, model.ExampleConfig{}) result := runBasicTestForClient(c, client) assertMessages(c, result, 1) }) c.Run("Error in receipt", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_CALL_SHOULD_FAIL=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{CallShouldFail: true}) result := client.Execute(model.ExampleRequest{Text: "hello"}) c.Assert(result.Err(), qt.IsNil) receipt := <-result.Receipt() @@ -266,28 +273,60 @@ func TestExecTyped(t *testing.T) { // The "stdout print tests" are just to make sure that the server behaves and does not hang. c.Run("Print to stdout outside server before", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") runBasicTestForClient(c, client) }) c.Run("Print to stdout inside server", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_INSIDE_SERVER=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}, "EXECRPC_PRINT_INSIDE_SERVER=true") runBasicTestForClient(c, client) }) c.Run("Print to stdout outside server before", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") runBasicTestForClient(c, client) }) c.Run("Print to stdout inside after", func(c *qt.C) { - client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_OUTSIDE_SERVER_AFTER=true") + client := newTestClient(c, codecs.JSONCodec{}, model.ExampleConfig{}, "EXECRPC_PRINT_OUTSIDE_SERVER_AFTER=true") runBasicTestForClient(c, client) }) } -func TestExecTypedConcurrent(t *testing.T) { - client := newTestClient(t, codecs.JSONCodec{}) +// Make sure that the README example compiles and runs. +func TestReadmeExample(t *testing.T) { + c := qt.New(t) + + client := newTestClientForServer(c, "readmeexample", codecs.JSONCodec{}, model.ExampleConfig{}) + var wg errgroup.Group + wg.Go(func() error { + for msg := range client.MessagesRaw() { + s := string(msg.Body) + msg := fmt.Sprintf("got message: %s id: %d status: %d header size: %d actual size: %d", s, msg.Header.ID, msg.Header.Status, msg.Header.Size, len(s)) + if s != "log message" { + return fmt.Errorf("unexpected message: %s", msg) + } + } + return nil + }) + result := client.Execute(model.ExampleRequest{Text: "world"}) + c.Assert(result.Err(), qt.IsNil) + var hellos []string + for m := range result.Messages() { + hellos = append(hellos, m.Hello) + } + c.Assert(hellos, qt.DeepEquals, []string{"Hello 1!", "Hello 2!", "Hello 3!"}) + receipt := <-result.Receipt() + c.Assert(receipt.LastModified, qt.Not(qt.Equals), int64(0)) + c.Assert(receipt.ETag, qt.Equals, "44af821b6bd180d0") + c.Assert(receipt.Text, qt.Equals, "echoed: world") + c.Assert(result.Err(), qt.IsNil) + c.Assert(client.Close(), qt.IsNil) + c.Assert(wg.Wait(), qt.IsNil) +} + +func TestTypedConcurrent(t *testing.T) { + client := newTestClient(t, codecs.JSONCodec{}, model.ExampleConfig{}) var g errgroup.Group for i := 0; i < 100; i++ { @@ -324,9 +363,9 @@ func TestExecTypedConcurrent(t *testing.T) { func BenchmarkClient(b *testing.B) { const word = "World" - runBenchmark := func(name string, codec codecs.Codec, env ...string) { + runBenchmark := func(name string, codec codecs.Codec, cfg model.ExampleConfig, env ...string) { b.Run(name, func(b *testing.B) { - client := newTestClient(b, codec, env...) + client := newTestClient(b, codec, cfg, env...) b.RunParallel(func(pb *testing.PB) { for pb.Next() { result := client.Execute(model.ExampleRequest{Text: word}) @@ -344,11 +383,12 @@ func BenchmarkClient(b *testing.B) { }) } - runBenchmarksForCodec := func(codec codecs.Codec) { - runBenchmark("1 message "+codec.Name(), codec) - runBenchmark("100 messages "+codec.Name(), codec, "EXECRPC_NUM_MESSAGES=100") + runBenchmarksForCodec := func(codec codecs.Codec, cfg model.ExampleConfig) { + runBenchmark("1 message "+codec.Name(), codec, cfg) + cfg.NumMessages = 100 + runBenchmark("100 messages "+codec.Name(), codec, cfg) } - runBenchmarksForCodec(codecs.JSONCodec{}) - runBenchmark("100 messages JSON, no hasher ", codecs.JSONCodec{}, "EXECRPC_NUM_MESSAGES=100", "EXECRPC_NO_HASHER=true") - runBenchmarksForCodec(codecs.TOMLCodec{}) + runBenchmarksForCodec(codecs.JSONCodec{}, model.ExampleConfig{}) + runBenchmark("100 messages JSON, no hasher ", codecs.JSONCodec{}, model.ExampleConfig{NumMessages: 100}, "EXECRPC_NO_HASHER=true") + runBenchmarksForCodec(codecs.TOMLCodec{}, model.ExampleConfig{}) } diff --git a/examples/model/model.go b/examples/model/model.go index 7a7085e..81e64c8 100644 --- a/examples/model/model.go +++ b/examples/model/model.go @@ -2,6 +2,24 @@ package model import "github.com/bep/execrpc" +type ExampleConfig struct { + // Used in tests. + CallShouldFail bool `json:"callShouldFail"` + SendLogMessage bool `json:"sendLogMessage"` + NoClose bool `json:"noClose"` + NoReadingReceipt bool `json:"noReadingReceipt"` + DropMessages bool `json:"dropMessages"` + NumMessages int `json:"numMessages"` +} + +func (cfg *ExampleConfig) Init() error { + if cfg.NumMessages < 1 { + cfg.NumMessages = 1 + } + + return nil +} + // ExampleRequest is just a simple example request. type ExampleRequest struct { Text string `json:"text"` diff --git a/examples/servers/raw/go.mod b/examples/servers/raw/go.mod index 0c37f8a..2592804 100644 --- a/examples/servers/raw/go.mod +++ b/examples/servers/raw/go.mod @@ -1,6 +1,6 @@ module github.com/bep/execrpc/examples/servers/raw -go 1.19 +go 1.21 require github.com/bep/execrpc v0.3.0 diff --git a/examples/servers/readmeexample/go.mod b/examples/servers/readmeexample/go.mod new file mode 100644 index 0000000..afc9f9c --- /dev/null +++ b/examples/servers/readmeexample/go.mod @@ -0,0 +1,13 @@ +module github.com/bep/execrpc/examples/servers/readmeexample + +go 1.21 + +require github.com/bep/execrpc v0.3.0 + +require ( + github.com/bep/helpers v0.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.0.2 // indirect + golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect +) + +replace github.com/bep/execrpc => ../../.. diff --git a/examples/servers/readmeexample/go.sum b/examples/servers/readmeexample/go.sum new file mode 100644 index 0000000..fc21b62 --- /dev/null +++ b/examples/servers/readmeexample/go.sum @@ -0,0 +1,21 @@ +github.com/bep/helpers v0.1.0 h1:HFLG+W6axHackmKMk0houEnz9G2aiBrDMZyOvL9J0WM= +github.com/bep/helpers v0.1.0/go.mod h1:/QpHdmcPagDw7+RjkLFCvnlUc8lQ5kg4KDrEkb2Yyco= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= +github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/servers/readmeexample/main.go b/examples/servers/readmeexample/main.go new file mode 100644 index 0000000..a315ebf --- /dev/null +++ b/examples/servers/readmeexample/main.go @@ -0,0 +1,96 @@ +package main + +import ( + "hash" + "hash/fnv" + "log" + + "github.com/bep/execrpc" + "github.com/bep/execrpc/examples/model" +) + +func main() { + log.SetFlags(0) + log.SetPrefix("readme-example: ") + + var clientConfig model.ExampleConfig + + server, err := execrpc.NewServer( + execrpc.ServerOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + // Optional function to provide a hasher for the ETag. + GetHasher: func() hash.Hash { + return fnv.New64a() + }, + + // Allows you to delay message delivery, and drop + // them after reading the receipt (e.g. the ETag matches the ETag seen by client). + DelayDelivery: false, + + // Optional function to initialize the server + // with the client configuration. + // This will be called once on server start. + Init: func(cfg model.ExampleConfig) error { + clientConfig = cfg + return clientConfig.Init() + }, + + // Handle the incoming call. + Handle: func(c *execrpc.Call[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) { + // Raw messages are passed directly to the client, + // typically used for log messages. + c.SendRaw( + execrpc.Message{ + Header: execrpc.Header{ + Version: 32, + Status: 150, + }, + Body: []byte("log message"), + }, + ) + + // Enqueue one or more messages. + c.Enqueue( + model.ExampleMessage{ + Hello: "Hello 1!", + }, + model.ExampleMessage{ + Hello: "Hello 2!", + }, + ) + + c.Enqueue( + model.ExampleMessage{ + Hello: "Hello 3!", + }, + ) + + // Wait for the framework generated receipt. + receipt := <-c.Receipt() + + // ETag provided by the framework. + // A hash of all message bodies. + // fmt.Println("Receipt:", receipt.ETag) + + // Modify if needed. + receipt.Size = uint32(123) + receipt.Text = "echoed: " + c.Request.Text + + // Close the message stream and send the receipt. + // Pass true to drop any queued messages, + // this is only relevant if DelayDelivery is enabled. + c.Close(false, receipt) + }, + }, + ) + if err != nil { + handleErr(err) + } + + if err := server.Start(); err != nil { + handleErr(err) + } +} + +func handleErr(err error) { + log.Fatalf("error: failed to start typed echo server: %s", err) +} diff --git a/examples/servers/typed/go.mod b/examples/servers/typed/go.mod index 7076eff..dd3de8a 100644 --- a/examples/servers/typed/go.mod +++ b/examples/servers/typed/go.mod @@ -1,6 +1,6 @@ module github.com/bep/execrpc/examples/servers/typed -go 1.19 +go 1.21 require github.com/bep/execrpc v0.3.0 diff --git a/examples/servers/typed/main.go b/examples/servers/typed/main.go index 813e3b8..b243861 100644 --- a/examples/servers/typed/main.go +++ b/examples/servers/typed/main.go @@ -18,27 +18,13 @@ func main() { // Some test flags from the client. var ( + delayDelivery = os.Getenv("EXECRPC_DELAY_DELIVERY") != "" + noHasher = os.Getenv("EXECRPC_NO_HASHER") != "" printOutsideServerBefore = os.Getenv("EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE") != "" printOutsideServerAfter = os.Getenv("EXECRPC_PRINT_OUTSIDE_SERVER_AFTER") != "" printInsideServer = os.Getenv("EXECRPC_PRINT_INSIDE_SERVER") != "" - callShouldFail = os.Getenv("EXECRPC_CALL_SHOULD_FAIL") != "" - sendLogMessage = os.Getenv("EXECRPC_SEND_TWO_LOG_MESSAGES") != "" - noClose = os.Getenv("EXECRPC_NO_CLOSE") != "" - noReadingReceipt = os.Getenv("EXECRPC_NO_READING_RECEIPT") != "" - numMessagesStr = os.Getenv("EXECRPC_NUM_MESSAGES") - numMessages = 1 - delayDelivery = os.Getenv("EXECRPC_DELAY_DELIVERY") != "" - dropMessages = os.Getenv("EXECRPC_DROP_MESSAGES") != "" - noHasher = os.Getenv("EXECRPC_NO_HASHER") != "" ) - if numMessagesStr != "" { - numMessages, _ = strconv.Atoi(numMessagesStr) - if numMessages < 1 { - numMessages = 1 - } - } - if printOutsideServerBefore { fmt.Println("Printing outside server before") } @@ -51,16 +37,22 @@ func main() { } } + var clientConfig model.ExampleConfig + server, err := execrpc.NewServer( - execrpc.ServerOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + execrpc.ServerOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ GetHasher: getHasher, DelayDelivery: delayDelivery, - Handle: func(c *execrpc.Call[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) { + Init: func(cfg model.ExampleConfig) error { + clientConfig = cfg + return clientConfig.Init() + }, + Handle: func(call *execrpc.Call[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) { if printInsideServer { fmt.Println("Printing inside server") } - if callShouldFail { - c.Close( + if clientConfig.CallShouldFail { + call.Close( false, model.ExampleReceipt{ Error: &model.Error{Msg: "failed to echo"}, @@ -69,8 +61,8 @@ func main() { return } - if sendLogMessage { - c.SendRaw( + if clientConfig.SendLogMessage { + call.SendRaw( execrpc.Message{ Header: execrpc.Header{ Version: 32, @@ -88,23 +80,23 @@ func main() { ) } - for i := 0; i < numMessages; i++ { - c.Enqueue( + for i := 0; i < clientConfig.NumMessages; i++ { + call.Enqueue( model.ExampleMessage{ - Hello: strconv.Itoa(i) + ": Hello " + c.Request.Text + "!", + Hello: strconv.Itoa(i) + ": Hello " + call.Request.Text + "!", }, ) } - if !noClose { + if !clientConfig.NoClose { var receipt model.ExampleReceipt - if !noReadingReceipt { - receipt = <-c.Receipt() - receipt.Text = "echoed: " + c.Request.Text + if !clientConfig.NoReadingReceipt { + receipt = <-call.Receipt() + receipt.Text = "echoed: " + call.Request.Text receipt.Size = uint32(123) } - c.Close(dropMessages, receipt) + call.Close(clientConfig.DropMessages, receipt) } }, }, diff --git a/message.go b/message.go index 50e88f1..c561aaa 100644 --- a/message.go +++ b/message.go @@ -31,6 +31,7 @@ func (m *Message) Write(w io.Writer) error { // Header is the header of a message. // ID and Size are set by the system. +// Status may be set by the system. type Header struct { ID uint32 Version uint16 diff --git a/server.go b/server.go index 2700115..8dc1626 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "hash" "io" "os" + "sync" "time" "github.com/bep/execrpc/codecs" @@ -19,10 +20,15 @@ const ( // MessageStatusContinue is the status code for a message that should continue the conversation. MessageStatusContinue + // MessageStatusInitServer is the status code for a message used to initialize/configure the server. + MessageStatusInitServer + // MessageStatusErrDecodeFailed is the status code for a message that failed to decode. MessageStatusErrDecodeFailed // MessageStatusErrEncodeFailed is the status code for a message that failed to encode. MessageStatusErrEncodeFailed + // MessageStatusErrInitServerFailed is the status code for a message that failed to initialize the server. + MessageStatusErrInitServerFailed // MessageStatusSystemReservedMax is the maximum value for a system reserved status code. MessageStatusSystemReservedMax = 99 @@ -36,14 +42,14 @@ func NewServerRaw(opts ServerRawOptions) (*ServerRaw, error) { s := &ServerRaw{ call: opts.Call, } - s.dispatcher = messageDispatcher{ + s.dispatcher = &messageDispatcher{ s: s, } return s, nil } // NewServer creates a new Server. using the given options. -func NewServer[Q, M, R comparable](opts ServerOptions[Q, M, R]) (*Server[Q, M, R], error) { +func NewServer[C, Q, M, R any](opts ServerOptions[C, Q, M, R]) (*Server[C, Q, M, R], error) { if opts.Handle == nil { return nil, fmt.Errorf("opts: Handle function is required") } @@ -63,14 +69,39 @@ func NewServer[Q, M, R comparable](opts ServerOptions[Q, M, R]) (*Server[Q, M, R ) callRaw := func(message Message, d Dispatcher) error { + if message.Header.Status == MessageStatusInitServer { + if opts.Init == nil { + m := createErrorMessage(fmt.Errorf("opts: Init function is required"), message.Header, MessageStatusErrInitServerFailed) + d.SendMessage(m) + return nil + } + + var cfg C + err := opts.Codec.Decode(message.Body, &cfg) + if err != nil { + m := createErrorMessage(err, message.Header, MessageStatusErrDecodeFailed) + d.SendMessage(m) + return nil + } + + if err := opts.Init(cfg); err != nil { + m := createErrorMessage(err, message.Header, MessageStatusErrInitServerFailed) + d.SendMessage(m) + return nil + } + + // OK. + var receipt Message + receipt.Header = message.Header + receipt.Header.Status = MessageStatusOK + d.SendMessage(receipt) + return nil + } + var q Q err := opts.Codec.Decode(message.Body, &q) if err != nil { - m := Message{ - Header: message.Header, - Body: []byte(fmt.Sprintf("failed to decode request: %s. Check that client and server uses the same codec.", err)), - } - m.Header.Status = MessageStatusErrDecodeFailed + m := createErrorMessage(err, message.Header, MessageStatusErrDecodeFailed) d.SendMessage(m) return nil } @@ -135,6 +166,9 @@ func NewServer[Q, M, R comparable](opts ServerOptions[Q, M, R]) (*Server[Q, M, R b, err := opts.Codec.Encode(m) h := message.Header h.Status = MessageStatusContinue + if h.ID == 0 { + panic("message ID must not be 0 for request/response messages") + } m := createMessage(b, err, h, MessageStatusErrEncodeFailed) if opts.DelayDelivery { messageBuff = append(messageBuff, m) @@ -168,7 +202,7 @@ func NewServer[Q, M, R comparable](opts ServerOptions[Q, M, R]) (*Server[Q, M, R return nil, err } - s := &Server[Q, M, R]{ + s := &Server[C, Q, M, R]{ messagesRaw: messagesRaw, ServerRaw: rawServer, } @@ -202,11 +236,7 @@ func setReceiptValuesIfNotSet(size uint32, checksum string, r any) { func createMessage(b []byte, err error, h Header, failureStatus uint16) Message { var m Message if err != nil { - m = Message{ - Header: h, - Body: []byte(fmt.Sprintf("failed create message: %s. Check that client and server uses the same codec.", err)), - } - m.Header.Status = failureStatus + return createErrorMessage(err, h, failureStatus) } else { m = Message{ Header: h, @@ -216,8 +246,26 @@ func createMessage(b []byte, err error, h Header, failureStatus uint16) Message return m } +func createErrorMessage(err error, h Header, failureStatus uint16) Message { + var additionalMsg string + if failureStatus == MessageStatusErrDecodeFailed || failureStatus == MessageStatusErrEncodeFailed { + additionalMsg = " Check that client and server uses the same codec." + } + m := Message{ + Header: h, + Body: []byte(fmt.Sprintf("failed create message (error code %d): %s.%s", failureStatus, err, additionalMsg)), + } + m.Header.Status = failureStatus + return m +} + // ServerOptions is the options for a server. -type ServerOptions[Q, M, R any] struct { +type ServerOptions[C, Q, M, R any] struct { + // Init is the function that will be called when the server is started. + // It can be used to initialize the server with the given configuration. + // If an error is returned, the server will stop. + Init func(C) error + // Handle is the function that will be called when a request is received. Handle func(*Call[Q, M, R]) @@ -237,12 +285,12 @@ type ServerOptions[Q, M, R any] struct { } // Server is a stringly typed server for requests of type Q and responses of tye R. -type Server[Q, M, R any] struct { +type Server[C, Q, M, R any] struct { messagesRaw chan Message *ServerRaw } -func (s *Server[Q, M, R]) Start() error { +func (s *Server[C, Q, M, R]) Start() error { err := s.ServerRaw.Start() // Close the standalone message channel. @@ -259,7 +307,7 @@ func (s *Server[Q, M, R]) Start() error { // See Server for a generic, typed version. type ServerRaw struct { call func(Message, Dispatcher) error - dispatcher messageDispatcher + dispatcher *messageDispatcher started bool onStop func() @@ -375,7 +423,8 @@ type ServerRawOptions struct { } type messageDispatcher struct { - s *ServerRaw + mu sync.Mutex + s *ServerRaw } // Call is the request/response exchange between the client and server. @@ -436,7 +485,9 @@ type Dispatcher interface { SendMessage(...Message) } -func (s messageDispatcher) SendMessage(ms ...Message) { +func (s *messageDispatcher) SendMessage(ms ...Message) { + s.mu.Lock() + defer s.mu.Unlock() for _, m := range ms { m.Header.Size = uint32(len(m.Body)) if err := m.Write(s.s.out); err != nil {