diff --git a/.gitignore b/.gitignore index 7fe5237b6..c96cf2d0c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,11 +4,8 @@ *.swo /coverage.out* /tests/output/ -/nvidia-container-runtime -/nvidia-container-runtime.* -/nvidia-container-runtime-hook -/nvidia-container-toolkit -/nvidia-ctk +/nvidia-* /shared-* /release-* /bin +/toolkit-test diff --git a/internal/discover/compat_libs.go b/internal/discover/compat_libs.go index 027ca2ed2..7e7f9ff4f 100644 --- a/internal/discover/compat_libs.go +++ b/internal/discover/compat_libs.go @@ -9,16 +9,12 @@ import ( // NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook. // This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version. -func NewCUDACompatHookDiscoverer(logger logger.Interface, nvidiaCDIHookPath string, driver *root.Driver) Discover { +func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover { _, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver) var args []string if !strings.Contains(cudaVersionPattern, "*") { args = append(args, "--host-driver-version="+cudaVersionPattern) } - return CreateNvidiaCDIHook( - nvidiaCDIHookPath, - "enable-cuda-compat", - args..., - ) + return hookCreator.Create("enable-cuda-compat", args...) } diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index e80dd0be2..4665ce295 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -36,21 +36,21 @@ import ( // TODO: The logic for creating DRM devices should be consolidated between this // and the logic for generating CDI specs for a single device. This is only used // when applying OCI spec modifications to an incoming spec in "legacy" mode. 
-func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCDIHookPath string) (Discover, error) { +func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) { drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot) if err != nil { return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err) } - drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCDIHookPath) + drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, hookCreator) discover := Merge(drmDeviceNodes, drmByPathSymlinks) return discover, nil } // NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan. -func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) { - libraries := newGraphicsLibrariesDiscoverer(logger, driver, nvidiaCDIHookPath) +func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) { + libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator) configs := NewMounts( logger, @@ -95,13 +95,13 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc type graphicsDriverLibraries struct { Discover - logger logger.Interface - nvidiaCDIHookPath string + logger logger.Interface + hookCreator HookCreator } var _ Discover = (*graphicsDriverLibraries)(nil) -func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) Discover { +func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover { cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver) libraries := NewMounts( @@ -140,9 +140,9 @@ func newGraphicsLibrariesDiscoverer(logger 
logger.Interface, driver *root.Driver ) return &graphicsDriverLibraries{ - Discover: Merge(libraries, xorgLibraries), - logger: logger, - nvidiaCDIHookPath: nvidiaCDIHookPath, + Discover: Merge(libraries, xorgLibraries), + logger: logger, + hookCreator: hookCreator, } } @@ -203,9 +203,9 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) { return nil, nil } - hooks := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links) + hook := d.hookCreator.Create("create-symlinks", links...) - return hooks.Hooks() + return hook.Hooks() } // isDriverLibrary checks whether the specified filename is a specific driver library. @@ -275,19 +275,19 @@ func buildXOrgSearchPaths(libRoot string) []string { type drmDevicesByPath struct { None - logger logger.Interface - nvidiaCDIHookPath string - devRoot string - devicesFrom Discover + logger logger.Interface + hookCreator HookCreator + devRoot string + devicesFrom Discover } // newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer -func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCDIHookPath string) Discover { +func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover { d := drmDevicesByPath{ - logger: logger, - nvidiaCDIHookPath: nvidiaCDIHookPath, - devRoot: devRoot, - devicesFrom: devices, + logger: logger, + hookCreator: hookCreator, + devRoot: devRoot, + devicesFrom: devices, } return &d @@ -315,13 +315,9 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) { - args = append(args, "--link", l) + args = append(args, l) } - hook := CreateNvidiaCDIHook( - d.nvidiaCDIHookPath, - "create-symlinks", - args..., - ) + hook := d.hookCreator.Create("create-symlinks", args...) 
- return []Hook{hook}, nil + return hook.Hooks() } // getSpecificLinkArgs returns the required specific links that need to be created diff --git a/internal/discover/graphics_test.go b/internal/discover/graphics_test.go index a515c9390..3aea93cb3 100644 --- a/internal/discover/graphics_test.go +++ b/internal/discover/graphics_test.go @@ -25,6 +25,7 @@ import ( func TestGraphicsLibrariesDiscoverer(t *testing.T) { logger, _ := testlog.NewNullLogger() + hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook") testCases := []struct { description string @@ -136,9 +137,9 @@ func TestGraphicsLibrariesDiscoverer(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { d := &graphicsDriverLibraries{ - Discover: tc.libraries, - logger: logger, - nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook", + Discover: tc.libraries, + logger: logger, + hookCreator: hookCreator, } devices, err := d.Devices() diff --git a/internal/discover/hooks.go b/internal/discover/hooks.go index 4259ccf86..0f239bfd6 100644 --- a/internal/discover/hooks.go +++ b/internal/discover/hooks.go @@ -25,54 +25,66 @@ import ( var _ Discover = (*Hook)(nil) // Devices returns an empty list of devices for a Hook discoverer. -func (h Hook) Devices() ([]Device, error) { +func (h *Hook) Devices() ([]Device, error) { return nil, nil } // Mounts returns an empty list of mounts for a Hook discoverer. -func (h Hook) Mounts() ([]Mount, error) { +func (h *Hook) Mounts() ([]Mount, error) { return nil, nil } // Hooks allows the Hook type to also implement the Discoverer interface. // It returns a single hook -func (h Hook) Hooks() ([]Hook, error) { - return []Hook{h}, nil +func (h *Hook) Hooks() ([]Hook, error) { + if h == nil { + return nil, nil + } + + return []Hook{*h}, nil } -// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target. 
-func CreateCreateSymlinkHook(nvidiaCDIHookPath string, links []string) Discover { - if len(links) == 0 { - return None{} - } +// Option is a function that configures a CDIHook +type Option func(*CDIHook) - var args []string - for _, link := range links { - args = append(args, "--link", link) - } - return CreateNvidiaCDIHook( - nvidiaCDIHookPath, - "create-symlinks", - args..., - ) +type CDIHook struct { + nvidiaCDIHookPath string } -// CreateNvidiaCDIHook creates a hook which invokes the NVIDIA Container CLI hook subcommand. -func CreateNvidiaCDIHook(nvidiaCDIHookPath string, hookName string, additionalArgs ...string) Hook { - return cdiHook(nvidiaCDIHookPath).Create(hookName, additionalArgs...) +type HookCreator interface { + Create(string, ...string) *Hook } -type cdiHook string +func NewHookCreator(nvidiaCDIHookPath string) HookCreator { + CDIHook := &CDIHook{ + nvidiaCDIHookPath: nvidiaCDIHookPath, + } -func (c cdiHook) Create(name string, args ...string) Hook { - return Hook{ + return CDIHook +} + +func (c CDIHook) Create(name string, args ...string) *Hook { + if name == "create-symlinks" { + if len(args) == 0 { + return nil + } + + links := []string{} + for _, arg := range args { + links = append(links, "--link", arg) + } + args = links + } + + return &Hook{ Lifecycle: cdi.CreateContainerHook, - Path: string(c), + Path: c.nvidiaCDIHookPath, Args: append(c.requiredArgs(name), args...), } } -func (c cdiHook) requiredArgs(name string) []string { - base := filepath.Base(string(c)) + +func (c CDIHook) requiredArgs(name string) []string { + base := filepath.Base(c.nvidiaCDIHookPath) if base == "nvidia-ctk" { return []string{base, "hook", name} } diff --git a/internal/discover/ldconfig.go b/internal/discover/ldconfig.go index b81b9be59..3fab927af 100644 --- a/internal/discover/ldconfig.go +++ b/internal/discover/ldconfig.go @@ -25,12 +25,12 @@ import ( ) // NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. 
A logger can also be specified -func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHookPath, ldconfigPath string) (Discover, error) { +func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) { d := ldconfig{ - logger: logger, - nvidiaCDIHookPath: nvidiaCDIHookPath, - ldconfigPath: ldconfigPath, - mountsFrom: mounts, + logger: logger, + hookCreator: hookCreator, + ldconfigPath: ldconfigPath, + mountsFrom: mounts, } return &d, nil @@ -38,10 +38,10 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHoo type ldconfig struct { None - logger logger.Interface - nvidiaCDIHookPath string - ldconfigPath string - mountsFrom Discover + logger logger.Interface + hookCreator HookCreator + ldconfigPath string + mountsFrom Discover } // Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths. @@ -50,16 +50,18 @@ func (d ldconfig) Hooks() ([]Hook, error) { if err != nil { return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err) } - h := CreateLDCacheUpdateHook( - d.nvidiaCDIHookPath, + + h := createLDCacheUpdateHook( + d.hookCreator, d.ldconfigPath, getLibraryPaths(mounts), ) - return []Hook{h}, nil + + return h.Hooks() } -// CreateLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache -func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []string) Hook { +// createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache +func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook { var args []string if ldconfig != "" { @@ -70,13 +72,7 @@ func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []str args = append(args, "--folder", f) } - hook := CreateNvidiaCDIHook( - executable, - "update-ldcache", - 
args..., - ) - - return hook + return hookCreator.Create("update-ldcache", args...) } // getLibraryPaths extracts the library dirs from the specified mounts diff --git a/internal/discover/ldconfig_test.go b/internal/discover/ldconfig_test.go index 0b214c77b..ddbda4cc0 100644 --- a/internal/discover/ldconfig_test.go +++ b/internal/discover/ldconfig_test.go @@ -31,6 +31,7 @@ const ( func TestLDCacheUpdateHook(t *testing.T) { logger, _ := testlog.NewNullLogger() + hookCreator := NewHookCreator(testNvidiaCDIHookPath) testCases := []struct { description string @@ -97,7 +98,7 @@ func TestLDCacheUpdateHook(t *testing.T) { Lifecycle: "createContainer", } - d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCDIHookPath, tc.ldconfigPath) + d, err := NewLDCacheUpdateHook(logger, mountMock, hookCreator, tc.ldconfigPath) require.NoError(t, err) hooks, err := d.Hooks() diff --git a/internal/discover/symlinks.go b/internal/discover/symlinks.go index b7637aa26..a9cd811ad 100644 --- a/internal/discover/symlinks.go +++ b/internal/discover/symlinks.go @@ -23,20 +23,20 @@ import ( type additionalSymlinks struct { Discover - version string - nvidiaCDIHookPath string + version string + hookCreator HookCreator } // WithDriverDotSoSymlinks decorates the provided discoverer. // A hook is added that checks for specific driver symlinks that need to be created. 
-func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover { +func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover { if version == "" { version = "*.*" } return &additionalSymlinks{ - Discover: mounts, - nvidiaCDIHookPath: nvidiaCDIHookPath, - version: version, + Discover: mounts, + hookCreator: hookCreator, + version: version, } } @@ -73,8 +73,12 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) { return hooks, nil } - hook := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links).(Hook) - return append(hooks, hook), nil + createSymlinkHooks, err := d.hookCreator.Create("create-symlinks", links...).Hooks() + if err != nil { + return nil, fmt.Errorf("failed to create symlink hook: %v", err) + } + + return append(hooks, createSymlinkHooks...), nil } // getLinksForMount maps the path to created links if any. diff --git a/internal/discover/symlinks_test.go b/internal/discover/symlinks_test.go index 7653b847b..2a6c98129 100644 --- a/internal/discover/symlinks_test.go +++ b/internal/discover/symlinks_test.go @@ -306,12 +306,13 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) { }, } + hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook") for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { d := WithDriverDotSoSymlinks( tc.discover, tc.version, - "/path/to/nvidia-cdi-hook", + hookCreator, ) devices, err := d.Devices() diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index 584391aa3..a0239df8e 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -36,7 +36,7 @@ import ( // NVIDIA_GDRCOPY=enabled // // If not devices are selected, no changes are made. 
-func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) { +func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil @@ -81,7 +81,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image // If the feature flag has explicitly been toggled, we don't make any modification. if !cfg.Features.DisableCUDACompatLibHook.IsEnabled() { - cudaCompatDiscoverer, err := getCudaCompatModeDiscoverer(logger, cfg, driver) + cudaCompatDiscoverer, err := getCudaCompatModeDiscoverer(logger, cfg, driver, hookCreator) if err != nil { return nil, fmt.Errorf("failed to construct CUDA Compat discoverer: %w", err) } @@ -91,13 +91,13 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image return NewModifierFromDiscoverer(logger, discover.Merge(discoverers...)) } -func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver) (discover.Discover, error) { +func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator discover.HookCreator) (discover.Discover, error) { // For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook. if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook { return nil, nil } - compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, cfg.NVIDIACTKConfig.Path, driver) + compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, driver) // For non-legacy modes we return the hook as is. 
These modes *should* already include the update-ldcache hook. if cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" { return compatLibHookDiscoverer, nil @@ -108,7 +108,7 @@ func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, dr ldcacheUpdateHookDiscoverer, err := discover.NewLDCacheUpdateHook( logger, discover.None{}, - cfg.NVIDIACTKConfig.Path, + hookCreator, "", ) if err != nil { diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 31aa6ef3b..6e602d7a9 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -29,18 +29,16 @@ import ( // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. -func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver) (oci.SpecModifier, error) { +func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { if required, reason := requiresGraphicsModifier(containerImage); !required { logger.Infof("No graphics modifier required: %v", reason) return nil, nil } - nvidiaCDIHookPath := cfg.NVIDIACTKConfig.Path - mounts, err := discover.NewGraphicsMountsDiscoverer( logger, driver, - nvidiaCDIHookPath, + hookCreator, ) if err != nil { return nil, fmt.Errorf("failed to create mounts discoverer: %v", err) @@ -52,7 +50,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI logger, containerImage.DevicesFromEnvvars(image.EnvVarNvidiaVisibleDevices), devRoot, - nvidiaCDIHookPath, + hookCreator, ) if err != nil { return nil, fmt.Errorf("failed to construct discoverer: %v", err) diff --git a/internal/platform-support/dgpu/by-path-hooks.go 
b/internal/platform-support/dgpu/by-path-hooks.go index cd38f5e72..b78720a24 100644 --- a/internal/platform-support/dgpu/by-path-hooks.go +++ b/internal/platform-support/dgpu/by-path-hooks.go @@ -27,11 +27,11 @@ import ( // byPathHookDiscoverer discovers the entities required for injecting by-path DRM device links type byPathHookDiscoverer struct { - logger logger.Interface - devRoot string - nvidiaCDIHookPath string - pciBusID string - deviceNodes discover.Discover + logger logger.Interface + devRoot string + hookCreator discover.HookCreator + pciBusID string + deviceNodes discover.Discover } var _ discover.Discover = (*byPathHookDiscoverer)(nil) @@ -53,18 +53,9 @@ func (d *byPathHookDiscoverer) Hooks() ([]discover.Hook, error) { return nil, nil } - var args []string - for _, l := range links { - args = append(args, "--link", l) - } - - hook := discover.CreateNvidiaCDIHook( - d.nvidiaCDIHookPath, - "create-symlinks", - args..., - ) + hook := d.hookCreator.Create("create-symlinks", links...) 
- return []discover.Hook{hook}, nil + return hook.Hooks() } // Mounts returns an empty slice for a full GPU diff --git a/internal/platform-support/dgpu/nvml.go b/internal/platform-support/dgpu/nvml.go index f24f4d552..2ad36a24e 100644 --- a/internal/platform-support/dgpu/nvml.go +++ b/internal/platform-support/dgpu/nvml.go @@ -58,11 +58,11 @@ func (o *options) newNvmlDGPUDiscoverer(d requiredInfo) (discover.Discover, erro ) byPathHooks := &byPathHookDiscoverer{ - logger: o.logger, - devRoot: o.devRoot, - nvidiaCDIHookPath: o.nvidiaCDIHookPath, - pciBusID: pciBusID, - deviceNodes: deviceNodes, + logger: o.logger, + devRoot: o.devRoot, + hookCreator: o.hookCreator, + pciBusID: pciBusID, + deviceNodes: deviceNodes, } dd := discover.Merge( diff --git a/internal/platform-support/dgpu/nvsandboxutils.go b/internal/platform-support/dgpu/nvsandboxutils.go index ebeea7c84..f8925e4a2 100644 --- a/internal/platform-support/dgpu/nvsandboxutils.go +++ b/internal/platform-support/dgpu/nvsandboxutils.go @@ -28,12 +28,12 @@ import ( ) type nvsandboxutilsDGPU struct { - lib nvsandboxutils.Interface - uuid string - devRoot string - isMig bool - nvidiaCDIHookPath string - deviceLinks []string + lib nvsandboxutils.Interface + uuid string + devRoot string + isMig bool + hookCreator discover.HookCreator + deviceLinks []string } var _ discover.Discover = (*nvsandboxutilsDGPU)(nil) @@ -53,11 +53,11 @@ func (o *options) newNvsandboxutilsDGPUDiscoverer(d UUIDer) (discover.Discover, } nvd := nvsandboxutilsDGPU{ - lib: o.nvsandboxutilslib, - uuid: uuid, - devRoot: strings.TrimSuffix(filepath.Clean(o.devRoot), "/dev"), - isMig: o.isMigDevice, - nvidiaCDIHookPath: o.nvidiaCDIHookPath, + lib: o.nvsandboxutilslib, + uuid: uuid, + devRoot: strings.TrimSuffix(filepath.Clean(o.devRoot), "/dev"), + isMig: o.isMigDevice, + hookCreator: o.hookCreator, } return &nvd, nil @@ -112,18 +112,9 @@ func (d *nvsandboxutilsDGPU) Hooks() ([]discover.Hook, error) { return nil, nil } - var args []string - for _, l 
:= range d.deviceLinks { - args = append(args, "--link", l) - } - - hook := discover.CreateNvidiaCDIHook( - d.nvidiaCDIHookPath, - "create-symlinks", - args..., - ) + hook := d.hookCreator.Create("create-symlinks", d.deviceLinks...) - return []discover.Hook{hook}, nil + return hook.Hooks() } func (d *nvsandboxutilsDGPU) Mounts() ([]discover.Mount, error) { diff --git a/internal/platform-support/dgpu/options.go b/internal/platform-support/dgpu/options.go index 2fd1c01bd..6b2d62ce7 100644 --- a/internal/platform-support/dgpu/options.go +++ b/internal/platform-support/dgpu/options.go @@ -17,15 +17,16 @@ package dgpu import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" ) type options struct { - logger logger.Interface - devRoot string - nvidiaCDIHookPath string + logger logger.Interface + devRoot string + hookCreator discover.HookCreator isMigDevice bool // migCaps stores the MIG capabilities for the system. 
@@ -52,10 +53,10 @@ func WithLogger(logger logger.Interface) Option { } } -// WithNVIDIACDIHookPath sets the path to the NVIDIA Container Toolkit CLI path for the library -func WithNVIDIACDIHookPath(path string) Option { +// WithHookCreator sets the hook creator for the library +func WithHookCreator(hookCreator discover.HookCreator) Option { return func(l *options) { - l.nvidiaCDIHookPath = path + l.hookCreator = hookCreator } } diff --git a/internal/platform-support/tegra/csv.go b/internal/platform-support/tegra/csv.go index 9af38d715..ca760ec5c 100644 --- a/internal/platform-support/tegra/csv.go +++ b/internal/platform-support/tegra/csv.go @@ -59,7 +59,7 @@ func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) { targetsByType[csv.MountSpecLib], ), "", - o.nvidiaCDIHookPath, + o.hookCreator, ) // We process the explicitly requested symlinks. diff --git a/internal/platform-support/tegra/csv_test.go b/internal/platform-support/tegra/csv_test.go index dca09bb57..129bf00ca 100644 --- a/internal/platform-support/tegra/csv_test.go +++ b/internal/platform-support/tegra/csv_test.go @@ -181,13 +181,14 @@ func TestDiscovererFromCSVFiles(t *testing.T) { }, } + hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook") for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { defer setGetTargetsFromCSVFiles(tc.moutSpecs)() o := tegraOptions{ logger: logger, - nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook", + hookCreator: hookCreator, csvFiles: []string{"dummy"}, ignorePatterns: tc.ignorePatterns, symlinkLocator: tc.symlinkLocator, diff --git a/internal/platform-support/tegra/symlinks.go b/internal/platform-support/tegra/symlinks.go index cc677638e..822d482fd 100644 --- a/internal/platform-support/tegra/symlinks.go +++ b/internal/platform-support/tegra/symlinks.go @@ -26,9 +26,9 @@ import ( type symlinkHook struct { discover.None - logger logger.Interface - nvidiaCDIHookPath string - targets []string + logger logger.Interface + 
hookCreator discover.HookCreator + targets []string // The following can be overridden for testing symlinkChainLocator lookup.Locator @@ -39,7 +39,7 @@ type symlinkHook struct { func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover { return symlinkHook{ logger: o.logger, - nvidiaCDIHookPath: o.nvidiaCDIHookPath, + hookCreator: o.hookCreator, targets: targets, symlinkChainLocator: o.symlinkChainLocator, resolveSymlink: o.resolveSymlink, @@ -48,10 +48,7 @@ func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover // Hooks returns a hook to create the symlinks from the required CSV files func (d symlinkHook) Hooks() ([]discover.Hook, error) { - return discover.CreateCreateSymlinkHook( - d.nvidiaCDIHookPath, - d.getCSVFileSymlinks(), - ).Hooks() + return d.hookCreator.Create("create-symlinks", d.getCSVFileSymlinks()...).Hooks() } // getSymlinkCandidates returns a list of symlinks that are candidates for being created. diff --git a/internal/platform-support/tegra/tegra.go b/internal/platform-support/tegra/tegra.go index 1031fc726..6ad774b4e 100644 --- a/internal/platform-support/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -30,7 +30,7 @@ type tegraOptions struct { csvFiles []string driverRoot string devRoot string - nvidiaCDIHookPath string + hookCreator discover.HookCreator ldconfigPath string librarySearchPaths []string ignorePatterns ignoreMountSpecPatterns @@ -80,7 +80,7 @@ func New(opts ...Option) (discover.Discover, error) { return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) } - ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.nvidiaCDIHookPath, o.ldconfigPath) + ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.hookCreator, o.ldconfigPath) if err != nil { return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err) } @@ -133,10 +133,10 @@ func WithCSVFiles(csvFiles []string) Option { } } -// 
WithNVIDIACDIHookPath sets the path to the nvidia-cdi-hook binary. -func WithNVIDIACDIHookPath(nvidiaCDIHookPath string) Option { +// WithHookCreator sets the hook creator for the discoverer. +func WithHookCreator(hookCreator discover.HookCreator) Option { return func(o *tegraOptions) { - o.nvidiaCDIHookPath = nvidiaCDIHookPath + o.hookCreator = hookCreator } } diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index e88213dc3..c1a82ac93 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -21,6 +21,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" @@ -74,6 +75,8 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return nil, err } + hookCreator := discover.NewHookCreator(cfg.NVIDIACTKConfig.Path) + mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) // We update the mode here so that we can continue passing just the config to other functions. 
cfg.NVIDIAContainerRuntimeConfig.Mode = mode @@ -90,13 +93,13 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp case "nvidia-hook-remover": modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger)) case "graphics": - graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver) + graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver, hookCreator) if err != nil { return nil, err } modifiers = append(modifiers, graphicsModifier) case "feature-gated": - featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image, driver) + featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image, driver, hookCreator) if err != nil { return nil, err } diff --git a/pkg/nvcdi/common-nvml.go b/pkg/nvcdi/common-nvml.go index 6e9661cb7..fbb5f01d1 100644 --- a/pkg/nvcdi/common-nvml.go +++ b/pkg/nvcdi/common-nvml.go @@ -36,7 +36,7 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) { }, ) - graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.nvidiaCDIHookPath) + graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.hookCreator) if err != nil { l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err) } diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index f49f1129b..3fbc0e947 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -102,17 +102,17 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover driverDotSoSymlinksDiscoverer := discover.WithDriverDotSoSymlinks( libraries, version, - l.nvidiaCDIHookPath, + l.hookCreator, ) discoverers = append(discoverers, driverDotSoSymlinksDiscoverer) if l.HookIsSupported(HookEnableCudaCompat) { // TODO: The following should use the version directly. 
- cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.nvidiaCDIHookPath, l.driver) + cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver) discoverers = append(discoverers, cudaCompatLibHookDiscoverer) } - updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.nvidiaCDIHookPath, l.ldconfigPath) + updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath) discoverers = append(discoverers, updateLDCache) d := discover.Merge(discoverers...) diff --git a/pkg/nvcdi/driver-wsl.go b/pkg/nvcdi/driver-wsl.go index d184d7779..97e267c5b 100644 --- a/pkg/nvcdi/driver-wsl.go +++ b/pkg/nvcdi/driver-wsl.go @@ -39,7 +39,7 @@ var requiredDriverStoreFiles = []string{ } // newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers. -func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCDIHookPath, ldconfigPath string) (discover.Discover, error) { +func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string) (discover.Discover, error) { err := dxcore.Init() if err != nil { return nil, fmt.Errorf("failed to initialize dxcore: %v", err) @@ -56,11 +56,11 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCD } logger.Infof("Using WSL driver store paths: %v", driverStorePaths) - return newWSLDriverStoreDiscoverer(logger, driverRoot, nvidiaCDIHookPath, ldconfigPath, driverStorePaths) + return newWSLDriverStoreDiscoverer(logger, driverRoot, hookCreator, ldconfigPath, driverStorePaths) } // newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter. 
-func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvidiaCDIHookPath string, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) { +func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) { var searchPaths []string seen := make(map[string]bool) for _, path := range driverStorePaths { @@ -88,12 +88,12 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi ) symlinkHook := nvidiaSMISimlinkHook{ - logger: logger, - mountsFrom: libraries, - nvidiaCDIHookPath: nvidiaCDIHookPath, + logger: logger, + mountsFrom: libraries, + hookCreator: hookCreator, } - ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath) + ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, hookCreator, ldconfigPath) d := discover.Merge( libraries, @@ -106,9 +106,9 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi type nvidiaSMISimlinkHook struct { discover.None - logger logger.Interface - mountsFrom discover.Discover - nvidiaCDIHookPath string + logger logger.Interface + mountsFrom discover.Discover + hookCreator discover.HookCreator } // Hooks returns a hook that creates a symlink to nvidia-smi in the driver store. @@ -135,7 +135,7 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) { } link := "/usr/bin/nvidia-smi" links := []string{fmt.Sprintf("%s::%s", target, link)} - symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCDIHookPath, links) + symlinkHook := m.hookCreator.Create("create-symlinks", links...) 
return symlinkHook.Hooks() } diff --git a/pkg/nvcdi/driver-wsl_test.go b/pkg/nvcdi/driver-wsl_test.go index b9aac1a10..27247cc66 100644 --- a/pkg/nvcdi/driver-wsl_test.go +++ b/pkg/nvcdi/driver-wsl_test.go @@ -29,6 +29,7 @@ import ( func TestNvidiaSMISymlinkHook(t *testing.T) { logger, _ := testlog.NewNullLogger() + hookCreator := discover.NewHookCreator("nvidia-cdi-hook") errMounts := errors.New("mounts error") @@ -143,9 +144,9 @@ func TestNvidiaSMISymlinkHook(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { m := nvidiaSMISimlinkHook{ - logger: logger, - mountsFrom: tc.mounts, - nvidiaCDIHookPath: "nvidia-cdi-hook", + logger: logger, + mountsFrom: tc.mounts, + hookCreator: hookCreator, } devices, err := m.Devices() diff --git a/pkg/nvcdi/full-gpu-nvml.go b/pkg/nvcdi/full-gpu-nvml.go index 003515ca8..1b293ed8a 100644 --- a/pkg/nvcdi/full-gpu-nvml.go +++ b/pkg/nvcdi/full-gpu-nvml.go @@ -71,7 +71,7 @@ func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, erro deviceNodes, err := dgpu.NewForDevice(d, dgpu.WithDevRoot(l.devRoot), dgpu.WithLogger(l.logger), - dgpu.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath), + dgpu.WithHookCreator(l.hookCreator), dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib), ) if err != nil { @@ -81,7 +81,7 @@ func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, erro deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer( l.logger, l.devRoot, - l.nvidiaCDIHookPath, + l.hookCreator, deviceNodes, ) diff --git a/pkg/nvcdi/lib-csv.go b/pkg/nvcdi/lib-csv.go index 75ad00a4b..4d59941a2 100644 --- a/pkg/nvcdi/lib-csv.go +++ b/pkg/nvcdi/lib-csv.go @@ -44,7 +44,7 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) { tegra.WithLogger(l.logger), tegra.WithDriverRoot(l.driverRoot), tegra.WithDevRoot(l.devRoot), - tegra.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath), + tegra.WithHookCreator(l.hookCreator), tegra.WithLdconfigPath(l.ldconfigPath), 
tegra.WithCSVFiles(l.csvFiles), tegra.WithLibrarySearchPaths(l.librarySearchPaths...), diff --git a/pkg/nvcdi/lib-wsl.go b/pkg/nvcdi/lib-wsl.go index dd0e8db0c..82be607e0 100644 --- a/pkg/nvcdi/lib-wsl.go +++ b/pkg/nvcdi/lib-wsl.go @@ -54,7 +54,7 @@ func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { // GetCommonEdits generates a CDI specification that can be used for ANY devices func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) { - driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCDIHookPath, l.ldconfigPath) + driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.hookCreator, l.ldconfigPath) if err != nil { return nil, fmt.Errorf("failed to create discoverer for WSL driver: %v", err) } diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 8e7653b44..97a391682 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -23,6 +23,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvml/pkg/nvml" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" @@ -56,6 +57,7 @@ type nvcdilib struct { mergedDeviceOptions []transform.MergedDeviceOption disabledHooks disabledHooks + hookCreator discover.HookCreator } // New creates a new nvcdi library @@ -79,6 +81,9 @@ func New(opts ...Option) (Interface, error) { if l.nvidiaCDIHookPath == "" { l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" } + // create hookCreator + l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath) + if l.driverRoot == "" { l.driverRoot = "/" } diff --git a/pkg/nvcdi/management.go b/pkg/nvcdi/management.go index f0fa900e4..0d2f98703 100644 --- a/pkg/nvcdi/management.go +++ b/pkg/nvcdi/management.go @@ -138,7 +138,7 @@ func (m *managementlib) newManagementDeviceDiscoverer() (discover.Discover, erro 
deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer( m.logger, m.devRoot, - m.nvidiaCDIHookPath, + m.hookCreator, deviceNodes, ) diff --git a/pkg/nvcdi/mig-device-nvml.go b/pkg/nvcdi/mig-device-nvml.go index 5c1a504c2..729ade5ec 100644 --- a/pkg/nvcdi/mig-device-nvml.go +++ b/pkg/nvcdi/mig-device-nvml.go @@ -54,7 +54,7 @@ func (l *nvmllib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) deviceNodes, err := dgpu.NewForMigDevice(parent, mig, dgpu.WithDevRoot(l.devRoot), dgpu.WithLogger(l.logger), - dgpu.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath), + dgpu.WithHookCreator(l.hookCreator), dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib), ) if err != nil { diff --git a/pkg/nvcdi/workarounds-device-folder-permissions.go b/pkg/nvcdi/workarounds-device-folder-permissions.go index 511eb1fce..71967ac49 100644 --- a/pkg/nvcdi/workarounds-device-folder-permissions.go +++ b/pkg/nvcdi/workarounds-device-folder-permissions.go @@ -25,10 +25,10 @@ import ( ) type deviceFolderPermissions struct { - logger logger.Interface - devRoot string - nvidiaCDIHookPath string - devices discover.Discover + logger logger.Interface + devRoot string + devices discover.Discover + hookCreator discover.HookCreator } var _ discover.Discover = (*deviceFolderPermissions)(nil) @@ -39,12 +39,12 @@ var _ discover.Discover = (*deviceFolderPermissions)(nil) // The nested devices that are applicable to the NVIDIA GPU devices are: // - DRM devices at /dev/dri/* // - NVIDIA Caps devices at /dev/nvidia-caps/* -func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, nvidiaCDIHookPath string, devices discover.Discover) discover.Discover { +func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator discover.HookCreator, devices discover.Discover) discover.Discover { d := &deviceFolderPermissions{ - logger: logger, - devRoot: devRoot, - nvidiaCDIHookPath: nvidiaCDIHookPath, - devices: devices, + logger: logger, + 
devRoot: devRoot, + hookCreator: hookCreator, + devices: devices, } return d @@ -70,13 +70,9 @@ func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) { args = append(args, "--path", folder) } - hook := discover.CreateNvidiaCDIHook( - d.nvidiaCDIHookPath, - "chmod", - args..., - ) + hook := d.hookCreator.Create("chmod", args...) - return []discover.Hook{hook}, nil + return []discover.Hook{*hook}, nil } func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) {