 import sys
 import shlex
 
+import importlib.util
+import pathlib
+import fnmatch
+
 from setuptools import setup, find_packages
 from setuptools.extension import Extension
 from setuptools.command.build_ext import build_ext
@@ -91,6 +95,30 @@ def build_extensions(self):
 # Cuda detection
 
 
+# Partly taken from JAX:
+# https://github.com/google/jax/blob/4cca2335220dcc953edd2ac764b2387e53527495/jax/_src/lib/__init__.py#L129
+def get_cuda_paths_from_nvidia_pypi():
+    # Check whether nvidia-cuda-nvcc-cu* is installed. To do so we need the
+    # site-packages directory of this install, which we locate through mpi4py
+    # (which must be installed).
+    mpi4py_spec = importlib.util.find_spec("mpi4py")
+    depot_path = pathlib.Path(os.path.dirname(mpi4py_spec.origin)).parent
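+    # For example (hypothetical path, for illustration only): if mpi4py resolves to
+    # .../site-packages/mpi4py/__init__.py, then depot_path is .../site-packages.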
+
+    # If the pip package nvidia-cuda-nvcc-cu11 is installed, it should have
+    # both of the things XLA looks for in the cuda path, namely bin/ptxas and
+    # nvvm/libdevice/libdevice.10.bc
+    #
+    # The files are split in two sets of directories, so we return both.
+    maybe_cuda_paths = [
+        depot_path / "nvidia" / "cuda_nvcc",
+        depot_path / "nvidia" / "cuda_runtime",
+    ]
+    if all(p.is_dir() for p in maybe_cuda_paths):
+        return [str(p) for p in maybe_cuda_paths]
+    else:
+        return []
+
+
 # Taken from CUPY (MIT License)
 def get_cuda_path():
     nvcc_path = search_on_path(("nvcc", "nvcc.exe"))
@@ -169,22 +197,86 @@ def get_sycl_info():
 sycl_info = get_sycl_info()
 
 
+def find_files(bases, pattern):
+    """Return list of files matching pattern in base folders and subfolders."""
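+    # Hypothetical usage example (paths are illustrative):
+    #   find_files(["/opt/cuda"], "libcudart.so*")
+    #   -> ["/opt/cuda/lib64/libcudart.so", "/opt/cuda/lib64/libcudart.so.12"]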
+    if isinstance(bases, (str, pathlib.Path)):
+        bases = [bases]
+
+    result = []
+    for base in bases:
+        for root, dirs, files in os.walk(base):
+            for name in files:
+                if fnmatch.fnmatch(name, pattern):
+                    result.append(os.path.join(root, name))
+    return result
+
+
 def get_cuda_info():
-    cuda_info = {"compile": [], "libdirs": [], "libs": []}
-    cuda_path = get_cuda_path()
-    if not cuda_path:
-        return cuda_info
+    cuda_info = {"compile": [], "libdirs": [], "libs": [], "rpaths": []}
 
-    incdir = os.path.join(cuda_path, "include")
-    if os.path.isdir(incdir):
-        cuda_info["compile"].append(incdir)
+    # First check whether the nvidia-cuda-nvcc-cu* package is installed. In that case
+    # we ignore CUDA_ROOT, which is the same behaviour as JAX.
+    cuda_paths = get_cuda_paths_from_nvidia_pypi()
 
-    for libdir in ("lib64", "lib"):
-        full_dir = os.path.join(cuda_path, libdir)
-        if os.path.isdir(full_dir):
-            cuda_info["libdirs"].append(full_dir)
+    # If not, try to find the CUDA path by hand
+    if len(cuda_paths) > 0:
+        nvidia_pypi_package = True
+    else:
+        nvidia_pypi_package = False
+        _cuda_path = get_cuda_path()
+        if _cuda_path is None:
+            cuda_paths = []
+        else:
+            cuda_paths = [_cuda_path]
 
-    cuda_info["libs"].append("cudart")
+    if len(cuda_paths) == 0:
+        return cuda_info
+
+    for cuda_path in cuda_paths:
+        incdir = os.path.join(cuda_path, "include")
+        if os.path.isdir(incdir):
+            cuda_info["compile"].append(incdir)
+
+        for libdir in ("lib64", "lib"):
+            full_dir = os.path.join(cuda_path, libdir)
+            if os.path.isdir(full_dir):
+                cuda_info["libdirs"].append(full_dir)
+
+    # We need to link against libcudart.so:
+    # - If we are using a standard CUDA installation, we simply add a link flag for
+    #   libcudart.so.
+    # - If we are using the nvidia-cuda-nvcc-cu* package, we need to find the exact
+    #   version of libcudart.so to link against, because the package does not provide
+    #   a generic symlink to libcudart.so but only libcudart.so.XX.
+    #
+    # Moreover, if we are using nvidia-cuda-nvcc we must add an rpath (runtime search
+    # path), because we cannot expect the user to point LD_LIBRARY_PATH at the
+    # nvidia-cuda-nvcc package.
+    if not nvidia_pypi_package:
+        cuda_info["libs"].append("cudart")
+    else:
+        possible_libcudart = find_files(cuda_paths, "libcudart.so*")
+
+        if any(os.path.basename(p) == "libcudart.so" for p in possible_libcudart):
+            # If the generic symlink is present, use the standard linker flag.
+            # In theory with nvidia-cuda-nvcc-cu12 we should never reach this point,
+            # but in the future they might fix it.
+            cuda_info["libs"].append("cudart")
+        elif len(possible_libcudart) > 0:
+            # This should be the standard case for nvidia-cuda-nvcc-cu*, where we find
+            # a library libcudart.so.XX. The syntax to link against a specific version
+            # is -l:libcudart.so.XX.
+            # We arbitrarily choose the first match and add the runtime search path
+            # accordingly.
+            lib_to_link = possible_libcudart[0]
+            cuda_info["libs"].append(f":{os.path.basename(lib_to_link)}")
+            cuda_info["rpaths"].append(os.path.dirname(lib_to_link))
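+            # With a hypothetical match such as libcudart.so.12, this results in the
+            # link flag -l:libcudart.so.12 plus an rpath entry for its directory.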
+        else:
+            # If we cannot find libcudart.so, we cannot build the extension.
+            # This should never happen with the nvidia-cuda-nvcc-cu* package.
+            cuda_info["libs"].append("cudart")
+
+    print("\n\nCUDA INFO:", cuda_info, "\n\n")
     return cuda_info
 
 
@@ -237,13 +329,18 @@ def get_extensions():
     )
 
     if cuda_info["compile"] and cuda_info["libdirs"]:
+        extra_extension_args = {}
+        if len(cuda_info["rpaths"]) > 0:
+            extra_extension_args["runtime_library_dirs"] = cuda_info["rpaths"]
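+            # runtime_library_dirs is typically emitted as rpath linker flags
+            # (e.g. -Wl,-rpath,<dir> with GCC on Linux), so the built extension can
+            # locate libcudart.so.XX at runtime without LD_LIBRARY_PATH.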
+
         extensions.append(
             Extension(
                 name=f"{CYTHON_SUBMODULE_NAME}.mpi_xla_bridge_gpu",
                 sources=[f"{CYTHON_SUBMODULE_PATH}/mpi_xla_bridge_gpu.pyx"],
                 include_dirs=cuda_info["compile"],
                 library_dirs=cuda_info["libdirs"],
                 libraries=cuda_info["libs"],
+                **extra_extension_args,
             )
         )
     else: