From c2b0c39570ea8e781edc9595e23c6e1403ade011 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Sun, 10 Aug 2025 15:23:32 -0700
Subject: [PATCH 1/2] [C10D] Add check_rng_sync util

Debugs RNG desync by checking the current state on each rank in the group
and summarizing the differences if any are detected.

Notes:
- used allgather instead of gather since it's simpler to do this SPMD rather
  than add conditional behavior, though I could be convinced we only want to
  log on rank0.

Usage: `check_rng_sync(generator, group)`

Prints something like this:

(cuda):
```
[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)
```

(cpu):
```
[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks    Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0        7633364531954955665
[rank2]:E0810 ] 1        8807615394212033278
[rank2]:E0810 ] 2-3      -6150027303226666531
```

ghstack-source-id: 3d60739a0a791f3c761dab3af9aef0b5149f2cda
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160283
---
 test/distributed/test_collective_utils.py | 57 +++++++++++++-
 torch/distributed/collective_utils.py     | 95 ++++++++++++++++++++++-
 2 files changed, 149 insertions(+), 3 deletions(-)

diff --git a/test/distributed/test_collective_utils.py b/test/distributed/test_collective_utils.py
index a150a55f77be6..6f8f1dadf5751 100644
--- a/test/distributed/test_collective_utils.py
+++ b/test/distributed/test_collective_utils.py
@@ -2,10 +2,20 @@
 
 from unittest import mock
 
+import torch
 import torch.distributed as c10d
-from torch.distributed.collective_utils import all_gather, broadcast
+from torch.distributed.collective_utils import (
+    _check_rng_sync,
+    all_gather,
+    broadcast,
+    check_rng_sync,
+)
 from torch.testing._internal.common_distributed import MultiProcessTestCase
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 
 
 class TestCollectiveUtils(MultiProcessTestCase):
@@ -116,6 +126,49 @@ def test_all_gather_result_raises_exceptions_from_func(
         with self.assertRaisesRegex(Exception, expected_exception):
             all_gather(data_or_fn=func)
 
+    @parametrize("device", ["cpu", "cuda"])
+    def test_check_rng_sync(
+        self,
+        device,
+    ) -> None:
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
+        )
+        group = torch.distributed.distributed_c10d._get_default_group()
+        generator = torch.Generator(device=device)
+        generator.manual_seed(123)
+        value_ranks, _ = _check_rng_sync(generator, group)
+        self.assertEqual(len(value_ranks), 1, value_ranks)
+        for actual, expected in zip(value_ranks.values(), [{0, 1, 2, 3}]):
+            self.assertEqual(actual, expected, actual)
+
+        if torch.distributed.get_rank() == 1:
+            torch.randn((10,), device=device, generator=generator)
+        value_ranks, _ = _check_rng_sync(generator, group)
+        self.assertEqual(len(value_ranks), 2, value_ranks)
+        for actual, expected in zip(value_ranks.values(), [{0, 2, 3}, {1}]):
+            self.assertEqual(actual, expected, actual)
+
+        if torch.distributed.get_rank() == 0:
+            generator.manual_seed(456)
+        value_ranks, _ = _check_rng_sync(generator, group)
+        self.assertEqual(len(value_ranks), 3, value_ranks)
+        for actual, expected in zip(value_ranks.values(), [{0}, {1}, {2, 3}]):
+            self.assertEqual(actual, expected, actual)
+
+        # Prints something like the following. We don't verify the log output here,
+        # but at least make sure the function does not crash.
+        # [rank0]:E0808 ] Generator desync detected:
+        # [rank0]:E0808 ] Ranks    (Seed, Offset) values
+        # [rank0]:E0808 ] -------  -----------------------
+        # [rank0]:E0808 ] 0        (456, 0)
+        # [rank0]:E0808 ] 1        (123, 4)
+        # [rank0]:E0808 ] 2-3      (123, 0)
+        check_rng_sync(generator, group)
+
+
+instantiate_parametrized_tests(TestCollectiveUtils)
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py
index b1a7c824c2e3b..b551930b94d32 100644
--- a/torch/distributed/collective_utils.py
+++ b/torch/distributed/collective_utils.py
@@ -9,12 +9,21 @@
 
 from __future__ import annotations
 
+import logging
+from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Generic, Optional, TypeVar, Union
+from typing import Any, Callable, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
 
+import torch
 import torch.distributed as dist
 
 
+logger = logging.getLogger(__name__)
+
 T = TypeVar("T")
 
 
@@ -215,3 +224,87 @@ def all_gather_object_enforce_type(
             f"Object type at index {i} is {type(object_list[i])}, "
             f"while first object type is {type(first_obj)}"
         )
+
+
+def summarize_ranks(numbers: Iterable[int]) -> str:
+    numbers = sorted(numbers)
+    result = []
+    current_range_start = numbers[0]
+    for i in range(1, len(numbers)):
+        if numbers[i] == numbers[i - 1] + 1:
+            pass
+        else:
+            if current_range_start == numbers[i - 1]:
+                result.append(str(current_range_start))
+            else:
+                result.append(f"{current_range_start}-{numbers[i - 1]}")
+            current_range_start = numbers[i]
+    if current_range_start == numbers[-1]:
+        result.append(str(current_range_start))
+    else:
+        result.append(f"{current_range_start}-{numbers[-1]}")
+    return ", ".join(result)
+
+
+def _check_philox_rng_sync(
+    generator: torch.Generator, group: dist.ProcessGroup
+) -> tuple[dict[Any, set], str]:
+    local_state = generator.get_state()
+    all_states = [torch.empty_like(local_state) for _ in range(group.size())]
+    torch.distributed.all_gather(all_states, local_state)
+    seeds_offsets = [
+        (state[:8].view(torch.uint64).item(), state[8:].view(torch.uint64).item())
+        for state in all_states
+    ]
+    seed_offset_ranks = defaultdict(set)
+    for rank, (seed, offset) in enumerate(seeds_offsets):
+        seed_offset_ranks[(seed, offset)].add(rank)
+    return seed_offset_ranks, "(Seed, Offset)"
+
+
+def _check_cpu_rng_sync(
+    generator: torch.Generator, group: dist.ProcessGroup
+) -> tuple[dict[Any, set], str]:
+    # seed is returned as uint64_t from C impl, so may not fit in torch int64 tensor directly.
+    state_tensor = generator.get_state()
+    all_state_tensors = [torch.empty_like(state_tensor) for _ in range(group.size())]
+    torch.distributed.all_gather(all_state_tensors, state_tensor)
+    state_ranks = defaultdict(set)
+    for rank, state_tensor in enumerate(all_state_tensors):
+        # Hacky way to summarize the state vector of the CPU rng. Is there a better way to do this?
+        # The properties that matter most are (1) it's different if there is a state difference, (2) it's printable
+        # (see the desync table; it's not viable to print the whole state vector of size 5k)
+        state_ranks[hash(tuple(state_tensor.tolist()))].add(rank)
+    return state_ranks, "Generator state hash"
+
+
+def _check_rng_sync(
+    generator: torch.Generator, group: dist.ProcessGroup
+) -> tuple[dict[Any, set], str]:
+    if generator.device.type == "cuda":
+        return _check_philox_rng_sync(generator, group)
+    elif generator.device.type == "cpu":
+        return _check_cpu_rng_sync(generator, group)
+    else:
+        raise NotImplementedError(
+            f"Unsupported generator device: {generator.device.type}"
+        )
+
+
+def _desync_table_str(tag: str, value_ranks: dict[Any, set[int]]) -> str:
+    headers = ["Ranks", f"{tag} values"]
+    rank_values = [
+        [summarize_ranks(ranks), str(value)] for value, ranks in value_ranks.items()
+    ]
+    from tabulate import tabulate
+
+    return tabulate(rank_values, headers=headers)
+
+
+def check_rng_sync(generator: torch.Generator, group: dist.ProcessGroup) -> None:
+    value_ranks, value_header = _check_rng_sync(generator, group)
+    if len(value_ranks) > 1:
+        logger.error(
+            "Generator desync detected:\n%s",
+            _desync_table_str(value_header, value_ranks),
+        )

From 85c6506b0bd6da2a6c2c2f5761f63ad18f60beee Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Sun, 10 Aug 2025 15:54:06 -0700
Subject: [PATCH 2/2] WIP summarize ranks

ghstack-source-id: 37921cd51fdf8d8561d783195a14dc40c0718b61
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160284
---
 test/distributed/test_collective_utils.py | 43 +++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/test/distributed/test_collective_utils.py b/test/distributed/test_collective_utils.py
index 6f8f1dadf5751..1651087366243 100644
--- a/test/distributed/test_collective_utils.py
+++ b/test/distributed/test_collective_utils.py
@@ -9,13 +9,17 @@
     all_gather,
     broadcast,
     check_rng_sync,
+    summarize_ranks,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.testing._internal.common_distributed import MultiProcessTestCase
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    TestCase,
 )
+from torch.testing._internal.distributed.fake_pg import FakeStore
 
 
 class TestCollectiveUtils(MultiProcessTestCase):
@@ -168,6 +172,45 @@ def test_check_rng_sync(
         check_rng_sync(generator, group)
 
 
+class TestUtils(TestCase):
+    def setUp(self):
+        super().setUp()
+
+        if not c10d.is_initialized():
+            self.rank = 0
+            self.world_size = 4096
+
+            store = FakeStore()
+            c10d.init_process_group(
+                backend="fake",
+                world_size=self.world_size,
+                rank=self.rank,
+                store=store,
+            )
+
+    def tearDown(self):
+        c10d.destroy_process_group()
+
+    def test_summarize_ranks(self):
+        mesh_dim_names = ("pp", "dp", "tp")
+        mesh = init_device_mesh("cpu", (8, 64, 8), mesh_dim_names=mesh_dim_names)
+        ranks_lists = {name: mesh[name].mesh.tolist() for name in mesh_dim_names}
+        summaries = {
+            name: summarize_ranks(ranks_lists[name]) for name in mesh_dim_names
+        }
+        self.assertEqual(summaries["pp"], "0, 512, 1024, 1536, 2048, 2560, 3072, 3584")
+        # TODO: what would be the best format for abbreviating striding?
+        # self.assertEqual(summaries["pp"], "0, 512, ..., 3584")
+        # self.assertEqual(summaries["pp"], "0, (stride 512), 3584")
+        self.assertEqual(
+            summaries["dp"],
+            "0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, "
+            "184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, 280, 288, 296, 304, 312, 320, 328, 336, "
+            "344, 352, 360, 368, 376, 384, 392, 400, 408, 416, 424, 432, 440, 448, 456, 464, 472, 480, 488, 496, 504",
+        )
+        self.assertEqual(summaries["tp"], "0-7")
+
+
 instantiate_parametrized_tests(TestCollectiveUtils)
 
 if __name__ == "__main__":