Skip to content

[invoke_subgraph] Force the output stride to be same as eager #152806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: gh/anijain2305/753/base
Choose a base branch
from
15 changes: 15 additions & 0 deletions torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7431,6 +7431,10 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:

@ir_dataclass(frozen=False)
class InvokeSubgraph(ExternKernel):
"""
Implementation of InvokeSubgraph HOP
"""

subgraph: Optional[Subgraph] = None
operands: Optional[list[TensorBox]] = None
outputs: Optional[list[MultiOutput]] = None
Expand Down Expand Up @@ -7515,6 +7519,17 @@ def create_output(output: IRNode, ind: int):
skip_size_stride_alignment_checks=True,
)

# Force the output strides to be same as the original strides
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a test at the very least. You can add an invoke_subgraph node, then do a graph pass that changes the outputs in the invoke_subgraph subgraph, and then check to make sure the strides are still what you expect.

Copy link
Contributor Author

@anijain2305 anijain2305 May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was not able to get a test working.

I was looking at a regression when I wrap the whole model with the invoke subgraph. When I diffed the output code, I saw an extra kernel after the invoke subgraph call, even though there was no operation outside of the invoke subgraph call. So this PR was my attempt to make the stride of invoke subgraph same as eager output to avoid that extra kernel. This fixed the regression. But after your comment about passes changing meta vals, I am not sure if this is correct (or what should be the solution to avoid the extra kernel)

new_outputs = []
fake_outputs = V.graph.current_node.meta["val"]
for idx, output in enumerate(outputs):
if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
new_outputs.append(output)
else:
example_stride = handle_sym_expr(fake_outputs[idx].stride())
new_outputs.append(cls.require_exact_strides(output, example_stride))
Comment on lines +7529 to +7530
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is right. Can Inductor passes change the fake_outputs in a way that they differ from eager?

If so we need to record the meta vals at the time of tracing, before passes run, and then use the metadata on them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this applies to inputs of the invoke subgraph then as well. Currently, we rely on the meta vals of the inputs of invoke subgraph, which could be different from eager because of graph passes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remind me why we want to force the inputs and output strides to be the same as eager? If we were not doing invoke_subgraph, inductor is allowed to change intermediates in the graph to have whatever strides it wants, with some exceptions.

Copy link
Contributor Author

@anijain2305 anijain2305 May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to reduce compile time. We compile a subgraph once and then call the same subgraph output code on second call. Since the input strides can be different for different subgraph calls, we restride the input to a fixed value at the beginning of each subgraph.

This allows us to reuse the output code of a subgraph. This is very important for compile time, otherwise the major benefits of invoke subgraph are not realized.

It is possible that the restriding is not to eager strides but to some strides after inductor graph passes are run. Nevertheless, it's a fixed and valid input strides.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have some infrastructure to do this already (for inputs), check out

if _should_save_eager_input_vals(target, (args, kwargs)):
# NOTE "eager_input_vals"
# We save the original (args, kwargs) FakeTensor values for nodes
# that have exact stride requirements. This is useful downstream.
# We use this information inside Inductor to ensure that inputs to
# stride-sensitive operators have the correct strides.
arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type]
node.meta["eager_input_vals"] = (arg_inp, kwarg_inp)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea - let's use the above mechanism

Copy link
Contributor Author

@anijain2305 anijain2305 Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can use this for input. Is there anything for the output strides? The pointer is only for the inputs, but I also want to constrain the outputs.

outputs = new_outputs

outputs = [create_output(output, i) for i, output in enumerate(outputs)]
invoke_subgraph.outputs = outputs
return outputs
Expand Down
Loading