[scan] cloned aliased input when lowering scan to while_loop #158168
base: gh/ydwu4/279/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158168.
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 4cddc7d with merge base fc25c68. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Fixes #153679. The CSE pass accidentally de-dups the zero buffers that are allocated to store the gradients of scan's init, causing aliasing among the inputs. CSE runs after dynamo and AOT, so this went undetected by our front-end safety checks. The fix follows the same spirit as skipping CSE for aten.empty: we just don't CSE aten.full. The alternatives considered: 1. Figure out a way to tell CSE not to de-dup in this case. This seems tedious to do. 2. Detect the aliasing and do a clone when we decompose scan to while_loop. This is OK, but it is in the spirit of patching the symptom instead of fixing the original problem. [ghstack-poisoned]
How does this cause aliasing among the inputs?
```python
grad_init1 = torch.full(s1, s2, ...)
grad_init2 = torch.full(s1, s2, ...)
torch.ops.higher_order.scan(backward_combine_graph, (grad_init1, grad_init2), ...)
```

The two torch.full calls happen to have the same arguments because the forward_combine_graph's output shapes are the same. CSE finds that the two torch.full calls have identical arguments and are therefore subject to elimination, so it changes the graph to:

```python
grad_init = torch.full(s1, s2, ...)
torch.ops.higher_order.scan(backward_combine_graph, (grad_init, grad_init), ...)
```
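The mechanism above can be sketched in plain Python. This is a hypothetical toy CSE cache keyed on (op, args), not PyTorch's actual CSE pass; `full` here is a stand-in for aten.full that returns a fresh mutable buffer. It shows how de-duping two identical allocation calls makes the results alias, so a mutation through one handle is visible through the other:

```python
# Hypothetical toy CSE: cache results keyed on (op name, args), so two
# calls with identical arguments return the *same* object.
cse_cache = {}

def cse_call(op, *args):
    key = (op.__name__, args)
    if key not in cse_cache:
        cse_cache[key] = op(*args)
    return cse_cache[key]

def full(numel, fill):
    # stand-in for aten.full: allocates a fresh mutable buffer
    return [fill] * numel

grad_init1 = cse_call(full, 4, 0.0)
grad_init2 = cse_call(full, 4, 0.0)
assert grad_init1 is grad_init2  # de-duped: the two "buffers" now alias

grad_init2[0] = 1.0
print(grad_init1[0])  # prints 1.0: mutating one is visible through the other
```

This is exactly why the two distinct gradient-init buffers fed to scan must not be collapsed into one.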
Stack from ghstack (oldest at bottom):
Fixes #153679.
The CSE pass accidentally de-dups the zero buffers that are allocated to store the gradients of scan's init, causing aliasing among the inputs. CSE runs after dynamo and AOT, so this went undetected by our front-end safety checks. And it's not wrapped by auto_functionalized because there is no aliasing at the autograd and functionalization keys: the zeros_like calls have not been de-duped yet at that point.
The alternatives include: 1. Figure out a way to tell CSE not to de-dup in this case. This seems tedious to do and complicates the CSE function's interface. 2. Detect the aliasing and do a clone when we decompose scan to while_loop. This is OK but seems ad hoc. It also breaks the contract that there shouldn't be aliasing among inputs for HOPs not wrapped with auto_functionalized.
The final fix, implemented after discussion with @zou3519, is option 2: aliasing is OK as long as there are no in-place mutations, so we can fix it right before we start introducing buffers etc.
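The chosen fix can be sketched as follows. This is a hypothetical illustration (not the actual lowering code): before lowering scan to while_loop, walk the carried inputs, and whenever an input aliases one we have already seen, replace it with a clone so every buffer is distinct. `clone` here is a caller-supplied stand-in for aten.clone:

```python
# Hypothetical sketch of the fix: clone aliased inputs so that every
# carried buffer passed to while_loop is a distinct object.
def dedupe_aliased_inputs(inputs, clone):
    seen = set()
    out = []
    for buf in inputs:
        if id(buf) in seen:
            buf = clone(buf)  # break the alias with a fresh copy
        seen.add(id(buf))
        out.append(buf)
    return out

a = [0.0] * 4
fixed = dedupe_aliased_inputs([a, a], clone=lambda b: list(b))
assert fixed[0] is not fixed[1]  # no aliasing after the pass
fixed[0][0] = 1.0
print(fixed[1][0])  # prints 0.0: mutations no longer leak across inputs
```

Since the clone happens only at decomposition time, earlier passes are free to CSE the allocations, and the no-aliasing requirement is restored exactly where the buffers start being mutated.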
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben