Fix torch.export.export() GPU failure with RNN modules. #155734
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155734
Note: Links to docs will display an error until the docs builds have been completed.
❌ 11 New Failures as of commit 2a5c538 with merge base 3a56237.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "release notes: nn"
torch/nn/modules/rnn.py (Outdated)

    if len(unique_data_ptrs) != len(self._flat_weights):
        return
    try:
        from torch.multiprocessing.reductions import StorageWeakRef
The StorageWeakRef implementation doesn't actually look multiprocessing-specific, so you should move it out of there. It is also not good practice to do local imports like this. @bdhirsh, what do you think about defining it in torch/__init__.py?
Here would be the right place (line 322 at commit 43a0918):

    class TensorWeakRef:
@tugsbayasgalan @albanD I've moved StorageWeakRef to torch.utils.weak as suggested, and updated all relevant imports across the codebase. Let me know if there’s anything else you'd like me to adjust.
torch/nn/modules/rnn.py (Outdated)

        }
        if len(unique_storage_refs) != len(self._flat_weights):
            return
    except Exception:
Need a test case.
Ping!
Working on it today.
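A test along the lines requested above might look like this (a sketch with an illustrative test name, not the PR's actual test; it exercises a CPU analogue of the failure path by constructing an RNN module under FakeTensorMode, where parameters have no real backing memory):

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode


def test_flatten_parameters_with_fake_tensors():
    # torch.export.export() traces with fake tensors. Building and
    # flattening an LSTM under FakeTensorMode goes through the same
    # flatten_parameters() path; the old p.data_ptr() aliasing check
    # raised "Cannot access data pointer of Tensor" on fake tensors,
    # while the fixed check should simply skip flattening.
    with FakeTensorMode():
        lstm = nn.LSTM(input_size=4, hidden_size=8)
        lstm.flatten_parameters()  # must not raise
    return True
```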
torch/nn/modules/rnn.py (Outdated)

    except Exception:
        # Fallback for cases where StorageWeakRef is not available or fails
        # This maintains PT2 compatibility by skipping aliasing check
        pass
Why is it ok to just pass here?
Yep we should at least narrow down the exception.
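A narrowed handler could be sketched like this (the helper name and the matched message are illustrative assumptions, not the PR's final code):

```python
import torch
from torch.multiprocessing.reductions import StorageWeakRef


def weights_share_storage(weights):
    # Illustrative helper: detect whether any two weight tensors alias
    # the same underlying storage, swallowing only the specific failure
    # mode seen with fake tensors instead of a blanket `except Exception`.
    try:
        refs = {StorageWeakRef(w.untyped_storage()) for w in weights}
    except RuntimeError as e:
        # Fake/meta tensors have no accessible data pointer; treat only
        # that case as "cannot tell, assume no aliasing".
        if "data pointer" not in str(e):
            raise
        return False
    return len(refs) != len(weights)
```

Anything other than the expected RuntimeError still propagates, which addresses the concern about silently passing.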
    @@ -21,43 +21,6 @@
    pass


    class StorageWeakRef:
It is NOT ok to remove public APIs without appropriate care.
In this case, I don't think it is worth breaking existing user code by removing this altogether.
You can import StorageWeakRef from torch.utils.weak here to preserve the current public API.
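The suggested re-export might be sketched as follows (a sketch only; the try/except lets it run whether or not the move of StorageWeakRef to torch.utils.weak has landed in a given build):

```python
# Keep StorageWeakRef importable from its historical public location so
# existing `from torch.multiprocessing.reductions import StorageWeakRef`
# call sites keep working after the class moves to torch.utils.weak.
try:
    from torch.utils.weak import StorageWeakRef  # noqa: F401
except ImportError:
    # Move not yet landed in this build; the class is still defined in
    # its original home.
    from torch.multiprocessing.reductions import StorageWeakRef  # noqa: F401
```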
Imported, thanks for pointing this out.
Looks like some tests are still failing.
    @@ -336,3 +337,40 @@ def __call__(self):
        # TODO, add _fix_weakref type binding
        out._fix_weakref()  # type: ignore[attr-defined]
        return out


    class StorageWeakRef:
FYI @ezyang I forgot we had this in serialization. We should migrate this like the above when we have a minute, now that we have pyobject preservation for Storage.
Ah yes. Sounds LLM amenable, file an issue maybe?
@albanD @tugsbayasgalan Can we make sure this lands? This has definitely been a huge irritation in the past.
CI needs fixing, happy to review again after that.
@penknife6153 just wondering if you plan to work on this further?
@tugsbayasgalan been sick for the past few weeks, but I'm currently working on this.
@penknife6153 Just checking in to see if you've made any progress on this PR. Since the ONNX team plans to switch to export-based graph capture very soon, this item needs to be closed sooner, I think (cc: @titaiwangms).
We need this by 2.9.
Hi @tugsbayasgalan and @titaiwangms! I've drafted some tests for RNN, LSTM, and GRU, added error handling at RNNBase, and resolved the merge conflicts. Aiming to have everything finalized within this weekend.
        }
    except Exception as e:
        if isinstance(e, RuntimeError) and "share storage" in str(e):
            raise  # Re-raise actual aliasing errors
Sus
torch.export.export() uses fake tensors during graph tracing, which don't have actual memory storage. The RNN module's flatten_parameters() method was calling p.data_ptr() for aliasing detection, causing "Cannot access data pointer of Tensor" errors on GPU.

Fixes #155309

Changes:
- Replaced p.data_ptr() with StorageWeakRef(p.untyped_storage()) for PT2 compatibility

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @tugsbayasgalan
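The swap the PR description refers to can be illustrated with a before/after sketch (simplified from flatten_parameters(); the helper names are hypothetical):

```python
import torch
from torch.multiprocessing.reductions import StorageWeakRef


def weights_are_distinct_before(weights):
    # Before: raw data pointers. This raises "Cannot access data pointer
    # of Tensor" when `weights` are the fake tensors torch.export uses.
    return len({w.data_ptr() for w in weights}) == len(weights)


def weights_are_distinct_after(weights):
    # After: weak references to the underlying storages. This avoids
    # touching the data pointer, which is what failed during export.
    return len({StorageWeakRef(w.untyped_storage()) for w in weights}) == len(weights)
```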