
Implement CUDA graph support for torch.cond and torch.while_loop #140979


Closed
wants to merge 15 commits

Conversation

galv
Collaborator

@galv galv commented Nov 18, 2024

This is a new PR for #130386, which got stale and was closed. Since I force-pushed to that branch in order to rebase it on top of main, the PR can no longer be reopened, according to isaacs/github#361.

I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534

Since starting this, torch.cond and torch.while_loop now apparently have support for backward passes. I will look into what it might take to support that.
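
For context, here is a minimal sketch of the usage this PR targets (assuming CUDA 12.4+ and a CUDA build of PyTorch; the shapes and branch functions are illustrative, not taken from the PR's tests):

```python
import torch

def fn(pred, x):
    # with this PR, torch.cond becomes a conditional IF node in the captured graph
    return torch.cond(pred, lambda t: t.sin(), lambda t: t.cos(), [x])

pred = torch.tensor(True, device="cuda")
x = torch.randn(8, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = fn(pred, x)

pred.fill_(False)  # the branch taken at replay should follow the predicate's current value
g.replay()
```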

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @StrongerXi @ColinPeppler @desertfire

@galv galv requested review from eqy, syed-ahmed, zou3519 and a team as code owners November 18, 2024 21:05

pytorch-bot bot commented Nov 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140979

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ee03af1 with merge base 93316cf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -0,0 +1,26 @@
### Mini Repro ###
Collaborator Author

TODO: Remove this.

@@ -0,0 +1,5 @@
import torch
Collaborator Author

TODO: Remove this.

@eellison eellison self-requested a review November 19, 2024 01:45
@galv galv marked this pull request as draft November 20, 2024 08:04
@galv
Collaborator Author

galv commented Nov 20, 2024

I converted this to a draft because the change that added a backward pass to torch.cond made me realize that there are a few issues in my design. In particular, I assumed that all stream capture would occur on one thread, but in fact it is normal behavior for the backward pass to use one or more new threads.

Contributor

@eellison eellison left a comment

looks good! a few comments

// capture mode. The easiest solution is to handle stream creation
// and deletion ourselves.

// Be sure to call cudaStreamDestroy on this when it is finished
Contributor

Where do we call cudaStreamDestroy? I don't see it.

Collaborator Author

It is inside end_capture_to_conditional_node. https://github.com/pytorch/pytorch/pull/140979/files#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR561

Another possibility is to make this return a unique_ptr<cudaStream_t> with a custom destructor. That would be a little bit awkward since the std::stack<at::cuda::CUDAStreamGuard> conditional_node_streams_ would need an additional stack just containing these unique pointers (CUDAStreamGuard lacks a move constructor, so I'm not sure that having a std::pair of CUDAStreamGuard and a unique_ptr would work here).

Collaborator Author

So I did try your suggestion: 7cdced5

To be honest, I think it makes things more opaque and confusing in this instance to use smart pointers. In particular, I almost accidentally depended on the evaluation order of a function's arguments here (passing std::move(child_stream) means that *child_stream in another argument might already be null): 7cdced5#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR456-R462

I'm thinking of reverting the change personally, but would be happy to keep it if you feel strongly about it.

return gm.forward

class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
def __init__(self, stream):
Contributor

nit: any reason to pass the stream in here instead of instantiating it inside? It's not used anywhere else.
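
A minimal sketch of that suggestion (the class name is from the PR; the body is illustrative):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        # create the warmup stream here instead of taking it as a parameter
        self.stream = torch.cuda.Stream()
```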

@@ -371,7 +371,7 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
), f"Dense implementation operands must be a list of tensors and ints {operands}"
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
if not torch.cuda.is_current_stream_capturing():
if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
Contributor

Since we're already making a custom TorchDispatchMode to do the warmup, could we put this logic in that class to avoid cudagraphs leaking here?

Collaborator Author

Good idea. I think I encountered some weirdness, described here: https://github.com/pytorch/pytorch/pull/140979/files#diff-4093d24068320f82997aecfaf4260aa495e3b4322e178cb5625f9555e3f8b0f5R388-R392

and here:

#140322

But let me take a second look. I would be happy to make that change, assuming there is nothing standing in the way.

Contributor

If #140322 is blocking, we can do this as a follow-up.

Collaborator Author

I did what you suggested. Unfortunately, we cannot currently use code within __torch_dispatch__ to dispatch ops in a different way, but we can use py_impl(MyTorchDispatchMode) to do this, so I went with that workaround. It is a little bit ugly because it causes some circular dependencies and separates code that should all sit in one class. It should probably be redone once #140322 is handled. You can see how I do it for cond here: https://github.com/pytorch/pytorch/pull/140979/files#diff-4093d24068320f82997aecfaf4260aa495e3b4322e178cb5625f9555e3f8b0f5R384-R408

Note that I use two TorchDispatchModes: one for warmup, and one for making sure that I do stream capture appropriately for conditional nodes. I automatically enter CUDAGraphCaptureControlFlowOpDispatchMode in the __enter__ method of torch.cuda.graph. This unfortunately means that users who use CUDAGraph.capture_begin() and CUDAGraph.capture_end() instead of torch.cuda.graph() won't be able to capture conditional ops. I don't think this is a problem, to be honest. Do you think it is?
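
A rough sketch of the py_impl wiring described above; the mode class name comes from the PR, while the import path and body are assumptions rather than the actual diff:

```python
from torch._higher_order_ops.cond import cond_op  # assumed import path
from torch.utils._python_dispatch import TorchDispatchMode

class CUDAGraphCaptureControlFlowOpDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # ordinary ops pass through unchanged while the mode is active
        return func(*args, **(kwargs or {}))

@cond_op.py_impl(CUDAGraphCaptureControlFlowOpDispatchMode)
def cond_under_capture(mode, pred, true_fn, false_fn, operands):
    # Placeholder: the real implementation inserts a conditional IF node
    # into the capturing CUDA graph and captures each branch onto its own
    # child stream; see the PR diff for the actual logic.
    ...
```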

flat_args = [to_real_tensor(a) for a in flat_args]
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)

r = func(*args, **kwargs)
# If one of the inputs is a CUDA tensor, it is possible that
Contributor

Could we also prevent cuda graph stream logic from leaking into fake tensors by updating the __torch_dispatch__ logic?

Inside __torch_dispatch__, instantiate inputs with to_real_tensor, then map the outputs back to fake tensors? We can also do the thread_cuda_stream_capture_mode handling in it.

Collaborator Author

I have to admit, I don't understand your suggestion here.

Basically, as I understand it, FakeTensors don't always know how to do correct shape inference for some operations. One of these is torch.stft. In this situation, real tensors are instantiated, the operation is run, and then fake tensors are constructed using metadata from the real output tensors. run_fallback_kernel(), the function I modified, is the function that FakeTensorMode uses to do this, so I think I might already be doing what you're asking for. Let me know if this makes sense at all!

I encountered this issue because I work a lot on audio data. You can see a unit test here at test/functorch/test_control_flow_cuda_initialization.py's test_cond_stft: https://github.com/pytorch/pytorch/pull/140979/files#diff-ce3a6ac7310a0a222e113af989463fe8f5bd7394a13a9cbfa58cfc0a7ceb494eR55
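
Schematically, the fallback path described above looks like this (a sketch, not the actual run_fallback_kernel code; it assumes a single-tensor output):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def fallback_sketch(fake_mode: FakeTensorMode, func, args):
    def to_real(a):
        # materialize a real tensor with the same metadata as the fake one
        if isinstance(a, FakeTensor):
            return torch.zeros(a.shape, dtype=a.dtype, device=a.device)
        return a

    real_out = func(*[to_real(a) for a in args])  # run the real kernel once
    return fake_mode.from_tensor(real_out)        # keep only the output metadata
```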

Contributor

The test there is not using Fake Tensors, so this code is not touched.

I guess taking a step back - I don't think we should do first-run support for cuda graphs. The general pytorch constraint is that you need to have warmed up your model on the first run, and then you can capture. There are a number of tests along these lines (cublasCreate, etc.).

Contributor

I might try to resolve this in a follow up

Collaborator Author

The test there is not using Fake Tensors, so this code is not touched.

Actually, that test is using fake tensors in order to compute the output shape of the stft operation. But anyway I agree this is a bit sketchy, looking back on this.

@@ -132,6 +135,11 @@ class graph:
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
collect_garbage (bool, optional): If True, call torch.cuda.synchronize() followed by gc.collect() to free
Contributor

This is currently only set to False in ControlFlowOpWarmupDispatchMode. So, I guess the use case there is that someone is capturing a cuda graph in eager mode and calls into a torch.compile-wrapped function? Wouldn't we expect that they should have already warmed up the function once prior to capture?

Collaborator Author

Fundamentally, the challenge is that my current approach to this can cause two distinct cuda stream captures to happen at the same time.

First of all, pytorch never actually intended for more than one stream capture to be possible, from what I can see. This is clearly stated in the documentation of torch.cuda.graph(). See here: https://github.com/pytorch/pytorch/pull/140979/files#diff-39e542d87359e7d5381d036cbcea9ec759fbe469578bcdc5693ce6cfab7f1a54R187-R188 If torch.cuda.graph is constructed without a stream, every call to __init__ will use the same single cuda stream. And calling cudaStreamBeginCapture() on an already-capturing stream will fail.

As you see, entering the torch.cuda.graph() context manager will call torch.cuda.synchronize(). Unfortunately, torch.cuda.synchronize() will basically call cudaStreamSynchronize() on every stream in the cuda context, and cudaStreamSynchronize() will invalidate stream capture for a capturing stream.

That is why I added this boolean flag, since my "warmup" TorchDispatchMode will always start a second stream capture when another stream capture is already going on.

Unfortunately, there is more to it than that.

The warmup TorchDispatchMode does two things:

  1. It ensures that all conditionally-executed code paths are "warmed up" by running them at least once.
  2. It avoids side effects while doing (1) by "stubbing out" all kernel launches via relaxed stream capture (see the sketch below). Side effects are not acceptable because I don't want both sides of an if-else statement to actually execute; that would never happen during normal execution and could therefore put the program into an undefined state.
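
For illustration, a sketch of the warmup idea in (2), assuming it runs outside any other capture; it mirrors the PR's ControlFlowOpWarmupDispatchMode in spirit only and is not production-safe:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class WarmupSketchMode(TorchDispatchMode):
    """Record each op into a throwaway graph so its kernels never actually run."""

    def __init__(self):
        super().__init__()
        self.stream = torch.cuda.Stream()  # warm up on a side stream

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.stream(self.stream):
            graph.capture_begin(capture_error_mode="relaxed")
            try:
                # the launch is only recorded, so it has no side effects;
                # the outputs are uninitialized placeholders
                out = func(*args, **(kwargs or {}))
            finally:
                graph.capture_end()
        return out
```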

Unfortunately, even though I do relaxed stream capture on the second stream, the first stream capture may be in global or thread-local stream capture mode. There is no way to temporarily "turn off" stream capture on a stream and then turn it back on right now (though in principle, I see no reason why we couldn't do it). This means that unsafe actions done during the warmup will still invalidate the first stream's capture. Thus, I need to change the thread's capture mode to relaxed at various points, which I'm not happy with.

Pytorch has succeeded so far with just using the torch.cuda.graph() context manager for capturing (though someone could use torch.cuda.CUDAGraph()'s APIs directly). This is the first time that the code we want to execute differs depending on whether a stream is capturing or not.

I can probably keep running with pytorch's current assumption that only one stream can capture at a time if the user passes me a function describing the work they want to do. I would then make a wrapper over it that uses torch.cuda.graph() twice, sequentially: once in relaxed mode, and then again in the user-requested mode (sketched below).

As you can see, there are unfortunately quite a few annoying details here. But I am inclined to agree with you that there is probably an easier, safer way to do this.
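
A sketch of that sequential wrapper idea, assuming the user hands us a callable and its inputs (capture_with_warmup is a hypothetical name, not something in this PR):

```python
import torch

def capture_with_warmup(fn, *args, capture_error_mode="global"):
    # First pass: a throwaway capture in relaxed mode; kernels are only
    # recorded, so the warmup can visit conditional paths without GPU side
    # effects while host-side lazy initialization still happens.
    warmup_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(warmup_graph, capture_error_mode="relaxed"):
        fn(*args)

    # Second pass: the capture the user actually asked for.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, capture_error_mode=capture_error_mode):
        out = fn(*args)
    return graph, out
```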

Contributor

That is why I added this boolean flag, since my "warmup" TorchDispatchMode will always start a second stream capture when another stream capture is already going on.

I think a lot of these problems stem from trying to do warmup when stream capture is already on.

self.__class__.default_capture_stream = torch.cuda.Stream()
# We use an external stream to prevent the rare situation
# of a stream that is already in use being grabbed from
# the stream pool. This requires destroying the external
Contributor

For appropriately rare situations, I would lean towards not supporting them until they become an actual issue. Or at least have tests for these edge cases to be sure we are handling them correctly.

Collaborator Author

Makes sense. I'm inclined to just undo this change, to be honest. I'm worried about causing unexpected behavior or errors by stuffing too much into this PR. Developers in the know can simply pass their own cuda stream to torch.cuda.graph() to prevent trying to restart stream capture on the same stream.

@galv galv requested a review from bdhirsh as a code owner December 2, 2024 06:40
@galv
Collaborator Author

galv commented Dec 3, 2024

@pytorchbot label "release notes: cuda"

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Dec 3, 2024
@galv
Collaborator Author

galv commented Dec 3, 2024

@eqy do you know how I can run CI? It's pretty helpful not to have to run the entire test suite locally.

@galv galv force-pushed the galv/cudagraphs-conditional-nodes-3 branch from 1097032 to da3a7bc Compare December 3, 2024 19:09
# unavoidable. Presumably even under stream
# capture, we still want to save and restore the
# rng state.
cuda_rng_state = torch.cuda.default_generators[
Collaborator Author

TODO: Make sure that resetting the CUDA RNG state works as intended when doing stream capture. Frankly, I'm not even sure what the intended behavior is supposed to be. This is the first time that a code sequence like:

with torch.cuda.graph(g):
    torch.compile(f)(x)  # compile and run f under capture

is being done.

Contributor

VLLM runs torch.compile and captures it afterwards. The recommendation (as with other cuda graph usage) is that you should warm up the run first, then capture:

compiled_f = torch.compile(f)
compiled_f(x)                 # warm up outside capture
with torch.cuda.graph(g):
    compiled_f(x)             # then capture

Should we revert this?

galv and others added 13 commits February 20, 2025 06:11
This allows torch.cond and torch.while_loop to be captured in a single
cuda graph. This is done by manually inserting conditional IF nodes
and conditional WHILE nodes during stream capture. This approach is
discussed here:

https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes

Previously, data-dependent control flow would force the
usage of cuda graph trees, since data-dependent control flow was done
on the CPU. Now, data-dependent control flow is done on the GPU.

This work depends upon CUDA 12.4, since cuda graph conditional nodes
were introduced in CUDA 12.4.

This works only with torch.compile(..., backend="eager") and
torch.compile(..., backend="cudagraphs") backends right now. Notably,
there is no inductor support at this time!
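
As a concrete illustration of the kind of program this targets with the cudagraphs backend (a hedged sketch; the function and shapes are made up, not taken from the PR's tests):

```python
import torch

def run(x):
    def cond_fn(acc):
        return acc.sum() < 100.0  # termination check computed on the GPU

    def body_fn(acc):
        return (acc * 2.0,)

    return torch.while_loop(cond_fn, body_fn, (x,))

compiled = torch.compile(run, backend="cudagraphs")
(out,) = compiled(torch.ones(4, device="cuda"))
```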

Conditional nodes for cuda graphs were first experimented with in
https://arxiv.org/abs/2406.03791 . While this paper showed strong
improvements in data-dependent workloads that were very CPU-overhead
bound, the next place to look for improvement is horizontal and
vertical kernel fusion, which can eventually be enabled automatically
once conditional nodes work with backends like inductor. This PR is
the first step towards that. I also expect this work to benefit more
sophisticated models like autoregressive decoding of LLMs, at least
for users that are using static shape kv-caches.

This is work done by @tingyangk and me.

We have a sophisticated example of RNN-T greedy decoding (the
algorithm discussed in the paper) working with this new feature here:
tingyangk/NeMo@975a806#diff-2c2a72c9a5392d4a6ea5149fea3ce7900b9fd2c630e460bbab94547379553ceaR376

Some design decisions:

1. We initially implemented this by using the pypi cuda-python
package. We changed to modifying CUDAGraph.cpp instead, and wrapping
the new methods in pybind11 to expose them in python. This would
presumably allow for triton's C++ code generator to call the relevant
APIs. It also avoids a new dependency on cuda-python. As a nuisance,
the "bodies" of conditional nodes need to be captured to new streams.
Therefore, we need access to the stream capture mode of the currently
capturing stream to pass to cudaStreamBeginCaptureToGraph().

2. We expose the currently capturing CUDAGraph via a thread_local
variable with the static method
CUDAGraph::get_currently_capturing_graph(). Basically, since we don't
pass the CUDAGraph at every step of a particular computation, we need
to be able to access it to call the newly exposed APIs. There is a
long doc string in CUDAGraph.cpp about this, but the TL;DR is that it
may be more correct to have a global hash table mapping CUDAStreams to
CUDAGraphs instead. It depends on whether or not pytorch code is ever
expected to stream capture to a graph from more than one thread. I
don't believe that any pytorch code today does this, but I may be
mistaken.

Enable conditional nodes only for CUDA 12.4 and up.

Fixups from PR comments.

Add a new TorchDispatchMode to warmup the ops in a program.

This ensures we can do stream capture successfully.

Change the internal implementation to allow for more than one stream
capture at a time.

Enable relaxed stream capture mode on current thread where appropriate.

Always create a new stream in the relevant contexts, to prevent stream
capture issues.

Linting.

Lint fixes and CI fixes.

Make sure all gradients are 0 rather than None.

Do autograd on the same thread as forward prop. That prevents errors
from unsafe actions like cublasInitialize() breaking stream capture.

Remove note to self.

lint

fix test

simplify stream capture.

Don't check for cudaStreamIsCapturing() in the eager mode
implementations. Instead use a new TorchDispatchMode inside of torch.cuda.graph().

Fix while loop unit tests

lint

Use smart pointers to destroy the externally created cudaStream_t

Honestly not sure that this is better than what I had before with
manually calling cudaStreamDestroy(). Might revert.

lint

fix support for arbitrary pytrees with cuda graphs for while_loop

lint

Document thread_cuda_stream_capture_mode.

Fix "integer conversion resulted in a change of sign" warning.

Only nvcc would report this while compiling CUDAGraph.cu.

Make sure that populate_builtin_to_tensor_fn_map() uses CPU ops.

The following program used to fail:

```
import torch

torch.set_default_device("cuda")

def f(pred, x):
    return torch.cond(pred, lambda y: torch.sum(y), lambda y: torch.sum(torch.stft(y, 512, return_complex=False)), [x])

g = torch.cuda.CUDAGraph()

pred = torch.tensor(True)
x = torch.ones(1024 * 1024)

with torch.cuda.graph(g, capture_error_mode="thread_local"):
    f(pred, x)
```

torch.set_default_device("cuda") would make the kernels used by
populate_builtin_to_tensor_fn_map() run on the cuda stream used for
stream capture, which would cause a crash for operator.not_, which
calls cudaStreamSynchronize(). Even if this weren't the case, all of
these operations would be unnecessarily captured into the cuda graph,
which reduces performance when replaying the cuda graph.

The dispatch key does not influence the function name, so I just force
dispatching on the CPU to fix this.

Simplify graph capture of conditional nodes.

The check for whether or not we've already warmed up a function is not
helpful because in practice it will never be true, unless we
maintained a global cache, rather than an instance variable-local
cache. Remove this for now.

lint

Remove logic to use an external stream rather than a stream from the
pytorch pool.

It is unlikely that the user will have 32 concurrent cuda stream
captures at once.

Save before moving around.

Give up on eager backend. Enable cudagraphs backend support.

Revert materializing gradients to 0 rather than None.

It is handled by main now.

lint
while_loop does not support modifying captured variables right now.
Filter out test that calls .item() (causes device sync) as expected to fail.
@galv galv force-pushed the galv/cudagraphs-conditional-nodes-3 branch from 81f2d44 to 630af47 Compare February 20, 2025 06:16
Raymo111 pushed a commit that referenced this pull request Feb 20, 2025
…op (#140979)

Pull Request resolved: #140979
Approved by: https://github.com/eqy, https://github.com/eellison
Raymo111 pushed a commit that referenced this pull request Feb 20, 2025
…while_loop (#140979)"

This reverts commit c7515da.

Reverted #140979 on behalf of https://github.com/huydhn due to This change has been reported to break internal code ([comment](#140979 (comment)))
@facebook-github-bot
Contributor

@eellison has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…op (pytorch#140979)

Pull Request resolved: pytorch#140979
Approved by: https://github.com/eqy, https://github.com/eellison
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…while_loop (pytorch#140979)"

This reverts commit c7515da.

Reverted pytorch#140979 on behalf of https://github.com/huydhn due to This change has been reported to break internal code ([comment](pytorch#140979 (comment)))
@tlrmchlsmth

Looking forward to trying this out in vLLM once it lands!

@galv
Collaborator Author

galv commented Apr 7, 2025

@tlrmchlsmth unfortunately I never succeeded in getting a reproducer from @huydhn. Since the failure is internal to Meta and there is no externally working reproducer, I've put this on hold.

Let me know what you're thinking of doing with this. If you simply want to minimize the overhead of the CPU synchronizing with the GPU when you check for termination at the end of every iteration, #146145 and #146924 combined will allow for that in an interesting way. However, it will increase memory usage due to unrolling your loop a few iterations, and I know that device memory usage from cuda graph nodes has been a problem in vLLM before.

@BoyuanFeng BoyuanFeng self-requested a review April 7, 2025 22:54
Contributor

github-actions bot commented Jun 6, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 6, 2025
@github-actions github-actions bot closed this Jul 6, 2025