# Owner(s): ["oncall: distributed"] import operator import os import sys import threading from functools import reduce from unittest import skip, SkipTest import torch import torch.autograd import torch.distributed as dist from torch._C._distributed_c10d import ReduceOp if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) from torch.testing._internal.common_distributed import ( MultiThreadedTestCase, skip_if_lt_x_gpu, spawn_threads_and_init_comms, ) from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase DEFAULT_WORLD_SIZE = 4 class TestCollectivesWithWrapper(TestCase): @spawn_threads_and_init_comms(world_size=4) def test_broadcast_object_list(self): val = 99 if dist.get_rank() == 0 else None object_list = [val] * dist.get_world_size() dist.broadcast_object_list(object_list=object_list) self.assertEqual(99, object_list[0]) def test_collective_error_on_rank_zero(self): @spawn_threads_and_init_comms(world_size=4) def _test_method(self): input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather output_tensors = [ torch.empty_like(input_tensor) for _ in range(dist.get_world_size()) ] dist.all_gather(output_tensors, input_tensor) if dist.get_rank() == 0: raise AssertionError("Mimic real test failure.") # fail on rank 0 dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather with self.assertRaises(RuntimeError): _test_method(self) def test_collective_error_on_rank_non_zero(self): @spawn_threads_and_init_comms(world_size=4) def _test_method(self): input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather output_tensors = [ torch.empty_like(input_tensor) for _ in range(dist.get_world_size()) ] dist.all_gather(output_tensors, input_tensor) if dist.get_rank() == 1: raise AssertionError("Mimic real test failure.") # fail on rank 1 dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather with self.assertRaises(RuntimeError): _test_method(self) def test_collective_error_on_rank_non_zero_all(self): @spawn_threads_and_init_comms(world_size=4) def _test_method(self): input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather output_tensors = [ torch.empty_like(input_tensor) for _ in range(dist.get_world_size()) ] dist.all_gather(output_tensors, input_tensor) if dist.get_rank() > 0: raise AssertionError( "Mimic real test failure." 
                )  # fail on all non-zero rank

            dist.all_gather(output_tensors, input_tensor)  # perform 2nd all gather

        with self.assertRaises(RuntimeError):
            _test_method(self)

    def test_skip(self):
        @spawn_threads_and_init_comms(world_size=4)
        @skip("check if skip exception can be captured correctly.")
        def _test_method(self):
            pass

        if not IS_SANDCASTLE:
            with self.assertRaises(SkipTest):
                _test_method(self)

    @spawn_threads_and_init_comms(world_size=4)
    def test_all_to_all_single_tensor(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)
        sizes = torch.ones(world_size, dtype=torch.int64)

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send, sizes, sizes)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))

    @spawn_threads_and_init_comms(world_size=4)
    def test_all_to_all_single_list(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)
        sizes = [1] * world_size

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send, sizes, sizes)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))

    @spawn_threads_and_init_comms(world_size=4)
    def test_all_to_all_single_none(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))


class TestCollectivesWithBaseClass(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
        super().setUp()
        self._spawn_threads()

    def tearDown(self):
        super().tearDown()
        os.environ["TORCH_DIST_INIT_BARRIER"] = "0"

    def test_allgather(self):
        input_tensor = torch.ones(3, 3) * dist.get_rank()
        output_tensors = [
            torch.empty_like(input_tensor) for _ in range(self.world_size)
        ]
        dist.all_gather(output_tensors, input_tensor)
        for rank, out_tensor in enumerate(output_tensors):
            self.assertEqual(out_tensor, torch.ones(3, 3) * rank)

    def test_broadcast(self):
        input_tensor = torch.ones(3, 3) * dist.get_rank()
        for rank in range(self.world_size):
            cloned_input = input_tensor.clone()
            dist.broadcast(cloned_input, src=rank)
            self.assertEqual(cloned_input, torch.ones(3, 3) * rank)

    def test_scatter(self):
        if dist.get_rank() == 0:
            scatter_list = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
        else:
            scatter_list = None

        output_tensor = torch.empty(3, 3)
        dist.scatter(output_tensor, scatter_list)
        self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank())

    def test_reduce_scatter(self):
        to_reduce_scatter = [
            torch.ones(3, 3) * rank for rank in range(self.world_size)
        ]
        output_tensor = torch.empty(3, 3)

        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        expected_tensor = torch.ones(3, 3) * dist.get_rank() * self.world_size
        self.assertEqual(output_tensor, expected_tensor)

        output_tensor = torch.empty(3, 3)
        dist.reduce_scatter(output_tensor, to_reduce_scatter, op=dist.ReduceOp.AVG)
        expected_tensor = torch.ones(3, 3) * dist.get_rank()
        self.assertEqual(output_tensor, expected_tensor)

    def test_broadcast_object_list(self):
        val = 99 if dist.get_rank() == 0 else None
        object_list = [val] * dist.get_world_size()
        print(f"{dist.get_rank()} -> {dist.get_world_size()}")

        dist.broadcast_object_list(object_list=object_list)
        self.assertEqual(99, object_list[0])

    def test_all_reduce(self):
        output = torch.ones(3, 3) * dist.get_rank()
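        # all_reduce defaults to ReduceOp.SUM, so every rank should end up with
        # the sum of all ranks: 0 + 1 + ... + (world_size - 1).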
        dist.all_reduce(output)
        res_num = ((0 + self.world_size - 1) * self.world_size) / 2
        self.assertEqual(output, torch.ones(3, 3) * res_num)

    def test_all_to_all(self):
        rank = self.rank
        world_size = self.world_size
        input_tensor_list = [
            torch.ones(3, 3) * x
            for x in range(rank * world_size, (rank + 1) * world_size)
        ]
        output_tensor_list = [torch.empty_like(tensor) for tensor in input_tensor_list]
        dist.all_to_all(output_tensor_list, input_tensor_list)
        expected_tensor_list = [
            torch.ones(3, 3) * x
            for x in range(rank, world_size * world_size, world_size)
        ]
        self.assertEqual(expected_tensor_list, output_tensor_list)

    def test_all_reduce_ops(self):
        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.PRODUCT)
        expected = reduce(operator.mul, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.MIN)
        self.assertEqual(1, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.MAX)
        self.assertEqual(self.world_size, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BAND)
        expected = reduce(operator.and_, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BOR)
        expected = reduce(operator.or_, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BXOR)
        expected = reduce(operator.xor, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

    def test_assert_equal_on_rank(self):
        # RNG is shared across threads. So instead of asserting on all threads
        # we only assert on rank 0
        self_tensor = torch.rand(3, 3)
        rank_0_tensor = self_tensor.clone()
        dist.broadcast(rank_0_tensor, src=0)
        self.assertEqualOnRank(rank_0_tensor, self_tensor, rank=0)
        self.assertNotEqualOnRank(rank_0_tensor, self_tensor, rank=1)

    def test_subpg(self):
        subpg0 = dist.new_group([0, 1])
        subpg1 = dist.new_group([2, 3])
        current_rank = dist.get_rank()
        output = torch.ones(3, 3) * current_rank

        # call all_reduce on subpg0 and subpg1 concurrently
        if current_rank in [0, 1]:
            dist.all_reduce(output, group=subpg0)
        else:
            dist.all_reduce(output, group=subpg1)

        if current_rank in [0, 1]:
            self.assertEqual(output, torch.ones(3, 3) * 1)
        else:
            self.assertEqual(output, torch.ones(3, 3) * 5)

    def test_using_pg_from_another_thread(self):
        def stuff_in_other_thread(pg):
            x = torch.rand(4, requires_grad=True)
            dist.all_reduce(x, group=pg)

        t = threading.Thread(target=stuff_in_other_thread, args=(dist.group.WORLD,))
        t.start()
        t.join()

    def test_gather(self):
        if dist.get_rank() == 0:
            gather_list = [torch.empty(3, 3) for _ in range(self.world_size)]
        else:
            gather_list = None

        input_tensor = torch.ones(3, 3) * dist.get_rank()
        dist.gather(input_tensor, gather_list)

        if dist.get_rank() == 0:
            for i in range(self.world_size):
                self.assertEqual(gather_list[i], torch.ones(3, 3) * i)

    def test_all_reduce_coalesced(self):
        t0 = torch.ones(3, 3) * dist.get_rank()
        t1 = torch.ones(3, 3) * dist.get_rank() * 2
        dist.all_reduce_coalesced([t0, t1])
        res_num = ((0 + self.world_size - 1) * self.world_size) / 2
        self.assertEqual(t0, torch.ones(3, 3) * res_num)
        self.assertEqual(t1, torch.ones(3, 3) * (res_num * 2))

    @skip_if_lt_x_gpu(1)
    def test_bwd_sees_fwd_pg(self):
        fwd_tid = threading.current_thread().ident

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, rank):
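                # forward doubles the per-rank tensor; backward below asserts it
                # runs on the same thread as forward and can still use the
                # process group initialized for that thread.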
                result = rank * 2
                ctx.save_for_backward(result, rank)
                assert int(rank.item()) == dist.get_rank()
                return result

            @staticmethod
            def backward(ctx, grad_output):
                result, rank = ctx.saved_tensors
                bwd_tid = threading.current_thread().ident

                self.assertEqual(
                    fwd_tid,
                    bwd_tid,
                    f"bwd not running in the same thread as fwd for rank {rank.item()}",
                )
                self.assertTrue(dist.is_initialized())
                self.assertEqual(int(rank.item()), dist.get_rank())
                dist.all_reduce(result)
                self.assertEqual(int(result.item()), 12)  # (0 + 1 + 2 + 3) * 2

                return grad_output * result

        x = torch.tensor(
            [dist.get_rank()], dtype=torch.float, device="cuda", requires_grad=True
        )
        x = MyFunc.apply(x)
        x.sum().backward()


if __name__ == "__main__":
    run_tests()