# Owner(s): ["oncall: distributed"] import contextlib import functools import io from collections import OrderedDict from copy import deepcopy from itertools import product import torch import torch.nn.functional as F import torch.nn.parallel as dp from torch import nn from torch.cuda.amp import autocast from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, onlyCUDA, skipMeta, ) from torch.testing._internal.common_utils import ( _assertGradAndGradgradChecks, dtype2prec_DONTUSE, gradcheck, run_tests, skip_but_pass_in_sandcastle_if, TestCase, ) NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL") # batched grad doesn't support data parallel gradcheck = functools.partial(gradcheck, check_batched_grad=False) _assertGradAndGradgradChecks = functools.partial( _assertGradAndGradgradChecks, check_batched_grad=False ) class TestDataParallel(TestCase): @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported") def test_data_parallel_buffers_requiring_grad(self): class TestModule(nn.Module): def __init__(self, t): super().__init__() self.t_rg = nn.Buffer(t) self.t_not_rg = nn.Buffer(t.detach().clone()) def forward(self, x): return x * self.t_rg + self.t_not_rg m = TestModule( torch.randn(100, device="cuda", requires_grad=True, dtype=torch.double) ) self.assertTrue(m.t_rg.requires_grad) dpm = nn.DataParallel(m, [0, 1]) inp = torch.randn(2, 100, device="cuda", dtype=torch.double) def fn(t): return dpm(inp) gradcheck(fn, (m.t_rg,)) @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported") def test_data_parallel_rnn(self): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.rnn = torch.nn.LSTM( 300, 1024, 1, batch_first=True, bidirectional=True ) def forward(self, x): self.rnn.flatten_parameters() return self.rnn(x) def step(model): opt = torch.optim.SGD(model.parameters(), lr=10) input = torch.ones(4, 4, 300).to(0) output = model(input) loss = F.mse_loss(output[0], torch.zeros_like(output[0])) loss.backward() opt.step() with torch.no_grad(): model = TestModule().to(0) model_dp = torch.nn.DataParallel(deepcopy(model)) # make sure DP does not crash when grad is disabled. 
            # See #21108
            model_dp(torch.rand(2, 4, 300).to(0))

        step(model)
        step(model_dp)

        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
            self.assertTrue(p1.allclose(p2))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_lazy_linear(self):
        with self.assertRaisesRegex(
            ValueError, "Attempted to use an uninitialized parameter"
        ):
            model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
            model_dp(torch.rand(10, 10).to(0))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply(self):
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        expected1 = l1(i1)
        expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        # or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply_autocast(self):
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        with autocast():
            expected1 = l1(i1)
            expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        # or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            with autocast():
                outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "CUDA unavailable")
    def test_parallel_apply_passes_exception(self):
        # we define and instantiate a module that will throw a KeyError
        class TestModule(nn.Module):
            def forward(self, *args):
                return {}["wonderful"]

        l1 = TestModule().to("cuda", torch.float)
        # and check that parallel_apply passes on the exception
        # (we can use a single device twice for this test)
        with self.assertRaisesRegex(
            KeyError,
            "Caught KeyError in replica \\d "
            "on device 0.\nOriginal Traceback"
            "[\\s\\S]+wonderful",
        ):
            dp.parallel_apply(modules=(l1, l1), inputs=(None, None))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_multiple_input(self):
        class TestModule(nn.Module):
            def forward(self, var1, var2, float1, var3=None):
                if var3 is None:
                    return float1 * (var1 * var2)
                else:
                    return float1 * (var1 * var2 + var3)

        m = TestModule()
        var1 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var2 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var3 = torch.randn(5, 5, dtype=torch.float, requires_grad=False)

        float1 = torch.randn(1).item()

        expected = m(var1, var2, float1)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        def local_test(out):
            with torch.no_grad():
                var1.grad.fill_(0.0)
                var2.grad.fill_(0.0)
            loss = out.sum()
            loss.backward()
            self.assertEqual(out, expected)
            self.assertEqual(gvar1_exp, var1.grad)
            self.assertEqual(gvar2_exp, var2.grad)

        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (1, 0))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (0,))
        local_test(out)

        with torch.no_grad():
            var1.grad.fill_(0.0)
            var2.grad.fill_(0.0)
        expected = m(var1, var2, float1, var3=var3)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        dpm = nn.DataParallel(TestModule())
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        dpm = nn.DataParallel(TestModule(), device_ids=[0])
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        kwarg_wrap = {"var3": var3}
        out = dp.data_parallel(
            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap
        )
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
        local_test(out)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_small_back(self):
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        out = dp.data_parallel(l, i, (0, 1))
        self.assertEqual(out, l(i))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_device(self):
        r"""Test device[0] check at forward time."""
        l = nn.Linear(2, 2)
        inp = torch.randn(2, 2)
        inp_cuda0 = inp.cuda(0)
        inp_cuda1 = inp.cuda(1)
        error_msg = "module must have its parameters and buffers on device {}"

        @contextlib.contextmanager
        def dummy_ctx_manager():
            yield

        def test(inner_m, dp_device, inp, device_ids, should_fail):
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))

            if isinstance(device_ids[0], torch.device):
                expect_device = device_ids[0]
            else:
                expect_device = torch.device(f"cuda:{device_ids[0]}")

            if should_fail:

                def assert_correct():
                    return self.assertRaisesRegex(
                        RuntimeError, error_msg.format(expect_device)
                    )

            else:
                assert_correct = dummy_ctx_manager

            # test DataParallel module
            dpm = nn.DataParallel(inner_m, device_ids)
            if dp_device is not None:
                dpm = dpm.to(dp_device)

            with assert_correct():
                dpm(inp)

            # test functional
            with assert_correct():
                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)

        test(l.to("cpu"), None, inp, None, should_fail=True)
        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)

        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
        test(l.cpu(), "cuda", inp_cuda0, None, should_fail=False)
        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
        test(l.cpu(), "cuda:1", inp_cuda1, [1, 0], should_fail=False)

        s = nn.Sequential(l.cpu())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
        test(s, None, inp, None, should_fail=False)
        test(s, None, inp, [0, 1], should_fail=False)
        test(s, None, inp, [1, 0], should_fail=True)
        test(s.cpu(), None, inp, [1, 0], should_fail=True)
        test(s.cuda(1), None, inp, [1, 0], should_fail=False)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_no_refcycles(self):
        # Python 2.7 will create reference cycles with the following
        # Module on multiple GPUs, but Python 3 shouldn't unless
        # there are refcycles on the PyTorch side (or the defined module)
        import gc

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                return self.linear(x)

        gc.collect()
        model = nn.DataParallel(Model().cuda())
        data = torch.randn(1, device="cuda")
        model(data)

        refcycles = gc.collect()
        self.assertEqual(refcycles, 0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_no_grad(self):
        test = self

        class Layer(nn.Module):
            def forward(self, x):
                test.assertFalse(torch.is_grad_enabled())
                return x

        l = Layer()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        with torch.no_grad():
            dp.data_parallel(l, i, (0, 1))
        self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1)))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel(self):
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda:1")
        l.cuda(1)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad, expected)

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_sparse(self):
        l = nn.Embedding(10, 5, sparse=True).to("cuda:1")
        i = torch.randint(10, (20, 5), device="cuda:1", dtype=torch.long)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad.coalesce(), expected.coalesce())

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_output(self):
        def fn(input):
            return [
                input,
                (input.sin(), input.cos(), [input.add(1)]),
                input,
                OrderedDict(a=input, b=[input.sin()]),
            ]

        class Net(nn.Module):
            def forward(self, input):
                return fn(input)

        i = torch.randn(2, 2).float().cuda(1)
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), i, gpus)
        self.assertEqual(output, fn(i))
        self.assertIsInstance(output[0], torch.Tensor)
        self.assertIsInstance(output[1], tuple)
        self.assertIsInstance(output[1][0], torch.Tensor)
        self.assertIsInstance(output[1][1], torch.Tensor)
        self.assertIsInstance(output[1][2], list)
        self.assertIsInstance(output[1][2][0], torch.Tensor)
        self.assertIsInstance(output[2], torch.Tensor)
        self.assertIsInstance(output[3], dict)
        self.assertEqual(len(output[3]), 2)
        self.assertIn("a", output[3])
        self.assertIn("b", output[3])
        self.assertIsInstance(output[3]["a"], torch.Tensor)
        self.assertIsInstance(output[3]["b"], list)
        self.assertIsInstance(output[3]["b"][0], torch.Tensor)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_input(self):
        def fn(input):
            return input[1][0]

        class Net(nn.Module):
            def forward(self, *input):
                return fn(input)

        i = torch.randn(20, 3, dtype=torch.float, device="cuda:1")
        input = (i.cos(), (i.sin(), i), i.sin())
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), input, gpus)
        self.assertEqual(output, fn(input))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_module_zero_inputs(self):
        class TestModule(nn.Module):
            def forward(self):
                t = torch.eye(2, 3, device="cuda:0")
                return t + (1 - t)

        def test_helper(output, expected):
            self.assertEqual(output.get_device(), 0)
            self.assertEqual(output, expected)

        expected = torch.ones(2, 3, device="cuda:0")
        model = TestModule()

        test_helper(nn.DataParallel(model, [0])(), expected)
        test_helper(nn.DataParallel(model, [0, 1])(), expected)
        test_helper(dp.data_parallel(model, None, [0]), expected)
        test_helper(dp.data_parallel(model, (), [0, 1]), expected)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_device_args(self):
        cuda0 = torch.device("cuda:0")
        cuda1 = torch.device("cuda:1")

        # test output_device
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
        self.assertEqual(out, l(i))

        # test device_ids
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
        self.assertEqual(out, l(i))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_function_deletion(self):
        # this test case originated from #16532
        def gradient_penalty(net, x):
            output = net(x)
            loss = torch.autograd.grad(
                outputs=output,
                inputs=x,
                grad_outputs=x.new_ones(output.size()),
                create_graph=True,
                retain_graph=True,
            )[0].mean()
            return loss

        net = nn.Linear(4, 1).cuda()
        dpn = nn.DataParallel(net, [0, 1])
        x = torch.ones(2, 4, requires_grad=True).cuda()

        dpn.zero_grad()
        loss = gradient_penalty(dpn, x)
        loss.backward()
        grads = [p.grad for p in net.parameters()]
        self.assertEqual(2, len(grads))
        self.assertEqual(
            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device="cuda:0"), grads[0]
        )
        self.assertEqual(torch.tensor([0.0], device="cuda:0"), grads[1])

    def _test_scatter(self, tensor):
        x = tensor.detach().requires_grad_()
        result = dp.scatter(x, (0, 1))
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], x[:2])
        self.assertEqual(result[0].get_device(), 0)
        self.assertEqual(result[1], x[2:])
        self.assertEqual(result[1].get_device(), 1)
        grad = result[0].detach().clone().fill_(2)
        result[0].backward(grad)
        self.assertEqual(x.grad[:2], grad)
        self.assertEqual(x.grad[2:], grad.clone().zero_())
        _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_cpu(self):
        self._test_scatter(torch.randn((4, 4), dtype=torch.double))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_gpu(self):
        self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
    def test_data_parallel_complex(self):
        # We expect complex parameters to be broadcast by view_as_real, e.g. move from C to R^2
        class Cplx(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.cplx = torch.nn.Parameter(
                    torch.zeros(1, 10, dtype=torch.cfloat).cuda()
                )

            def forward(self, x):
                return x + self.cplx

        cplx = torch.nn.DataParallel(Cplx().cuda())
        input = torch.rand(1, 10, dtype=torch.cfloat).cuda()
        result = cplx(input)
        # 2 is the extra real view dimension here
        self.assertEqual(result.size(), torch.Size([1, 10, 2]))
        self.assertEqual(result, torch.view_as_real(input))

    def _test_gather(self, output_device):
        inputs = (
            torch.randn(2, 4, device="cuda:0", requires_grad=True, dtype=torch.double),
            torch.randn(2, 4, device="cuda:1", requires_grad=True, dtype=torch.double),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([4, 4]))
        self.assertEqual(result[:2], inputs[0])
        self.assertEqual(result[2:], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn((4, 4), dtype=torch.double)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[:2])
        self.assertEqual(inputs[1].grad, grad[2:])
        _assertGradAndGradgradChecks(
            self, lambda x, y: dp.gather((x, y), output_device), inputs
        )

        # test scalar inputs, should stack into a vector in this case
        inputs = (
            torch.randn((), device="cuda:0", requires_grad=True, dtype=torch.double),
            torch.randn((), device="cuda:1", requires_grad=True, dtype=torch.double),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([2]))
        self.assertEqual(result[0], inputs[0])
        self.assertEqual(result[1], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn(2, dtype=torch.double)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[0])
        self.assertEqual(inputs[1].grad, grad[1])
        _assertGradAndGradgradChecks(
            self, lambda x, y: dp.gather((x, y), output_device), inputs
        )

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_cpu(self):
        self._test_gather(-1)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_gpu(self):
        self._test_gather(0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_different_len_dicts(self):
        inputs = (
            {"a": torch.randn(1, 2, requires_grad=True, device="cuda:0")},
            {
                "b": torch.randn(1, 2, requires_grad=True, device="cuda:1"),
                "a": torch.randn(1, 2, requires_grad=True, device="cuda:1"),
            },
        )
        with self.assertRaises(ValueError):
            _ = dp.gather(inputs, target_device=0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate(self):
        module = nn.Linear(10, 5).float().cuda()
        input = torch.randn(2, 10, dtype=torch.float, device="cuda")
        expected_output = module(input)
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(module, devices)
            for i, replica in enumerate(replicas):
                for p in replica.parameters():
                    self.assertEqual(p.get_device(), i)
                replica_input = input.cuda(i)
                self.assertEqual(replica(replica_input), expected_output)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate_buffers(self):
        net = nn.Module()
        net.bn = nn.BatchNorm2d(10)
        net.cuda()
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(net, devices)
            for i, replica in enumerate(replicas):
                self.assertEqual(
                    replica.bn.running_mean.get_device(),
                    i,
                    msg="buffer on wrong device",
                )
                self.assertEqual(
                    replica.bn.running_var.get_device(), i, msg="buffer on wrong device"
                )
                self.assertEqual(
                    replica.bn.num_batches_tracked.get_device(),
                    i,
                    msg="buffer on wrong device",
                )

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_zero_grad(self):
        # zero_grad should warn about using gradients inside forward
        class Net(torch.nn.Module):
            def __init__(self, testcase):
                super().__init__()
                self._testcase = testcase

            def forward(self, x):
                with self._testcase.assertWarnsRegex(
                    UserWarning,
                    r"Calling \.zero_grad\(\) from a module created with nn\.DataParallel\(\) has no effect.",
                ):
                    self.zero_grad()
                return x

        module = Net(self).cuda()
        dpm = dp.DataParallel(module)
        dpm(torch.rand(4, 3, 6, 5))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_autocast(self):
        class Model(torch.nn.Linear):
            def __init__(self) -> None:
                super().__init__(8, 8)

            @torch.autocast(device_type="cuda")
            def forward(self, input):
                return super().forward(input)

        model = dp.DataParallel(Model().cuda().to(dtype=torch.float32))
        input = torch.randn((8, 8), dtype=torch.float32, device="cuda")
        self.assertTrue(model(input).dtype is torch.float16)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_save_replica_module(self):
        # DataParallel replicas can be saved (gh-37182)
        module = torch.nn.Linear(8, 8).cuda()
        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=False)
        data = io.BytesIO()
        torch.save(dpm, data)
        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=True)
        torch.save(dpm, data)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_strided_grad_layout(self):
        class ConvNet(nn.Module):
            def __init__(self, layouts, dtype_list):
                super().__init__()
                self.dtypes = dtype_list
                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(
                    memory_format=layouts[0], dtype=dtype_list[0]
                )
                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(
                    memory_format=layouts[1], dtype=dtype_list[1]
                )
                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(
                    memory_format=layouts[2], dtype=dtype_list[2]
                )
                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(
                    memory_format=layouts[3], dtype=dtype_list[3]
                )

            def forward(self, x):
                x = x.to(self.dtypes[0])
                x = self.conv0(x).to(self.dtypes[1])
                x = self.conv1(x).to(self.dtypes[2])
                x = self.conv2(x).to(self.dtypes[3])
                x = self.conv3(x)
                return x

        layer_formats = (
            [torch.contiguous_format] * 4,
            [torch.channels_last] * 2 + [torch.contiguous_format] * 2,
            [torch.channels_last] * 4,
        )
        layer_dtypes = (
            [torch.float] * 4,
            [torch.float] * 2 + [torch.half] * 2,
            [torch.half] * 4,
        )

        ndevs = torch.cuda.device_count()
        input = torch.randn(ndevs * 8, 8, 8, 8, device="cuda:0", dtype=torch.float)
        target = torch.randn(ndevs * 8, 8, 4, 4, device="cuda:0", dtype=torch.float)
        device_ids = list(range(ndevs))

        with torch.backends.cudnn.flags(
            enabled=True, deterministic=True, benchmark=False
        ):
            for formats, dtype_list in product(layer_formats, layer_dtypes):
                # report the dtype combination under test (not the imported
                # `dtypes` decorator) in the debug message
                model_msg = f"formats = {formats} dtypes = {dtype_list}"
                try:
                    m = ConvNet(formats, dtype_list).cuda(device="cuda:0")
                    m_dp = dp.DataParallel(deepcopy(m), device_ids=device_ids)
                    opt = torch.optim.SGD(m.parameters(), lr=0.1)
                    opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
                    has_half = any(p.dtype is torch.half for p in m.parameters())
                    tol = 1.0e-3 if has_half else 1.0e-5
                except BaseException:
                    # Prints case-specific debugging info to narrow down failing case.
                    print(
                        "Caught exception during model creation for " + model_msg,
                        flush=True,
                    )
                    raise
                # 2 iters: First iter creates grads, second iter tries zeroed grads.
                for it in range(2):
                    iter_msg = f"iter = {it} " + model_msg
                    named_msg = iter_msg
                    try:
                        F.mse_loss(m(input).float(), target).backward()
                        F.mse_loss(m_dp(input).float(), target).backward()
                        for i, ((layer_name, m_child), m_dp_child) in enumerate(
                            zip(m.named_children(), m_dp.module.children())
                        ):
                            named_msg = layer_name + ".weight " + iter_msg
                            self.assertTrue(
                                m_child.weight.grad.is_contiguous(
                                    memory_format=formats[i]
                                ),
                                named_msg,
                            )
                            self.assertTrue(
                                m_dp_child.weight.grad.is_contiguous(
                                    memory_format=formats[i]
                                ),
                                named_msg,
                            )
                            for (param_name, p), p_dp in zip(
                                m_child.named_parameters(), m_dp_child.parameters()
                            ):
                                named_msg = (
                                    layer_name + "." + param_name + " " + iter_msg
                                )
                                self.assertEqual(p.grad, p_dp.grad, rtol=tol, atol=tol)
                        opt.step()
                        opt_dp.step()
                        opt.zero_grad()
                        opt_dp.zero_grad()
                    except BaseException:
                        # Makes sure we still get info if an error occurred somewhere other than the asserts.
                        print(
                            "Caught exception during iterations at " + named_msg,
                            flush=True,
                        )
                        raise

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parameter_list_dict_replica(self):
        class MyMod(torch.nn.Module):
            def __init__(self, data, check_fn):
                super().__init__()
                self.data = data
                self.check_fn = check_fn

            def forward(self, inp):
                self.check_fn(self)
                return inp

        p1 = torch.nn.Parameter(torch.rand(10))
        p2 = torch.nn.Parameter(torch.rand(10))
        key0 = 0
        key1 = 1

        def check_fn(self_):
            self.assertEqual(p1, self_.data[key0])
            self.assertEqual(p2, self_.data[key1])
            self.assertTrue(self_.data[key0].requires_grad)
            self.assertTrue(self_.data[key1].requires_grad)
            self.assertIsNotNone(self_.data[key0].grad_fn)
            self.assertIsNotNone(self_.data[key1].grad_fn)

        module = MyMod(torch.nn.ParameterList([p1, p2]), check_fn).cuda()
        model = dp.DataParallel(module)
        input = torch.randn((8, 8), device="cuda")

        # Runs the check_fn
        model(input)

        key0 = "0"
        key1 = "1"
        module = MyMod(torch.nn.ParameterDict({"0": p1, "1": p2}), check_fn).cuda()
        model = dp.DataParallel(module)
        input = torch.randn((8, 8), device="cuda")

        # Runs the check_fn
        model(input)


class TestDataParallelDeviceType(TestCase):
    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module(self, device, dtype):
        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        net = nn.DataParallel(l)
        out = net(i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input)

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input=i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_list(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input["data"])

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={"data": i, "unused": []})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_dict(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input["data"])

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={"data": i, "unused": {}})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_tuple(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input["data"])

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={"data": i, "unused": ()})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)


instantiate_device_type_tests(TestDataParallelDeviceType, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()