[invoke_subgraph] Force the output stride to be same as eager #152806
base: gh/anijain2305/753/base
Changes from all commits:
3b0eea0, e303edc, 1e79f75, dfe1877, b6b482d, ba8920b
@@ -7431,6 +7431,10 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:

 @ir_dataclass(frozen=False)
 class InvokeSubgraph(ExternKernel):
+    """
+    Implementation of InvokeSubgraph HOP
+    """
+
     subgraph: Optional[Subgraph] = None
     operands: Optional[list[TensorBox]] = None
     outputs: Optional[list[MultiOutput]] = None
@@ -7515,6 +7519,17 @@ def create_output(output: IRNode, ind: int):

                 skip_size_stride_alignment_checks=True,
             )

+        # Force the output strides to be same as the original strides
+        new_outputs = []
+        fake_outputs = V.graph.current_node.meta["val"]
+        for idx, output in enumerate(outputs):
+            if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
+                new_outputs.append(output)
+            else:
+                example_stride = handle_sym_expr(fake_outputs[idx].stride())
+                new_outputs.append(cls.require_exact_strides(output, example_stride))
Comment on lines +7529 to +7530:

Reviewer: I'm not sure this is right. Can Inductor passes change the fake_outputs in a way that makes them differ from eager? If so, we need to record the meta vals at the time of tracing, before passes run, and then use the metadata on those.

Author: I guess this applies to the inputs of invoke_subgraph as well. Currently we rely on the meta vals of the inputs of invoke_subgraph, which could differ from eager because of graph passes.

Reviewer: Can you remind me why we want to force the input and output strides to be the same as eager? If we were not using invoke_subgraph, Inductor is allowed to change intermediates in the graph to have whatever strides it wants, with some exceptions.

Author: This is to reduce compile time. We compile a subgraph once and then reuse the same output code on subsequent calls. Since the input strides can differ across subgraph calls, we restride the inputs to a fixed value at the beginning of each subgraph, which lets us reuse the subgraph's output code. This is very important for compile time; otherwise the major benefits of invoke_subgraph are not realized. It is possible that the restriding targets are not eager strides but strides produced after Inductor graph passes run; nevertheless, they are a fixed and valid set of input strides.

Reviewer: We have some infrastructure to do this already (for inputs); see torch/fx/experimental/proxy_tensor.py, lines 1127 to 1134 (at commit bc11afd).

Reviewer: Yeah, let's use the above mechanism.

Author: I can use this for …
+
+        outputs = new_outputs

         outputs = [create_output(output, i) for i, output in enumerate(outputs)]
         invoke_subgraph.outputs = outputs
         return outputs
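For readers following the thread above, the sketch below illustrates in hypothetical form what "record the meta vals at the time of tracing, before passes run" could look like. It is not the proxy_tensor.py mechanism linked in the thread; the helper name and the "eager_strides" meta key are invented for illustration.

import torch
import torch.fx as fx

def snapshot_eager_strides(gm: fx.GraphModule) -> None:
    """Record each node's stride metadata before any graph passes can change it.

    Hypothetical helper; the meta key "eager_strides" is made up and is not
    what PyTorch uses internally.
    """
    for node in gm.graph.nodes:
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor):
            node.meta["eager_strides"] = tuple(val.stride())
        elif isinstance(val, (list, tuple)):
            node.meta["eager_strides"] = [
                tuple(v.stride()) if isinstance(v, torch.Tensor) else None
                for v in val
            ]

# A lowering that wants "strides as traced" would then read
# node.meta.get("eager_strides") instead of node.meta["val"].stride(),
# since the latter may have been rewritten by graph passes.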
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs a test at the very least. You can add an invoke_subgraph node, then do a graph pass that changes the outputs in the invoke_subgraph subgraph, and then check to make sure the strides are still what you expect.
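For illustration, a minimal sketch of the stride-comparison part of such a test follows. It assumes mark_compile_region from torch._higher_order_ops.invoke_subgraph is the user-facing way to produce an invoke_subgraph HOP (the exact entry point may differ across versions), and it omits the graph pass that perturbs the subgraph outputs, which a full test would also need.

import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region  # assumed entry point

@mark_compile_region
def gn(x):
    # Eager output is non-contiguous (transposed), so a backend that forces
    # contiguous outputs would produce different strides than eager.
    return x.t() * 2

def fn(x):
    return gn(x)

x = torch.randn(8, 16)
eager_out = fn(x)
compiled_out = torch.compile(fn, backend="inductor", fullgraph=True)(x)
assert compiled_out.stride() == eager_out.stride(), (
    compiled_out.stride(), eager_out.stride())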
Author: Yes, I was not able to get a test working.

I was looking at a regression when wrapping the whole model with invoke_subgraph. When I diffed the output code, I saw an extra kernel after the invoke_subgraph call, even though there was no operation outside of the invoke_subgraph call. This PR was my attempt to make the output strides of invoke_subgraph match the eager output and avoid that extra kernel, and it did fix the regression. But after your comment about passes changing meta vals, I am not sure whether this is correct (or what the right solution is to avoid the extra kernel).
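To make the "extra kernel" concrete: when the strides of a subgraph's outputs differ from what the surrounding code expects, the fix-up is an extra copy into a buffer with the required strides. The sketch below is a plain-PyTorch illustration of that restride copy, not Inductor's implementation, followed by one way to dump the generated output code for diffing; run_and_get_code is a helper used in Inductor's own tests, and f and its input here are placeholders for the real wrapped model.

import torch
from torch._inductor.utils import run_and_get_code

# Illustration only: restriding an output into a buffer with the expected
# strides costs an extra elementwise copy, i.e. the "extra kernel" above.
def restride(t: torch.Tensor, want_stride) -> torch.Tensor:
    if t.stride() == tuple(want_stride):
        return t
    out = torch.empty_strided(t.size(), want_stride, dtype=t.dtype, device=t.device)
    out.copy_(t)  # the extra kernel
    return out

x = torch.randn(4, 8).t()   # size (8, 4), stride (1, 8): non-contiguous
y = restride(x, (4, 1))     # force a row-major layout
assert y.stride() == (4, 1)

# To check whether the generated code contains such a copy, the output code can
# be dumped and diffed, e.g. with TORCH_LOGS="output_code" or programmatically:
def f(inp):
    return inp.t() * 2      # placeholder for the real model / subgraph

result, codes = run_and_get_code(torch.compile(f, backend="inductor"), torch.randn(8, 16))
print(codes[0])             # inspect for an unexpected copy/restride kernel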