🐛 Describe the bug
After trying the fix from #152371 (thanks so much for landing this so quickly), I was still seeing memory leaks. I found another issue: memory usage on MPS explodes when the sequence length varies sufficiently across SDPA calls. This does not occur with CUDA.
Reproduction script:
```python
import torch
import torch.nn.functional as F
import sys


def get_memory_stats(device: torch.device):
    if device.type == "mps":
        peak_active = torch.mps.current_allocated_memory()
        peak_alloc = torch.mps.driver_allocated_memory()
        return peak_active, peak_alloc
    elif device.type == "cuda":
        peak_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0)
        peak_alloc = torch.cuda.max_memory_allocated()
        return peak_active, peak_alloc


def format_bytes(size_bytes):
    """Converts bytes to a readable string (KB, MB, GB)."""
    if size_bytes < 1024:
        return f"{size_bytes} B"
    elif size_bytes < 1024**2:
        return f"{size_bytes / 1024:.2f} KB"
    elif size_bytes < 1024**3:
        return f"{size_bytes / 1024**2:.2f} MB"
    else:
        return f"{size_bytes / 1024**3:.2f} GB"


def run_sdpa_test_single_bs(batch_size, num_iterations, num_heads, head_dim, min_seq_len, max_seq_len, device, dtype):
    actual_max_seq_len = max(max_seq_len, min_seq_len + 1)
    peak_active, peak_alloc = get_memory_stats(device)
    print(f" Initial Memory: Active={format_bytes(peak_active)}, Alloc={format_bytes(peak_alloc)}")
    for i in range(num_iterations):
        # Draw a fresh random sequence length each iteration so SDPA sees many distinct shapes.
        seq_len = torch.randint(min_seq_len, actual_max_seq_len, (1,)).item()
        query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
        with torch.no_grad():
            F.scaled_dot_product_attention(query, key, value)
        peak_active, peak_alloc = get_memory_stats(device)
        if (i + 1) % (num_iterations // 10 or 1) == 0:
            print(f" Step {i + 1}/{num_iterations}: Active={format_bytes(peak_active)}, Alloc={format_bytes(peak_alloc)}")
    final_peak_active, final_peak_alloc = get_memory_stats(device)
    print(f" Final Memory: Active={format_bytes(final_peak_active)}, Alloc={format_bytes(final_peak_alloc)}")
    print(f"--- Finished SDPA Test for BS={batch_size}, SeqLen Range=({min_seq_len}-{actual_max_seq_len - 1}) ---")


if __name__ == "__main__":
    batch_size = 4
    num_iterations = 400
    num_heads = 8
    head_dim = 128
    min_seq_len = 128
    max_seq_len = min_seq_len + int(sys.argv[1])
    device = torch.device(sys.argv[2])
    dtype = torch.bfloat16
    run_sdpa_test_single_bs(batch_size, num_iterations, num_heads, head_dim, min_seq_len, max_seq_len, device, dtype)
```
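Run it as `python mem_test.py <extra_seq_len> <device>`; the first argument sets how far the random sequence length can range above the 128 minimum. When reading the numbers below: on MPS, `Active` is `torch.mps.current_allocated_memory()` (memory currently occupied by tensors) and `Alloc` is `torch.mps.driver_allocated_memory()` (total memory the Metal driver has reserved for the process), so the blow-up shows up in the `Alloc` column. On CUDA, both columns report peaks from the caching allocator.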
CUDA results:

```
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 128 cuda
Initial Memory: Active=0 B, Alloc=0 B
Step 40/400: Active=8.71 MB, Alloc=8.71 MB
Step 80/400: Active=8.71 MB, Alloc=8.71 MB
Step 120/400: Active=9.66 MB, Alloc=9.66 MB
Step 160/400: Active=9.66 MB, Alloc=9.66 MB
Step 200/400: Active=9.66 MB, Alloc=9.66 MB
Step 240/400: Active=9.66 MB, Alloc=9.66 MB
Step 280/400: Active=9.66 MB, Alloc=9.66 MB
Step 320/400: Active=9.66 MB, Alloc=9.66 MB
Step 360/400: Active=9.66 MB, Alloc=9.66 MB
Step 400/400: Active=9.66 MB, Alloc=9.66 MB
Final Memory: Active=9.66 MB, Alloc=9.66 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-255) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 256 cuda
Initial Memory: Active=0 B, Alloc=0 B
Step 40/400: Active=12.00 MB, Alloc=12.00 MB
Step 80/400: Active=12.00 MB, Alloc=12.00 MB
Step 120/400: Active=13.17 MB, Alloc=13.17 MB
Step 160/400: Active=13.17 MB, Alloc=13.17 MB
Step 200/400: Active=13.17 MB, Alloc=13.17 MB
Step 240/400: Active=13.17 MB, Alloc=13.17 MB
Step 280/400: Active=13.17 MB, Alloc=13.17 MB
Step 320/400: Active=13.17 MB, Alloc=13.17 MB
Step 360/400: Active=13.17 MB, Alloc=13.17 MB
Step 400/400: Active=13.17 MB, Alloc=13.17 MB
Final Memory: Active=13.17 MB, Alloc=13.17 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-383) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 512 cuda
Initial Memory: Active=0 B, Alloc=0 B
Step 40/400: Active=20.78 MB, Alloc=20.78 MB
Step 80/400: Active=20.78 MB, Alloc=20.78 MB
Step 120/400: Active=20.78 MB, Alloc=20.78 MB
Step 160/400: Active=20.78 MB, Alloc=20.78 MB
Step 200/400: Active=20.78 MB, Alloc=20.78 MB
Step 240/400: Active=20.78 MB, Alloc=20.78 MB
Step 280/400: Active=20.78 MB, Alloc=20.78 MB
Step 320/400: Active=20.78 MB, Alloc=20.78 MB
Step 360/400: Active=20.78 MB, Alloc=20.78 MB
Step 400/400: Active=20.78 MB, Alloc=20.78 MB
Final Memory: Active=20.78 MB, Alloc=20.78 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-639) ---
root@cb0c541d80f5:/workspace/axolotl# python ../mem_test.py 2048 cuda
Initial Memory: Active=0 B, Alloc=0 B
Step 40/400: Active=67.58 MB, Alloc=67.58 MB
Step 80/400: Active=67.58 MB, Alloc=67.58 MB
Step 120/400: Active=67.58 MB, Alloc=67.58 MB
Step 160/400: Active=67.58 MB, Alloc=67.58 MB
Step 200/400: Active=67.58 MB, Alloc=67.58 MB
Step 240/400: Active=67.58 MB, Alloc=67.58 MB
Step 280/400: Active=67.58 MB, Alloc=67.58 MB
Step 320/400: Active=67.89 MB, Alloc=67.89 MB
Step 360/400: Active=67.89 MB, Alloc=67.89 MB
Step 400/400: Active=68.14 MB, Alloc=68.14 MB
Final Memory: Active=68.14 MB, Alloc=68.14 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-2175) ---
```
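Note that on CUDA the peak plateaus within the first ~100 steps and scales roughly with the largest sequence length drawn, which is what you would expect from the caching allocator reusing blocks across shapes.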
MPS results:

```
> python minimal_test.py 128 mps
Initial Memory: Active=0 B, Alloc=384.00 KB
Step 40/400: Active=5.86 MB, Alloc=77.17 MB
Step 80/400: Active=5.86 MB, Alloc=85.52 MB
Step 120/400: Active=5.83 MB, Alloc=117.83 MB
Step 160/400: Active=5.86 MB, Alloc=118.02 MB
Step 200/400: Active=4.17 MB, Alloc=118.28 MB
Step 240/400: Active=5.83 MB, Alloc=118.41 MB
Step 280/400: Active=5.84 MB, Alloc=118.47 MB
Step 320/400: Active=5.84 MB, Alloc=118.48 MB
Step 360/400: Active=5.83 MB, Alloc=118.56 MB
Step 400/400: Active=5.83 MB, Alloc=118.61 MB
Final Memory: Active=5.83 MB, Alloc=118.61 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-255) ---
> python minimal_test.py 256 mps
Initial Memory: Active=0 B, Alloc=384.00 KB
Step 40/400: Active=7.81 MB, Alloc=143.22 MB
Step 80/400: Active=7.81 MB, Alloc=151.73 MB
Step 120/400: Active=7.81 MB, Alloc=184.08 MB
Step 160/400: Active=7.81 MB, Alloc=184.47 MB
Step 200/400: Active=7.81 MB, Alloc=184.77 MB
Step 240/400: Active=7.81 MB, Alloc=185.03 MB
Step 280/400: Active=8.11 MB, Alloc=185.28 MB
Step 320/400: Active=7.81 MB, Alloc=185.50 MB
Step 360/400: Active=7.81 MB, Alloc=185.78 MB
Step 400/400: Active=17.01 MB, Alloc=185.88 MB
Final Memory: Active=17.01 MB, Alloc=185.88 MB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-383) ---
> python minimal_test.py 512 mps
Initial Memory: Active=0 B, Alloc=384.00 KB
Step 40/400: Active=5.06 MB, Alloc=1.13 GB
Step 80/400: Active=17.57 MB, Alloc=1.13 GB
Step 120/400: Active=15.55 MB, Alloc=1.13 GB
Step 160/400: Active=10.97 MB, Alloc=1.13 GB
Step 200/400: Active=7.15 MB, Alloc=1.13 GB
Step 240/400: Active=15.55 MB, Alloc=1.13 GB
Step 280/400: Active=10.97 MB, Alloc=1.13 GB
Step 320/400: Active=17.57 MB, Alloc=1.13 GB
Step 360/400: Active=10.97 MB, Alloc=1.13 GB
Step 400/400: Active=17.57 MB, Alloc=1.13 GB
Final Memory: Active=17.57 MB, Alloc=1.13 GB
--- Finished SDPA Test for BS=4, SeqLen Range=(128-639) ---
```
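In case it helps narrow this down: my guess (unverified) is that the MPS backend caches per-shape resources for every distinct sequence length it sees, so driver memory grows with the number of distinct shapes rather than with tensor size. If that is right, bucketing sequence lengths should bound the growth. Below is a minimal sketch of such a workaround; `sdpa_bucketed` and the bucket size of 64 are hypothetical names/values of mine, not PyTorch API:

```python
import torch
import torch.nn.functional as F

def sdpa_bucketed(query, key, value, bucket: int = 64):
    # Hypothetical workaround sketch (assumes no user-supplied attn_mask):
    # pad the sequence dimension up to a multiple of `bucket` so SDPA only
    # ever sees a bounded set of distinct shapes.
    seq_len = query.shape[-2]
    padded = ((seq_len + bucket - 1) // bucket) * bucket
    pad = padded - seq_len
    if pad == 0:
        return F.scaled_dot_product_attention(query, key, value)
    # F.pad pads the last dims first: (0, 0) leaves head_dim alone,
    # (0, pad) extends the sequence dimension with zeros.
    q = F.pad(query, (0, 0, 0, pad))
    k = F.pad(key, (0, 0, 0, pad))
    v = F.pad(value, (0, 0, 0, pad))
    # Boolean mask (True = attend): every query row may attend only to
    # the real key positions, never to the zero padding.
    mask = torch.zeros(padded, padded, dtype=torch.bool, device=query.device)
    mask[:, :seq_len] = True
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Slice off the padded query rows; real rows are unaffected by padding.
    return out[..., :seq_len, :]
```

Because the padded key positions are fully masked out, the softmax over each real query row is computed over exactly the original keys, so the sliced result should match unpadded SDPA up to floating-point differences.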
Versions
On MPS:

```
Collecting environment information...
PyTorch version: 2.8.0.dev20250430
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.1.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.8 (main, Jan 5 2025, 06:55:30) [Clang 19.1.6 ] (64-bit runtime)
Python platform: macOS-15.1.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Max
Versions of relevant libraries:
[pip3] numpy==2.2.3
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.8.0.dev20250430
[pip3] torchao==0.10.0+cpu
[pip3] torchaudio==2.6.0.dev20250430
[pip3] torchdata==0.11.0
[pip3] torchtune==0.0.0
[pip3] torchvision==0.22.0.dev20250430
[conda] No relevant packages
```
On CUDA:

```
Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.0.0
Libc version: glibc-2.35
Python version: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-196-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000
Nvidia driver version: 550.127.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7543 32-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 2800.0000
CPU min MHz: 1500.0000
BogoMIPS: 5599.84
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization: AMD-V
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] apollo-torch==1.0.3
[pip3] galore-torch==1.0
[pip3] numpy==2.0.1
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0+cu124
[pip3] torch-optimi==0.2.1
[pip3] torchao==0.9.0
[pip3] torchvision==0.21.0+cu124
[pip3] triton==3.2.0
[conda] No relevant packages
```