Skip to content

Add beginnings of torch::stable::accelerator #159679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: gh/mikaylagawarecki/332/base
Choose a base branch
from

Conversation

mikaylagawarecki
Copy link
Contributor

@mikaylagawarecki mikaylagawarecki commented Aug 1, 2025

Adds

  • torch::stable::accelerator::DeviceGuard: std::unique_ptr to DeviceGuardOpauqe mostly copied from the below (but made generic)

    class AOTICudaGuard {
    public:
    AOTICudaGuard(int32_t device_index) : guard_(nullptr, delete_cuda_guard) {
    CUDAGuardHandle ptr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(
    aoti_torch_create_cuda_guard(device_index, &ptr));
    guard_.reset(ptr);
    }
    void set_index(int32_t device_index) {
    AOTI_TORCH_ERROR_CODE_CHECK(
    aoti_torch_cuda_guard_set_index(guard_.get(), device_index));
    }
    private:
    std::unique_ptr<CUDAGuardOpaque, DeleterFnPtr> guard_;
    };

    • constructor DeviceGuard(DeviceIndex) (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device)
    • set_index(DeviceIndex)
  • torch::stable::accelerator::Stream: std::shared_ptr to StreamOpaque

    • constructor Stream(StreamHandle stream) (similar to torch::stable::Tensor)
    • id() -> StreamId
  • getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream

Stack from ghstack (oldest at bottom):

Copy link

pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159679

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit 43994c1 with merge base 556e2a7 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

mikaylagawarecki added a commit that referenced this pull request Aug 1, 2025
ghstack-source-id: 1530b24
Pull Request resolved: #159679
Copy link
Contributor

github-actions bot commented Aug 1, 2025

Attention! PyTorch one of the C-stable API file was changed

You MUST NOT change existing function declarations in this, as this header defines a stable C ABI. If you need to change the signature for a function, introduce a new v2 version of the function and modify code generation to target the new version of the function.


Caused by:

mikaylagawarecki added a commit that referenced this pull request Aug 1, 2025
ghstack-source-id: 2f2fa7b
Pull Request resolved: #159679
@mikaylagawarecki mikaylagawarecki changed the title Add beginning of torch::stable::accelerator Add beginnings of torch::stable::accelerator Aug 1, 2025
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: fae00fc
Pull Request resolved: #159679
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 083b8bc
Pull Request resolved: #159679

using DeviceIndex = int8_t;
using StreamId = int64_t;
class DeviceGuard {
Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something that I'm not sure about -- do the copy / move semantics need to match the real device guard although this just a wrapper that is holding a std::unique_ptr to DeviceGuardOpaque?

~DeviceGuard() = default;
/// Copy is disallowed
DeviceGuard(const DeviceGuard&) = delete;
DeviceGuard& operator=(const DeviceGuard&) = delete;
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
DeviceGuard(DeviceGuard&& other) = delete;
DeviceGuard& operator=(DeviceGuard&& other) = delete;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't need to copy the existing DeviceGuard. For example, I agree with deleting the default constructor for this API. We can disallow copy and move if it's easiest to maintain anyway.

mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: afc07a8
Pull Request resolved: #159679
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 23b8be1
Pull Request resolved: #159679
Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from 

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 655868b
Pull Request resolved: #159679
Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from 

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 94eaf15
Pull Request resolved: #159679
@albanD
Copy link
Collaborator

albanD commented Aug 4, 2025

FYI @guangyey @EikanWang in case you have some feedback here.

Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from 

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 2ede272
Pull Request resolved: #159679
Copy link
Collaborator

@guangyey guangyey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

@mikaylagawarecki mikaylagawarecki marked this pull request as ready for review August 5, 2025 17:31
if torch.cuda.is_available():
extra_compile_args["cxx"].append("-DUSE_CUDA")
extension = CUDAExtension

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm maybe this would be a good call to move the CUDA stuff into its own C++ file and build them separately..

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I don't see the issue with the current approach, so gonna keep it as is unless there's something specific you're concerned about! :)

@@ -1590,3 +1594,53 @@ AOTITorchError aoti_torch_call_dispatcher(
}
});
}

AOTITorchError aoti_torch_create_device_guard(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +1643 to +1644
c10::Stream* stream_ptr = new c10::Stream(stream);
*ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a new stream on the heap is gonna leak memory, we should be able to assign the stream in line 1642 into the pointer (tho idk the cast semantics).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh after reading the rest of the code, I see what the user flow is intended to be. That said, I'm not sure if we want the semantics of get_current_stream to create a new stream on the heap (this is likely not expected from the user POV), and if we do end up sticking with this, we need to loudly document this as.

I think we do not want to mess with the memory of the stream at all...cc @albanD for thoughts

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm looks like @albanD had no thoughts :b

I'm not sure what to do here, this one looks different from

AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index,
void** ret_stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
});
}

As that is just directly returning the id. However, the actual getCurrentStream returns a c10::Stream, so I'm trying to match the semantic in the stable ABI (so in future users can add other stream methods on a stream object)

c10::Stream getCurrentStream(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
c10::impl::VirtualGuardImpl impl(device_type);
return impl.getStream({device_type, device_index});
}

I think at::acceerator::getCurrentStream is returning c10::Stream by value which is why we will need to create a new object on the heap if we want to return it to the caller and transfer ownership to the stable::Stream

I can add a comment for sure, but would be curious how else we can tweak this in a way that's more user/memory friendly, would it make more sense to just return the .id() directly?


using DeviceIndex = int8_t;
using StreamId = int64_t;
class DeviceGuard {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't need to copy the existing DeviceGuard. For example, I agree with deleting the default constructor for this API. We can disallow copy and move if it's easiest to maintain anyway.


// Construct a stable::Stream from a StreamHandle
// Steals ownership from the StreamHandle
explicit Stream(StreamHandle stream)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see, so the expected use case is:

  • user calls getCurrentStream which creates a Stream
  • then they can call id on it

@janeyx99
Copy link
Contributor

janeyx99 commented Aug 6, 2025

Thanks for taking this on! I'm left comments, i think my two biggest questions are:

  1. How are we sure DeviceGuard is working as expected haha (add test for set_index and I'm not sure how to add something to ensure we're not leaking memory)
  2. How do we want to represent Stream and what is the story of its memory handling

stack[0] = from(res);
}

int64_t test_stream(int8_t device_index) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

schema being int means the arg is int64_t


} // namespace

using DeviceIndex = int8_t; // this is from c10/core/Device.h
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should NOT rely on this one at the shim level.
For this part of the world, let's make it an int32_t for consistency with the existing shim.
We will have to change the tensor.h version.


DeviceGuard guard(device_index);
int currentDevice;
cudaError_t err = cudaGetDevice(&currentDevice);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check that before the guard you have another device than this one.
Because if there is only 1 gpu, I'm not sure this will ever do anything

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test that exercises this from python is decorated with @deviceCountAtLeast(2) so this situation won't happen I think :)

Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 11, 2025
ghstack-source-id: 02220eb
Pull Request resolved: #159679
mikaylagawarecki added a commit to mikaylagawarecki/pytorch that referenced this pull request Aug 11, 2025
ghstack-source-id: 02220eb
Pull Request resolved: pytorch#159679
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants