Description
🐛 Describe the bug
When compiling a model containing a `permute` followed by a `Conv1d` operation, the Inductor backend fails with a stride validation error:
```
failed on inductor expected size 64==64, stride 1==100 at dim=1; expected size 100==100, stride 64==1 at dim=2
Error in op: torch.ops.aten.convolution.default
```
The same model runs successfully with the `eager` and `aot_eager` backends.
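The input layout may be relevant, although this is an observation and not a confirmed root cause. After `permute(0, 2, 1)`, the (32, 100, 1) input becomes a (32, 1, 100) view with strides (100, 1, 1). Because dimension 1 has size 1, those strides fit both ordinary row-major (N, C, L) memory and (N, L, C) memory. The sketch below only inspects strides and never calls `torch.compile`. The names `xp` and `xc` are illustrative.

```python
import torch

x = torch.randn(32, 100, 1)
xp = x.permute(0, 2, 1)  # the layout Conv1d receives in the reproducer

print(x.stride())              # (100, 1, 1)
print(xp.shape, xp.stride())   # torch.Size([32, 1, 100]) (100, 1, 1)

# The contiguity check ignores strides of size-1 dimensions, so the permuted
# view already counts as contiguous and .contiguous() returns it unchanged.
print(xp.is_contiguous())        # True
print(xp.contiguous().stride())  # (100, 1, 1)

# A fresh allocation in contiguous_format gets the textbook row-major strides.
xc = xp.clone(memory_format=torch.contiguous_format)
print(xc.stride())               # (100, 100, 1)
```

Feeding `xc` instead of `xp` to the convolution is one layout variation that could be tried when narrowing the bug down. It has not been tested as a workaround here.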
Minimal Reproducer
```python
import torch
import torch.nn as nn


class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (N, L, 1) -> (N, 1, L)
        return self.conv(x)


model = ConvModel()
x = torch.randn(32, 100, 1)


def run_test(model, inputs, backend):
    try:
        model = torch.compile(model, backend=backend)
        output = model(*inputs)
        print(f"succeed on {backend}")
    except Exception as e:
        print(f"failed on {backend}", str(e))


run_test(model, [x], "eager")
run_test(model, [x], "aot_eager")
run_test(model, [x], "inductor")
```
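The reproducer above only reports whether each backend raises an exception. A slightly stricter harness also compares each backend's output with plain eager execution using `torch.testing.assert_close`, so it would also flag a backend that runs but returns different values. The sketch below is one way to do that. `check_backend` is a hypothetical helper, not a PyTorch API. The sketch keeps the reproducer's model and input shape, and it deliberately does not use `torch.no_grad()`, so the graph is compiled under the same autograd setting as in the original run.

```python
import torch
import torch.nn as nn


class ConvModel(nn.Module):
    # Same model as in the reproducer.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        return self.conv(x)


def check_backend(model, inputs, backend):
    """Hypothetical helper: return None if `backend` matches eager, else an error string."""
    expected = model(*inputs)  # uncompiled eager reference
    try:
        compiled = torch.compile(model, backend=backend)
        actual = compiled(*inputs)
        # Detach so the comparison ignores autograd history.
        torch.testing.assert_close(actual.detach(), expected.detach())
    except Exception as e:
        return f"{type(e).__name__}: {e}"
    return None


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ConvModel()
    x = torch.randn(32, 100, 1)
    for backend in ("eager", "aot_eager", "inductor"):
        err = check_backend(model, [x], backend)
        print(f"{backend}: {'ok' if err is None else err}")
```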
Error logs
```
succeed on eager
succeed on aot_eager
failed on inductor: expected size 64==64, stride 1==100 at dim=1; expected size 100==100, stride 64==1 at dim=2
Error in op: torch.ops.aten.convolution.default
This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
Use torch.library.opcheck to test your custom op.
See https://pytorch.org/docs/stable/library.html#torch.library.opcheck
```
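This reading of the message assumes that each clause is ordered `actual==expected`. Under that assumption, the compiled graph expected standard row-major strides for the (32, 64, 100) convolution output, but the tensor it received stores channels as the innermost dimension, which is (N, L, C) memory order. The sketch below rebuilds both stride patterns from the sizes alone. The log does not print the batch stride, so the value 6400 is only what each layout implies.

```python
import torch

# Output shape of Conv1d(1, 64, kernel_size=3, padding=1) on a length-100 input.
N, C_out, L = 32, 64, 100

# Layout the compiled graph expected: row-major (N, C, L).
expected = torch.empty(N, C_out, L)
print(expected.stride())  # (6400, 100, 1)

# Layout matching the reported actual strides: memory ordered (N, L, C),
# viewed back as (N, C, L) so the logical shape is unchanged.
actual_like = torch.empty(N, L, C_out).permute(0, 2, 1)
print(actual_like.shape, actual_like.stride())  # torch.Size([32, 64, 100]) (6400, 1, 64)

# Same sizes, different strides at dim=1 and dim=2, as in the log.
for dim in (1, 2):
    print(f"dim={dim}: actual stride {actual_like.stride(dim)} vs expected {expected.stride(dim)}")
```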
Versions
```
Collecting environment information...
PyTorch version: 2.9.0.dev20250729+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 4.0.2
Libc version: N/A
Python version: 3.10.10 (tags/v3.10.10:aad5f6a, Feb 7 2023, 17:20:36) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
```
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @chauhang @penguinwu @eellison @zou3519 @bdhirsh