diff --git a/client.go b/client.go index f8a6b9c..40c4209 100644 --- a/client.go +++ b/client.go @@ -81,6 +81,24 @@ func ClientWithFormat(format Format) ClientOption { } } +// ClientWithProtocolVersion will result in the given protocol version being used. +// +// The default is to use the protocol version returned by the plugin. +func ClientWithProtocolVersion(version int) ClientOption { + return func(clientOptions *clientOptions) { + clientOptions.protocolVersion = version + } +} + +// ClientWithSpec will result in the given Spec being used. +// +// The default is to use the Spec returned by the plugin. +func ClientWithSpec(spec Spec) ClientOption { + return func(clientOptions *clientOptions) { + clientOptions.spec = spec + } +} + // CallOption is an option for an individual client call. type CallOption func(*callOptions) @@ -91,9 +109,9 @@ type client struct { stderr io.Writer format Format - spec Spec - specErr error - lock sync.RWMutex + protocolVersion int + spec Spec + lock sync.RWMutex } func newClient( @@ -111,9 +129,11 @@ func newClient( clientOptions.format = FormatBinary } return &client{ - runner: runner, - stderr: clientOptions.stderr, - format: clientOptions.format, + runner: runner, + stderr: clientOptions.stderr, + format: clientOptions.format, + protocolVersion: clientOptions.protocolVersion, + spec: clientOptions.spec, } } @@ -125,22 +145,26 @@ func newClient( // be desirable for situations where clients are long-lived, for example in services. func (c *client) Spec(ctx context.Context) (Spec, error) { // Difficult to use sync.OnceValues since we want to use the context for cancellation - // when passing to the runner. It's awkward if the client constructor took a conteext. + // when passing to the runner. It's awkward if the client constructor took a context. c.lock.RLock() - if c.spec != nil || c.specErr != nil { + if c.spec != nil { c.lock.RUnlock() - return c.spec, c.specErr + return c.spec, nil } c.lock.RUnlock() c.lock.Lock() defer c.lock.Unlock() - if c.spec != nil || c.specErr != nil { - return c.spec, c.specErr + if c.spec != nil { + return c.spec, nil } - c.spec, c.specErr = c.getSpecUncached(ctx) - return c.spec, c.specErr + spec, err := c.getSpecUncached(ctx) + if err != nil { + return nil, err + } + c.spec = spec + return spec, nil } func (c *client) Call( @@ -217,7 +241,7 @@ func (c *client) getSpecUncached(ctx context.Context) (Spec, error) { } func (c *client) checkProtocolVersion(ctx context.Context) error { - version, err := c.getProtocolVersionUncached(ctx) + version, err := c.getProtocolVersion(ctx) if err != nil { return err } @@ -227,6 +251,18 @@ func (c *client) checkProtocolVersion(ctx context.Context) error { return nil } +func (c *client) getProtocolVersion(ctx context.Context) (int, error) { + if c.protocolVersion != 0 { + return c.protocolVersion, nil + } + protocolVersion, err := c.getProtocolVersionUncached(ctx) + if err != nil { + return 0, err + } + c.protocolVersion = protocolVersion + return protocolVersion, nil +} + func (c *client) getProtocolVersionUncached(ctx context.Context) (int, error) { stdout := bytes.NewBuffer(nil) if err := c.runner.Run( @@ -251,8 +287,10 @@ func (c *client) getProtocolVersionUncached(ctx context.Context) (int, error) { } type clientOptions struct { - stderr io.Writer - format Format + stderr io.Writer + format Format + protocolVersion int + spec Spec } func newClientOptions() *clientOptions {