From 3b0eea0262ecd9606d153fa96bd92ae85d62325b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 5 May 2025 00:02:56 -0700 Subject: [PATCH] [invoke_subgraph] Force the output stride to be same as eager [ghstack-poisoned] --- torch/_inductor/ir.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 54f451ad5843d..6f517f23a5a56 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7431,6 +7431,10 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: @ir_dataclass(frozen=False) class InvokeSubgraph(ExternKernel): + """ + Implementation of InvokeSubgraph HOP + """ + subgraph: Optional[Subgraph] = None operands: Optional[list[TensorBox]] = None outputs: Optional[list[MultiOutput]] = None @@ -7515,6 +7519,17 @@ def create_output(output: IRNode, ind: int): skip_size_stride_alignment_checks=True, ) + # Force the output strides to be same as the original strides + new_outputs = [] + fake_outputs = V.graph.current_node.meta["val"] + for idx, output in enumerate(outputs): + if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)): + new_outputs.append(output) + else: + example_stride = handle_sym_expr(fake_outputs[idx].stride()) + new_outputs.append(cls.require_exact_strides(output, example_stride)) + outputs = new_outputs + outputs = [create_output(output, i) for i, output in enumerate(outputs)] invoke_subgraph.outputs = outputs return outputs