Skip to content

Commit 9d14e4b

Browse files
committed
port 5 distributed tests to Intel GPU
1 parent 2dccff7 commit 9d14e4b

File tree

6 files changed

+124
-75
lines changed

6 files changed

+124
-75
lines changed

test/distributed/_composable/test_checkpoint.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
import torch.nn as nn
1111
from torch.distributed._composable import checkpoint
1212
from torch.testing._internal.common_cuda import TEST_CUDA
13-
from torch.testing._internal.common_utils import run_tests, TestCase
13+
from torch.testing._internal.common_fsdp import get_devtype
14+
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
1415
from torch.utils.checkpoint import CheckpointError
1516

1617

18+
device_type = torch.device(get_devtype())
19+
device_module = torch.get_device_module(device_type)
20+
21+
1722
class MemoryDelta(ContextDecorator):
1823
def __init__(self, device: torch.device):
1924
self.device: torch.device = device
@@ -22,16 +27,16 @@ def __init__(self, device: torch.device):
2227

2328
def __enter__(self):
2429
self.active_memory_enter = (
25-
torch.cuda.memory_stats()["active_bytes.all.current"]
26-
if self.device.type == "cuda"
30+
device_module.memory_stats()["active_bytes.all.current"]
31+
if self.device.type == "cuda" or self.device.type == "xpu"
2732
else 0
2833
)
2934
return self
3035

3136
def __exit__(self, *exc):
3237
self.active_memory_exit = (
33-
torch.cuda.memory_stats()["active_bytes.all.current"]
34-
if self.device.type == "cuda"
38+
device_module.memory_stats()["active_bytes.all.current"]
39+
if self.device.type == "cuda" or self.device.type == "xpu"
3540
else 0
3641
)
3742

@@ -126,7 +131,7 @@ def _test_tensor_only(
126131
loss2 = net2(x2).sum()
127132
loss2.backward()
128133

129-
if x.is_cuda:
134+
if x.is_cuda or x.is_xpu:
130135
self.assertTrue(mem2.delta() < mem1.delta())
131136

132137
for p1, p2 in zip(net1.parameters(), net2.parameters()):
@@ -137,10 +142,10 @@ def test_tensor_only_cpu(self):
137142
net = ToyModel()
138143
self._test_tensor_only(net, x)
139144

140-
@unittest.skipIf(not TEST_CUDA, "no cuda")
145+
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
141146
def test_tensor_only_gpu(self):
142-
x = torch.randn(20, 100, device="cuda:0")
143-
net = ToyModel().to("cuda:0")
147+
x = torch.randn(20, 100, device=f"{device_type.type}:0")
148+
net = ToyModel().to(f"{device_type.type}:0")
144149
self._test_tensor_only(net, x)
145150

146151
def test_random_cpu(self):

0 commit comments

Comments
 (0)