[inductor] propagate shapes in CSEVariable #152198
base: gh/isuruf/141/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152198
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 576d5a5 with merge base 556e2a7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Sorry, quick unsolicited review while waiting for something else, feel free to ignore! 🙂
torch/_inductor/codegen/common.py
Outdated
@@ -1761,6 +1765,7 @@ def generate(
    write: bool = True,
    assignment: bool = True,
    dtype: Optional[torch.dtype] = None,
    shape: Optional[ShapeType] = None,
Does it have to be `Optional` here, either? Couldn't the default be `()`? It makes the code simpler if you're not always checking for `None`.
`()` means the variable is a scalar. `None` means we don't know.
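For illustration, a minimal sketch (hypothetical, not from this PR) of how a consumer of the shape annotation could treat those cases; `ShapeType` mirrors the alias added in `torch/_inductor/codegen/common.py`, while `describe` is made up:

```python
from typing import Optional, Sequence, Union

# Same alias as the one introduced in torch/_inductor/codegen/common.py.
ShapeType = Optional[Sequence[Union[int, str]]]

def describe(shape: ShapeType) -> str:
    # Purely illustrative helper, not part of the PR.
    if shape is None:
        return "shape unknown, fall back to conservative codegen"
    if len(shape) == 0:
        return "scalar (0-d) value"
    return f"tensor with dims {tuple(shape)}"

print(describe(None))           # shape unknown
print(describe(()))             # scalar
print(describe(("XBLOCK", 1)))  # known 2-d shape with a symbolic dim
```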
Looks great! Excited to get this landed. The one comment that I think should definitely be addressed is to not use this in codegen yet, since it is not 100% comprehensive and needs testing.
Excited to get this working, clean up some code, and also turn on lazy broadcasting (which was an issue with the native matmul codegen).
ndims = self.triton_tensor_ndim()
if ndims == 1:
    return f"triton_helpers.promote_to_tensor({value})", shape

nreduce = self.num_reduction_dims
sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce
new_shape = (
    (*shape[: (ndims - nreduce)], *[1] * nreduce) if shape is not None else None
)
return f"{value}[{', '.join(sizes)}]", new_shape
nit: if this is changing codegen, maybe better to wait until we have shapes everywhere/testing?
See the tests/assertions added here: fd553b9, and in the opinfo tests.
This doesn't change codegen, just adds a new function that returns the shape as well as the `reduction_resize` codegen.
torch/_inductor/shape_propagation.py
Outdated
    dtype: torch.dtype,
    src_dtype: Optional[torch.dtype] = None,
    use_compute_types: bool = True,
) -> ShapeType:
Just for my understanding - what's the reason we need this? I would assume that with a single ShapeVar input, we would take that shape.
If there are non-ShapeVar inputs, we return None to say that we are not sure about the shape. This is because there are operations where the shape might be implicit in one of the inputs (e.g. `index_expr`).
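As a rough sketch of that rule (the names `ShapeVar` and `broadcast_shapes_or_none` are invented here and much simplified relative to the actual `torch/_inductor/shape_propagation.py`):

```python
from typing import Optional, Sequence, Union

ShapeType = Optional[Sequence[Union[int, str]]]

class ShapeVar:
    # Stand-in for a CSE variable that carries a known shape.
    def __init__(self, shape: ShapeType) -> None:
        self.shape = shape

def broadcast_shapes_or_none(*inputs: object) -> ShapeType:
    # If any input is not a ShapeVar (e.g. a raw index expression whose shape
    # is implicit), report "unknown" rather than guessing.
    shapes = []
    for inp in inputs:
        if not isinstance(inp, ShapeVar) or inp.shape is None:
            return None
        shapes.append(tuple(inp.shape))
    # All shapes known: this sketch just requires them to agree.
    result = shapes[0]
    for s in shapes[1:]:
        if s != result:
            return None
    return result

print(broadcast_shapes_or_none(ShapeVar(("XBLOCK",)), ShapeVar(("XBLOCK",))))  # ('XBLOCK',)
print(broadcast_shapes_or_none(ShapeVar(("XBLOCK",)), "x0 + 5 * x1"))          # None
```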
@@ -1848,8 +1854,12 @@ def generate(
    assert isinstance(expr, str)
    cache_key = expr
    var = self.try_get(cache_key)
    if shape is None and not assignment:
        # since there's no assignment to a variable, use any shape here
Would it make sense to make the shape the assigned variable's shape?
We could if `expr` was a `CSEVariable`, but not in other cases.
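A hedged sketch of what such a fallback could look like, with `CSEVariable` reduced to a stub; this is illustrative only and not the code in `torch/_inductor/codegen/common.py`:

```python
from typing import Optional, Sequence, Union

ShapeType = Optional[Sequence[Union[int, str]]]

class CSEVariable:
    # Minimal stand-in for inductor's CSEVariable, just enough to carry a shape.
    def __init__(self, name: str, shape: ShapeType = None) -> None:
        self.name = name
        self.shape = shape

def resolve_shape(expr: Union[str, CSEVariable], shape: ShapeType, assignment: bool) -> ShapeType:
    # Illustrative only: reuse the shape from `expr` when it is a CSEVariable,
    # otherwise keep the "use any shape" behaviour for non-assignments.
    if shape is None and isinstance(expr, CSEVariable):
        return expr.shape
    if shape is None and not assignment:
        return ()  # no variable is assigned, so any shape works
    return shape

print(resolve_shape(CSEVariable("tmp3", ("XBLOCK",)), None, assignment=True))  # ('XBLOCK',)
print(resolve_shape("tl.full([], 0, tl.int32)", None, assignment=False))       # ()
```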
torch/_inductor/codegen/common.py
Outdated
@@ -79,6 +80,8 @@
# causes typing errors in subclasses (defined in other files).
OpVarT = str

ShapeType = Optional[Sequence[Union[int, str]]]
Should we make this non-Optional? Scalars have a shape of (). As with dtypes in codegen, it's much more useful if we can 100% rely on these instead of them being only sometimes present.
Eventually yes, but we don't have 100% coverage to make it so (e.g. with the cpp backend).
Hi, very excited to see this feature landing — I was actually looking into tracking variable shapes and came across this PR.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased
Stack from ghstack (oldest at bottom):
Fixes #149905
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov