[scan] cloned aliased input when lowering scan to while_loop #158168


Open
wants to merge 4 commits into base: gh/ydwu4/279/base

Conversation

ydwu4
Contributor

@ydwu4 ydwu4 commented Jul 12, 2025

Stack from ghstack (oldest at bottom):

Fixes #153679.

The CSE pass accidentally de-dups the zero buffers that are allocated to store the gradients of scan's init, causing aliasing among the inputs. CSE runs after dynamo and AOT, so the aliasing went undetected by our front-end safety checks. The op also isn't wrapped by auto_functionalized, because there is no aliasing at the autograd and functionalization keys: the zeros_like calls haven't been de-duped yet at that point.

The alternatives include:

  1. Figure out a way to tell CSE not to de-dup in this case. This seems tedious and complicates the CSE function's interface.
  2. Detect the aliasing and clone when we decompose scan to while_loop. This is OK but seems ad hoc. It also weakens the contract that there should be no aliasing among the inputs of HOPs that are not wrapped with auto_functionalized.

The final fix, implemented after discussion with @zou3519, is option 2: aliasing is OK as long as there are no in-place mutations, so we can clone right before we start introducing buffers etc.
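The chosen fix can be sketched in plain Python. This is a hypothetical stand-in, not the actual lowering code: `clone_aliased_inputs` is an invented helper name, `copy.copy` stands in for `tensor.clone()`, and the only point illustrated is that any input aliasing an earlier input is replaced by a fresh object before the loop body can observe it.

```python
import copy

def clone_aliased_inputs(inputs):
    """Hypothetical sketch of the fix: before lowering scan to
    while_loop, replace any input that aliases an earlier input with a
    copy (tensor.clone() in the real lowering), so the loop body never
    sees two names for the same underlying buffer."""
    seen = set()   # ids of objects already emitted
    out = []
    for x in inputs:
        if id(x) in seen:
            x = copy.copy(x)   # break the alias
        seen.add(id(x))
        out.append(x)
    return tuple(out)
```

Because nothing mutates the inputs in place before this point, cloning here is semantics-preserving; it only costs one extra buffer per duplicated input.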

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben


pytorch-bot bot commented Jul 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158168

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4cddc7d with merge base fc25c68:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ydwu4 added a commit that referenced this pull request Jul 12, 2025
ghstack-source-id: 4ee4d47
Pull Request resolved: #158168
@ydwu4 ydwu4 added the topic: not user facing topic category label Jul 12, 2025
Fixes #153679.

The CSE pass accidentally de-dups the zero buffers that are allocated to store the gradients of scan's init, causing aliasing among the inputs. CSE runs after dynamo and AOT, so the aliasing went undetected by our front-end safety checks. The fix follows the same spirit as skipping CSE for aten.empty: we just don't CSE aten.full.

The alternatives considered:
1. Figure out a way to tell CSE not to de-dup in this case. This seems tedious to do...
2. Detect the aliasing and clone when we decompose scan to while_loop. This is OK but in the spirit of patching the symptom instead of fixing the original problem.


[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jul 12, 2025
ghstack-source-id: 6657b7a
Pull Request resolved: #158168
@ydwu4 ydwu4 requested a review from zou3519 July 14, 2025 17:38
@zou3519
Contributor

zou3519 commented Jul 18, 2025

The CSE pass accidently de-dups the zero buffer that are allocated to store the gradients of init to scan causing an aliasing among the inputs.

How does this cause aliasing among the inputs?

@ydwu4
Contributor Author

ydwu4 commented Jul 18, 2025

How does this cause aliasing among the inputs?
Before CSE the graph looks like:

grad_init1 = torch.full(s1, s2, ...) 
grad_init2 = torch.full(s1, s2, ...)
torch.ops.higher_order.scan(backward_combine_graph, (grad_init1, grad_init2), ...)

The two torch.full calls happen to have the same arguments because the forward_combine_graph's outputs have the same shape. CSE finds that the two torch.full calls have the same arguments and are therefore subject to CSE, so it changes the graph to:

grad_init = torch.full(s1, s2, ...) 
torch.ops.higher_order.scan(backward_combine_graph, (grad_init, grad_init), ...)
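The merge described above can be reproduced with a toy value-numbering pass. Everything here is a simplified sketch (the node representation and names are invented, not torch.fx or the actual Inductor CSE), but it shows the exact failure: two identical `aten.full` nodes collapse into one, and the `scan` call ends up receiving the same value twice.

```python
def cse(nodes):
    """Toy common-subexpression elimination over (name, op, args)
    triples. Two nodes with the same op and the same (already
    canonicalized) args are merged into one."""
    canon = {}   # (op, args) -> name of first node computing it
    remap = {}   # eliminated name -> surviving name
    out = []
    for name, op, args in nodes:
        args = tuple(remap.get(a, a) for a in args)  # rewrite uses
        key = (op, args)
        if key in canon:
            remap[name] = canon[key]   # duplicate: drop this node
        else:
            canon[key] = name
            out.append((name, op, args))
    return out

graph = [
    ("grad_init1", "aten.full", ("s1", "s2")),
    ("grad_init2", "aten.full", ("s1", "s2")),
    ("res", "higher_order.scan", ("grad_init1", "grad_init2")),
]
```

Running `cse(graph)` drops `grad_init2` and rewrites the scan's inputs to `("grad_init1", "grad_init1")`, i.e. the aliased-input situation the PR fixes.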

Fixes #153679.

The CSE pass accidentally de-dups the zero buffers that are allocated to store the gradients of scan's init, causing aliasing among the inputs. CSE runs after dynamo and AOT, so the aliasing went undetected by our front-end safety checks. The op also isn't wrapped by auto_functionalized, because there is no aliasing at the autograd and functionalization keys: the zeros_like calls haven't been de-duped yet at that point. The fix follows the same spirit as skipping CSE for aten.empty: we just don't CSE aten.full.

The alternatives include:
1. Figure out a way to tell CSE not to de-dup in this case. This seems tedious and complicates the CSE function's interface.
2. Detect the aliasing and clone when we decompose scan to while_loop. This is OK but seems ad hoc. It also weakens the contract that there should be no aliasing among the inputs of HOPs that are not wrapped with auto_functionalized.
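Alternative 1 amounts to giving the CSE pass a skip list, by analogy with the existing aten.empty carve-out the earlier commit message mentions. A hedged sketch, with an invented node representation and an invented `skip` set (this is not the real Inductor CSE interface):

```python
UNSAFE_TO_CSE = {"aten.empty", "aten.full"}  # hypothetical skip set

def cse_with_skips(nodes, skip=UNSAFE_TO_CSE):
    """Toy CSE over (name, op, args) triples, with a skip set: ops
    whose results must stay distinct buffers are never deduplicated,
    while everything else is merged as usual."""
    canon, remap, out = {}, {}, []
    for name, op, args in nodes:
        args = tuple(remap.get(a, a) for a in args)
        key = (op, args)
        if op not in skip and key in canon:
            remap[name] = canon[key]   # true duplicate: merge
            continue
        if op not in skip:
            canon[key] = name
        out.append((name, op, args))
    return out
```

The downside the PR description notes is visible here: every such carve-out grows the skip set and the pass's interface, whereas cloning at the scan-to-while_loop boundary handles the problem in one place.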

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jul 24, 2025
@ydwu4 ydwu4 changed the title [scan][cse] avoid cse zeros like gradient buffers [scan] cloned aliased input when lowering scan to while_loop Jul 24, 2025
Fixes #153679.

The CSE pass accidentally de-dups the zero buffers that are allocated to store the gradients of scan's init, causing aliasing among the inputs. CSE runs after dynamo and AOT, so the aliasing went undetected by our front-end safety checks. The op also isn't wrapped by auto_functionalized, because there is no aliasing at the autograd and functionalization keys: the zeros_like calls haven't been de-duped yet at that point.

The alternatives include:
1. Figure out a way to tell CSE not to de-dup in this case. This seems tedious and complicates the CSE function's interface.
2. Detect the aliasing and clone when we decompose scan to while_loop. This is OK but seems ad hoc. It also weakens the contract that there should be no aliasing among the inputs of HOPs that are not wrapped with auto_functionalized.

The final fix, implemented after discussion with zou3519, is option 2: aliasing is OK as long as there are no in-place mutations, so we can clone right before we start introducing buffers etc.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 11, 2025