
Remove accidental host synchronization in autograd cpu offload #159698


Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions test/run_test.py
@@ -1231,6 +1231,7 @@ def run_ci_sanity_check(test: ShardedTest, test_directory, options):
"test_cuda_primary_ctx": run_test_with_subprocess,
"test_cuda_nvml_based_avail": run_test_with_subprocess,
"test_cuda_trace": run_test_with_subprocess,
"test_cuda_sync_guard": run_test_with_subprocess,
"test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja,
"test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja,
"distributed/test_distributed_spawn": test_distributed,
42 changes: 42 additions & 0 deletions test/test_cuda_sync_guard.py
@@ -0,0 +1,42 @@
# Owner(s): ["module: cuda"]
Collaborator:

Please move these checks to test_autograd; there is no need for a special file.

Collaborator (author):

Unfortunately, I believe there is a need. CudaSyncGuard sets a global variable, not a thread-local one. If another thread in test_autograd.py happens to be running a test concurrently in which host synchronization is okay, that test will error out while the CudaSyncGuard context manager is active. Therefore, I need to run this test in its own process (see run_test.py for how that is enabled). Does that make sense?

I suppose my concern is not really a concern if every test runs serially. Can I assume that this is the case? If so, I am okay making your proposed change.
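
As a rough sketch of why this matters, assuming CudaSyncGuard wraps torch.cuda.set_sync_debug_mode (the SyncGuardSketch name below is illustrative, not the real helper): the mode it flips is process-wide, so a concurrently running test on another thread that legitimately synchronizes would suddenly start failing.

import torch


class SyncGuardSketch:
    """Illustrative stand-in for CudaSyncGuard (assumed to wrap set_sync_debug_mode)."""

    def __init__(self, mode: str = "error") -> None:
        self.mode = mode  # "default", "warn", or "error"

    def __enter__(self) -> "SyncGuardSketch":
        self.prev = torch.cuda.get_sync_debug_mode()
        # Process-global, not thread-local: every thread now errors on host syncs.
        torch.cuda.set_sync_debug_mode(self.mode)
        return self

    def __exit__(self, *exc) -> None:
        torch.cuda.set_sync_debug_mode(self.prev)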

Collaborator:

You can move that to the test/autograd/* folder then, and make sure to set the appropriate flag in run_test.py to avoid parallelism.

@ngimel I guess it is expected for CudaSyncGuard to be global and not thread-local?

Collaborator:

Also, maybe @serialTest() would be a way to force the single test to run serially.
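
Roughly, if the test were folded into test_autograd.py, that would look like the following (class and test names are placeholders), assuming the serialTest decorator from torch.testing._internal.common_utils:

from torch.testing._internal.common_utils import TestCase, serialTest


class TestAutogradOffload(TestCase):  # placeholder class name
    @serialTest()  # keep this one test out of the parallel test pool
    def test_save_on_cpu_does_not_synchronize(self):
        ...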

Collaborator:

Yes, CudaSyncGuard is global


import sys
import unittest

import torch
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    CudaSyncGuard,
    NoTest,
    run_tests,
    TestCase,
)


# NOTE: this needs to be run in a brand new process

if not TEST_CUDA:
print("CUDA not available, skipping tests", file=sys.stderr)
TestCase = NoTest # noqa: F811


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class Test(TestCase):
    def test_autograd_save_on_cpu_does_not_synchronize(self):
        a = torch.randn(5, requires_grad=True, device="cuda")
        b = torch.randn(5, requires_grad=True, device="cuda")
        c = torch.randn(5, requires_grad=True, device="cuda")

        def f(a, b, c):
            prod_1 = a * b
            prod_2 = prod_1 * c
            y = prod_2 * a
            return y

        with CudaSyncGuard("error"), torch.autograd.graph.save_on_cpu(pin_memory=True):
            y = f(a, b, c)
            y.sum().backward()


if __name__ == "__main__":
    run_tests()
5 changes: 3 additions & 2 deletions torch/autograd/graph.py
@@ -366,13 +366,14 @@ def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
         def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]:
             if not pin_memory:
                 return (tensor.device, tensor.cpu())
+            actually_pin_memory = device_module.is_available() and not tensor.is_sparse
             packed = torch.empty(
                 tensor.size(),
                 dtype=tensor.dtype,
                 layout=tensor.layout,
-                pin_memory=(device_module.is_available() and not tensor.is_sparse),
+                pin_memory=actually_pin_memory,
             )
-            packed.copy_(tensor)
+            packed.copy_(tensor, non_blocking=actually_pin_memory)
Collaborator:

@albanD, @soulitzer what stream would this copy run on? Is it guaranteed that offloading and onloading will run on the same stream? Otherwise we do need a sync between offloading and onloading.

Collaborator (author):

We can empirically check with CUPTI.

I ran the test case I made with cudaProfilerStart() and cudaProfilerStop() added just before and after the CudaSyncGuard context manager to remove noise.

report18.nsys-rep.zip

If you go to "stats system view", and then "CUDA GPU Trace", and then sort by the "Strm" column, you will see that all GPU operations run on stream 7.

However, your comment made me realize that this CPU memory offloading hook is honestly pretty poor: it uses the same stream as the rest of the computation. An ideal implementation would use a different stream for these copies, so that you don't pay a time cost in exchange for using less memory.

If I were to add the ability to use a separate stream to this existing API, though, it could break users who rely on device memory being immediately recycled (if anyone is even using this at all).

Both Megatron-LM and, I think, torchao have their own memory offloading context managers that handle this. I'm happy to give up on this PR, or just merge this without using a separate stream for copying.
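
To make the tradeoff concrete, here is a rough sketch (not part of this PR) of what offloading on a dedicated copy stream could look like; the hook names and the event-based hand-off are illustrative, not an actual implementation:

import torch

copy_stream = torch.cuda.Stream()  # dedicated stream for D2H offload copies


def pack_on_side_stream(t: torch.Tensor):
    done = torch.cuda.Event()
    copy_stream.wait_stream(torch.cuda.current_stream())  # wait until t is produced
    with torch.cuda.stream(copy_stream):
        cpu = torch.empty(t.size(), dtype=t.dtype, pin_memory=True)
        cpu.copy_(t, non_blocking=True)
        done.record()
    # Keep t's device memory from being reused until the copy on copy_stream finishes.
    t.record_stream(copy_stream)
    return (t.device, cpu, done)


def unpack_from_side_stream(packed):
    device, cpu, done = packed
    torch.cuda.current_stream().wait_event(done)  # order the H2D copy after the D2H one
    return cpu.to(device, non_blocking=True)

These would be registered via torch.autograd.graph.saved_tensors_hooks(pack_on_side_stream, unpack_from_side_stream); whether the extra pinned buffers and lifetime bookkeeping are worth it is exactly the reliability-versus-performance question raised further down.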

Collaborator:

Yeah, I suspected that this hook is not doing anything smart (which is why torchao and Megatron need something better). But for workloads that use different streams for different autograd nodes, the copy might run on the stream the node ran on, or on some other stream; I didn't dig deep enough into the implementation to know which.

Collaborator:

Yes this hook is more of a simple showcase than anything extra fancy.

Is it guaranteed that offloading and onloading will run on the same stream?

Yes, it is: the autograd engine sets the backward stream to be the forward one, so both will happen on the stream that was ambient during the forward op call.
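
A small illustration of those stream semantics, assuming the documented behavior that a node's backward runs on the stream that was current during its forward:

import torch

side = torch.cuda.Stream()
x = torch.randn(8, device="cuda", requires_grad=True)

with torch.cuda.stream(side):
    y = (x * 2).sum()  # forward op runs on `side`; a pack hook would run here too

with torch.cuda.stream(side):
    y.backward()  # the backward node replays on `side`, so an unpack hook would too

torch.cuda.current_stream().wait_stream(side)  # sync before reading x.grad elsewhere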

Collaborator (author):

Alright, thanks for the confirmation @albanD. I will get this PR mergeable when I have some downtime in the next few days.

Collaborator:

What exactly do you plan on doing in the next iteration?
In particular, would we trade off reliability to enable performance (as I mentioned below)?

Collaborator:

IIRC the reason we did it this way was to ensure that the CPU-side object is always correct, whichever way the user pokes at it. Since it is a showcase, we prioritized reliability over performance.

             return (tensor.device, packed)

         def unpack_from_cpu(packed: tuple[torch.device, torch.Tensor]) -> torch.Tensor:
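
For reference, a small standalone illustration of the behavior the hunk above changes: a plain D2H copy_ blocks the host, while non_blocking=True into pinned memory does not (buffer names are illustrative):

import torch

src = torch.randn(1 << 20, device="cuda")

# Old behavior: a copy without non_blocking synchronizes the host with the GPU.
blocking_dst = torch.empty(src.size(), dtype=src.dtype)
blocking_dst.copy_(src)  # host waits here until the GPU copy completes

# New behavior: pinned destination + non_blocking=True enqueues the copy on the
# current stream and lets the host keep going.
pinned_dst = torch.empty(src.size(), dtype=src.dtype, pin_memory=True)
pinned_dst.copy_(src, non_blocking=True)
torch.cuda.current_stream().synchronize()  # sync only when the data is actually needed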