Skip to content

Commit

Permalink
[wip] session: implement Call and CallStreaming
Browse files Browse the repository at this point in the history
  • Loading branch information
enr0n committed Aug 21, 2024
1 parent eac4c79 commit bf348ce
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 9 deletions.
63 changes: 62 additions & 1 deletion vici/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"context"
"errors"
"fmt"
"iter"
"net"
"sync"
)
Expand Down Expand Up @@ -194,20 +195,80 @@ func withTestConn(conn net.Conn) SessionOption {
// the command fails, the response Message is returned along with the error returned by
// Message.Err.
func (s *Session) CommandRequest(cmd string, msg *Message) (*Message, error) {
return s.Call(context.Background(), cmd, msg)
}

// Call
func (s *Session) Call(ctx context.Context, cmd string, in *Message) (*Message, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.cc == nil {
return nil, errors.New("session closed")
}

resp, err := s.request(context.Background(), cmd, msg)
resp, err := s.request(ctx, cmd, in)
if err != nil {
return nil, err
}

return resp, resp.Err()
}

// CallStreaming
func (s *Session) CallStreaming(ctx context.Context, cmd string, event string, in *Message) (seq iter.Seq2[*Message, error], err error) {
s.mu.Lock()
defer func() {
if err != nil {
s.mu.Unlock()
}
}()

if s.cc == nil {
return nil, errors.New("session closed")
}

if err := s.eventRegister(ctx, event); err != nil {
return nil, err
}
defer func() {
if err != nil {
// nolint
s.eventUnregister(ctx, event)
}
}()

if err := s.cc.packetWrite(ctx, newPacket(pktCmdRequest, cmd, in)); err != nil {
return nil, err
}

return func(yield func(*Message, error) bool) {
defer s.mu.Unlock()
// nolint
defer s.eventUnregister(ctx, event)

for {
p, err := s.cc.packetRead(ctx)
if err != nil {
yield(nil, err)
return
}

switch p.ptype {
case pktEvent:
if !yield(p.msg, p.msg.Err()) {
return
}
case pktCmdResponse:
yield(p.msg, p.msg.Err())
return // End of event stream
default:
yield(nil, fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype))
return
}
}
}, nil
}

// StreamedCommandRequest sends a streamed command request to the server. StreamedCommandRequest
// behaves like CommandRequest, but accepts an event argument, which specifies the event type
// to stream while the command request is active. The complete stream of messages received from
Expand Down
17 changes: 9 additions & 8 deletions vici/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package vici

import (
"context"
"flag"
"fmt"
"net"
Expand Down Expand Up @@ -84,7 +85,7 @@ func TestCommandRequestAfterClose(t *testing.T) {
// only meant to test the package API, and the specific commands used are out
// of convenience; any command that satisfies the need of the test could be used.
//
// For example, TestStreamedCommandRequest uses the 'list-authorities' command, but
// For example, TestCallStreaming uses the 'list-authorities' command, but
// any event-streaming vici command could be used.
//
// These tests are only run when the -integration flag is set to true.
Expand Down Expand Up @@ -119,10 +120,10 @@ func TestCommandRequest(t *testing.T) {
}
}

// TestStreamedCommandRequest tests StreamedCommandRequest by calling the
// TestCallStreaming tests CallStreaming by calling the
// 'list-authorities' command. Likely, there will be no authorities returned,
// but make sure any Messages that are streamed have non-nil err.
func TestStreamedCommandRequest(t *testing.T) {
func TestCallStreaming(t *testing.T) {
maybeSkipIntegrationTest(t)

s, err := NewSession()
Expand All @@ -131,14 +132,14 @@ func TestStreamedCommandRequest(t *testing.T) {
}
defer s.Close()

ms, err := s.StreamedCommandRequest("list-authorities", "list-authority", nil)
resp, err := s.CallStreaming(context.Background(), "list-authorities", "list-authority", nil)
if err != nil {
t.Fatalf("Failed to list authorities: %v", err)
t.Fatalf("Failed to make streaming call: %v", err)
}

for i, m := range ms {
if m.Err() != nil {
t.Fatalf("Got error in message #%d: %v", i+1, m.Err())
for _, err := range resp {

Check failure on line 140 in vici/session_test.go

View workflow job for this annotation

GitHub Actions / test (1.23.x, ubuntu-latest)

cannot range over resp (variable of type iter.Seq2[*Message, error]): requires go1.23 or later (-lang was set to go1.20; check go.mod)

Check failure on line 140 in vici/session_test.go

View workflow job for this annotation

GitHub Actions / lint (1.23.x, ubuntu-latest)

cannot range over resp (variable of type iter.Seq2[*Message, error]): requires go1.23 or later (-lang was set to go1.20; check go.mod) (typecheck)
if err != nil {
t.Fatalf("Got error from CallStreaming: %v", err)
}
}
}
Expand Down

0 comments on commit bf348ce

Please sign in to comment.