Skip to content

[C10D] Add check_rng_sync util #160283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: gh/wconstab/440/base
Choose a base branch
from

Conversation

wconstab
Copy link
Contributor

@wconstab wconstab commented Aug 10, 2025

Stack from ghstack (oldest at bottom):

Debugs RNG desync by checking the current state on each rank in the group and summarizing the differences if any are detected.

Notes:

  • used allgather instead of gather since its simpler to do this SPMD rather than add conditional behavior, though I could be convinced we only want to log on rank0.

Usage:
check_rng_sync(generator, group)

Prints something like this:

(cuda):

[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)

(cpu):

[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks      Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0                  7633364531954955665
[rank2]:E0810 ] 1                  8807615394212033278
[rank2]:E0810 ] 2-3               -6150027303226666531

Debugs RNG desync by checking the current state on each rank in the group and summarizing the differences if any are detected.

Notes:
- used allgather instead of gather since its simpler to do this SPMD rather than add conditional behavior, though I could be convinced we only want to log on rank0.

Usage:
`check_rng_sync(generator, group)`

Prints something like this:

(cuda):
```
[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)
```

(cpu):
```
[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks      Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0                  7633364531954955665
[rank2]:E0810 ] 1                  8807615394212033278
[rank2]:E0810 ] 2-3               -6150027303226666531
```

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Aug 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160283

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 4 Pending

As of commit 2e1102a with merge base 556e2a7 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Aug 10, 2025
Debugs RNG desync by checking the current state on each rank in the group and summarizing the differences if any are detected.

Notes:
- used allgather instead of gather since its simpler to do this SPMD rather than add conditional behavior, though I could be convinced we only want to log on rank0.

Usage:
`check_rng_sync(generator, group)`

Prints something like this:

(cuda):
```
[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)
```

(cpu):
```
[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks      Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0                  7633364531954955665
[rank2]:E0810 ] 1                  8807615394212033278
[rank2]:E0810 ] 2-3               -6150027303226666531
```

ghstack-source-id: 3d60739
Pull Request resolved: #160283
@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue labels Aug 10, 2025
Debugs RNG desync by checking the current state on each rank in the group and summarizing the differences if any are detected.

Notes:
- used allgather instead of gather since its simpler to do this SPMD rather than add conditional behavior, though I could be convinced we only want to log on rank0.

Usage:
`check_rng_sync(generator, group)`

Prints something like this:

(cuda):
```
[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)
```

(cpu):
```
[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks      Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0                  7633364531954955665
[rank2]:E0810 ] 1                  8807615394212033278
[rank2]:E0810 ] 2-3               -6150027303226666531
```

cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k pragupta

[ghstack-poisoned]
Debugs RNG desync by checking the current state on each rank in the group and summarizing the differences if any are detected.

Notes:
- used allgather instead of gather since its simpler to do this SPMD rather than add conditional behavior, though I could be convinced we only want to log on rank0.

Usage:
`check_rng_sync(generator, group)`

Prints something like this:

(cuda):
```
[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)
```

(cpu):
```
[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks      Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0                  7633364531954955665
[rank2]:E0810 ] 1                  8807615394212033278
[rank2]:E0810 ] 2-3               -6150027303226666531
```

cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k pragupta

[ghstack-poisoned]
Debugs RNG desync by checking the current state on each rank in the group and summarizing the differences if any are detected.

Notes:
- used allgather instead of gather since its simpler to do this SPMD rather than add conditional behavior, though I could be convinced we only want to log on rank0.

Usage:
`check_rng_sync(generator, group)`

Prints something like this:

(cuda):
```
[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)
```

(cpu):
```
[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks      Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0                  7633364531954955665
[rank2]:E0810 ] 1                  8807615394212033278
[rank2]:E0810 ] 2-3               -6150027303226666531
```

[ghstack-poisoned]
Debugs RNG desync by checking the current state on each rank in the group and summarizing the differences if any are detected.

Notes:
- used allgather instead of gather since its simpler to do this SPMD rather than add conditional behavior, though I could be convinced we only want to log on rank0.

Usage:
`check_rng_sync(generator, group)`

Prints something like this:

(cuda):
```
[rank0]:E0808 ] Generator desync detected:
[rank0]:E0808 ] Ranks    (Seed, Offset) values
[rank0]:E0808 ] -------  -----------------------
[rank0]:E0808 ] 0        (456, 0)
[rank0]:E0808 ] 1        (123, 4)
[rank0]:E0808 ] 2-3      (123, 0)
```

(cpu):
```
[rank2]:E0810 ] Generator desync detected:
[rank2]:E0810 ] Ranks      Generator State Hash values
[rank2]:E0810 ] -------  -----------------------------
[rank2]:E0810 ] 0                  7633364531954955665
[rank2]:E0810 ] 1                  8807615394212033278
[rank2]:E0810 ] 2-3               -6150027303226666531
```

[ghstack-poisoned]
# [rank0]:E0808 ] ------- -----------------------
# [rank0]:E0808 ] 0 (456, 0)
# [rank0]:E0808 ] 1 (123, 4)
# [rank0]:E0808 ] 2-3 (123, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you not have the function also return a str

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I can do that.

all_states = [torch.empty_like(local_state) for _ in range(group.size())]
torch.distributed.all_gather(all_states, local_state)
seeds_offsets = [
(state[:8].view(torch.uint64).item(), state[8:].view(torch.uint64).item())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remind me what's going on here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's seed,offset both uint64_t packed together into a 16-byte, iirc

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants