-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathtest_sac_estimator.py
90 lines (77 loc) · 3.06 KB
/
test_sac_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Owner(s): ["module: unknown"]
import unittest
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.sac_estimator import SACEstimator
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
class TestSACEstimator(TestCase):
    """Tests that drive ``SACEstimator`` with models built under ``FakeTensorMode``.

    Both test methods are skipped when CUDA is unavailable, because they place
    the model and its input on ``torch.cuda.current_device()``.
    """

    def _sac_estimation(
        self,
        estimate_mode: str,
        model: torch.nn.Module,
        inp: torch.Tensor,
    ):
        """Run one forward/backward pass under a fresh estimator, then fit its curve.

        Args:
            estimate_mode: Forwarded as ``estimate_mode_type`` when the
                estimator context is entered.
            model: Module that is called on ``inp``.
            inp: Input passed to ``model``.
        """
        estimator = SACEstimator()
        with estimator(estimate_mode_type=estimate_mode):
            total = model(inp).sum()
            total.backward()
        # Two segments, and no graphs written out, so the call leaves no files behind.
        estimator.pwlf_sac_tradeoff_curve(n_segments=2, save_tradeoff_graphs=False)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_transformer_sac_estimation(self):
        """Run both estimation modes on a small GPT-2 style ``Transformer``.

        This is a smoke test: it makes no assertions and passes as long as
        neither estimation mode raises.
        """
        device = torch.cuda.current_device()
        batch_size, seq_len, vocab_size = 8, 1024, 8192
        config = ModelArgs(
            n_layers=4,
            n_heads=12,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dim=768,
            dropout_p=0.1,
        )
        with FakeTensorMode():
            # Only model construction relies on the default-device context;
            # the token tensor names its device explicitly.
            with torch.device(device):
                model = Transformer(config)
            tokens = torch.randint(
                0, config.vocab_size, (batch_size, config.max_seq_len), device=device
            )

            for mode in ("operator-level-benchmark", "operator-level-cost-model"):
                self._sac_estimation(mode, model, tokens)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_simple_model_sac_estimation(self):
        """Check the view-like, random and in-place op records for a tiny model."""

        class Foo(torch.nn.Module):
            # The statistics asserted below are looked up under the key "Foo",
            # so this class name must stay as it is.
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 10)
                self.relu1 = torch.nn.ReLU(inplace=True)

            def forward(self, x):
                # Linear, then three in-place ops in order: ReLU, cos_, sin_.
                activated = self.relu1(self.fc1(x))
                return torch.sin_(torch.cos_(activated))

        device = torch.cuda.current_device()
        with FakeTensorMode():
            with torch.device(device):
                model = Foo()
            sample = torch.rand((10, 5), device=device)
            sac_estimator = SACEstimator()
            with sac_estimator(estimate_mode_type="operator-level-benchmark"):
                total = model(sample).sum()
                total.backward()

        foo_stats = sac_estimator.sac_mod_stats["Foo"]
        self.assertEqual(foo_stats.view_like_ops, [0])
        self.assertEqual(foo_stats.rand_ops, [])
        self.assertEqual(foo_stats.inplace_ops, [(2, 1), (3, 1), (4, 1)])
# Entry point for direct execution: delegate to the run_tests helper imported
# from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()