Add beginnings of torch::stable::accelerator #159679


Open
wants to merge 10 commits into base: gh/mikaylagawarecki/332/base
@@ -1,9 +1,14 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/util/Exception.h>

#ifdef LAE_USE_CUDA
#include <cuda_runtime.h>
#endif

#include <optional>

void inline sgd_math(
@@ -320,3 +325,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("my_zero_", &boxed_my_zero_);
}

// Test functions for torch::stable::accelerator APIs

#ifdef LAE_USE_CUDA
int64_t test_device_guard(int64_t device_index) {
using torch::stable::accelerator::DeviceGuard;

STD_TORCH_CHECK(
device_index >= std::numeric_limits<int32_t>::min() &&
device_index <= std::numeric_limits<int32_t>::max(),
"Device index is out of range of DeviceIndex (int32_t).");

DeviceGuard guard(device_index);
int currentDevice;
cudaError_t err = cudaGetDevice(&currentDevice);
[Collaborator] Check that before the guard you have another device than this one. Because if there is only 1 GPU, I'm not sure this will ever do anything.

[Contributor Author] The test that exercises this from Python is decorated with @deviceCountAtLeast(2), so this situation won't happen I think :)

STD_TORCH_CHECK(err == cudaSuccess);
return currentDevice;
}

void boxed_test_device_guard(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_device_guard(to<int64_t>(stack[0]));
stack[0] = from(res);
}

int64_t test_device_guard_set_index() {
using torch::stable::accelerator::DeviceGuard;

DeviceGuard guard(1);
guard.set_index(0);
int currentDevice;
cudaError_t err = cudaGetDevice(&currentDevice);
STD_TORCH_CHECK(err == cudaSuccess);
return currentDevice;
}

void boxed_test_device_guard_set_index(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_device_guard_set_index();
stack[0] = from(res);
}

int64_t test_stream(int64_t device_index) {
STD_TORCH_CHECK(
device_index >= std::numeric_limits<int32_t>::min() &&
device_index <= std::numeric_limits<int32_t>::max(),
"Device index is out of range of DeviceIndex (int32_t).");

return torch::stable::accelerator::getCurrentStream(
static_cast<int32_t>(device_index)).id();
}

void boxed_test_stream(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_stream(to<int64_t>(stack[0]));
stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("test_device_guard(int device_index) -> int");
m.def("test_device_guard_set_index() -> int");
m.def("test_stream(int device_index) -> int");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_device_guard", &boxed_test_device_guard);
m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
m.impl("test_stream", &boxed_test_stream);
}
#endif // LAE_USE_CUDA
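Aside: every boxed wrapper above follows the same stable-ABI calling convention: arguments arrive on a StableIValue stack, are unpacked with to<T>, and outputs are written back with from. A minimal sketch of that pattern with a hypothetical payload function (my_add_one is illustrative, not part of this PR):

int64_t my_add_one(int64_t x) {
  return x + 1;
}

void boxed_my_add_one(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  // Unpack argument 0 from the stack, call the typed payload,
  // then overwrite slot 0 with the single output.
  int64_t res = my_add_one(to<int64_t>(stack[0]));
  stack[0] = from(res);
}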
@@ -164,3 +164,37 @@ def fill_infinity(t) -> Tensor:
Returns: The modified tensor (same as input)
"""
return torch.ops.libtorch_agnostic.fill_infinity.default(t)


def test_device_guard(device_index) -> int:
    """
    Tests the DeviceGuard functionality by creating a device guard for the given
    device index and returning the current device.

    Args:
        device_index: Device index to set the guard to

    Returns: result of cudaGetDevice() as an integer after using the guard
    """
    return torch.ops.libtorch_agnostic.test_device_guard.default(device_index)


def test_device_guard_set_index() -> int:
    """
    Tests the DeviceGuard set_index functionality by creating a device guard with
    index 1, then setting it to index 0, and returning the current device.

    Returns: result of cudaGetDevice() as an integer after using set_index
    """
    return torch.ops.libtorch_agnostic.test_device_guard_set_index.default()


def test_stream(device_index) -> int:
    """
    Tests the Stream functionality by getting the current stream ID for the
    specified device.

    Args:
        device_index: Device index to get the stream for

    Returns: Stream ID as an integer
    """
    return torch.ops.libtorch_agnostic.test_stream.default(device_index)
11 changes: 9 additions & 2 deletions test/cpp_extensions/libtorch_agnostic_extension/setup.py
@@ -4,7 +4,8 @@

from setuptools import find_packages, setup

-from torch.utils.cpp_extension import BuildExtension, CppExtension
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension


ROOT_DIR = Path(__file__).parent
@@ -35,10 +36,16 @@ def get_extension():
"cxx": ["-fdiagnostics-color=always"],
}

extension = CppExtension
# allow including <cuda_runtime.h>
if torch.cuda.is_available():
extra_compile_args["cxx"].append("-DLAE_USE_CUDA")
extension = CUDAExtension

[Contributor] Hmmm, maybe this would be a good call to move the CUDA stuff into its own C++ file and build them separately.

[Contributor Author, @mikaylagawarecki, Aug 8, 2025] Hmmm, I don't see the issue with the current approach, so gonna keep it as is unless there's something specific you're concerned about! :)

    sources = list(CSRC_DIR.glob("**/*.cpp"))

    return [
-        CppExtension(
+        extension(
            "libtorch_agnostic._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,
@@ -5,6 +5,7 @@

import torch
from torch.testing._internal.common_device_type import (
    deviceCountAtLeast,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
@@ -218,6 +219,38 @@ def test_fill_infinity(self, device):
        expected = torch.full_like(t, math.inf)
        self.assertEqual(out, expected)

    @onlyCUDA
    @deviceCountAtLeast(2)
    def test_device_guard(self, device):
        import libtorch_agnostic

        device_index = 1
        out = libtorch_agnostic.ops.test_device_guard(device_index)
        self.assertEqual(out, device_index)

    @onlyCUDA
    @deviceCountAtLeast(2)
    def test_device_guard_set_index(self, device):
        import libtorch_agnostic

        # This test creates a DeviceGuard with index 1, then sets it to index 0
        # and returns the current device (should be 0)
        out = libtorch_agnostic.ops.test_device_guard_set_index()
        self.assertEqual(out, 0)

    @onlyCUDA
    def test_stream(self, device):
        import libtorch_agnostic

        stream = torch.cuda.Stream()
        device = torch.cuda.current_device()

        with stream:
            expected_stream_id = torch.cuda.current_stream(0).stream_id
            stream_id = libtorch_agnostic.ops.test_stream(device)

        self.assertEqual(stream_id, expected_stream_id)

instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":
30 changes: 30 additions & 0 deletions torch/csrc/inductor/aoti_torch/c/shim.h
@@ -493,6 +493,36 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
const char* overloadName,
StableIValue* stack);

// Device-generic guard for managing device context
struct DeviceGuardOpaque;
using DeviceGuardHandle = DeviceGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_device_guard(
int32_t device_index,
DeviceGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_device_guard(DeviceGuardHandle guard);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_device_guard_set_index(
DeviceGuardHandle guard,
int32_t device_index);

// Device-generic stream for managing stream objects
struct StreamOpaque;
using StreamHandle = StreamOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_delete_stream(StreamHandle stream);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_stream_id(StreamHandle stream, int64_t* ret_stream_id);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream(
int32_t device_index,
StreamHandle* ret_stream // returns new reference
);
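Taken together, these entry points follow a create/query/delete pattern in which the caller owns every returned handle. A minimal caller-side sketch of the intended lifecycle, assuming an accelerator backend is available (each call returns an AOTITorchError that a real caller should check):

DeviceGuardHandle guard = nullptr;
aoti_torch_create_device_guard(/*device_index=*/0, &guard); // guard set to device 0
aoti_torch_device_guard_set_index(guard, 1);                // re-target guard to device 1
aoti_torch_delete_device_guard(guard);                      // caller releases the guard

StreamHandle stream = nullptr;
aoti_torch_get_current_stream(/*device_index=*/0, &stream); // new reference, caller-owned
int64_t stream_id = 0;
aoti_torch_stream_id(stream, &stream_id);                   // query the opaque stream's id
aoti_torch_delete_stream(stream);                           // caller releases the stream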

#ifdef USE_CUDA

struct CUDAGuardOpaque;
56 changes: 56 additions & 0 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -24,6 +24,10 @@
#include <iostream>
#include <vector>

#include <c10/core/Device.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
@@ -1612,3 +1616,55 @@ AOTITorchError aoti_torch_call_dispatcher(
}
});
}

AOTITorchError aoti_torch_create_device_guard(
int32_t device_index,
DeviceGuardHandle* ret_guard // returns new reference
) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
// checked=true will fail if no accelerator is available
const auto device_type =
at::accelerator::getAccelerator(/*checked=*/true).value();
c10::Device device(device_type, device_index);
c10::DeviceGuard* guard = new c10::DeviceGuard(device);
*ret_guard = reinterpret_cast<DeviceGuardHandle>(guard);
});
}

AOTITorchError aoti_torch_delete_device_guard(DeviceGuardHandle guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ delete reinterpret_cast<c10::DeviceGuard*>(guard); });
}

AOTITorchError aoti_torch_device_guard_set_index(
DeviceGuardHandle guard,
int32_t device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ reinterpret_cast<c10::DeviceGuard*>(guard)->set_index(device_index); });
}

AOTITorchError aoti_torch_delete_stream(StreamHandle stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ delete reinterpret_cast<c10::Stream*>(stream); });
}

AOTITorchError aoti_torch_stream_id(
StreamHandle stream,
int64_t* ret_stream_id) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
c10::Stream* stream_ptr = reinterpret_cast<c10::Stream*>(stream);
*ret_stream_id = stream_ptr->id();
});
}

// This function creates a new Stream object and makes StreamHandle point to it.
// The caller is responsible for managing the object's lifecycle.
AOTITorchError aoti_torch_get_current_stream(
int32_t device_index,
StreamHandle* ret_stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
c10::Stream stream = at::accelerator::getCurrentStream(device_index);
c10::Stream* stream_ptr = new c10::Stream(stream);
*ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
[Comment on lines +1667 to +1668]

[Contributor] Creating a new stream on the heap is gonna leak memory; we should be able to assign the stream in line 1642 into the pointer (tho idk the cast semantics).

[Contributor] Ahhh, after reading the rest of the code, I see what the user flow is intended to be. That said, I'm not sure if we want the semantics of get_current_stream to create a new stream on the heap (this is likely not expected from the user POV), and if we do end up sticking with this, we need to loudly document it as such.

I think we do not want to mess with the memory of the stream at all... cc @albanD for thoughts

[Contributor Author, @mikaylagawarecki, Aug 8, 2025] Hmm, looks like @albanD had no thoughts :b

I'm not sure what to do here; this one looks different from

AOTITorchError aoti_torch_get_current_cuda_stream(
    int32_t device_index,
    void** ret_stream) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    *(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
  });
}

as that is just directly returning the id. However, the actual getCurrentStream returns a c10::Stream, so I'm trying to match that semantic in the stable ABI (so in future users can add other stream methods on a stream object):

c10::Stream getCurrentStream(c10::DeviceIndex device_index) {
  const auto device_type = getAccelerator(true).value();
  c10::impl::VirtualGuardImpl impl(device_type);
  return impl.getStream({device_type, device_index});
}

I think at::accelerator::getCurrentStream is returning c10::Stream by value, which is why we will need to create a new object on the heap if we want to return it to the caller and transfer ownership to the stable::Stream.

I can add a comment for sure, but would be curious how else we can tweak this in a way that's more user/memory friendly; would it make more sense to just return the .id() directly?

});
}
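To make the ownership discussion above concrete: nothing is leaked as long as every handle returned by aoti_torch_get_current_stream is eventually passed to aoti_torch_delete_stream, which is what the stable::Stream wrapper in torch/csrc/stable/accelerator.h (below) automates. A raw caller-side sketch under that assumption:

// The handle owns a heap-allocated c10::Stream; tying deletion to a
// smart pointer guarantees aoti_torch_delete_stream runs exactly once.
StreamHandle raw = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_stream(0, &raw));
std::unique_ptr<StreamOpaque, void (*)(StreamHandle)> stream(
    raw, [](StreamHandle s) {
      AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(s));
    });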
71 changes: 71 additions & 0 deletions torch/csrc/stable/accelerator.h
@@ -0,0 +1,71 @@
#pragma once

#include <torch/csrc/inductor/aoti_runtime/utils.h>

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

namespace torch::stable::accelerator {

namespace {
inline void delete_device_guard(void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_delete_device_guard(reinterpret_cast<DeviceGuardHandle>(ptr)));
}

} // namespace

// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we
// can converge on in this world as DeviceIndex in libtorch is not stable.
using DeviceIndex = int32_t;
using StreamId = int64_t; // this is from c10/core/Stream.h

class DeviceGuard {
[Contributor Author, @mikaylagawarecki, Aug 4, 2025] Something that I'm not sure about -- do the copy/move semantics need to match the real device guard, although this is just a wrapper that is holding a std::unique_ptr to DeviceGuardOpaque?

public:
~DeviceGuard() = default;
/// Copy is disallowed
DeviceGuard(const DeviceGuard&) = delete;
DeviceGuard& operator=(const DeviceGuard&) = delete;
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
DeviceGuard(DeviceGuard&& other) = delete;
DeviceGuard& operator=(DeviceGuard&& other) = delete;

[Contributor] It doesn't need to copy the existing DeviceGuard. For example, I agree with deleting the default constructor for this API. We can disallow copy and move if it's easiest to maintain anyway.

public:
explicit DeviceGuard() = delete;
explicit DeviceGuard(DeviceIndex device_index)
: guard_(nullptr, delete_device_guard) {
DeviceGuardHandle ptr = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_create_device_guard(device_index, &ptr));
guard_.reset(ptr);
}

void set_index(DeviceIndex device_index) {
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_device_guard_set_index(guard_.get(), device_index));
}

private:
std::unique_ptr<DeviceGuardOpaque, aot_inductor::DeleterFnPtr> guard_;
};

class Stream {
public:
explicit Stream() = delete;

// Construct a stable::Stream from a StreamHandle
// Steals ownership from the StreamHandle
explicit Stream(StreamHandle stream)
[Contributor] Ah, I see, so the expected use case is:

  • user calls getCurrentStream which creates a Stream
  • then they can call id on it

: stream_(stream, [](StreamHandle stream) {
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(stream));
}) {}

StreamId id() const {
StreamId stream_id;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_stream_id(stream_.get(), &stream_id));
return stream_id;
}

private:
std::shared_ptr<StreamOpaque> stream_;
};

inline Stream getCurrentStream(DeviceIndex device_index) {
StreamHandle stream = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_stream(device_index, &stream));
return Stream(stream);
}

} // namespace torch::stable::accelerator
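For reference, a minimal end-to-end usage sketch of this header from an extension kernel (the device indices are illustrative and assume at least two accelerator devices, mirroring the tests above):

#include <torch/csrc/stable/accelerator.h>

void example() {
  using namespace torch::stable::accelerator;
  DeviceGuard guard(/*device_index=*/1);           // switch to device 1 for this scope
  guard.set_index(0);                              // re-target the guard to device 0
  Stream s = getCurrentStream(/*device_index=*/0); // caller owns the wrapped handle
  StreamId id = s.id();                            // opaque id of device 0's current stream
  (void)id;
}                                                  // guard restores the previous device here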