
[Inductor] Fix remove_noop_ops pass where the types for the same_meta would differ #154460


Open · wants to merge 2 commits into main

Conversation


@Gab-Menezes commented May 27, 2025

Summary: Fixes a bug where val1 and val2 in the same_meta function could have different types, leading to a compiler crash, since types other than Tensor and FakeTensor don't have the tensor attributes that same_meta reads (e.g. size()).

I spotted the problem in the wild, where val2 was of type SymInt while the expected type was FakeTensor. It only showed up when I tried to compile a large model with dynamic=True, so it appears to be related to dynamic shapes. My setup is somewhat convoluted to reproduce, but here are some other examples of people encountering the problem:

chengzeyi/Comfy-WaveSpeed#18
openvinotoolkit/openvino#22412
pytorch/TensorRT#2356
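
For context, a minimal sketch of the shape of the fix (not the verbatim PyTorch source; the real same_meta in torch/_inductor/fx_passes/post_grad.py compares symbolic shapes via statically_known_true(sym_eq(...)), as the backtrace below shows). The essential change is the type guard:

    import torch

    def same_meta(node1, node2):
        """Sketch: do two FX nodes carry equivalent tensor metadata?"""
        val1 = node1.meta.get("val")
        val2 = node2.meta.get("val")
        # The guard added by this PR: only torch.Tensor subclasses
        # (FakeTensor included) have size()/stride(); anything else,
        # e.g. a SymInt under dynamic=True, is rejected up front.
        if not (
            issubclass(type(val1), torch.Tensor)
            and issubclass(type(val2), torch.Tensor)
        ):
            return False
        # Simplified comparison; the real pass handles symbolic sizes.
        return (
            val1.size() == val2.size()
            and val1.dtype == val2.dtype
            and val1.device == val2.device
        )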

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben


pytorch-bot bot commented May 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154460

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a63a768 with merge base 53affa2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented May 27, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: Gab-Menezes / name: Gabriel Jorge Menezes (fa4a629, a63a768)

@Gab-Menezes (Author):

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the "topic: not user facing" label on May 27, 2025
@Gab-Menezes (Author) commented May 27, 2025

Backtrace that led me to this fix:

  ...
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1544, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1519, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/__init__.py", line 2347, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2089, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 101, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1160, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 775, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1145, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 219, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 479, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1882, in fw_compiler_base
    _recursive_joint_graph_passes(gm)
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 367, in _recursive_joint_graph_passes
    joint_graph_passes(gm)
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 544, in joint_graph_passes
    GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_transform_observer.py", line 85, in apply_graph_pass
    return pass_fn(self.gm.graph)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py", line 739, in remove_noop_ops
    if same_meta(node, src) and cond(*args, **kwargs):
       ^^^^^^^^^^^^^^^^^^^^
  File "/REDACTED/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py", line 587, in same_meta
    and statically_known_true(sym_eq(val1.size(), val2.size()))
                                                  ^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'SymInt' object has no attribute 'size'
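
The failure is easy to confirm in isolation: a SymInt is a single symbolic integer and carries no tensor metadata, so the unguarded val2.size() call above has nothing to resolve. A quick check against a stock PyTorch install:

    import torch

    # Tensor (and FakeTensor, which subclasses Tensor) expose size();
    # SymInt does not, hence the AttributeError in the backtrace.
    print(hasattr(torch.Tensor, "size"))  # True
    print(hasattr(torch.SymInt, "size"))  # False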

@colesbury colesbury requested a review from eellison May 29, 2025 16:30
@colesbury colesbury added the "triaged" label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on May 29, 2025
Comment on lines +894 to +895:

    issubclass(type(val1), torch.Tensor)
    and issubclass(type(val2), torch.Tensor)
Contributor:

what is the difference between this and isinstance(val1, torch.Tensor)?
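
For anyone reading along, the difference is usually invisible but real: isinstance() dispatches through the metaclass's __instancecheck__ hook, while issubclass(type(x), T) walks the actual MRO of type(x) and bypasses that hook. A plain-Python illustration (hypothetical classes, not PyTorch code):

    class Meta(type):
        def __instancecheck__(cls, obj):
            # Pathological hook: claims everything is an instance.
            return True

    class Weird(metaclass=Meta):
        pass

    print(isinstance(42, Weird))        # True  -- the hook fires
    print(issubclass(type(42), Weird))  # False -- int's MRO has no Weird

For ordinary Tensor subclasses such as FakeTensor, the two spellings agree.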

@Gab-Menezes (Author):

@eellison are we good to merge this?


Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jul 29, 2025
@Gab-Menezes (Author):

@pytorchmergebot merge


pytorch-bot bot commented Aug 6, 2025

The pull workflow has not been scheduled for the PR yet. This could be because the author doesn't have permission to run those workflows, or because skip-checks keywords were added to the PR/commits; aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

Labels: module: inductor, open source, Stale, topic: not user facing, triaged