
torch compile produces nans with GQA #159469

@dabeschte


🐛 Describe the bug

torch.compile can create tensors with a wrong shape/stride when using GQA with SDPA.
I get NaNs in the gradients of all parameters and of the inputs in the backward pass.
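
For context, enable_gqa=True means each group of H_q // H_kv query heads shares one K/V head. Below is a minimal sketch of the equivalent eager computation, assuming the documented GQA semantics; this is the reference any shape/stride bug in the compiled graph would deviate from:

import torch

B, H_q, H_kv, S, D = 2, 32, 8, 555, 128
q = torch.randn(B, H_q, S, D)
k = torch.randn(B, H_kv, S, D)
v = torch.randn(B, H_kv, S, D)

# GQA: query head i attends to K/V head i // (H_q // H_kv), so expanding
# K/V along the head dim reduces it to an ordinary MHA call.
k_rep = k.repeat_interleave(H_q // H_kv, dim=1)
v_rep = v.repeat_interleave(H_q // H_kv, dim=1)

out_gqa = torch.nn.functional.scaled_dot_product_attention(q, k, v, enable_gqa=True)
out_ref = torch.nn.functional.scaled_dot_product_attention(q, k_rep, v_rep)
torch.testing.assert_close(out_gqa, out_ref)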

This bug was very difficult to pin down to a minimal code example that reproduces it - it turns out that a number of conditions have to be met to trigger this rare error:

  • odd sequence length (S % 2 == 1)
  • GQA (I used 32 heads for Q and 8 heads for K and V)
  • math backend (automatically selected when GQA is enabled and an attn_mask is passed; see the sketch below for forcing it explicitly)
  • the import from __future__ import annotations before compilation happens - this looks like it is not actually necessary; maybe it just triggered a recompile, since I now get NaNs even without it

Maybe some of these conditions are not needed or are just misleading, but it already took me multiple days to reduce our code to this minimal reproducible example. I hope you can reproduce the error with it :)
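
For completeness, a minimal sketch of how the math backend can be pinned explicitly, using the torch.nn.attention.sdpa_kernel context manager (available since PyTorch 2.3), in case automatic backend selection behaves differently on other setups:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 32, 555, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 555, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 555, 128, device="cuda", dtype=torch.bfloat16)

# Restrict SDPA to the math backend; this should match what gets selected
# automatically when enable_gqa=True and a bool attn_mask is passed
# (an assumption based on the observation above).
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, enable_gqa=True)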

example:

from __future__ import annotations
import torch


class MaskedMHA(torch.nn.Module):
    def __init__(self, H_q, H_kv, D):
        super().__init__()
        self.H_kv = H_kv
        num_heads_total = H_q + 2 * H_kv
        self.qkv_proj_vid = torch.nn.Linear(H_q*D, num_heads_total*D)
        self.qkv_proj_txt = torch.nn.Linear(H_q*D, num_heads_total*D)
        self.out_proj = torch.nn.Linear(H_q*D, H_q*D)
        self.H_q = H_q
        self.D = D

        
    def forward(self, x_vid, x_txt, attn_mask):
        # packed qkv projections: (B, S, (H_q + 2*H_kv) * D)
        qkv_vid = self.qkv_proj_vid(x_vid)
        qkv_txt = self.qkv_proj_txt(x_txt)
        # unpack the head dim: (B, S, H_q + 2*H_kv, D)
        qkv_vid = qkv_vid.reshape((*qkv_vid.shape[:-1], -1, self.D))
        qkv_txt = qkv_txt.reshape((*qkv_txt.shape[:-1], -1, self.D))

        
        # split the packed projections into q/k/v along the head dim
        q_vid = qkv_vid[..., :self.H_q, :]
        k_vid = qkv_vid[..., self.H_q:self.H_q + self.H_kv, :]
        v_vid = qkv_vid[..., self.H_q + self.H_kv:, :]

        q_txt = qkv_txt[..., :self.H_q, :]
        k_txt = qkv_txt[..., self.H_q:self.H_q + self.H_kv, :]
        v_txt = qkv_txt[..., self.H_q + self.H_kv:, :]

        # concatenate vid and txt along the sequence dim (q has H_q heads, k/v have H_kv)
        q = torch.cat([q_vid, q_txt], dim=-3)
        k = torch.cat([k_vid, k_txt], dim=-3)
        v = torch.cat([v_vid, v_txt], dim=-3)

        # SDPA expects (B, H, S, D)
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(-2, -3), k.transpose(-2, -3), v.transpose(-2, -3),
            attn_mask=attn_mask, enable_gqa=True,
        )
        out = out.transpose(-2, -3)

        return out

def test_masked_mha():
    # B, H, S, D, device, dtype are globals defined below; S = 555 is odd
    S_vid = 300
    S_txt = S - S_vid
    x1 = torch.randn(B, S_vid, H*D, requires_grad=True, device=device)
    x2 = torch.randn(B, S_txt, H*D, requires_grad=True, device=device)
    attn_mask = torch.ones(B, 1, S, S, dtype=torch.bool, device=device)

    H_kv = H // 4
    mha = MaskedMHA(H, H_kv, D)
    mha = mha.to(device)
    mha = torch.compile(mha, fullgraph=True)
    with torch.autocast(device_type="cuda", dtype=dtype, cache_enabled=False):
        out_vid = mha(x1, x2, attn_mask)
        target_vid = torch.randn_like(out_vid)

        loss_vid = (out_vid - target_vid).mean()
        loss = loss_vid
    loss.backward()
    torch.cuda.synchronize()
    print(f"x1 grad any nan={torch.any(x1.grad.isnan()).item()}")
    print(f"x2 grad any nan={torch.any(x2.grad.isnan()).item()}")
    print(f"loss={loss.item()}")
    for param_idx, param in enumerate(mha.parameters()):
        print(f"{param_idx=} {param.grad.max() if param.grad is not None else None}")


B, H, S, D = 64, 32, 555, 128
device = "cuda"
dtype = torch.bfloat16
torch.compiler.reset()

test_masked_mha()

example output

x1 grad any nan=True
x2 grad any nan=True
loss=0.0004730224609375
param_idx=0 nan
param_idx=1 nan
param_idx=2 nan
param_idx=3 nan
param_idx=4 None
param_idx=5 None
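
For comparison, a sketch of the eager-mode counterpart I would run to confirm the NaNs are specific to torch.compile (a hypothetical helper, reusing the globals and MaskedMHA from the repro above):

def test_masked_mha_eager():
    # same setup as test_masked_mha(), but without torch.compile
    S_vid = 300
    x1 = torch.randn(B, S_vid, H*D, requires_grad=True, device=device)
    x2 = torch.randn(B, S - S_vid, H*D, requires_grad=True, device=device)
    attn_mask = torch.ones(B, 1, S, S, dtype=torch.bool, device=device)

    mha = MaskedMHA(H, H // 4, D).to(device)
    with torch.autocast(device_type="cuda", dtype=dtype, cache_enabled=False):
        out = mha(x1, x2, attn_mask)
        loss = (out - torch.randn_like(out)).mean()
    loss.backward()
    print(f"eager x1 grad any nan={torch.any(x1.grad.isnan()).item()}")
    print(f"eager x2 grad any nan={torch.any(x2.grad.isnan()).item()}")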

compiled output code

output_code_bw.txt
output_code_fw.txt

Versions

Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.31

Python version: 3.11.10 (main, Oct 16 2024, 04:38:48) [Clang 18.1.8 ] (64-bit runtime)
Python platform: Linux-5.15.0-1062-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.183.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             96
On-line CPU(s) list:                0-95
Thread(s) per core:                 1
Core(s) per socket:                 48
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              1
Model name:                         AMD EPYC 7R13 Processor
Stepping:                           1
CPU MHz:                            2650.000
BogoMIPS:                           5300.00
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          3 MiB
L1i cache:                          3 MiB
L2 cache:                           48 MiB
L3 cache:                           384 MiB
NUMA node0 CPU(s):                  0-47
NUMA node1 CPU(s):                  48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @coconutruben

Metadata

Labels

module: inductor, oncall: pt2, pt2: ubn ("unbreak now" hi-pri; only applies to the PyTorch Compiler Team), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
