-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Fix GroupNorm(num_groups=1) to match LayerNorm behavior #159736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix GroupNorm(num_groups=1) to match LayerNorm behavior #159736
Conversation
- Special-case num_groups=1 to use LayerNorm computation for float32/float64 - Add comprehensive tests for equivalence - Skip mixed dtype tests for num_groups=1 to avoid numerical precision issues - No performance impact for other cases Fixes pytorch#75862
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159736
Note: Links to docs will display an error until the docs builds have been completed. ❌ 34 New Failures, 2 Unrelated FailuresAs of commit 1fdc50d with merge base a626dc8 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is a correct change.
Looking at the discussion on #75862 and https://discuss.pytorch.org/t/groupnorm-num-groups-1-and-layernorm-are-not-equivalent/145468, it does not seem that we expect LayerNorm(num_groups=1)
and GroupNorm
to be equivalent
Thank you for the review @mikaylagawarecki. I now understand that this behavior difference is intentional, not a bug. Would it be valuable if I pivoted this PR to add documentation clarifying that GroupNorm(num_groups=1) and LayerNorm are intentionally different? This could help future users avoid the same confusion. |
Fixes #75862
Summary
This PR fixes a bug where
GroupNorm(num_groups=1)
does not produce the same output asLayerNorm
, despite the documentation claiming they should be equivalent.The Problem
Currently,
GroupNorm
withnum_groups=1
normalizes over all dimensions (C×H×W), computing a single mean and variance per sample. However,LayerNorm
normalizes over the C dimension for each spatial position, computing H×W different mean/var values.The Fix
This PR special-cases
num_groups=1
inF.group_norm
to use the LayerNorm computation path for float32/float64 tensors, ensuring mathematical equivalence.Implementation Details
num_groups=1
casesnum_groups > 1
Testing
num_groups > 1
Backward Compatibility
This change fixes incorrect behavior to match documented behavior. While it changes outputs for
num_groups=1
, this aligns with user expectations based on the documentation.