
[inductor] propagate shapes in CSEVariable #152198


Open · wants to merge 18 commits into base: gh/isuruf/141/base

Conversation

@isuruf (Collaborator) commented Apr 25, 2025

[ghstack-poisoned]

pytorch-bot bot commented Apr 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152198

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 576d5a5 with merge base 556e2a7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

isuruf added a commit that referenced this pull request Apr 25, 2025
ghstack-source-id: f2aa06b
Pull Request resolved: #152198
@rec (Collaborator) left a comment


Sorry, quick unsolicited review while waiting for something else, feel free to ignore!

🙂

@@ -1761,6 +1765,7 @@ def generate(
write: bool = True,
assignment: bool = True,
dtype: Optional[torch.dtype] = None,
shape: Optional[ShapeType] = None,
Collaborator:

Does it have to be Optional here, either? Couldn't the default be ()?

It makes the code simpler if you're not always checking for None.

Collaborator Author:

() means the variable is a scalar. None means we don't know.
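The convention can be sketched in a few lines (a hypothetical illustration; only the `ShapeType` alias comes from this PR, the `describe` helper is made up for the example):

```python
from typing import Optional, Sequence, Union

# ShapeType as introduced by this PR; entries may be symbolic (e.g. "XBLOCK").
ShapeType = Optional[Sequence[Union[int, str]]]

def describe(shape: ShapeType) -> str:
    """Hypothetical helper showing the () vs. None convention."""
    if shape is None:
        return "unknown shape"  # nothing can be assumed about the value
    if len(shape) == 0:
        return "scalar"  # known 0-d value
    return f"tensor of shape {tuple(shape)}"
```

So `()` carries positive information ("this is 0-d") that a default of `None` deliberately does not.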

isuruf added a commit that referenced this pull request May 2, 2025
ghstack-source-id: 540faf4
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request May 2, 2025
ghstack-source-id: 617d336
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request May 6, 2025
ghstack-source-id: 93c674e
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request Jun 6, 2025
ghstack-source-id: 9d6d2eb
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request Jun 9, 2025
ghstack-source-id: a0c4ca6
Pull Request resolved: #152198
@isuruf isuruf added the topic: not user facing topic category label Jun 9, 2025
isuruf added a commit that referenced this pull request Jun 10, 2025
ghstack-source-id: fe852a4
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request Jun 12, 2025
ghstack-source-id: 0ec2b1f
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request Jun 13, 2025
ghstack-source-id: 1fcd065
Pull Request resolved: #152198
@isuruf isuruf requested a review from eellison June 13, 2025 15:37
@isuruf isuruf marked this pull request as ready for review June 13, 2025 15:37
isuruf added a commit that referenced this pull request Jun 13, 2025
ghstack-source-id: 1308366
Pull Request resolved: #152198
@eellison (Contributor) left a comment

Looks great! Excited to get this landed. The one comment that I think should definitely be addressed is to not use this in codegen since it is not 100% comprehensive/needs testing.

Excited to get this working, clean up some code, and also turn on lazy broadcasting (which was an issue with the native matmul codegen)

Comment on lines +2460 to +2469
ndims = self.triton_tensor_ndim()
if ndims == 1:
return f"triton_helpers.promote_to_tensor({value})", shape

nreduce = self.num_reduction_dims
sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce
new_shape = (
(*shape[: (ndims - nreduce)], *[1] * nreduce) if shape is not None else None
)
return f"{value}[{', '.join(sizes)}]", new_shape
Contributor:

nit: if this is changing codegen, maybe better to wait until we have shapes everywhere/testing ?

Contributor:

See tests/assertions added in fd553b9, and in the opinfo tests

Collaborator Author:

This doesn't change codegen, just adds a new function that returns the shape as well as reduction_resize codegen.
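The hunk quoted above can be restated as a standalone sketch (hypothetical free-function form; in the PR this is a method on the Triton kernel class, which derives `ndims`/`nreduce` from its own state, and the returned string is a Triton indexing expression, not executed Python):

```python
from typing import Optional, Sequence, Tuple, Union

ShapeType = Optional[Sequence[Union[int, str]]]

def reduction_resize(value: str, shape: ShapeType,
                     ndims: int, nreduce: int) -> Tuple[str, ShapeType]:
    # 1-d kernels promote the scalar result instead of indexing into it.
    if ndims == 1:
        return f"triton_helpers.promote_to_tensor({value})", shape
    # Keep the non-reduction dims, insert None (newaxis) for reduced dims.
    sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce
    new_shape = (
        (*shape[: ndims - nreduce], *[1] * nreduce) if shape is not None else None
    )
    return f"{value}[{', '.join(sizes)}]", new_shape
```

For example, with `ndims=2` and `nreduce=1`, `"tmp0"` with shape `("XBLOCK", "RBLOCK")` becomes the expression `tmp0[:, None]` with shape `("XBLOCK", 1)`, and an unknown (`None`) shape stays unknown.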

dtype: torch.dtype,
src_dtype: Optional[torch.dtype] = None,
use_compute_types: bool = True,
) -> ShapeType:
Contributor:

just for my understanding - reason why we need this ? I would assume since single ShapeVar input, we would take that shape.

Collaborator Author:

If there are non-ShapeVar inputs, we return None to say that we are not sure about the shape. This is because there are operations where the shape might be implicit in one of the inputs (e.g. index_expr).
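A minimal sketch of that rule (hypothetical names; the real implementation inspects `CSEVariable` inputs, and this sketch requires non-scalar shapes to agree rather than implementing full broadcasting):

```python
from typing import Optional, Sequence, Union

ShapeType = Optional[Sequence[Union[int, str]]]

class ShapeVar:
    """Stand-in for a CSEVariable that carries a shape."""
    def __init__(self, shape: ShapeType) -> None:
        self.shape = shape

def infer_output_shape(*inputs) -> ShapeType:
    shapes = []
    for inp in inputs:
        # A non-ShapeVar input (raw expression, constant, index_expr, ...)
        # may carry a shape implicitly, so report "unknown" (None).
        if not isinstance(inp, ShapeVar) or inp.shape is None:
            return None
        shapes.append(tuple(inp.shape))
    # Scalars broadcast against anything; otherwise require agreement.
    nonscalar = {s for s in shapes if s != ()}
    if len(nonscalar) > 1:
        return None
    return nonscalar.pop() if nonscalar else ()
```

A raw string operand therefore poisons the inferred shape to `None`, while all-known operands propagate a concrete shape.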

@@ -1848,8 +1854,12 @@ def generate(
assert isinstance(expr, str)
cache_key = expr
var = self.try_get(cache_key)
if shape is None and not assignment:
# since there's no assignment to a variable, use any shape here
Contributor:

would it make sense to make the shape the assigned variables shape ?

Collaborator Author:

We could if expr was a CSEVariable, but not in other cases.
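The caching behaviour discussed in this thread can be sketched as follows (a hypothetical, heavily simplified stand-in for `CSE.generate`; the cache dict and `tmpN` naming are illustrative):

```python
from typing import Dict, Optional, Sequence, Union

ShapeType = Optional[Sequence[Union[int, str]]]

class CSEVariable:
    def __init__(self, name: str, shape: ShapeType) -> None:
        self.name = name
        self.shape = shape

class CSE:
    def __init__(self) -> None:
        self._cache: Dict[str, CSEVariable] = {}
        self._counter = 0

    def generate(self, expr: str, shape: ShapeType = None,
                 assignment: bool = True) -> CSEVariable:
        var = self._cache.get(expr)
        if var is not None:
            if shape is None and not assignment:
                # No assignment to a variable happens, so the cached
                # entry can be reused whatever shape it recorded.
                return var
            if var.shape == shape:
                return var
        var = CSEVariable(f"tmp{self._counter}", shape)
        self._counter += 1
        self._cache[expr] = var
        return var
```

The point of the `not assignment` branch is that the emitted expression is used inline, so no shape claim needs to be made about a named variable.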

@@ -79,6 +80,8 @@
# causes typing errors in subclasses (defined in other files).
OpVarT = str

ShapeType = Optional[Sequence[Union[int, str]]]
Contributor:

Should we make this non-Optional ? Scalars have a shape of (). As with dtypes in codegen, its much more useful if we can 100% rely on these instead of being sometimes present.

Collaborator Author:

Eventually yes, but we don't have 100% coverage to make it so (e.g. with the cpp backend).

@nullplay (Collaborator) commented:
Hi, very excited to see this feature landing — I was actually looking into tracking variable shapes and came across this PR.
Just to clarify, this looks like it's for generic CSEVariable, not specifically for TritonCSEVariable, right?

isuruf added a commit that referenced this pull request Jul 28, 2025
ghstack-source-id: 31e608f
Pull Request resolved: #152198
@isuruf (Collaborator Author) commented Aug 8, 2025

@pytorchbot rebase

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased gh/isuruf/141/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/152198)

pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
ghstack-source-id: 1088b3b
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request Aug 11, 2025
ghstack-source-id: 5f115d1
Pull Request resolved: #152198
isuruf added a commit that referenced this pull request Aug 11, 2025
ghstack-source-id: 5dcbe75
Pull Request resolved: #152198