10
10
import torch .nn as nn
11
11
from torch .distributed ._composable import checkpoint
12
12
from torch .testing ._internal .common_cuda import TEST_CUDA
13
- from torch .testing ._internal .common_utils import run_tests , TestCase
13
+ from torch .testing ._internal .common_fsdp import get_devtype
14
+ from torch .testing ._internal .common_utils import run_tests , TEST_XPU , TestCase
14
15
from torch .utils .checkpoint import CheckpointError
15
16
16
17
18
# Device under test, built from whatever `get_devtype()` reports
# (presumably the accelerator available on this machine; confirm against
# common_fsdp), and the torch backend module that corresponds to it.
# `device_module` is what MemoryDelta calls `memory_stats()` on.
device_type = torch.device(get_devtype())
device_module = torch.get_device_module(device_type)
20
+
21
+
17
22
class MemoryDelta (ContextDecorator ):
18
23
def __init__ (self , device : torch .device ):
19
24
self .device : torch .device = device
@@ -22,16 +27,16 @@ def __init__(self, device: torch.device):
22
27
23
28
def __enter__(self):
    """Record the active-memory baseline when the context is entered.

    Stores the backend's ``active_bytes.all.current`` statistic in
    ``self.active_memory_enter`` when ``self.device`` is a CUDA or XPU
    device, and ``0`` for every other device type.

    Returns:
        This instance, so it can be bound with ``with ... as``.
    """
    if self.device.type in ("cuda", "xpu"):
        # Resolve the backend from the device being measured (not from a
        # module-level default) so the guard and the query always agree,
        # and ask for that device's stats rather than the current device's.
        backend = torch.get_device_module(self.device)
        stats = backend.memory_stats(self.device)
        self.active_memory_enter = stats["active_bytes.all.current"]
    else:
        # No allocator statistics are read for other device types.
        self.active_memory_enter = 0
    return self
30
35
31
36
def __exit__(self, *exc):
    """Record the active-memory reading when the context is left.

    Stores the backend's ``active_bytes.all.current`` statistic in
    ``self.active_memory_exit`` when ``self.device`` is a CUDA or XPU
    device, and ``0`` for every other device type.

    Args:
        *exc: Exception details passed by the ``with`` statement; they are
            not inspected here.
    """
    if self.device.type in ("cuda", "xpu"):
        # Resolve the backend from the device being measured (not from a
        # module-level default) so the guard and the query always agree,
        # and ask for that device's stats rather than the current device's.
        backend = torch.get_device_module(self.device)
        stats = backend.memory_stats(self.device)
        self.active_memory_exit = stats["active_bytes.all.current"]
    else:
        # No allocator statistics are read for other device types.
        self.active_memory_exit = 0
37
42
@@ -126,7 +131,7 @@ def _test_tensor_only(
126
131
loss2 = net2 (x2 ).sum ()
127
132
loss2 .backward ()
128
133
129
- if x .is_cuda :
134
+ if x .is_cuda or x . is_xpu :
130
135
self .assertTrue (mem2 .delta () < mem1 .delta ())
131
136
132
137
for p1 , p2 in zip (net1 .parameters (), net2 .parameters ()):
@@ -137,10 +142,10 @@ def test_tensor_only_cpu(self):
137
142
net = ToyModel ()
138
143
self ._test_tensor_only (net , x )
139
144
140
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
def test_tensor_only_gpu(self):
    """Run ``_test_tensor_only`` on device index 0 of the detected accelerator.

    Skipped when neither CUDA nor XPU is available.
    """
    # Build the device string once so the input and the model are
    # guaranteed to land on the same device.
    accelerator = f"{device_type.type}:0"
    x = torch.randn(20, 100, device=accelerator)
    net = ToyModel().to(accelerator)
    self._test_tensor_only(net, x)
145
150
146
151
def test_random_cpu (self ):
0 commit comments