diff --git a/test/run_test.py b/test/run_test.py
index e0bde4e6d52d..13008e5df665 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -1231,6 +1231,7 @@ def run_ci_sanity_check(test: ShardedTest, test_directory, options):
     "test_cuda_primary_ctx": run_test_with_subprocess,
     "test_cuda_nvml_based_avail": run_test_with_subprocess,
     "test_cuda_trace": run_test_with_subprocess,
+    "test_cuda_sync_guard": run_test_with_subprocess,
     "test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja,
     "test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja,
     "distributed/test_distributed_spawn": test_distributed,
diff --git a/test/test_cuda_sync_guard.py b/test/test_cuda_sync_guard.py
new file mode 100644
index 000000000000..541c9beb1269
--- /dev/null
+++ b/test/test_cuda_sync_guard.py
@@ -0,0 +1,42 @@
+# Owner(s): ["module: cuda"]
+
+import sys
+import unittest
+
+import torch
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_utils import (
+    CudaSyncGuard,
+    NoTest,
+    run_tests,
+    TestCase,
+)
+
+
+# NOTE: this needs to be run in a brand new process
+
+if not TEST_CUDA:
+    print("CUDA not available, skipping tests", file=sys.stderr)
+    TestCase = NoTest  # noqa: F811
+
+
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
+class Test(TestCase):
+    def test_autograd_save_on_cpu_does_not_synchronize(self):
+        a = torch.randn(5, requires_grad=True, device="cuda")
+        b = torch.randn(5, requires_grad=True, device="cuda")
+        c = torch.randn(5, requires_grad=True, device="cuda")
+
+        def f(a, b, c):
+            prod_1 = a * b
+            prod_2 = prod_1 * c
+            y = prod_2 * a
+            return y
+
+        with CudaSyncGuard("error"), torch.autograd.graph.save_on_cpu(pin_memory=True):
+            y = f(a, b, c)
+            y.sum().backward()
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index bf643a97f60f..83fc9a7c798a 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -366,13 +366,14 @@ def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
         def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]:
             if not pin_memory:
                 return (tensor.device, tensor.cpu())
+            actually_pin_memory = device_module.is_available() and not tensor.is_sparse
             packed = torch.empty(
                 tensor.size(),
                 dtype=tensor.dtype,
                 layout=tensor.layout,
-                pin_memory=(device_module.is_available() and not tensor.is_sparse),
+                pin_memory=actually_pin_memory,
             )
-            packed.copy_(tensor)
+            packed.copy_(tensor, non_blocking=actually_pin_memory)
             return (tensor.device, packed)

         def unpack_from_cpu(packed: tuple[torch.device, torch.Tensor]) -> torch.Tensor:
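A minimal sketch of the behavior the new test exercises (not part of the diff), assuming a CUDA-capable build. It uses torch.cuda.set_sync_debug_mode, the public API that CudaSyncGuard wraps in the test utilities; the tensor names are illustrative only:

    import torch

    src = torch.randn(1024, device="cuda")

    # Pinned (page-locked) host memory is what lets copy_ use an asynchronous
    # device-to-host transfer instead of a synchronizing one, which is why the
    # diff only passes non_blocking=True when the destination is pinned.
    pinned = torch.empty(src.size(), dtype=src.dtype, pin_memory=True)

    # Raise on any implicit device synchronization, mirroring
    # CudaSyncGuard("error") in test_cuda_sync_guard.py.
    torch.cuda.set_sync_debug_mode("error")
    try:
        pinned.copy_(src, non_blocking=True)  # async D2H copy; no sync expected
    finally:
        torch.cuda.set_sync_debug_mode("default")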