-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathtest_mem_tracker.py
241 lines (214 loc) · 9.22 KB
/
test_mem_tracker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# Owner(s): ["module: unknown"]
import gc
import unittest
import torch
import torch.nn as nn
from torch.distributed._tools.mem_tracker import MemTracker
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
skipIfTorchDynamo,
TestCase,
)
from torch.utils.checkpoint import checkpoint
class TestMemTracker(TestCase):
def _init_cublas_workspace(self, dev: torch.device):
lin = torch.nn.Linear(768, 768, device=dev)
inp = torch.randn(1, 768, device=dev)
lin(inp).sum().backward()
del lin
del inp
def _reset_mem_stats(self, dev: torch.device):
torch.cuda.empty_cache()
torch.cuda.reset_accumulated_memory_stats(dev)
torch.cuda.reset_peak_memory_stats(dev)
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
@skipIfRocm()
def test_cuda_tracker_equivalence(
self,
):
"""
Tests that the tracker correctly calculates the peak memory.
"""
dev = torch.device(torch.cuda.current_device())
self._init_cublas_workspace(dev)
gc.collect(1)
self._reset_mem_stats(dev)
mem_stats = torch.cuda.memory_stats(dev)
pre_cuda_active = mem_stats["active_bytes.all.current"]
bsz, n_layers, dim, dtype = 16, 4, 512, torch.bfloat16
class DummyModel(nn.Module):
def __init__(self, n_layers: int, dim: int, dtype: torch.dtype):
super().__init__()
self.linears = nn.ModuleList()
for _ in range(n_layers):
self.linears.append(nn.Linear(dim, dim, dtype=dtype))
self.linears.append(nn.ReLU())
def forward(self, x):
for layer in self.linears:
x = layer(x)
return x
with torch.device(dev):
model = DummyModel(n_layers, dim, dtype=dtype)
optim = torch.optim.Adam(model.parameters(), foreach=True)
input_batch = torch.randn(bsz, dim, device=dev, dtype=dtype)
mem_tracker = MemTracker()
mem_tracker.track_external(model, optim, input_batch)
with mem_tracker as mt:
for iter_idx in range(2):
model(input_batch).sum().backward()
optim.step()
optim.zero_grad()
if iter_idx == 0:
mt.reset_mod_stats()
# Check for accuracy of peak memory
tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
mem_stats = torch.cuda.memory_stats(dev)
cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
accuracy = tracker_max / cuda_max
self.assertAlmostEqual(accuracy, 1.0, delta=0.1)
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_tracker_with_activation_checkpointing(
self,
):
"""
Tests that the tracker correctly computes the peak memory during activation checkpointing.
"""
dev = torch.device(torch.cuda.current_device())
self._init_cublas_workspace(dev)
gc.collect(1)
self._reset_mem_stats(dev)
mem_stats = torch.cuda.memory_stats(dev)
pre_cuda_active = mem_stats["active_bytes.all.current"]
bsz, n_layers, dim, dtype = 128, 4, 1024, torch.float16
class MLPBlock(nn.Module):
def __init__(self, dim: int, dtype: torch.dtype):
super().__init__()
self.mlp_block = nn.Sequential(
nn.Linear(dim, 2 * dim, dtype=dtype),
nn.ReLU(),
nn.Linear(2 * dim, dim, dtype=dtype),
)
def forward(self, x):
return self.mlp_block(x)
class MyModule(nn.Module):
def __init__(
self, n_layers: int, dim: int, dtype: torch.dtype, use_ac: bool = False
):
super().__init__()
self.mlp_blocks = nn.ModuleList()
self.use_ac = use_ac
for _ in range(n_layers):
self.mlp_blocks.append(MLPBlock(dim, dtype=dtype))
def forward(self, x):
for i, block in enumerate(self.mlp_blocks):
if i >= 1 and self.use_ac:
x = checkpoint(
block, x, preserve_rng_state=True, use_reentrant=False
)
else:
x = block(x)
return x
with torch.device(dev):
model = MyModule(n_layers, dim, dtype, True)
optim = torch.optim.Adam(model.parameters(), foreach=True)
mem_tracker = MemTracker()
mem_tracker.track_external(model, optim)
with mem_tracker as mt:
input_batch = torch.randn(bsz, dim, dim, device=dev, dtype=dtype)
for iter_idx in range(2):
model(input_batch).sum().backward()
optim.step()
optim.zero_grad()
if iter_idx == 0:
mt.reset_mod_stats()
# Check for accuracy of peak memory
tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
mem_stats = torch.cuda.memory_stats(dev)
cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
accuracy = tracker_max / cuda_max
self.assertAlmostEqual(accuracy, 1.0, delta=0.1)
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
def test_tracker_attribution(self):
"""
Tests that the tracker correctly categorizes params, gradients, and optimizer states.
"""
dev = torch.device(torch.get_default_device())
gc.collect(1)
bsz, n_layers, dim, dtype = 16, 3, 128, torch.float32
def get_param_grad_optstate_actual_bytes(
model: nn.Module, opt: torch.optim.Optimizer
) -> tuple[int, int, int]:
param_bytes = 0
grad_bytes = 0
opt_state_bytes = 0
for param in model.parameters():
if param.device == dev:
param_bytes += param.numel() * param.element_size()
if param.grad is not None and param.grad.device == dev:
grad_bytes += param.grad.numel() * param.grad.element_size()
for state in opt.state.values():
for v in state.values():
if isinstance(v, torch.Tensor) and v.device == dev:
opt_state_bytes += v.numel() * v.element_size()
return param_bytes, grad_bytes, opt_state_bytes
def get_param_grad_optstate_bytes_from_tracker(
tracker: MemTracker,
) -> tuple[int, int, int]:
snapshot = tracker.get_tracker_snapshot()
param_bytes = snapshot[dev]["Parameter"]
grad_bytes = snapshot[dev]["Gradient"]
opt_state_bytes = snapshot[dev]["Optstate"]
return param_bytes, grad_bytes, opt_state_bytes
def test_attribution_equivalence(
mt: MemTracker,
model: nn.Module,
opt: torch.optim.Optimizer,
) -> None:
actual = get_param_grad_optstate_actual_bytes(model, opt)
tracker = get_param_grad_optstate_bytes_from_tracker(mt)
for a, b in zip(actual, tracker):
if a == 0:
self.assertEqual(b, 0)
else:
self.assertAlmostEqual(b / a, 1.0, delta=0.1)
class DummyModel(nn.Module):
def __init__(self, n_layers: int, dim: int, dtype: torch.dtype):
super().__init__()
self.MLP_layers = nn.ModuleList()
for _ in range(n_layers):
self.MLP_layers.extend(
[nn.Linear(dim, 2 * dim, dtype=dtype), nn.GELU()]
)
self.MLP_layers.extend(
[nn.Linear(2 * dim, dim, dtype=dtype), nn.GELU()]
)
def forward(self, x):
for layer in self.MLP_layers:
x = layer(x)
return x
with torch.device(dev):
model = DummyModel(n_layers, dim, dtype=dtype)
optim = torch.optim.Adam(model.parameters(), foreach=True)
mem_tracker = MemTracker()
mem_tracker.track_external(model, optim)
with mem_tracker as mt:
input_batch = torch.randn(bsz, dim, device=dev, dtype=dtype)
# Before forward: Only parameters and input are allocated
test_attribution_equivalence(mt, model, optim)
output = model(input_batch)
output.sum().backward()
# After backward: Gradients are allocated
test_attribution_equivalence(mt, model, optim)
output = None
optim.step()
# After step: Optimizer state is allocated
test_attribution_equivalence(mt, model, optim)
optim.zero_grad()
# After zero_grad: Gradients are deallocated
test_attribution_equivalence(mt, model, optim)
if __name__ == "__main__":
    # When this file is executed directly, hand off to the `run_tests` entry
    # point imported from torch.testing._internal.common_utils.
    run_tests()