MPS varying seq len SDPA memory leak #152550

@SalmanMohammadi

Description

🐛 Describe the bug

After trying the fix from #152371 (thanks so much for landing this so quickly), I was still seeing memory leaks. I found another issue: on MPS, memory usage grows sharply when the sequence length passed to SDPA varies enough between calls. This does not happen on CUDA.
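If the growth is tied to the number of distinct sequence lengths SDPA sees on MPS (an assumption; this issue does not establish the cause), a common mitigation is to round each sequence length up to a fixed bucket size so only a handful of shapes ever reach the kernel. The sketch below only shows the bucketing arithmetic; `bucketed_seq_len`, `distinct_shapes` and the bucket size of 64 are hypothetical choices, not part of the reproduction script below.

```python
def bucketed_seq_len(seq_len: int, multiple: int = 64) -> int:
    """Round seq_len up to the next multiple of `multiple` (hypothetical helper)."""
    if seq_len <= 0 or multiple <= 0:
        raise ValueError("seq_len and multiple must be positive")
    return ((seq_len + multiple - 1) // multiple) * multiple


def distinct_shapes(min_seq_len: int, max_seq_len: int, multiple: int = 64) -> int:
    """Number of distinct padded lengths covering every seq_len in [min_seq_len, max_seq_len]."""
    return len({bucketed_seq_len(n, multiple) for n in range(min_seq_len, max_seq_len + 1)})


# The ranges exercised by the MPS runs (range argument 128, 256 and 512).
for width in (128, 256, 512):
    lo, hi = 128, 128 + width - 1
    print(f"{lo}-{hi}: {hi - lo + 1} raw lengths -> {distinct_shapes(lo, hi)} padded lengths")
```

This reduces the 128, 256 and 512 possible lengths to 3, 5 and 9 padded lengths respectively. Padding means the extra key positions must be excluded, for example with a boolean `attn_mask` passed to `F.scaled_dot_product_attention`; whether this actually curbs the MPS growth is untested here.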

Reproduction script:

import torch
import torch.nn.functional as F
import sys


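def get_memory_stats(device: torch.device):
    """Return an (active, alloc) pair of byte counts for `device`.

    Note: on MPS both values are current readings; on CUDA both are peaks.
    """
    if device.type == "mps":
        # Memory currently occupied by live tensors.
        active = torch.mps.current_allocated_memory()
        # Total memory the Metal driver has allocated for this process,
        # including buffers cached by the MPS allocator.
        alloc = torch.mps.driver_allocated_memory()
        return active, alloc
    elif device.type == "cuda":
        # Peak bytes held by active tensors, and peak bytes allocated for tensors.
        # torch.cuda.memory_reserved() would be a closer analogue to
        # driver_allocated_memory(), since it also counts the caching allocator's pool.
        peak_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0)
        peak_alloc = torch.cuda.max_memory_allocated()
        return peak_active, peak_alloc
    raise ValueError(f"unsupported device type: {device.type}")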


def format_bytes(size_bytes):
    """Converts bytes to a readable string (KB, MB, GB)."""
    if size_bytes < 1024:
        return f"{size_bytes} B"
    elif size_bytes < 1024**2:
        return f"{size_bytes / 1024:.2f} KB"
    elif size_bytes < 1024**3:
        return f"{size_bytes / 1024**2:.2f} MB"
    else:
        return f"{size_bytes / 1024**3:.2f} GB"


def run_sdpa_test_single_bs(batch_size, num_iterations, num_heads, head_dim, min_seq_len, max_seq_len, device, dtype):
    actual_max_seq_len = max(max_seq_len, min_seq_len + 1)
    peak_active, peak_alloc = get_memory_stats(device)
    print(f"  Initial Memory: Active={format_bytes(peak_active)}, Alloc={format_bytes(peak_alloc)}")

    for i in range(num_iterations):
        seq_len = torch.randint(min_seq_len, actual_max_seq_len, (1,)).item()

        query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        with torch.no_grad():
            # The output is discarded; only the memory behaviour matters here.
            F.scaled_dot_product_attention(query, key, value)

        peak_active, peak_alloc = get_memory_stats(device)

        if (i + 1) % (num_iterations // 10 or 1) == 0:
            print(f"  Step {i + 1}/{num_iterations}: Active={format_bytes(peak_active)}, Alloc={format_bytes(peak_alloc)}")

    final_peak_active, final_peak_alloc = get_memory_stats(device)
    print(f"  Final Memory: Active={format_bytes(final_peak_active)}, Alloc={format_bytes(final_peak_alloc)}")
    print(f"--- Finished SDPA Test for BS={batch_size}, SeqLen Range=({min_seq_len}-{actual_max_seq_len - 1}) ---")


if __name__ == "__main__":
    batch_size = 4
    num_iterations = 400
    num_heads = 8
    head_dim = 128
    min_seq_len = 128
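    if len(sys.argv) != 3:
        sys.exit(f"usage: python {sys.argv[0]} <seq_len_range> <device>  (e.g. 512 mps)")
    # seq_len is drawn uniformly from [min_seq_len, max_seq_len).
    max_seq_len = min_seq_len + int(sys.argv[1])
    device = torch.device(sys.argv[2])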
    dtype = torch.bfloat16
    run_sdpa_test_single_bs(batch_size, num_iterations, num_heads, head_dim, min_seq_len, max_seq_len, device, dtype)
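The script draws `seq_len` uniformly from `[min_seq_len, max_seq_len)` with `torch.randint`, so the first argument sets how many distinct shapes SDPA can see over the 400 iterations. For n equally likely lengths and k draws, the expected number of distinct values is n(1 - (1 - 1/n)^k). A quick standard-library check of that formula for the three MPS runs (this is textbook probability, not a measurement from this issue; `expected_distinct` is a hypothetical helper):

```python
def expected_distinct(num_values: int, num_draws: int) -> float:
    """Expected count of distinct values after num_draws uniform draws from num_values options."""
    if num_values <= 0:
        raise ValueError("num_values must be positive")
    return num_values * (1.0 - (1.0 - 1.0 / num_values) ** num_draws)


# Range widths passed as the first argument in the MPS runs; num_iterations is 400.
for width in (128, 256, 512):
    print(f"range {width}: ~{expected_distinct(width, 400):.1f} distinct seq_len values")
```

That works out to roughly 122, 202 and 278 distinct lengths per run. Passing `0` as the range argument pins `seq_len` to 128 (since `actual_max_seq_len` becomes 129), which gives a fixed-shape control run with the same script.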

CUDA results:

root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 128 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=8.71 MB, Alloc=8.71 MB
  Step 80/400: Active=8.71 MB, Alloc=8.71 MB
  Step 120/400: Active=9.66 MB, Alloc=9.66 MB
  Step 160/400: Active=9.66 MB, Alloc=9.66 MB
  Step 200/400: Active=9.66 MB, Alloc=9.66 MB
  Step 240/400: Active=9.66 MB, Alloc=9.66 MB
  Step 280/400: Active=9.66 MB, Alloc=9.66 MB
  Step 320/400: Active=9.66 MB, Alloc=9.66 MB
  Step 360/400: Active=9.66 MB, Alloc=9.66 MB
  Step 400/400: Active=9.66 MB, Alloc=9.66 MB
  Final Memory: Active=9.66 MB, Alloc=9.66 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-255) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 256 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=12.00 MB, Alloc=12.00 MB
  Step 80/400: Active=12.00 MB, Alloc=12.00 MB
  Step 120/400: Active=13.17 MB, Alloc=13.17 MB
  Step 160/400: Active=13.17 MB, Alloc=13.17 MB
  Step 200/400: Active=13.17 MB, Alloc=13.17 MB
  Step 240/400: Active=13.17 MB, Alloc=13.17 MB
  Step 280/400: Active=13.17 MB, Alloc=13.17 MB
  Step 320/400: Active=13.17 MB, Alloc=13.17 MB
  Step 360/400: Active=13.17 MB, Alloc=13.17 MB
  Step 400/400: Active=13.17 MB, Alloc=13.17 MB
  Final Memory: Active=13.17 MB, Alloc=13.17 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-383) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 512 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=20.78 MB, Alloc=20.78 MB
  Step 80/400: Active=20.78 MB, Alloc=20.78 MB
  Step 120/400: Active=20.78 MB, Alloc=20.78 MB
  Step 160/400: Active=20.78 MB, Alloc=20.78 MB
  Step 200/400: Active=20.78 MB, Alloc=20.78 MB
  Step 240/400: Active=20.78 MB, Alloc=20.78 MB
  Step 280/400: Active=20.78 MB, Alloc=20.78 MB
  Step 320/400: Active=20.78 MB, Alloc=20.78 MB
  Step 360/400: Active=20.78 MB, Alloc=20.78 MB
  Step 400/400: Active=20.78 MB, Alloc=20.78 MB
  Final Memory: Active=20.78 MB, Alloc=20.78 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-639) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 2048 cuda
  Initial Memory: Active=0 B, Alloc=0 B
  Step 40/400: Active=67.58 MB, Alloc=67.58 MB
  Step 80/400: Active=67.58 MB, Alloc=67.58 MB
  Step 120/400: Active=67.58 MB, Alloc=67.58 MB
  Step 160/400: Active=67.58 MB, Alloc=67.58 MB
  Step 200/400: Active=67.58 MB, Alloc=67.58 MB
  Step 240/400: Active=67.58 MB, Alloc=67.58 MB
  Step 280/400: Active=67.58 MB, Alloc=67.58 MB
  Step 320/400: Active=67.89 MB, Alloc=67.89 MB
  Step 360/400: Active=67.89 MB, Alloc=67.89 MB
  Step 400/400: Active=68.14 MB, Alloc=68.14 MB
  Final Memory: Active=68.14 MB, Alloc=68.14 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-2175) ---
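As a sanity check on magnitudes, each of `query`, `key` and `value` (and the SDPA output, which has the same shape) holds batch_size × num_heads × seq_len × head_dim bfloat16 elements at 2 bytes each. A back-of-the-envelope sketch with the script's parameters; it ignores allocator rounding and any kernel workspace, so it only gives the order of magnitude to expect:

```python
BATCH_SIZE, NUM_HEADS, HEAD_DIM = 4, 8, 128  # values from the reproduction script
BF16_BYTES = 2  # element size of torch.bfloat16


def tensor_bytes(seq_len: int) -> int:
    """Bytes for one (batch, heads, seq_len, head_dim) bfloat16 tensor."""
    return BATCH_SIZE * NUM_HEADS * seq_len * HEAD_DIM * BF16_BYTES


# Largest seq_len each run can draw (upper bound of randint is exclusive).
for max_len in (255, 383, 639, 2175):
    qkv_mb = 3 * tensor_bytes(max_len) / 1024**2
    print(f"seq_len={max_len}: q+k+v = {qkv_mb:.2f} MB")
```

With these parameters one tensor is exactly seq_len / 128 MB, so q+k+v is about 5.98 MB at seq_len 255 and 50.98 MB at 2175: the same order as the CUDA peaks above, which presumably also cover the SDPA output and temporaries.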

MPS results:

> python minimal_test.py 128 mps
  Initial Memory: Active=0 B, Alloc=384.00 KB
  Step 40/400: Active=5.86 MB, Alloc=77.17 MB
  Step 80/400: Active=5.86 MB, Alloc=85.52 MB
  Step 120/400: Active=5.83 MB, Alloc=117.83 MB
  Step 160/400: Active=5.86 MB, Alloc=118.02 MB
  Step 200/400: Active=4.17 MB, Alloc=118.28 MB
  Step 240/400: Active=5.83 MB, Alloc=118.41 MB
  Step 280/400: Active=5.84 MB, Alloc=118.47 MB
  Step 320/400: Active=5.84 MB, Alloc=118.48 MB
  Step 360/400: Active=5.83 MB, Alloc=118.56 MB
  Step 400/400: Active=5.83 MB, Alloc=118.61 MB
  Final Memory: Active=5.83 MB, Alloc=118.61 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-255) ---
> python minimal_test.py 256 mps
  Initial Memory: Active=0 B, Alloc=384.00 KB
  Step 40/400: Active=7.81 MB, Alloc=143.22 MB
  Step 80/400: Active=7.81 MB, Alloc=151.73 MB
  Step 120/400: Active=7.81 MB, Alloc=184.08 MB
  Step 160/400: Active=7.81 MB, Alloc=184.47 MB
  Step 200/400: Active=7.81 MB, Alloc=184.77 MB
  Step 240/400: Active=7.81 MB, Alloc=185.03 MB
  Step 280/400: Active=8.11 MB, Alloc=185.28 MB
  Step 320/400: Active=7.81 MB, Alloc=185.50 MB
  Step 360/400: Active=7.81 MB, Alloc=185.78 MB
  Step 400/400: Active=17.01 MB, Alloc=185.88 MB
  Final Memory: Active=17.01 MB, Alloc=185.88 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-383) ---
> python minimal_test.py 512 mps
  Initial Memory: Active=0 B, Alloc=384.00 KB
  Step 40/400: Active=5.06 MB, Alloc=1.13 GB
  Step 80/400: Active=17.57 MB, Alloc=1.13 GB
  Step 120/400: Active=15.55 MB, Alloc=1.13 GB
  Step 160/400: Active=10.97 MB, Alloc=1.13 GB
  Step 200/400: Active=7.15 MB, Alloc=1.13 GB
  Step 240/400: Active=15.55 MB, Alloc=1.13 GB
  Step 280/400: Active=10.97 MB, Alloc=1.13 GB
  Step 320/400: Active=17.57 MB, Alloc=1.13 GB
  Step 360/400: Active=10.97 MB, Alloc=1.13 GB
  Step 400/400: Active=17.57 MB, Alloc=1.13 GB
  Final Memory: Active=17.57 MB, Alloc=1.13 GB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-639) ---
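To compare the MPS runs numerically, the strings printed by `format_bytes` can be converted back to bytes. A sketch with a hypothetical inverse helper, `parse_bytes`, applied to the final readings copied from the logs above (precision is limited by the two-decimal rounding in `format_bytes`):

```python
UNITS = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}


def parse_bytes(text: str) -> float:
    """Inverse of format_bytes, e.g. '118.61 MB' -> bytes (approximate: logs round to 2 decimals)."""
    value, unit = text.split()
    return float(value) * UNITS[unit]


# Final (Active, Alloc) readings copied from the three MPS runs above.
final_mps = {
    128: ("5.83 MB", "118.61 MB"),
    256: ("17.01 MB", "185.88 MB"),
    512: ("17.57 MB", "1.13 GB"),
}

baseline_alloc = parse_bytes(final_mps[128][1])
for width, (active, alloc) in final_mps.items():
    active_mb = parse_bytes(active) / 1024**2
    alloc_mb = parse_bytes(alloc) / 1024**2
    growth = parse_bytes(alloc) / baseline_alloc
    print(f"range {width}: Active={active_mb:.2f} MB, Alloc={alloc_mb:.2f} MB ({growth:.2f}x the range-128 run)")
```

Widening the range 4× (128 to 512) grows driver-allocated memory by about 9.8×, while Active stays under 18 MB in every run.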

Versions

On MPS:

Collecting environment information...
PyTorch version: 2.8.0.dev20250430
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.1.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.8 (main, Jan  5 2025, 06:55:30) [Clang 19.1.6 ] (64-bit runtime)
Python platform: macOS-15.1.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==2.2.3
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.8.0.dev20250430
[pip3] torchao==0.10.0+cpu
[pip3] torchaudio==2.6.0.dev20250430
[pip3] torchdata==0.11.0
[pip3] torchtune==0.0.0
[pip3] torchvision==0.22.0.dev20250430
[conda] No relevant packages

On CUDA:

Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.0.0
Libc version: glibc-2.35

Python version: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-196-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000

Nvidia driver version: 550.127.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             128
On-line CPU(s) list:                0-127
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7543 32-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          2
Stepping:                           1
Frequency boost:                    enabled
CPU max MHz:                        2800.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           5599.84
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                     AMD-V
L1d cache:                          2 MiB (64 instances)
L1i cache:                          2 MiB (64 instances)
L2 cache:                           32 MiB (64 instances)
L3 cache:                           512 MiB (16 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-31,64-95
NUMA node1 CPU(s):                  32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] apollo-torch==1.0.3
[pip3] galore-torch==1.0
[pip3] numpy==2.0.1
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0+cu124
[pip3] torch-optimi==0.2.1
[pip3] torchao==0.9.0
[pip3] torchvision==0.21.0+cu124
[pip3] triton==3.2.0
[conda] No relevant packages

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata

Assignees

No one assigned

    Labels

    module: memory usage (PyTorch is using more memory than it should, or it is leaking memory)
    module: mps (Related to Apple Metal Performance Shaders framework)
    module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
