diff --git a/test/distributed/launcher/test_run.py b/test/distributed/launcher/test_run.py
index f71bffd527c1..d271e60954ae 100644
--- a/test/distributed/launcher/test_run.py
+++ b/test/distributed/launcher/test_run.py
@@ -273,10 +273,29 @@ def test_nproc_launch_unknown_configurations(self):
     )
     @patch("torch.cuda.is_available", return_value=True)
     @patch("torch.cuda.device_count", return_value=3)
-    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
+    @patch("torch.accelerator.is_available", return_value=True)
+    @patch("torch.accelerator.device_count", return_value=3)
+    @patch("torch.accelerator.current_accelerator", return_value=MagicMock(type="gpu"))
+    def test_nproc_gpu_launch_configurations(
+        self, _mock1, _mock2, _mock3, _mock4, _mock5
+    ):
         self._test_nproc_launch_configuration("auto", 3)
         self._test_nproc_launch_configuration("gpu", 3)
 
+    @skip_but_pass_in_sandcastle_if(
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
+    )
+    @patch("torch.xpu.is_available", return_value=True)
+    @patch("torch.xpu.device_count", return_value=3)
+    @patch("torch.accelerator.is_available", return_value=True)
+    @patch("torch.accelerator.device_count", return_value=3)
+    @patch("torch.accelerator.current_accelerator", return_value=MagicMock(type="xpu"))
+    def test_nproc_xpu_launch_configurations(
+        self, _mock1, _mock2, _mock3, _mock4, _mock5
+    ):
+        self._test_nproc_launch_configuration("auto", 3)
+        self._test_nproc_launch_configuration("xpu", 3)
+
     @skip_but_pass_in_sandcastle_if(
         TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
diff --git a/torch/distributed/run.py b/torch/distributed/run.py
index c37ecd8f72d8..bd1dfdb2a02f 100644
--- a/torch/distributed/run.py
+++ b/torch/distributed/run.py
@@ -77,7 +77,9 @@ ..
 note:: ``--nproc-per-node`` may be
        ``"gpu"`` (spawn one process per GPU),
        ``"cpu"`` (spawn one process per CPU),
+       ``"xpu"`` (spawn one process per XPU),
        ``"auto"`` (equivalent to ``"gpu"`` if CUDA is available,
+       else equivalent to ``"xpu"`` if XPU is available,
        else equivalent to ``"cpu"``),
        or an integer specifying the number of processes.
        See `torch.distributed.run.determine_local_world_size
@@ -413,7 +415,7 @@ def get_args_parser() -> ArgumentParser:
         action=env,
         type=str,
         default="1",
-        help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
+        help="Number of workers per node; supported values: [auto, cpu, gpu, xpu, int].",
     )
 
     #
@@ -694,21 +696,20 @@ def determine_local_world_size(nproc_per_node: str):
                 raise ValueError("Cuda is not available.") from e
             device_type = "gpu"
             num_proc = torch.cuda.device_count()
+        elif nproc_per_node == "xpu":
+            if not torch.xpu.is_available():
+                raise ValueError("Xpu is not available.") from e
+            device_type = "xpu"
+            num_proc = torch.xpu.device_count()
         elif nproc_per_node == torch._C._get_privateuse1_backend_name():
             if not _get_custom_mod_func("is_available")():
                 raise ValueError(f"{nproc_per_node} is not available.") from e
             device_type = nproc_per_node
             num_proc = _get_custom_mod_func("device_count")()
         elif nproc_per_node == "auto":
-            if torch.cuda.is_available():
-                num_proc = torch.cuda.device_count()
-                device_type = "gpu"
-            elif (
-                hasattr(torch, torch._C._get_privateuse1_backend_name())
-                and _get_custom_mod_func("is_available")()
-            ):
-                num_proc = _get_custom_mod_func("device_count")()
-                device_type = torch._C._get_privateuse1_backend_name()
+            if torch.accelerator.is_available():
+                num_proc = torch.accelerator.device_count()
+                device_type = torch.accelerator.current_accelerator().type  # type: ignore[union-attr]
             else:
                 num_proc = os.cpu_count()
                 device_type = "cpu"