🐛 Describe the bug
The Torch compiler can create tensors with a wrong shape / stride when using GQA SDPA. I get NaNs in the gradients of all parameters and the inputs in the backward pass.
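For context, `enable_gqa=True` lets Q have more heads than K/V; semantically it is equivalent to repeating each K/V head `H_q // H_kv` times along the head dimension. A minimal sketch of that equivalence (illustrative shapes, not the repro itself):

```python
import torch
import torch.nn.functional as F

B, H_q, H_kv, S, D = 2, 32, 8, 7, 16  # note the odd sequence length
q = torch.randn(B, H_q, S, D)
k = torch.randn(B, H_kv, S, D)
v = torch.randn(B, H_kv, S, D)

out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
# Reference: expand K/V so each group of H_q // H_kv query heads shares one KV head.
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(H_q // H_kv, dim=1),
    v.repeat_interleave(H_q // H_kv, dim=1),
)
assert torch.allclose(out_gqa, out_ref, atol=1e-5)
```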
This bug was very difficult to pin down to a minimal code example that reproduces it - it turns out that a lot of conditions have to be met to trigger this rare error:
- odd sequence length (S % 2 == 1)
- GQA (I used 32 heads for Q, 8 for K and V)
- math backend (will automatically be used when GQA is enabled and attn_mask is used; see the sketch after this list for pinning it explicitly)
- this import before the compilation happens: `from __future__ import annotations` - looks like this is not necessary? Maybe it just triggered a recompile, but now I get NaNs even without it.

Maybe some of these conditions are not needed or just misleading, but it already took me multiple days to reduce our code to this minimal reproducible example. I hope you can reproduce the error with it :)
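For what it's worth, the backend can also be pinned explicitly instead of relying on automatic selection. A minimal sketch, assuming `torch.nn.attention.sdpa_kernel` is available (PyTorch >= 2.3); this is not part of the repro below, just a way to make the math-backend condition deterministic:

```python
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative shapes only; the real repro uses the module below.
q = torch.randn(1, 32, 555, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 555, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 555, 128, device="cuda", dtype=torch.bfloat16)

# Force the math backend explicitly rather than letting SDPA pick it.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, enable_gqa=True)
```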
Example:

```python
from __future__ import annotations

import torch


class MaskedMHA(torch.nn.Module):
    def __init__(self, H_q, H_kv, D):
        super().__init__()
        self.H_kv = H_kv
        num_heads_total = H_q + 2 * H_kv
        self.qkv_proj_vid = torch.nn.Linear(H_q * D, num_heads_total * D)
        self.qkv_proj_txt = torch.nn.Linear(H_q * D, num_heads_total * D)
        # out_proj is defined but never used in forward, so its parameters
        # get no gradient (param_idx 4 and 5 in the output below).
        self.out_proj = torch.nn.Linear(H_q * D, H_q * D)
        self.H_q = H_q
        self.D = D

    def forward(self, x_vid, x_txt, attn_mask):
        qkv_vid = self.qkv_proj_vid(x_vid)
        qkv_txt = self.qkv_proj_txt(x_txt)
        qkv_vid = qkv_vid.reshape((*qkv_vid.shape[:-1], -1, self.D))
        qkv_txt = qkv_txt.reshape((*qkv_txt.shape[:-1], -1, self.D))
        # Split the fused projection into Q (H_q heads) and K/V (H_kv heads each).
        q_vid = qkv_vid[..., :self.H_q, :]
        k_vid = qkv_vid[..., self.H_q:self.H_q + self.H_kv, :]
        v_vid = qkv_vid[..., self.H_q + self.H_kv:, :]
        q_txt = qkv_txt[..., :self.H_q, :]
        k_txt = qkv_txt[..., self.H_q:self.H_q + self.H_kv, :]
        v_txt = qkv_txt[..., self.H_q + self.H_kv:, :]
        # Concatenate video and text tokens along the sequence dimension.
        q = torch.cat([q_vid, q_txt], dim=-3)
        k = torch.cat([k_vid, k_txt], dim=-3)
        v = torch.cat([v_vid, v_txt], dim=-3)
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(-2, -3),
            k.transpose(-2, -3),
            v.transpose(-2, -3),
            attn_mask=attn_mask,
            enable_gqa=True,
        )
        out = out.transpose(-2, -3)
        return out


def test_masked_mha():
    S_vid = 300
    S_txt = S - S_vid
    x1 = torch.randn(B, S_vid, H * D, requires_grad=True, device=device)
    x2 = torch.randn(B, S_txt, H * D, requires_grad=True, device=device)
    attn_mask = torch.ones(B, 1, S, S, dtype=torch.bool, device=device)
    H_kv = H // 4
    mha = MaskedMHA(H, H_kv, D)
    mha = mha.to(device)
    mha = torch.compile(mha, fullgraph=True)
    with torch.autocast(device_type="cuda", dtype=dtype, cache_enabled=False):
        out_vid = mha(x1, x2, attn_mask)
        target_vid = torch.randn_like(out_vid)
        loss_vid = (out_vid - target_vid).mean()
        loss = loss_vid
        loss.backward()
    torch.cuda.synchronize()
    print(f"x1 grad any nan={torch.any(x1.grad.isnan()).item()}")
    print(f"x2 grad any nan={torch.any(x2.grad.isnan()).item()}")
    print(f"loss={loss.item()}")
    for param_idx, param in enumerate(mha.parameters()):
        print(f"{param_idx=} {param.grad.max() if param.grad is not None else None}")


# S = 555 is odd (S % 2 == 1), one of the trigger conditions.
B, H, S, D = 64, 32, 555, 128
device = "cuda"
dtype = torch.bfloat16
torch.compiler.reset()
test_masked_mha()
```
Example output:

```
x1 grad any nan=True
x2 grad any nan=True
loss=0.0004730224609375
param_idx=0 nan
param_idx=1 nan
param_idx=2 nan
param_idx=3 nan
param_idx=4 None
param_idx=5 None
```
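To confirm the NaNs come from compilation and not from the model itself, the same module can be run in eager mode. A minimal sketch, reusing the globals from the repro above and simply skipping the `torch.compile` call:

```python
# Hypothetical eager baseline: same as test_masked_mha(), minus torch.compile.
mha_eager = MaskedMHA(H, H // 4, D).to(device)
x1 = torch.randn(B, 300, H * D, requires_grad=True, device=device)
x2 = torch.randn(B, S - 300, H * D, requires_grad=True, device=device)
attn_mask = torch.ones(B, 1, S, S, dtype=torch.bool, device=device)
with torch.autocast(device_type="cuda", dtype=dtype, cache_enabled=False):
    out = mha_eager(x1, x2, attn_mask)
    (out - torch.randn_like(out)).mean().backward()
# If the bug is in the compiler, this should print False.
print(torch.any(x1.grad.isnan()).item())
```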
Compiled output code:
output_code_bw.txt
output_code_fw.txt
Versions
Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.31
Python version: 3.11.10 (main, Oct 16 2024, 04:38:48) [Clang 18.1.8 ] (64-bit runtime)
Python platform: Linux-5.15.0-1062-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7R13 Processor
Stepping: 1
CPU MHz: 2650.000
BogoMIPS: 5300.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3 MiB
L1i cache: 3 MiB
L2 cache: 48 MiB
L3 cache: 384 MiB
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @coconutruben