Implement cuda graphs implementation of torch.cond and torch.while_loop #140979
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140979
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit ee03af1 with merge base 93316cf. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
repros/repro_out_grapharg.py (Outdated)
@@ -0,0 +1,26 @@
### Mini Repo ###
TODO: Remove this.
repros/repro_tuple_out.py (Outdated)
@@ -0,0 +1,5 @@
import torch
TODO: Remove this.
I converted this to a draft because the change that added a backward pass to torch.cond made me realize that there are a few issues in my design. In particular, I assumed that all stream capture would occur on one thread, but in fact it is normal behavior for the backward pass to use one or more new threads.
Force-pushed from b6d9e9f to 3818a0f.
looks good! a few comments
aten/src/ATen/cuda/CUDAGraph.cpp (Outdated)
// capture mode. The easiest solution is to handle stream creation
// and deletion ourselves.

// Be sure to call cudaStreamDestroy on this when it is finished
Where do we call cudaStreamDestroy? I don't see it.
It is inside end_capture_to_conditional_node. https://github.com/pytorch/pytorch/pull/140979/files#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR561
Another possibility is to make this return a unique_ptr<cudaStream_t> with a custom destructor. That would be a little awkward, since the std::stack<at::cuda::CUDAStreamGuard> conditional_node_streams_ would need an additional stack containing just these unique pointers (CUDAStreamGuard lacks a move constructor, so I'm not sure that a std::pair of CUDAStreamGuard and a unique_ptr would work here).
So I did try your suggestion: 7cdced5
To be honest, I think it makes things more opaque and confusing in this instance to use smart pointers. In particular, I almost accidentally depended upon the evaluation order of the arguments of a function call here (std::move(child_stream) means that *child_stream might return NULL): 7cdced5#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR456-R462
I'm personally inclined to revert the change, but would be happy to keep it if you feel strongly about it.
torch/_dynamo/backends/debugging.py (Outdated)
return gm.forward

class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
    def __init__(self, stream):
nit: any reason to pass the stream in here instead of instantiating it inside? It's not used anywhere else.
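A minimal sketch of the reviewer's suggestion (illustrative only, not the PR's actual class): since the stream is purely an internal detail of the warmup mode, it can be created in `__init__` rather than passed in by the caller.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
    def __init__(self) -> None:
        super().__init__()
        # Owned by the mode; callers no longer need to supply a stream.
        self.warmup_stream = torch.cuda.Stream()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run each op on the private warmup stream so that one-time lazy
        # initialization happens away from the stream being captured.
        with torch.cuda.stream(self.warmup_stream):
            return func(*args, **kwargs)
```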
torch/_higher_order_ops/cond.py (Outdated)
@@ -371,7 +371,7 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
    ), f"Dense implementation operands must be a list of tensors and ints {operands}"
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-   if not torch.cuda.is_current_stream_capturing():
+   if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
Since we're already making a custom TorchDispatchMode to do the warmup, could we put this logic in that class to avoid cudagraphs leaking here?
Good idea. I think I encountered some weirdness, which is described here: https://github.com/pytorch/pytorch/pull/140979/files#diff-4093d24068320f82997aecfaf4260aa495e3b4322e178cb5625f9555e3f8b0f5R388-R392
and here:
But let me take a second look. I would be happy to make that change, assuming there is nothing standing in the way.
If #140322 is blocking, we can do it as a follow-up.
I did what you suggested. Unfortunately we cannot use code within __torch_dispatch__ right now to dispatch ops in a different way, but we can use py_impl(MyTorchDispatchMode) to do it, so I went with that workaround, though it is a little bit ugly because it causes some circular dependencies and separates code that should all sit in one single class. It should probably be redone once #140322 is handled. You can see how I do it for cond here: https://github.com/pytorch/pytorch/pull/140979/files#diff-4093d24068320f82997aecfaf4260aa495e3b4322e178cb5625f9555e3f8b0f5R384-R408
Note that I use two TorchDispatchModes: one for warmup, and one for making sure that I do stream capture appropriately for conditional nodes. I automatically enter CUDAGraphCaptureControlFlowOpDispatchMode in the __enter__ method of torch.cuda.graph. This unfortunately means that users who use CUDAGraph.capture_begin() and CUDAGraph.capture_end() instead of torch.cuda.graph() won't be able to capture conditional ops. I don't think this is a problem, to be honest. Do you think it is?
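A hedged sketch of the py_impl workaround described above; the mode and function names here are illustrative, not the exact ones in the PR.

```python
import torch
from torch._higher_order_ops.cond import cond_op
from torch.utils._python_dispatch import TorchDispatchMode

class CUDAGraphCaptureControlFlowOpDispatchMode(TorchDispatchMode):
    """Entered by torch.cuda.graph() so that cond/while_loop can see it."""

@cond_op.py_impl(CUDAGraphCaptureControlFlowOpDispatchMode)
def cond_under_cuda_graph_capture(mode, pred, true_fn, false_fn, operands):
    # With py_impl, the dispatcher routes torch.cond here whenever the mode
    # is active, so the conditional-node capture logic can live outside the
    # generic dense implementation (at the cost of splitting related code
    # across modules, as noted above).
    ...
```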
torch/_subclasses/fake_tensor.py (Outdated)
flat_args = [to_real_tensor(a) for a in flat_args]
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)

r = func(*args, **kwargs)
# If one of the inputs is a CUDA tensor, it is possible that
Could we also prevent cuda graph stream logic from leaking into fake tensors by updating the __torch_dispatch__ logic? Inside __torch_dispatch__, instantiate the inputs with to_real_tensor, then map the outputs back to fake tensors. We could also do the thread_cuda_stream_capture_mode handling in there.
I have to admit, I don't understand your suggestion here.
Basically, as I understand it, FakeTensors don't always know how to do correct shape inference for some operations. One of these is torch.stft. In this situation, real tensors are instantiated, the operation is run, and then fake tensors are constructed using metadata from the output real tensors. run_fallback_kernel(), the function I modified, is the function that is used to do this in FakeTensorMode, so I think I might already be doing what you're asking for. Let me know if this makes sense at all!
I encountered this issue because I work a lot on audio data. You can see a unit test at test/functorch/test_control_flow_cuda_initialization.py's test_cond_stft: https://github.com/pytorch/pytorch/pull/140979/files#diff-ce3a6ac7310a0a222e113af989463fe8f5bd7394a13a9cbfa58cfc0a7ceb494eR55
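A rough, illustrative sketch (the helper name is hypothetical, not the PR's run_fallback_kernel) of the fallback flow described above: materialize real tensors, run the real kernel, then rebuild a fake output from the real output's metadata.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()

def fallback_via_real_tensors(func, *fake_args):
    # Build real tensors that mirror the fake inputs' metadata.
    real_args = [
        torch.zeros(a.shape, dtype=a.dtype, device=a.device)
        if isinstance(a, torch.Tensor)
        else a
        for a in fake_args
    ]
    # Run the real kernel; this is where an op like torch.stft, whose shape
    # inference is not known to FakeTensor, actually executes.
    real_out = func(*real_args)
    # Convert back to a fake tensor, keeping only shape/dtype/device.
    return fake_mode.from_tensor(real_out)

# Example: infer the output shape of stft without relying on meta inference.
with fake_mode:
    x = torch.empty(1024 * 1024)
fake_out = fallback_via_real_tensors(
    lambda t: torch.stft(t, 512, return_complex=False), x
)
print(fake_out.shape)
```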
The test there is not using Fake Tensors, so this code is not touched.
I guess, taking a step back - I don't think we should do first-run support of cuda graph. The general pytorch constraint is that you need to have warmed up your model on the first run, then you can capture. There are a number of tests along these lines (cublasCreate, etc.).
I might try to resolve this in a follow up
> The test there is not using Fake Tensors, so this code is not touched.

Actually, that test is using fake tensors in order to compute the output shape of the stft operation. But anyway, I agree this is a bit sketchy, looking back on this.
torch/cuda/graphs.py (Outdated)
@@ -132,6 +135,11 @@ class graph:
    may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
    actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
    unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
+   collect_garbage (bool, optional): If True, call torch.cuda.synchronize() followed by gc.collect() to free
This is currently only set to False in ControlFlowOpWarmupDispatchMode. So, I guess the use case there is that someone is capturing a cuda graph in eager mode and calls into a torch.compile-wrapped function? Wouldn't we expect that they should have already warmed up the function once prior to capture?
Fundamentally, the challenge is that my current approach to this can cause two distinct cuda stream captures to happen at the same time.
First of all, pytorch never actually intended for more than one stream capture to be possible, from what I can see. This is clearly stated in the documentation of torch.cuda.graph(); see here: https://github.com/pytorch/pytorch/pull/140979/files#diff-39e542d87359e7d5381d036cbcea9ec759fbe469578bcdc5693ce6cfab7f1a54R187-R188. If torch.cuda.graph is constructed without a stream, every call to init will use the same single cuda stream, and calling cudaStreamBeginCapture() on an already-capturing stream will fail.
As you see, entering the torch.cuda.graph() context manager will call torch.cuda.synchronize(). Unfortunately, torch.cuda.synchronize() will basically call cudaStreamSynchronize() on every stream in the cuda context, and cudaStreamSynchronize() will invalidate stream capture for a capturing stream.
That is why I added this boolean flag, since my "warmup" TorchDispatchMode will always start a second stream capture when another stream capture is already going on.
Unfortunately, there is more to it than that.
The warmup TorchDispatchMode does two things:
- It ensures that all conditionally-executed code paths are "warmed up" by running them at least once.
- It avoids side effects while doing (1) by "stubbing" out all kernel launches via doing relaxed stream capture (side effects are not acceptable because I don't want both sides of an if-else statement to execute. That would never happen during normal execution, and could therefore put the program into an undefined state).
Unfortunately, even though I do relaxed stream capture on the second stream, the first stream capture may be in global or thread-local stream capture mode. There is currently no way to temporarily "turn off" stream capture on a stream and then turn it back on (though in principle, I see no reason why we couldn't do it). This means that unsafe actions done during the warmup will still invalidate the first stream's capture. Thus, I need to change the thread's capture mode to relaxed at various points, which I'm not happy with.
Pytorch has succeeded so far with just using the torch.cuda.graph() context manager for capturing (though someone could use torch.cuda.CUDAGraph()'s APIs directly). This is the first time that the code we want to execute differs depending upon whether a stream is capturing or not.
I can probably keep running with pytorch's current assumption that only one stream can do capture at a time if the user passed me a function describing the work that they wanted to do. I would then make a wrapper function over it that uses torch.cuda.graph() twice, sequentially: once in relaxed mode, and then again in the user-requested mode (see the sketch below).
As you can see, there are unfortunately quite a few annoying details here. But I am inclined to agree with you that there is probably an easier, safer way to do this.
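A hedged sketch (assumed helper name, not code from the PR) of the wrapper idea above: drive the user-supplied function through torch.cuda.graph() twice, first in "relaxed" mode purely as a warmup, then in the mode the user actually asked for.

```python
import torch

def capture_with_warmup(fn, *args, capture_error_mode="global"):
    # Warmup pass: relaxed capture tolerates otherwise capture-unsafe actions
    # (e.g. lazy cuBLAS initialization) and, because kernels are recorded
    # rather than executed, conditionally-executed paths can be visited
    # without side effects. The resulting graph is simply discarded.
    warmup_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(warmup_graph, capture_error_mode="relaxed"):
        fn(*args)
    del warmup_graph

    # Real capture, sequentially after the warmup, in the requested mode.
    real_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(real_graph, capture_error_mode=capture_error_mode):
        out = fn(*args)
    return real_graph, out
```

Because the two captures run one after the other rather than nested, this keeps pytorch's existing single-capture-at-a-time assumption intact.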
> That is why I added this boolean flag, since my "warmup" TorchDispatchMode will always start a second stream capture when another stream capture is already going on.
I think a lot of these problems stem from trying to do warmup when stream capture is already on.
torch/cuda/graphs.py (Outdated)
self.__class__.default_capture_stream = torch.cuda.Stream()
# We use an external stream to prevent the rare situation
# of a stream that is already in use being grabbed from
# the stream pool. This requires destroying the external
For appropriately rare situations, I would lean towards not supporting them until they become an actual issue. Or at least have tests for these edge cases to be sure we are handling them correctly.
Makes sense. I'm inclined to just undo this change, to be honest. I'm worried about causing unexpected behavior or errors by stuffing too much into this PR. Developers in the know can simply pass their own cuda stream to torch.cuda.graph() to prevent trying to restart stream capture on the same stream.
@pytorchbot label "release notes: cuda"
@eqy do you know how I can run CI? It's pretty helpful not to have to run the entire test suite locally.
Force-pushed from 1097032 to da3a7bc.
torch/_dynamo/convert_frame.py (Outdated)
# unavoidable. Presumably even under stream
# capture, we still want to save and restore the
# rng state.
cuda_rng_state = torch.cuda.default_generators[
TODO: Make sure that resetting the cuda rng state works as intended when doing stream capture. Frankly, I'm not even sure what the intended behavior is supposed to be. This is the first time that a code sequence like:

with torch.cuda.graph(g):
    torch.compile(f)

is being done.
vLLM runs torch.compile and captures it afterwards. The recommendation (as with other cudagraphs) is that you should warm up the run first, then capture:

torch.compile(f)
with torch.cuda.graph(g):
    torch.compile(f)

Should we revert this?
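A hedged, runnable sketch of the recommended flow above (function and tensor names are illustrative): compile and warm up eagerly first, then capture the warmed-up callable.

```python
import torch

def f(x):
    return torch.relu(x) * 2.0

compiled_f = torch.compile(f)

x = torch.ones(1024, device="cuda")

# Warmup run outside capture: triggers compilation and any lazy
# initialization (cuBLAS handles, autotuning, etc.).
compiled_f(x)
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = compiled_f(x)

# Replays reuse the static input buffer `x`; update it in place as needed.
x.copy_(torch.randn(1024, device="cuda"))
g.replay()
```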
This allows torch.cond and torch.while_loop to be captured in a single cuda graph. This is done by manually inserting conditional IF nodes and conditional WHILE nodes during stream capture. This approach is discussed here: https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes

Previously, data-dependent control flow would force the usage of cuda graph trees, since data-dependent control flow was done on the CPU. Now, data-dependent control flow is done on the GPU. This work depends upon CUDA 12.4, since cuda graph conditional nodes were introduced in CUDA 12.4. This works only with the torch.compile(..., backend="eager") and torch.compile(..., backend="cudagraphs") backends right now. Notably, there is no inductor support at this time!

Conditional nodes for cuda graphs were first experimented with in https://arxiv.org/abs/2406.03791 . While this paper showed strong improvements in data-dependent workloads that were very CPU-overhead bound, the next place to look for improvement is horizontal and vertical kernel fusion, which can eventually be enabled automatically once conditional nodes work with backends like inductor. This PR is the first step towards that. I also expect this work to benefit more sophisticated models like autoregressive decoding of LLMs, at least for users that are using static-shape kv-caches.

This is work done by @tingyangk and me. We have a sophisticated example of RNN-T greedy decoding (the algorithm discussed in the paper) working with this new feature here: tingyangk/NeMo@975a806#diff-2c2a72c9a5392d4a6ea5149fea3ce7900b9fd2c630e460bbab94547379553ceaR376

Some design decisions:

1. We initially implemented this by using the pypi cuda-python package. We changed to modifying CUDAGraph.cpp instead, and wrapping the new methods in pybind11 to expose them in python. This would presumably allow triton's C++ code generator to call the relevant APIs. It also avoids a new dependency on cuda-python. As a nuisance, the "bodies" of conditional nodes need to be captured to new streams. Therefore, we need access to the stream capture mode of the currently capturing stream to pass to cudaStreamBeginCaptureToGraph().
2. We expose the currently capturing CUDAGraph via a thread_local variable with the static method CUDAGraph::get_currently_capturing_graph(). Basically, since we don't pass the CUDAGraph at every step of a particular computation, we need to be able to access it to call the newly exposed APIs. There is a long doc string in CUDAGraph.cpp about this, but the TL;DR is that it may be more correct to have a global hash table mapping CUDAStreams to CUDAGraphs instead. It depends on whether or not pytorch code is ever expected to stream capture to a graph from more than one thread. I don't believe that any pytorch code today does this, but I may be mistaken.

Squashed commit messages:

- Enable conditional nodes only for CUDA 12.4 and up.
- Fixups from PR comments.
- Add a new TorchDispatchMode to warm up the ops in a program. This ensures we can do stream capture successfully.
- Change the internal implementation to allow for more than one stream capture at a time. Enable relaxed stream capture mode on the current thread where appropriate. Always create a new stream in the relevant contexts, to prevent stream capture issues.
- Linting. Lint fixes and CI fixes.
- Make sure all gradients are 0 rather than None.
- Do autograd on the same thread as forward prop. That prevents errors from unsafe actions like cublasInitialize() breaking stream capture.
- Remove note to self.
- Lint fix. Test. Simplify stream capture: don't check for cudaStreamIsCapturing() in the eager mode implementations; instead use a new TorchDispatchMode inside of torch.cuda.graph().
- Fix while loop unit tests. Lint.
- Use smart pointers to destroy the externally created cudaStream_t. Honestly not sure that this is better than what I had before with manually calling cudaStreamDestroy(). Might revert.
- Lint fix. Support for arbitrary pytrees with cuda graphs for while_loop. Lint.
- Document thread_cuda_stream_capture_mode.
- Fix "integer conversion resulted in a change of sign". Only nvcc would report this while compiling CUDAGraph.cu.
- Make sure that populate_builtin_to_tensor_fn_map() uses CPU ops. The following program used to fail:

  ```
  import torch

  torch.set_default_device("cuda")

  def f(pred, x):
      return torch.cond(
          pred,
          lambda y: torch.sum(y),
          lambda y: torch.sum(torch.stft(y, 512, return_complex=False)),
          [x],
      )

  g = torch.cuda.CUDAGraph()
  pred = torch.tensor(True)
  x = torch.ones(1024 * 1024)
  with torch.cuda.graph(g, capture_error_mode="thread_local"):
      f(pred, x)
  ```

  torch.set_default_device("cuda") would make the kernels used by populate_builtin_to_tensor_fn_map() run on the cuda stream used for stream capture, which would cause a crash for operator.not_, which calls cudaStreamSynchronize(). Even if this weren't the case, all of these operations would be unnecessarily captured into the cuda graph, which reduces performance when replaying the cuda graph. The dispatch key does not influence the function name, so I just force dispatching on the CPU to fix this.
- Simplify graph capture of conditional nodes. The check for whether or not we've already warmed up a function is not helpful because in practice it will never be true, unless we maintained a global cache rather than an instance-variable-local cache. Remove this for now.
- Lint.
- Remove logic to use an external stream rather than a stream from the pytorch pool. It is unlikely that the user will have 32 concurrent cuda stream captures at once.
- Save before moving around.
- Give up on eager backend. Enable cudagraphs backend support.
- Revert materializing gradients to 0 rather than None. It is handled by main now.
- Lint.
- Remove is_infra_mode().
- while_loop does not support modifying captured variables right now.
- Filter out test that calls .item() (causes device sync) as expected to fail.
- Fix a few more tests.
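A hedged usage sketch of what the PR enables (not runnable on mainline PyTorch, since the change was later reverted): a data-dependent torch.while_loop captured into a single cuda graph, with the trip count decided on the GPU at replay time. The function names and the exact eager-capture flow are assumptions modeled on the repro above.

```python
import torch
from torch._higher_order_ops import while_loop

def cond_fn(i, x):
    return i < 8          # predicate stays on the GPU

def body_fn(i, x):
    return i + 1, x * 2.0

i = torch.zeros((), dtype=torch.int64, device="cuda")
x = torch.ones(1024, device="cuda")

# Warm up once outside capture so lazy initialization does not interfere
# with stream capture.
while_loop(cond_fn, body_fn, (i, x))

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    i_out, x_out = while_loop(cond_fn, body_fn, (i, x))

# On replay, the WHILE conditional node re-evaluates cond_fn on the GPU,
# so no CPU round trip is needed per iteration.
g.replay()
```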
Force-pushed from 81f2d44 to 630af47.
Merged: Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979). Pull Request resolved: #140979. Approved by: https://github.com/eqy, https://github.com/eellison
Revert "Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)". This reverts commit c7515da. Reverted #140979 on behalf of https://github.com/huydhn: this change has been reported to break internal code ([comment](#140979 (comment))).
@eellison has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Looking forward to trying this out in vLLM once it lands!
@tlrmchlsmth unfortunately I never succeeded in getting a reproducer from @huydhn. Since the failure is internal to Meta and there is no externally working reproducer, I've unfortunately put this on hold. Let me know what you're thinking of doing with this. If you simply want to minimize the overhead of the CPU synchronizing with the GPU when you check for termination at the end of every iteration, #146145 and #146924 combined will allow for that in an interesting way. However, it will increase memory usage due to unrolling your loop a few iterations, and I know that device memory usage from cuda graph nodes has been a problem in vLLM before.
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.
This is a new PR for #130386, which got stale and was closed. Since I force-pushed to that branch in order to rebase it on top of main, the PR can no longer be reopened, according to isaacs/github#361
I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534
Since starting this, torch.cond and torch.while_loop now apparently have support for backward passes. I will look into what it might take to support that.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @StrongerXi @ColinPeppler @desertfire