diff --git a/vici/client_conn.go b/vici/client_conn.go index bd6f962..665854d 100644 --- a/vici/client_conn.go +++ b/vici/client_conn.go @@ -28,7 +28,6 @@ import ( "fmt" "io" "net" - "time" ) const ( @@ -44,19 +43,56 @@ var ( ) type clientConn struct { - conn net.Conn + network string + addr string + dialer func(ctx context.Context, network, addr string) (net.Conn, error) + + closed bool + conn net.Conn +} + +func (cc *clientConn) dial(ctx context.Context) error { + if !cc.closed && cc.conn != nil { + return nil + } + + conn, err := cc.dialer(ctx, cc.network, cc.addr) + if err != nil { + return err + } + + cc.conn = conn + cc.closed = false + + return nil +} + +func (cc *clientConn) Close() error { + if cc.closed || cc.conn == nil { + return nil + } + + cc.closed = true + + return cc.conn.Close() } func (cc *clientConn) packetWrite(ctx context.Context, m *Message) error { - if err := cc.conn.SetWriteDeadline(time.Time{}); err != nil { + if err := cc.dial(ctx); err != nil { return err } + rc := cc.asyncPacketWrite(m) select { case <-ctx.Done(): - err := cc.conn.SetWriteDeadline(time.Now()) - return errors.Join(err, ctx.Err()) - case err := <-cc.awaitPacketWrite(m): + // Disconnect on context deadline to avoid data ordering + // problems with subsequent read/writes. Re-establish the + // connection later. + cc.Close() + <-rc + + return ctx.Err() + case err := <-rc: if err != nil { return err } @@ -65,15 +101,21 @@ func (cc *clientConn) packetWrite(ctx context.Context, m *Message) error { } func (cc *clientConn) packetRead(ctx context.Context) (*Message, error) { - if err := cc.conn.SetReadDeadline(time.Time{}); err != nil { + if err := cc.dial(ctx); err != nil { return nil, err } + rc := cc.asyncPacketRead() select { case <-ctx.Done(): - err := cc.conn.SetReadDeadline(time.Now()) - return nil, errors.Join(err, ctx.Err()) - case v := <-cc.awaitPacketRead(): + // Disconnect on context deadline to avoid data ordering + // problems with subsequent read/writes. Re-establish the + // connection later. + cc.Close() + <-rc + + return nil, ctx.Err() + case v := <-rc: switch v.(type) { case error: return nil, v.(error) @@ -86,7 +128,7 @@ func (cc *clientConn) packetRead(ctx context.Context) (*Message, error) { } } -func (cc *clientConn) awaitPacketWrite(m *Message) <-chan error { +func (cc *clientConn) asyncPacketWrite(m *Message) <-chan error { r := make(chan error, 1) buf := bytes.NewBuffer([]byte{}) @@ -117,7 +159,7 @@ func (cc *clientConn) awaitPacketWrite(m *Message) <-chan error { return r } -func (cc *clientConn) awaitPacketRead() <-chan any { +func (cc *clientConn) asyncPacketRead() <-chan any { r := make(chan any, 1) go func() { diff --git a/vici/client_conn_test.go b/vici/client_conn_test.go index b6452da..52d528f 100644 --- a/vici/client_conn_test.go +++ b/vici/client_conn_test.go @@ -127,7 +127,7 @@ func TestPacketRead(t *testing.T) { <-done } -func TestPacketWriteContext(t *testing.T) { +func TestPacketWriteContextCancel(t *testing.T) { client, srvr := net.Pipe() defer client.Close() defer srvr.Close() @@ -143,17 +143,27 @@ func TestPacketWriteContext(t *testing.T) { if !errors.Is(err, context.Canceled) { t.Fatalf("Expected cancel on packet write, but got %v", err) } +} + +func TestPacketWriteContextTimeout(t *testing.T) { + client, srvr := net.Pipe() + defer client.Close() + defer srvr.Close() - ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second) + cc := &clientConn{ + conn: client, + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - err = cc.packetWrite(ctx, goldNamedPacket) + err := cc.packetWrite(ctx, goldNamedPacket) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("Expected timeout on packet write, but got %v", err) } } -func TestPacketReadContext(t *testing.T) { +func TestPacketReadContextCancel(t *testing.T) { client, srvr := net.Pipe() defer client.Close() defer srvr.Close() @@ -169,11 +179,21 @@ func TestPacketReadContext(t *testing.T) { if !errors.Is(err, context.Canceled) { t.Fatalf("Expected cancel on packet read, but got %v", err) } +} + +func TestPacketReadContextTimeout(t *testing.T) { + client, srvr := net.Pipe() + defer client.Close() + defer srvr.Close() - ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second) + cc := &clientConn{ + conn: client, + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - _, err = cc.packetRead(ctx) + _, err := cc.packetRead(ctx) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("Expected timeout on packet read, but got %v", err) } diff --git a/vici/events.go b/vici/events.go index 1881cbb..78d4bf5 100644 --- a/vici/events.go +++ b/vici/events.go @@ -81,7 +81,10 @@ func (el *eventListener) Close() error { return err } - el.cc.conn.Close() + if el.cc != nil { + el.cc.Close() + el.cc = nil + } return nil } diff --git a/vici/session.go b/vici/session.go index 013b887..9965855 100644 --- a/vici/session.go +++ b/vici/session.go @@ -86,13 +86,15 @@ func (s *Session) newClientConn() (*clientConn, error) { return &clientConn{conn: s.conn}, nil } - conn, err := s.dialer(context.Background(), s.network, s.addr) - if err != nil { - return nil, err + cc := &clientConn{ + network: s.network, + addr: s.addr, + dialer: s.dialer, + conn: nil, } - cc := &clientConn{ - conn: conn, + if err := cc.dial(context.Background()); err != nil { + return nil, err } return cc, nil @@ -107,7 +109,7 @@ func (s *Session) Close() error { s.mu.Lock() defer s.mu.Unlock() if s.cc != nil { - if err := s.cc.conn.Close(); err != nil { + if err := s.cc.Close(); err != nil { return err }