From bf348cebcf1ad085e4d335ff0f980d2dea7f4953 Mon Sep 17 00:00:00 2001 From: Nick Rosbrook Date: Tue, 20 Aug 2024 17:22:29 -0400 Subject: [PATCH] [wip] session: implement Call and CallStreaming --- vici/session.go | 63 +++++++++++++++++++++++++++++++++++++++++++- vici/session_test.go | 17 ++++++------ 2 files changed, 71 insertions(+), 9 deletions(-) diff --git a/vici/session.go b/vici/session.go index b8c0dfe..3dadf1d 100644 --- a/vici/session.go +++ b/vici/session.go @@ -24,6 +24,7 @@ import ( "context" "errors" "fmt" + "iter" "net" "sync" ) @@ -194,13 +195,18 @@ func withTestConn(conn net.Conn) SessionOption { // the command fails, the response Message is returned along with the error returned by // Message.Err. func (s *Session) CommandRequest(cmd string, msg *Message) (*Message, error) { + return s.Call(context.Background(), cmd, msg) +} + +// Call +func (s *Session) Call(ctx context.Context, cmd string, in *Message) (*Message, error) { s.mu.Lock() defer s.mu.Unlock() if s.cc == nil { return nil, errors.New("session closed") } - resp, err := s.request(context.Background(), cmd, msg) + resp, err := s.request(ctx, cmd, in) if err != nil { return nil, err } @@ -208,6 +214,61 @@ func (s *Session) CommandRequest(cmd string, msg *Message) (*Message, error) { return resp, resp.Err() } +// CallStreaming +func (s *Session) CallStreaming(ctx context.Context, cmd string, event string, in *Message) (seq iter.Seq2[*Message, error], err error) { + s.mu.Lock() + defer func() { + if err != nil { + s.mu.Unlock() + } + }() + + if s.cc == nil { + return nil, errors.New("session closed") + } + + if err := s.eventRegister(ctx, event); err != nil { + return nil, err + } + defer func() { + if err != nil { + // nolint + s.eventUnregister(ctx, event) + } + }() + + if err := s.cc.packetWrite(ctx, newPacket(pktCmdRequest, cmd, in)); err != nil { + return nil, err + } + + return func(yield func(*Message, error) bool) { + defer s.mu.Unlock() + // nolint + defer s.eventUnregister(ctx, event) + + for { + p, err := s.cc.packetRead(ctx) + if err != nil { + yield(nil, err) + return + } + + switch p.ptype { + case pktEvent: + if !yield(p.msg, p.msg.Err()) { + return + } + case pktCmdResponse: + yield(p.msg, p.msg.Err()) + return // End of event stream + default: + yield(nil, fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype)) + return + } + } + }, nil +} + // StreamedCommandRequest sends a streamed command request to the server. StreamedCommandRequest // behaves like CommandRequest, but accepts an event argument, which specifies the event type // to stream while the command request is active. The complete stream of messages received from diff --git a/vici/session_test.go b/vici/session_test.go index ea00929..fdcf86e 100644 --- a/vici/session_test.go +++ b/vici/session_test.go @@ -21,6 +21,7 @@ package vici import ( + "context" "flag" "fmt" "net" @@ -84,7 +85,7 @@ func TestCommandRequestAfterClose(t *testing.T) { // only meant to test the package API, and the specific commands used are out // of convenience; any command that satisfies the need of the test could be used. // -// For example, TestStreamedCommandRequest uses the 'list-authorities' command, but +// For example, TestCallStreaming uses the 'list-authorities' command, but // any event-streaming vici command could be used. // // These tests are only run when the -integration flag is set to true. @@ -119,10 +120,10 @@ func TestCommandRequest(t *testing.T) { } } -// TestStreamedCommandRequest tests StreamedCommandRequest by calling the +// TestCallStreaming tests CallStreaming by calling the // 'list-authorities' command. Likely, there will be no authorities returned, // but make sure any Messages that are streamed have non-nil err. -func TestStreamedCommandRequest(t *testing.T) { +func TestCallStreaming(t *testing.T) { maybeSkipIntegrationTest(t) s, err := NewSession() @@ -131,14 +132,14 @@ func TestStreamedCommandRequest(t *testing.T) { } defer s.Close() - ms, err := s.StreamedCommandRequest("list-authorities", "list-authority", nil) + resp, err := s.CallStreaming(context.Background(), "list-authorities", "list-authority", nil) if err != nil { - t.Fatalf("Failed to list authorities: %v", err) + t.Fatalf("Failed to make streaming call: %v", err) } - for i, m := range ms { - if m.Err() != nil { - t.Fatalf("Got error in message #%d: %v", i+1, m.Err()) + for _, err := range resp { + if err != nil { + t.Fatalf("Got error from CallStreaming: %v", err) } } }