# Owner(s): ["oncall: distributed"] import sys import unittest from functools import partial, wraps import torch import torch.distributed as dist import torch.distributed._functional_collectives as ft_c import torch.distributed._tensor as dt import torch.distributed.distributed_c10d as c10d from functorch import make_fx from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.distributed.fake_pg import FakeStore from torch.testing._internal.inductor_utils import HAS_GPU if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) from torch.testing._internal.common_distributed import ( DistributedTestBase, MultiThreadedTestCase, requires_nccl, TEST_SKIPS, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, skipIfHpu, TEST_CUDA, TEST_HPU, TestCase, ) # NOTE: Instructions for adding new device types to this test file # # This test file contains two types of tests: # 1. Tests that run on both CPUs and accelerators # 2. Tests that run only on accelerators # # We use two variables to manage device types: # - `devices`: A list containing device types for both CPU and accelerator tests # - `DEVICE`: A string containing only the accelerator type for accelerator-only tests # # To add a new device type: # 1. Add a new `elif` statement in the if-else ladder below # 2. Check for the presence of your device (e.g., TEST_NEW_DEVICE) # 3. Append your device type to the `devices` list # 4. Assign your device type string to `DEVICE` # # Example: # elif TEST_NEW_DEVICE: # devices.append("new_device") # DEVICE = "new_device" DEVICE = "cuda" devices = ["cpu"] if TEST_HPU: devices.append("hpu") DEVICE = "hpu" elif TEST_CUDA: devices.append("cuda") def new_subgroups(group_size: int, pg_tag=None): world_size = dist.get_world_size() subgroups = [] cur_subgroup = None for subgroup_id in range(world_size // group_size): start_rank = subgroup_id * group_size end_rank = start_rank + group_size ranks_in_subgroup = list(range(start_rank, end_rank)) subgroup = c10d._new_group_with_tag( ranks=ranks_in_subgroup, pg_tag=pg_tag, ) subgroups.append(subgroup) rank = dist.get_rank() if rank in ranks_in_subgroup: cur_subgroup = subgroup return cur_subgroup, subgroups @skipIfHpu class TestExpand(MultiThreadedTestCase): @property def world_size(self): return 4 def setUp(self): super().setUp() self._spawn_threads() def test_expand_1d_rank_list(self): tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3]) self.assertEqual("", tag) self.assertEqual([0, 1, 2, 3], rankset) self.assertEqual(4, group_size) tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla") self.assertEqual("bla", tag) def test_expand_2d_rank_list(self): tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]]) self.assertEqual("", tag) self.assertEqual([0, 1, 2, 3], rankset) self.assertEqual(2, group_size) tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu") self.assertEqual("blu", tag) with self.assertRaisesRegex(ValueError, "group sizes must be identical"): ft_c._expand_group([[0], [1, 2, 3]]) def test_expand_process_group(self): tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD) self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag) self.assertEqual([0, 1, 2, 3], rankset) self.assertEqual(4, group_size) tag, rankset, group_size = 
        my_pg, _ = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group

        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=1)), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)


@skipIfHpu
class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    """
    The behavior we want is as follows:

    - rankset + tag will always result in the same PG.
      Do we enforce this by failing creation of new PGs or returning existing ones?
        Return the existing one.
    - The default tag gives the existing behavior.
        This means we should create duplicates.
    - _expand_group on a _default-tagged PG should always resolve to it.
        This means we can't depend on empty tag + rankset.
    """
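
    # A minimal sketch of the semantics above, assuming the 4-rank world this class
    # spawns; the tests below verify the same behavior via assertions, so this is
    # illustrative only and not executed:
    #
    #   pg_a, _ = new_subgroups(group_size=2, pg_tag="blu")
    #   pg_b, _ = new_subgroups(group_size=2, pg_tag="blu")
    #   # same rankset + same non-default tag -> the existing PG is returned
    #   assert pg_a == pg_b
    #
    #   pg_c, _ = new_subgroups(group_size=2)  # default (empty) tag
    #   pg_d, _ = new_subgroups(group_size=2)  # default tag creates a duplicate
    #   assert pg_c != pg_d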
""" def test_pg_creation_with_tag(self): my_group, _ = new_subgroups(group_size=2, pg_tag="blu") my_group2, _ = new_subgroups(group_size=2, pg_tag="blu") self.assertEqual(my_group, my_group2) my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2") self.assertNotEqual(my_group, my_group3) my_group4, _ = new_subgroups(group_size=2) self.assertNotEqual(my_group, my_group4) my_group5, _ = new_subgroups(group_size=2) self.assertNotEqual(my_group4, my_group5) def test_pg_lookup_roundtrip(self): pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu") pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2") pg_notag0, _ = new_subgroups(group_size=2) pg_notag1, _ = new_subgroups(group_size=2) def roundtrip(pg): tag, rankset, _ = ft_c._expand_group(pg) return c10d._find_pg_by_ranks_and_tag(tag, rankset) self.assertEqual(pg_tag0, roundtrip(pg_tag0)) self.assertEqual(pg_tag1, roundtrip(pg_tag1)) self.assertEqual(pg_notag0, roundtrip(pg_notag0)) self.assertEqual(pg_notag1, roundtrip(pg_notag1)) def test_pg_lookup_with_tag(self): pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu") pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla") pg_notag0, _ = new_subgroups(group_size=2) def roundtrip(pg, pg_tag): tag, rankset, _ = ft_c._expand_group(pg, pg_tag) return c10d._find_pg_by_ranks_and_tag(tag, rankset) self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu")) self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu")) # Cannot erase the tag of a PG self.assertEqual(pg_tag0, roundtrip(pg_tag0, "")) def test_find_or_create_pg(self): pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2) pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu") self.assertEqual(pg, pg_tag0) def test_find_root_pg(self): pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3]) self.assertEqual(dist.group.WORLD, pg) @instantiate_parametrized_tests @skipIfHpu class TestTraceableCollectives(MultiThreadedTestCase): @property def world_size(self): return 4 def setUp(self): super().setUp() self._spawn_threads() @parametrize("device", devices) def test_broadcast(self, device): if device == "cuda": if torch.cuda.device_count() < self.world_size: self.skipTest("Not enough CUDA devices") torch.cuda.set_device(dist.get_rank()) if dist.get_rank() == 0: tensor = torch.ones([4], device=device) else: tensor = torch.zeros([4], device=device) mesh = dt.DeviceMesh(device, torch.arange(4)) res = ft_c.broadcast(tensor, 0, mesh) self.assertEqual(res, torch.ones([4], device=device)) @parametrize("device", devices) def test_all_reduce_eager(self, device): if device == "cuda": if torch.cuda.device_count() < self.world_size: self.skipTest("Not enough CUDA devices") torch.cuda.set_device(dist.get_rank()) tensor = torch.ones([4], device=device) mesh = dt.DeviceMesh(device, torch.arange(4)) res = ft_c.all_reduce(tensor, "sum", mesh) self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float)) mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2)) res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1)) self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float)) @parametrize("device", devices) def test_all_reduce_coalesced_eager(self, device): if device == "cuda": if torch.cuda.device_count() < self.world_size: self.skipTest("Not enough CUDA devices") torch.cuda.set_device(dist.get_rank()) t0 = torch.ones([4], device=device) t1 = torch.ones([6], device=device) + 2 mesh = dt.DeviceMesh(device, torch.arange(4)) res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh) self.assertEqual(res[0], t0 * 4) self.assertEqual(res[1], t1 * 4) 
@parametrize("device", devices) def test_all_gather_tensor(self, device): if device == "cuda": if torch.cuda.device_count() < self.world_size: self.skipTest("Not enough CUDA devices") torch.cuda.set_device(dist.get_rank()) # testing 1d/2d mesh mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size)) mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2)) for mesh in [mesh_1d, mesh_2d]: dims_to_gather = [0, 1, 2] for dim in dims_to_gather: output_size = [3, 3, 3] output_size[dim] *= mesh.size(0) # each rank have its own tensor, all_gather gives a bigger tensor local_tensor = torch.ones([3, 3, 3], device=device) gathered_tensor = ft_c.all_gather_tensor( local_tensor, gather_dim=dim, group=(mesh, 0) ) self.assertEqual(gathered_tensor, torch.ones(output_size)) @parametrize("device", devices) def test_all_gather_into_tensor_coalesced(self, device): if device == "cuda": if torch.cuda.device_count() < self.world_size: self.skipTest("Not enough CUDA devices") torch.cuda.set_device(dist.get_rank()) tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1] mesh = dt.DeviceMesh(device, torch.arange(4)) res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh) self.assertEqual(2, len(res)) self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0]) self.assertEqual( torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1] ) @parametrize("device", devices) def test_reduce_scatter_tensor(self, device): if device == "cuda": if torch.cuda.device_count() < self.world_size: self.skipTest("Not enough CUDA devices") torch.cuda.set_device(dist.get_rank()) # testing 1d/2d mesh mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size)) mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2)) for mesh in [mesh_1d, mesh_2d]: dims_to_scatter = [0, 1] for dim in dims_to_scatter: group_size = mesh.size(0) input_size = [3, 3] output_size = [3, 3] output_size[dim] *= group_size input_tensor = torch.ones(output_size, device=device) res_num = 1 * group_size rs_tensor = ft_c.reduce_scatter_tensor( input_tensor, "sum", scatter_dim=dim, group=(mesh, 0) ) self.assertEqual(rs_tensor, torch.ones(input_size) * res_num) @parametrize("device", devices) def test_reduce_scatter_into_tensor_coalesced(self, device): if device == "cuda": if torch.cuda.device_count() < self.world_size: self.skipTest("Not enough CUDA devices") torch.cuda.set_device(dist.get_rank()) tensors = [ torch.ones([4], dtype=torch.int64, device=device), torch.ones([4], dtype=torch.int64, device=device) + 1, ] mesh = dt.DeviceMesh(device, torch.arange(4)) res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh) self.assertEqual(2, len(res)) self.assertEqual(torch.tensor([4], device=device), res[0]) self.assertEqual(torch.tensor([8], device=device), res[1]) class TestMetaCollectives(TestCase): def test_all_reduce(self): x = torch.rand((2, 3, 4), device="meta") out = ft_c.all_reduce(x, "sum", "0") self.assertEqual(x.size(), out.size()) @skipIfHpu class TestGradCollectives(MultiThreadedTestCase): @property def world_size(self): return 2 def setUp(self): super().setUp() self._spawn_threads() def test_all_reduce(self): x = torch.rand([4], requires_grad=True) y = torch.rand([4], requires_grad=True) out = ft_c.all_reduce(x, "sum", dist.group.WORLD) (out + y).sum().backward() self.assertIsNone(x.grad) class TestMakeFx(TestCase): def setUp(self): # make_fx is not thread-safe due to patching nd mutating global states # so create a fake_pg. 
        self.rank = 0
        self.world_size = 2
        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        super().tearDown()
        self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1

        graph = make_fx(allred)(torch.rand(4))
        FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))

        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))

        def allred_mesh(input):
            return ft_c.all_reduce(input, "sum", mesh) + 1

        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_graph.graph)
        )

        def allred_mesh_dim(input):
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1

        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_dim_graph.graph)
        )


BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO

# Adding support for the HCCL backend.
# To add a different backend, add an elif to the same chain with a conditional
# checking for the device type (along the lines of TEST_HPU or TEST_CUDA),
# and then set the BACKEND variable appropriately.
if TEST_HPU:
    BACKEND = dist.Backend.HCCL


# Allows you to check for multiple accelerators irrespective of device type.
# To add new device types to this check, simply follow the same format and
# append an elif with the conditional and the appropriate device count
# function for your new device.
def exit_if_lt_x_accelerators(x):
    if TEST_CUDA:
        if torch.cuda.device_count() < x:
            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
    elif TEST_HPU:
        if torch.hpu.device_count() < x:
            sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)


def with_comms(func=None):
    if func is None:
        return partial(with_comms)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if (
            BACKEND == dist.Backend.NCCL
            and torch.cuda.device_count() < self.world_size
        ):
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

        kwargs["device"] = DEVICE
        self.pg = self.create_pg(device=DEVICE)
        try:
            return func(self, *args, **kwargs)
        finally:
            torch.distributed.destroy_process_group()

    return wrapper


class TestCollectivesWithDistributedBackend(DistributedTestBase):
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self, device):
        exit_if_lt_x_accelerators(self.world_size)
        tensors = [
            torch.ones([4], device=device),
            torch.ones([4], device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self, device):
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self, device):
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self, device):
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.chunk(x, self.world_size)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self, device):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        compiled_allreduce(torch.randn(8, device=device), self.pg)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self, device=DEVICE):
        exit_if_lt_x_accelerators(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)  # noqa: F841
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        allreduce(torch.randn(8, device=device), pg=dist.group.WORLD)
        dist.destroy_process_group()

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing_with_dce_code(self, device):
        if self.world_size > 2:
            return

        def func(batch, group, rank):
            ret = ft_c.permute_tensor(batch, [1, 0], group)
            if hasattr(ret, "wait"):
                ret = ret.wait()
            if rank == 0:
                return ret
            else:
                return batch * 5

        compiled_func = torch.compile(func)
        compiled_func(torch.ones((100,), device=device), self.process_group, self.rank)
        dist.barrier()


class TestDistributedBackendCollectivesWithWorldSize4(
    TestCollectivesWithDistributedBackend
):
    @property
    def world_size(self):
        return 4

    @with_comms()
    def test_permute_tensor_with_sub_group(self, device):
        exit_if_lt_x_accelerators(self.world_size)

        mesh_dim_names = ["dp", "tp"]

        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=mesh_dim_names
        )

        for mesh_name in mesh_dim_names:
            mesh = mesh_2d[mesh_name]
            rank = mesh.get_local_rank()

            # rank0: [0., 1.], rank1: [2., 3.]
            send_tensor = (
                torch.arange(2, dtype=torch.float32, device=device) + 2 * rank
            )
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=mesh)

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device=device) + 2 * (
                (rank - 1 + 2) % 2
            )
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )


@instantiate_parametrized_tests
@skipIfHpu
class TestFunctionalAutograd(MultiThreadedTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @property
    def world_size(self):
        return 2

    @parametrize("compile", [True, False])
    def test_all_to_all_single(self, compile: bool = True) -> None:
        group = dist.group.WORLD.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 2
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        out = compiled(t, self.world_size)
        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

    def test_all_to_all_single_inductor(self) -> None:
        group = dist.group.WORLD.group_name

        t = torch.rand((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 10
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 2
            return out.sum()

        compiled = torch.compile(my_func, fullgraph=True)

        def run_with_backward():
            out = compiled(t, self.world_size)
            out.backward()

        _, codes = run_and_get_code(run_with_backward)
        for code in codes:
            FileCheck().check_count(
                "_c10d_functional.all_to_all_single.default", 1, exactly=True
            ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
                code
            )

        self.assertIsNotNone(t.grad)

    @parametrize("compile", [True, False])
    def test_all_gather_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            out = ft_c.all_gather_tensor_autograd(
                t * 1.0,
                gather_dim=dim,
                group=group,
            )
            out = out * 1.0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_gather = [0, 1, 2]
        for dim in dims_to_gather:
            output_size = [3, 3, 3]
            output_size[dim] *= self.world_size
            # each rank has its own tensor, all_gather gives a bigger tensor
            local_tensor = torch.ones([3, 3, 3], requires_grad=True)
            gathered_tensor = compiled(local_tensor, dim)
            self.assertEqual(gathered_tensor, torch.ones(output_size))

            gathered_tensor.sum().backward()
            self.assertEqual(
                local_tensor.grad,
                torch.full((3, 3, 3), fill_value=float(self.world_size)),
            )

    @parametrize("compile", [True, False])
    def test_reduce_scatter_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            rs_tensor = (
                ft_c.reduce_scatter_tensor_autograd(
                    input_tensor * 1.0, "sum", scatter_dim=dim, group=group
                )
                * 1.0
            )
            return rs_tensor

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_scatter = [0, 1]
        for dim in dims_to_scatter:
            group_size = self.world_size
            input_size = [3, 3]
            output_size = [3, 3]
            output_size[dim] *= group_size
            input_tensor = torch.ones(output_size, requires_grad=True)
            rs_tensor = compiled(input_tensor, dim)
            res_num = 1 * group_size
            self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)

            rs_tensor.sum().backward()
            self.assertEqual(
                input_tensor.grad, torch.full(output_size, fill_value=1.0)
            )


class TestFunctionalAutogradWithDistributedBackend(DistributedTestBase):
    @with_comms()
    def test_all_to_all_single(self, device) -> None:
        group = self.pg
        t = torch.ones((self.world_size, 2), requires_grad=True, device=device)

        sizes = [1] * self.world_size
        assert t.requires_grad
        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0

        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))


# Update the supported devices in DEVICE
instantiate_device_type_tests(
    TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE
)
instantiate_device_type_tests(
    TestDistributedBackendCollectivesWithWorldSize4, globals(), only_for=DEVICE
)
instantiate_device_type_tests(
    TestFunctionalAutogradWithDistributedBackend, globals(), only_for=DEVICE
)


if __name__ == "__main__":
    run_tests()