Skip to content

Add option to assert if kernel is not fully fused in foreach_map #159213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

haveheartt
Copy link

@haveheartt haveheartt commented Jul 26, 2025

@haveheartt haveheartt requested a review from zou3519 as a code owner July 26, 2025 22:02
Copy link

pytorch-bot bot commented Jul 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159213

Note: Links to docs will display an error until the docs builds have been completed.

❌ 11 New Failures, 1 Unrelated Failure

As of commit 93485fc with merge base 3ced107 (image):

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@HDCharles HDCharles requested a review from eellison July 28, 2025 21:26
@HDCharles HDCharles added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 28, 2025
@eellison eellison requested review from mlazos and removed request for eellison July 29, 2025 14:48
@zou3519 zou3519 removed their request for review July 31, 2025 14:49
@zou3519
Copy link
Contributor

zou3519 commented Jul 31, 2025

Going to defer to @mlazos

@@ -9,15 +9,15 @@ class ForeachMap(BaseHOP):
def __init__(self):
super().__init__("foreach_map")

def __call__(self, fn, *operands, **kwargs): # type: ignore[override]
def __call__(self, fn, *operands, assert_fused=False, **kwargs): # type: ignore[override]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've discussed things that "check if something is fused" before and the conclusion was always that "Inductor makes no promises". If so, this should be named something like a debug API. Thoughts @eellison ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, especially internally, we do not want this to get used, something changed, and then we get a huge number of failures.

Copy link
Contributor

@mlazos mlazos Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marking this as debug is fine. Will that work for you guys? I think that's the main ask from users anyway, they just want some way to test that it's working as intended.

Copy link
Contributor

@zou3519 zou3519 Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mark as debug and then document that we will break this would be good

@@ -1104,6 +1104,23 @@ def ref_fn(xs):

self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

def test_foreach_map_assert_fused_passes(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do one more test where you have a slightly more complex sequence of pointwise functions (maybe multiply the result as well? I want to make sure vertical fusion is tested as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, working on it!

return a + b
x = torch.randn(10)
y = torch.randn(10)
torch.compile(foreach_map, fullgraph=True)(fn, (x,), (y,), assert_fused=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also count the number of kernels and verify the accuracy like in the other tests? Other than that this looks good!

@mlazos
Copy link
Contributor

mlazos commented Aug 6, 2025

@haveheartt looks good, last thing can you preface assert_fused with debug to make it clear that it's a debug API ?

@mlazos
Copy link
Contributor

mlazos commented Aug 6, 2025

Looks good, please prefix the arg with debug as I mentioned above before merging, and mention that we don't have BC guarantees on fusion behvaior. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: inductor open source release notes: foreach_frontend release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add option to assert if kernel is not fully fused in foreach_map
6 participants