diff --git a/integration/gpu_test.go b/integration/gpu_test.go index 47edef4..f3fc1ac 100644 --- a/integration/gpu_test.go +++ b/integration/gpu_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/envbox/integration/integrationtest" @@ -41,8 +42,7 @@ func TestDocker_Nvidia(t *testing.T) { ) // Assert that we can run nvidia-smi in the inner container. - _, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "nvidia-smi") - require.NoError(t, err, "failed to run nvidia-smi in the inner container") + assertInnerNvidiaSMI(ctx, t, ctID) }) t.Run("Redhat", func(t *testing.T) { @@ -52,16 +52,29 @@ func TestDocker_Nvidia(t *testing.T) { // Start the envbox container. ctID := startEnvboxCmd(ctx, t, integrationtest.RedhatImage, "root", - "-v", "/usr/lib/x86_64-linux-gnu:/var/coder/usr/lib64", + "-v", "/usr/lib/x86_64-linux-gnu:/var/coder/usr/lib", "--env", "CODER_ADD_GPU=true", - "--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib64", + "--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib", "--runtime=nvidia", "--gpus=all", ) // Assert that we can run nvidia-smi in the inner container. - _, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "nvidia-smi") - require.NoError(t, err, "failed to run nvidia-smi in the inner container") + assertInnerNvidiaSMI(ctx, t, ctID) + + // Make sure dnf still works. This checks for a regression due to + // gpuExtraRegex matching `libglib.so` in the outer container. + // This had a dependency on `libpcre.so.3` which would cause dnf to fail. + out, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "dnf") + if !assert.NoError(t, err, "failed to run dnf in the inner container") { + t.Logf("dnf output:\n%s", strings.TrimSpace(out)) + } + + // Make sure libglib.so is not present in the inner container. + out, err = execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "ls", "-1", "/usr/lib/x86_64-linux-gnu/libglib*") + // An error is expected here. + assert.Error(t, err, "libglib should not be present in the inner container") + assert.Contains(t, out, "No such file or directory", "libglib should not be present in the inner container") }) t.Run("InnerUsrLibDirOverride", func(t *testing.T) { @@ -79,11 +92,58 @@ func TestDocker_Nvidia(t *testing.T) { "--gpus=all", ) - // Assert that the libraries end up in the expected location in the inner container. - out, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "ls", "-l", "/usr/lib/coder") + // Assert that the libraries end up in the expected location in the inner + // container. + out, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "ls", "-1", "/usr/lib/coder") require.NoError(t, err, "inner usr lib dir override failed") require.Regexp(t, `(?i)(libgl|nvidia|vulkan|cuda)`, out) }) + + t.Run("EmptyHostUsrLibDir", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + emptyUsrLibDir := t.TempDir() + + // Start the envbox container. + ctID := startEnvboxCmd(ctx, t, integrationtest.UbuntuImage, "root", + "-v", emptyUsrLibDir+":/var/coder/usr/lib", + "--env", "CODER_ADD_GPU=true", + "--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib", + "--runtime=nvidia", + "--gpus=all", + ) + + ofs := outerFiles(ctx, t, ctID, "/usr/lib/x86_64-linux-gnu/libnv*") + // Assert invariant: the outer container has the files we expect. + require.NotEmpty(t, ofs, "failed to list outer container files") + // Assert that expected files are available in the inner container. + assertInnerFiles(ctx, t, ctID, "/usr/lib/x86_64-linux-gnu/libnv*", ofs...) + assertInnerNvidiaSMI(ctx, t, ctID) + }) + + t.Run("CUDASample", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Start the envbox container. + ctID := startEnvboxCmd(ctx, t, integrationtest.CUDASampleImage, "root", + "-v", "/usr/lib/x86_64-linux-gnu:/var/coder/usr/lib", + "--env", "CODER_ADD_GPU=true", + "--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib", + "--runtime=nvidia", + "--gpus=all", + ) + + // Assert that we can run nvidia-smi in the inner container. + assertInnerNvidiaSMI(ctx, t, ctID) + + // Assert that /tmp/vectorAdd runs successfully in the inner container. + _, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "/tmp/vectorAdd") + require.NoError(t, err, "failed to run /tmp/vectorAdd in the inner container") + }) } // dockerRuntimes returns the list of container runtimes available on the host. @@ -101,6 +161,49 @@ func dockerRuntimes(t *testing.T) []string { return strings.Split(raw, "\n") } +// outerFiles returns the list of files in the outer container matching the +// given pattern. It does this by running `ls -1` in the outer container. +func outerFiles(ctx context.Context, t *testing.T, containerID, pattern string) []string { + t.Helper() + // We need to use /bin/sh -c to avoid the shell interpreting the glob. + out, err := execContainerCmd(ctx, t, containerID, "/bin/sh", "-c", "ls -1 "+pattern) + require.NoError(t, err, "failed to list outer container files") + files := strings.Split(strings.TrimSpace(out), "\n") + slices.Sort(files) + return files +} + +// assertInnerFiles checks that all the files matching the given pattern exist in the +// inner container. +func assertInnerFiles(ctx context.Context, t *testing.T, containerID, pattern string, expected ...string) { + t.Helper() + + // Get the list of files in the inner container. + // We need to use /bin/sh -c to avoid the shell interpreting the glob. + out, err := execContainerCmd(ctx, t, containerID, "docker", "exec", "workspace_cvm", "/bin/sh", "-c", "ls -1 "+pattern) + require.NoError(t, err, "failed to list inner container files") + innerFiles := strings.Split(strings.TrimSpace(out), "\n") + + // Check that the expected files exist in the inner container. + missingFiles := make([]string, 0) + for _, expectedFile := range expected { + if !slices.Contains(innerFiles, expectedFile) { + missingFiles = append(missingFiles, expectedFile) + } + } + require.Empty(t, missingFiles, "missing files in inner container: %s", strings.Join(missingFiles, ", ")) +} + +// assertInnerNvidiaSMI checks that nvidia-smi runs successfully in the inner +// container. +func assertInnerNvidiaSMI(ctx context.Context, t *testing.T, containerID string) { + t.Helper() + // Assert that we can run nvidia-smi in the inner container. + out, err := execContainerCmd(ctx, t, containerID, "docker", "exec", "workspace_cvm", "nvidia-smi") + require.NoError(t, err, "failed to run nvidia-smi in the inner container") + require.Contains(t, out, "NVIDIA-SMI", "nvidia-smi output does not contain NVIDIA-SMI") +} + // startEnvboxCmd starts the envbox container with the given arguments. // Ideally we would use ory/dockertest for this, but it doesn't support // specifying the runtime. We have alternatively used the docker client library, diff --git a/integration/integrationtest/docker.go b/integration/integrationtest/docker.go index 4390ee2..5b42bbe 100644 --- a/integration/integrationtest/docker.go +++ b/integration/integrationtest/docker.go @@ -42,6 +42,9 @@ const ( UbuntuImage = "gcr.io/coder-dev-1/sreya/ubuntu-coder" // Redhat UBI9 image as of 2025-03-05 RedhatImage = "registry.access.redhat.com/ubi9/ubi:9.5" + // CUDASampleImage is a CUDA sample image from NVIDIA's container registry. + // It contains a binary /tmp/vectorAdd which can be run to test the CUDA setup. + CUDASampleImage = "nvcr.io/nvidia/k8s/cuda-sample:vectoradd-cuda10.2" // RegistryImage is used to assert that we add certs // correctly to the docker daemon when pulling an image diff --git a/xunix/gpu.go b/xunix/gpu.go index 0708667..a9129d5 100644 --- a/xunix/gpu.go +++ b/xunix/gpu.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "regexp" + "slices" "sort" "strings" @@ -17,9 +18,9 @@ import ( ) var ( - gpuMountRegex = regexp.MustCompile("(?i)(nvidia|vulkan|cuda)") - gpuExtraRegex = regexp.MustCompile("(?i)(libgl|nvidia|vulkan|cuda)") - gpuEnvRegex = regexp.MustCompile("(?i)nvidia") + gpuMountRegex = regexp.MustCompile(`(?i)(nvidia|vulkan|cuda)`) + gpuExtraRegex = regexp.MustCompile(`(?i)(libgl(e|sx|\.)|nvidia|vulkan|cuda)`) + gpuEnvRegex = regexp.MustCompile(`(?i)nvidia`) sharedObjectRegex = regexp.MustCompile(`\.so(\.[0-9\.]+)?$`) ) @@ -39,6 +40,7 @@ func GPUEnvs(ctx context.Context) []string { func GPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]Device, []mount.MountPoint, error) { var ( + afs = GetFS(ctx) mounter = Mounter(ctx) devices = []Device{} binds = []mount.MountPoint{} @@ -64,6 +66,22 @@ func GPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]Device, []m // If it's not in /dev treat it as a bind mount. binds = append(binds, m) + // We also want to find any symlinks that point to the target. + // This is important for the nvidia driver as it mounts the driver + // files with the driver version appended to the end, and creates + // symlinks that point to the actual files. + links, err := SameDirSymlinks(afs, m.Path) + if err != nil { + log.Error(ctx, "find symlinks", slog.F("path", m.Path), slog.Error(err)) + } else { + for _, link := range links { + log.Debug(ctx, "found symlink", slog.F("link", link), slog.F("target", m.Path)) + binds = append(binds, mount.MountPoint{ + Path: link, + Opts: []string{"ro"}, + }) + } + } } } @@ -104,7 +122,11 @@ func usrLibGPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]mount return nil } - if !sharedObjectRegex.MatchString(path) || !gpuExtraRegex.MatchString(path) { + if !gpuExtraRegex.MatchString(path) { + return nil + } + + if !sharedObjectRegex.MatchString(path) { return nil } @@ -176,6 +198,75 @@ func recursiveSymlinks(afs FS, mountpoint string, path string) ([]string, error) return paths, nil } +// SameDirSymlinks returns all links in the same directory as `target` that +// point to target, either indirectly or directly. Only symlinks in the same +// directory as `target` are considered. +func SameDirSymlinks(afs FS, target string) ([]string, error) { + // Get the list of files in the directory of the target. + fis, err := afero.ReadDir(afs, filepath.Dir(target)) + if err != nil { + return nil, xerrors.Errorf("read dir %q: %w", filepath.Dir(target), err) + } + + // Do an initial pass to map all symlinks to their destinations. + allLinks := make(map[string]string) + for _, fi := range fis { + // Ignore non-symlinks. + if fi.Mode()&os.ModeSymlink == 0 { + continue + } + + absPath := filepath.Join(filepath.Dir(target), fi.Name()) + link, err := afs.Readlink(filepath.Join(filepath.Dir(target), fi.Name())) + if err != nil { + return nil, xerrors.Errorf("readlink %q: %w", fi.Name(), err) + } + + if !filepath.IsAbs(link) { + link = filepath.Join(filepath.Dir(target), link) + } + allLinks[absPath] = link + } + + // Now we can start checking for symlinks that point to the target. + var ( + found = make([]string, 0) + // Set an arbitrary upper limit to prevent infinite loops. + maxIterations = 10 + ) + for range maxIterations { + var foundThisTime bool + for linkName, linkDest := range allLinks { + // Ignore symlinks that point outside of target's directory. + if filepath.Dir(linkName) != filepath.Dir(target) { + continue + } + + // If the symlink points to the target, add it to the list. + if linkDest == target { + if !slices.Contains(found, linkName) { + found = append(found, linkName) + foundThisTime = true + } + } + + // If the symlink points to another symlink that we already determined + // points to the target, add it to the list. + if slices.Contains(found, linkDest) { + if !slices.Contains(found, linkName) { + found = append(found, linkName) + foundThisTime = true + } + } + } + // If we didn't find any new symlinks, we're done. + if !foundThisTime { + break + } + } + return found, nil +} + // TryUnmountProcGPUDrivers unmounts any GPU-related mounts under /proc as it causes // issues when creating any container in some cases. Errors encountered while // unmounting are treated as non-fatal. diff --git a/xunix/gpu_test.go b/xunix/gpu_test.go index f8d8d47..4324fcf 100644 --- a/xunix/gpu_test.go +++ b/xunix/gpu_test.go @@ -2,10 +2,13 @@ package xunix_test import ( "context" + "os" "path/filepath" + "sort" "testing" "github.com/spf13/afero" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/mount-utils" @@ -59,11 +62,17 @@ func TestGPUs(t *testing.T) { filepath.Join(usrLibMountpoint, "nvidia", "libglxserver_nvidia.so.1"), } - // fakeUsrLibFiles are files that should be written to the "mounted" - // /usr/lib directory. It includes files that shouldn't be returned. - fakeUsrLibFiles = append([]string{ + // fakeUsrLibFiles are files that we do not expect to be returned + // bind mounts for. + fakeUsrLibFiles = []string{ filepath.Join(usrLibMountpoint, "libcurl-gnutls.so"), - }, expectedUsrLibFiles...) + filepath.Join(usrLibMountpoint, "libglib.so"), + } + + // allUsrLibFiles are all the files that should be written to the + // "mounted" /usr/lib directory. It includes files that shouldn't + // be returned. + allUsrLibFiles = append(expectedUsrLibFiles, fakeUsrLibFiles...) ) ctx := xunix.WithFS(context.Background(), fs) @@ -90,10 +99,14 @@ func TestGPUs(t *testing.T) { err := fs.MkdirAll(filepath.Join(usrLibMountpoint, "nvidia"), 0o755) require.NoError(t, err) - for _, file := range fakeUsrLibFiles { + for _, file := range allUsrLibFiles { _, err = fs.Create(file) require.NoError(t, err) } + for _, mp := range mounter.MountPoints { + _, err = fs.Create(mp.Path) + require.NoError(t, err) + } devices, binds, err := xunix.GPUs(ctx, log, usrLibMountpoint) require.NoError(t, err) @@ -110,5 +123,111 @@ func TestGPUs(t *testing.T) { Opts: []string{"ro"}, }) } + for _, file := range fakeUsrLibFiles { + require.NotContains(t, binds, mount.MountPoint{ + Path: file, + Opts: []string{"ro"}, + }) + } }) } + +func Test_SameDirSymlinks(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + // We need to test with a real filesystem as the fake one doesn't + // support creating symlinks. + tmpDir = t.TempDir() + // We do test with the interface though! + afs = xunix.GetFS(ctx) + ) + + // Create some files in the temporary directory. + _, err := os.Create(filepath.Join(tmpDir, "file1.real")) + require.NoError(t, err, "create file") + _, err = os.Create(filepath.Join(tmpDir, "file2.real")) + require.NoError(t, err, "create file2") + _, err = os.Create(filepath.Join(tmpDir, "file3.real")) + require.NoError(t, err, "create file3") + _, err = os.Create(filepath.Join(tmpDir, "file4.real")) + require.NoError(t, err, "create file4") + + // Create a symlink to the file in the temporary directory. + // This needs to be done by the real os package. + err = os.Symlink(filepath.Join(tmpDir, "file1.real"), filepath.Join(tmpDir, "file1.link1")) + require.NoError(t, err, "create first symlink") + + // Create another symlink to the previous symlink. + err = os.Symlink(filepath.Join(tmpDir, "file1.link1"), filepath.Join(tmpDir, "file1.link2")) + require.NoError(t, err, "create second symlink") + + // Create a symlink to a file outside of the temporary directory. + err = os.MkdirAll(filepath.Join(tmpDir, "dir"), 0o755) + require.NoError(t, err, "create dir") + // Create a symlink from file2 to inside the dir. + err = os.Symlink(filepath.Join(tmpDir, "file2.real"), filepath.Join(tmpDir, "dir", "file2.link1")) + require.NoError(t, err, "create dir symlink") + + // Create a symlink with a relative path. To do this, we need to + // change the working directory to the temporary directory. + oldWorkingDir, err := os.Getwd() + require.NoError(t, err, "get working dir") + // Change the working directory to the temporary directory. + require.NoError(t, os.Chdir(tmpDir), "change working dir") + err = os.Symlink(filepath.Join(tmpDir, "file4.real"), "file4.link1") + require.NoError(t, err, "create relative symlink") + // Change the working directory back to the original. + require.NoError(t, os.Chdir(oldWorkingDir), "change working dir back") + + for _, tt := range []struct { + name string + expected []string + }{ + { + // Two symlinks to the same file. + name: "file1.real", + expected: []string{ + filepath.Join(tmpDir, "file1.link1"), + filepath.Join(tmpDir, "file1.link2"), + }, + }, + { + // Mid-way in the symlink chain. + name: "file1.link1", + expected: []string{ + filepath.Join(tmpDir, "file1.link2"), + }, + }, + { + // End of the symlink chain. + name: "file1.link2", + expected: []string{}, + }, + { + // Symlink to a file outside of the temporary directory. + name: "file2.real", + expected: []string{}, + }, + { + // No symlinks to this file. + name: "file3.real", + expected: []string{}, + }, + { + // One relative symlink. + name: "file4.real", + expected: []string{filepath.Join(tmpDir, "file4.link1")}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + fullPath := filepath.Join(tmpDir, tt.name) + actual, err := xunix.SameDirSymlinks(afs, fullPath) + require.NoError(t, err, "find symlink") + sort.Strings(actual) + assert.Equal(t, tt.expected, actual, "find symlinks %q", tt.name) + }) + } +}