Skip to content

Fix torch.export.export() GPU failure with RNN modules. #155734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

penknife6153
Copy link
Contributor

@penknife6153 penknife6153 commented Jun 11, 2025

torch.export.export() uses fake tensors during graph tracing, which don't have actual memory storage. The RNN module's flatten_parameters() method was calling p.data_ptr() for aliasing detection, causing "Cannot access data pointer of Tensor" errors on GPU.

Fixes #155309

Changes:

  • Replace p.data_ptr() with StorageWeakRef(p.untyped_storage()) for PT2 compatibility
  • Add graceful fallback for cases where StorageWeakRef fails
  • Works with fake tensors during exportFix torch.export.export() GPU failure with RNN modules

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @tugsbayasgalan

Copy link

pytorch-bot bot commented Jun 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155734

Note: Links to docs will display an error until the docs builds have been completed.

❌ 11 New Failures

As of commit 2a5c538 with merge base 3a56237 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@penknife6153
Copy link
Contributor Author

@pytorchbot label "release notes: nn"

@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Jun 11, 2025
@janeyx99 janeyx99 requested review from angelayi and tugsbayasgalan and removed request for albanD, mikaylagawarecki and jbschlosser June 13, 2025 01:32
@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 13, 2025
if len(unique_data_ptrs) != len(self._flat_weights):
return
try:
from torch.multiprocessing.reductions import StorageWeakRef
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StorageWeakRef implementation actually doesn't look like it is multiprocessing specific, so you should move it out of there. It is also not good practice to do local imports like this. @bdhirsh What do you think about defining in torch/init.py?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here would be the right place

class TensorWeakRef:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tugsbayasgalan @albanD I've moved StorageWeakRef to torch.utils.weak as suggested, and updated all relevant imports across the codebase. Let me know if there’s anything else you'd like me to adjust.

}
if len(unique_storage_refs) != len(self._flat_weights):
return
except Exception:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need a test case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Working on it today.

except Exception:
# Fallback for cases where StorageWeakRef is not available or fails
# This maintains PT2 compatibility by skipping aliasing check
pass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it ok to just pass here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep we should at least narrow down the exception.

@@ -21,43 +21,6 @@
pass


class StorageWeakRef:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is NOT ok to remove public APIs without appropriate care.
In this case, I don't think it is worth breaking existing user code by removing this altogether.
You can import StorageWeakRef from torch.utils.weak here to preserve the current public API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imported, thanks for pointing this out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like some tests are still failing.

@@ -336,3 +337,40 @@ def __call__(self):
# TODO, add _fix_weakref type binding
out._fix_weakref() # type: ignore[attr-defined]
return out


class StorageWeakRef:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @ezyang I forgot we had this in serialization. We should migrate this like the above when we have a minute now that we have pyobject preservation for Storage.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes. Sounds LLM amenable, file an issue maybe?

@ezyang
Copy link
Contributor

ezyang commented Jun 30, 2025

@albanD @tugsbayasgalan Can we make sure this lands? This has definitely been a huge irritation in the past

@albanD
Copy link
Collaborator

albanD commented Jul 9, 2025

CI Needs fixing, happy to review again after that.

@tugsbayasgalan
Copy link
Contributor

@penknife6153 just wondering if you plan to work on this further?

@penknife6153
Copy link
Contributor Author

@penknife6153 just wondering if you plan to work on this further?

@tugsbayasgalan been sick for the past few weeks, but I'm currently working on this.

@tugsbayasgalan
Copy link
Contributor

tugsbayasgalan commented Aug 4, 2025

@penknife6153 Just checking in to see if you made any progress on this PR, since ONNX team plans to switch to export based graph capture very soon, this item needs to be closed sooner i think (cc: @titaiwangms)

@titaiwangms
Copy link
Collaborator

We need this by 2.9

@penknife6153
Copy link
Contributor Author

@penknife6153 Just checking in to see if you made any progress on this PR, since ONNX team plans to switch to export based graph capture very soon, this item needs to be closed sooner i think (cc: @titaiwangms)

Hi @tugsbayasgalan and @titaiwangms!

I’ve drafted some tests for RNN, LSTM, and GRU, added error handling at RNNBase, and resolved the merge conflicts. Aiming to have everything finalized within this weekend.

}
except Exception as e:
if isinstance(e, RuntimeError) and "share storage" in str(e):
raise # Re-raise actual aliasing errors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sus

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: dynamo module: inductor open source release notes: nn release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torch.export.export() fails on GPU with LSTM model: "Cannot access data pointer of Tensor"
7 participants