diff --git a/client.go b/client.go index 6f293278..b6024afc 100644 --- a/client.go +++ b/client.go @@ -101,7 +101,7 @@ type Client struct { // forcefully killed. processKilled bool - hostSocketDir string + unixSocketCfg UnixSocketConfig } // NegotiatedVersion returns the protocol version negotiated with the server. @@ -240,6 +240,28 @@ type ClientConfig struct { // SkipHostEnv allows plugins to run without inheriting the parent process' // environment variables. SkipHostEnv bool + + // UnixSocketConfig configures additional options for any Unix sockets + // that are created. Not normally required. Not supported on Windows. + UnixSocketConfig *UnixSocketConfig +} + +type UnixSocketConfig struct { + // If set, go-plugin will change the owner of any Unix sockets created to + // this group, and set them as group-writable. Can be a name or gid. The + // client process must be a member of this group or chown will fail. + Group string + + // The directory to create Unix sockets in. Internally managed by go-plugin + // and deleted when the plugin is killed. + directory string +} + +func unixSocketConfigFromEnv() UnixSocketConfig { + return UnixSocketConfig{ + Group: os.Getenv(EnvUnixSocketGroup), + directory: os.Getenv(EnvUnixSocketDir), + } } // ReattachConfig is used to configure a client to reattach to an @@ -445,7 +467,7 @@ func (c *Client) Kill() { c.l.Lock() runner := c.runner addr := c.address - hostSocketDir := c.hostSocketDir + hostSocketDir := c.unixSocketCfg.directory c.l.Unlock() // If there is no runner or ID, there is nothing to kill. @@ -629,15 +651,33 @@ func (c *Client) Start() (addr net.Addr, err error) { } } + if c.config.UnixSocketConfig != nil { + c.unixSocketCfg.Group = c.config.UnixSocketConfig.Group + } + + if c.unixSocketCfg.Group != "" { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketGroup, c.unixSocketCfg.Group)) + } + var runner runner.Runner switch { case c.config.RunnerFunc != nil: - c.hostSocketDir, err = os.MkdirTemp("", "") + c.unixSocketCfg.directory, err = os.MkdirTemp("", "plugin-dir") if err != nil { return nil, err } - c.logger.Trace("created temporary directory for unix sockets", "dir", c.hostSocketDir) - runner, err = c.config.RunnerFunc(c.logger, cmd, c.hostSocketDir) + // os.MkdirTemp creates folders with 0o700, so if we have a group + // configured we need to make it group-writable. + if c.unixSocketCfg.Group != "" { + err = setGroupWritable(c.unixSocketCfg.directory, c.unixSocketCfg.Group, 0o770) + if err != nil { + return nil, err + } + } + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketDir, c.unixSocketCfg.directory)) + c.logger.Trace("created temporary directory for unix sockets", "dir", c.unixSocketCfg.directory) + + runner, err = c.config.RunnerFunc(c.logger, cmd, c.unixSocketCfg.directory) if err != nil { return nil, err } diff --git a/client_unix_test.go b/client_unix_test.go new file mode 100644 index 00000000..6c1f16a3 --- /dev/null +++ b/client_unix_test.go @@ -0,0 +1,97 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build !windows +// +build !windows + +package plugin + +import ( + "fmt" + "os" + "os/exec" + "os/user" + "runtime" + "syscall" + "testing" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin/internal/cmdrunner" + "github.com/hashicorp/go-plugin/runner" +) + +func TestSetGroup(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("go-plugin doesn't support unix sockets on Windows") + } + + group, err := user.LookupGroupId(fmt.Sprintf("%d", os.Getgid())) + if err != nil { + t.Fatal(err) + } + for name, tc := range map[string]struct { + group string + }{ + "as integer": {fmt.Sprintf("%d", os.Getgid())}, + "as name": {group.Name}, + } { + t.Run(name, func(t *testing.T) { + process := helperProcess("mock") + c := NewClient(&ClientConfig{ + HandshakeConfig: testHandshake, + Plugins: testPluginMap, + UnixSocketConfig: &UnixSocketConfig{ + Group: tc.group, + }, + RunnerFunc: func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) { + // Run tests inside the RunnerFunc to ensure we don't race + // with the code that deletes tmpDir when the client fails + // to start properly. + + // Test that it creates a directory with the proper owners and permissions. + info, err := os.Lstat(tmpDir) + if err != nil { + t.Fatal(err) + } + if info.Mode()&os.ModePerm != 0o770 { + t.Fatal(info.Mode()) + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + t.Fatal() + } + if stat.Gid != uint32(os.Getgid()) { + t.Fatalf("Expected %d, but got %d", os.Getgid(), stat.Gid) + } + + // Check the correct environment variables were set to forward + // Unix socket config onto the plugin. + var foundUnixSocketDir, foundUnixSocketGroup bool + for _, env := range cmd.Env { + if env == fmt.Sprintf("%s=%s", EnvUnixSocketDir, tmpDir) { + foundUnixSocketDir = true + } + if env == fmt.Sprintf("%s=%s", EnvUnixSocketGroup, tc.group) { + foundUnixSocketGroup = true + } + } + if !foundUnixSocketDir { + t.Errorf("Did not find correct %s env in %v", EnvUnixSocketDir, cmd.Env) + } + if !foundUnixSocketGroup { + t.Errorf("Did not find correct %s env in %v", EnvUnixSocketGroup, cmd.Env) + } + + process.Env = append(process.Env, cmd.Env...) + return cmdrunner.NewCmdRunner(l, process) + }, + }) + defer c.Kill() + + _, err := c.Start() + if err != nil { + t.Fatalf("err should be nil, got %s", err) + } + }) + } +} diff --git a/grpc_broker.go b/grpc_broker.go index 91eee6e6..b86561a0 100644 --- a/grpc_broker.go +++ b/grpc_broker.go @@ -268,7 +268,7 @@ type GRPCBroker struct { doneCh chan struct{} o sync.Once - socketDir string + unixSocketCfg UnixSocketConfig addrTranslator runner.AddrTranslator sync.Mutex @@ -279,14 +279,14 @@ type gRPCBrokerPending struct { doneCh chan struct{} } -func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator runner.AddrTranslator) *GRPCBroker { +func newGRPCBroker(s streamer, tls *tls.Config, unixSocketCfg UnixSocketConfig, addrTranslator runner.AddrTranslator) *GRPCBroker { return &GRPCBroker{ streamer: s, streams: make(map[uint32]*gRPCBrokerPending), tls: tls, doneCh: make(chan struct{}), - socketDir: socketDir, + unixSocketCfg: unixSocketCfg, addrTranslator: addrTranslator, } } @@ -295,7 +295,7 @@ func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator // // This should not be called multiple times with the same ID at one time. func (b *GRPCBroker) Accept(id uint32) (net.Listener, error) { - listener, err := serverListener(b.socketDir) + listener, err := serverListener(b.unixSocketCfg) if err != nil { return nil, err } diff --git a/grpc_client.go b/grpc_client.go index f11dd0da..583e4250 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -63,7 +63,7 @@ func newGRPCClient(doneCtx context.Context, c *Client) (*GRPCClient, error) { // Start the broker. brokerGRPCClient := newGRPCBrokerClient(conn) - broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.hostSocketDir, c.runner) + broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.unixSocketCfg, c.runner) go broker.Run() go brokerGRPCClient.StartStream() diff --git a/grpc_server.go b/grpc_server.go index 303d650a..369f958a 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -84,7 +84,7 @@ func (s *GRPCServer) Init() error { // Register the broker service brokerServer := newGRPCBrokerServer() plugin.RegisterGRPCBrokerServer(s.server, brokerServer) - s.broker = newGRPCBroker(brokerServer, s.TLS, "", nil) + s.broker = newGRPCBroker(brokerServer, s.TLS, unixSocketConfigFromEnv(), nil) go s.broker.Run() // Register the controller diff --git a/server.go b/server.go index 4e9a22c0..4b0f2b76 100644 --- a/server.go +++ b/server.go @@ -273,7 +273,7 @@ func Serve(opts *ServeConfig) { } // Register a listener so we can accept a connection - listener, err := serverListener(os.Getenv(EnvUnixSocketDir)) + listener, err := serverListener(unixSocketConfigFromEnv()) if err != nil { logger.Error("plugin init error", "error", err) return @@ -496,12 +496,12 @@ func Serve(opts *ServeConfig) { } } -func serverListener(dir string) (net.Listener, error) { +func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) { if runtime.GOOS == "windows" { return serverListener_tcp() } - return serverListener_unix(dir) + return serverListener_unix(unixSocketCfg) } func serverListener_tcp() (net.Listener, error) { @@ -546,8 +546,8 @@ func serverListener_tcp() (net.Listener, error) { return nil, errors.New("Couldn't bind plugin TCP listener") } -func serverListener_unix(dir string) (net.Listener, error) { - tf, err := os.CreateTemp(dir, "plugin") +func serverListener_unix(unixSocketCfg UnixSocketConfig) (net.Listener, error) { + tf, err := os.CreateTemp(unixSocketCfg.directory, "plugin") if err != nil { return nil, err } @@ -569,25 +569,8 @@ func serverListener_unix(dir string) (net.Listener, error) { // By default, unix sockets are only writable by the owner. Set up a custom // group owner and group write permissions if configured. - if groupString := os.Getenv(EnvUnixSocketGroup); groupString != "" { - groupID, err := strconv.Atoi(groupString) - if err != nil { - group, err := user.LookupGroup(groupString) - if err != nil { - return nil, fmt.Errorf("failed to find group ID from %s=%s environment variable: %w", EnvUnixSocketGroup, groupString, err) - } - groupID, err = strconv.Atoi(group.Gid) - if err != nil { - return nil, fmt.Errorf("failed to parse %q group's Gid as an integer: %w", groupString, err) - } - } - - err = os.Chown(path, -1, groupID) - if err != nil { - return nil, err - } - - err = os.Chmod(path, 0o660) + if unixSocketCfg.Group != "" { + err = setGroupWritable(path, unixSocketCfg.Group, 0o660) if err != nil { return nil, err } @@ -601,6 +584,32 @@ func serverListener_unix(dir string) (net.Listener, error) { }, nil } +func setGroupWritable(path, groupString string, mode os.FileMode) error { + groupID, err := strconv.Atoi(groupString) + if err != nil { + group, err := user.LookupGroup(groupString) + if err != nil { + return fmt.Errorf("failed to find gid from %q: %w", groupString, err) + } + groupID, err = strconv.Atoi(group.Gid) + if err != nil { + return fmt.Errorf("failed to parse %q group's gid as an integer: %w", groupString, err) + } + } + + err = os.Chown(path, -1, groupID) + if err != nil { + return err + } + + err = os.Chmod(path, mode) + if err != nil { + return err + } + + return nil +} + // rmListener is an implementation of net.Listener that forwards most // calls to the listener but also removes a file as part of the close. We // use this to cleanup the unix domain socket on close. diff --git a/server_unix_test.go b/server_unix_test.go index 1de10a1f..14416729 100644 --- a/server_unix_test.go +++ b/server_unix_test.go @@ -25,15 +25,13 @@ func TestUnixSocketGroupPermissions(t *testing.T) { t.Fatal(err) } for name, tc := range map[string]struct { - gid string + group string }{ "as integer": {fmt.Sprintf("%d", os.Getgid())}, "as name": {group.Name}, } { t.Run(name, func(t *testing.T) { - t.Setenv(EnvUnixSocketGroup, tc.gid) - - ln, err := serverListener_unix("") + ln, err := serverListener_unix(UnixSocketConfig{Group: tc.group}) if err != nil { t.Fatal(err) } diff --git a/testing.go b/testing.go index 27e05f01..ae48b7a3 100644 --- a/testing.go +++ b/testing.go @@ -166,7 +166,7 @@ func TestPluginGRPCConn(t testing.T, ps map[string]Plugin) (*GRPCClient, *GRPCSe } brokerGRPCClient := newGRPCBrokerClient(conn) - broker := newGRPCBroker(brokerGRPCClient, nil, "", nil) + broker := newGRPCBroker(brokerGRPCClient, nil, UnixSocketConfig{}, nil) go broker.Run() go brokerGRPCClient.StartStream()