diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 54f451ad5843d..6f517f23a5a56 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7431,6 +7431,10 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: @ir_dataclass(frozen=False) class InvokeSubgraph(ExternKernel): + """ + Implementation of the InvokeSubgraph HOP. + """ + subgraph: Optional[Subgraph] = None operands: Optional[list[TensorBox]] = None outputs: Optional[list[MultiOutput]] = None @@ -7515,6 +7519,17 @@ def create_output(output: IRNode, ind: int): skip_size_stride_alignment_checks=True, ) + # Force each output's strides to exactly match the strides of the corresponding fake output in the current node's meta["val"] + new_outputs = [] + fake_outputs = V.graph.current_node.meta["val"] + for idx, output in enumerate(outputs): + if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)): + new_outputs.append(output) + else: + example_stride = handle_sym_expr(fake_outputs[idx].stride()) + new_outputs.append(cls.require_exact_strides(output, example_stride)) + outputs = new_outputs + outputs = [create_output(output, i) for i, output in enumerate(outputs)] invoke_subgraph.outputs = outputs return outputs