
Commit cdd46aa

Add beginning of torch::stable::accelerator
ghstack-source-id: 655868b
Pull Request resolved: #159679
1 parent 08ea8fc commit cdd46aa

7 files changed: 259 additions, 2 deletions

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 50 additions & 0 deletions
@@ -1,9 +1,14 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/accelerator.h>
 #include <torch/csrc/stable/library.h>
 #include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
 #include <torch/headeronly/util/Exception.h>
 
+#ifdef USE_CUDA
+#include <cuda_runtime.h>
+#endif
+
 #include <optional>
 
 void inline sgd_math(
@@ -320,3 +325,48 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("my_zero_", &boxed_my_zero_);
 }
+
+// Test functions for torch::stable::accelerator APIs
+
+#ifdef USE_CUDA
+int test_device_guard(int8_t device_index) {
+  using torch::stable::accelerator::DeviceGuard;
+
+  DeviceGuard guard(device_index);
+  int currentDevice;
+  cudaError_t err = cudaGetDevice(&currentDevice);
+  STD_TORCH_CHECK(err == cudaSuccess);
+  return currentDevice;
+}
+
+void boxed_test_device_guard(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int res = test_device_guard(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+int64_t test_stream(int8_t device_index) {
+  auto id = torch::stable::accelerator::getCurrentStream(device_index).id();
+  return id;
+}
+
+void boxed_test_stream(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int64_t res = test_stream(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("test_device_guard(int device_index) -> int");
+  m.def("test_stream(int device_index) -> int");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("test_device_guard", &boxed_test_device_guard);
+  m.impl("test_stream", &boxed_test_stream);
+}
+#endif // USE_CUDA
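
For readers new to the boxed-kernel style used above, here is a minimal sketch of the same pattern applied to one more op. The op name echo_index is hypothetical and not part of this commit; the StableIValue stack, the to<>/from helpers, and the registration macros are exactly the ones already used in this file.

// Hypothetical sketch (not in this commit): read the argument from the
// StableIValue stack with to<>(), compute, and write the result back with from().
int64_t echo_index(int8_t device_index) {
  return static_cast<int64_t>(device_index);
}

void boxed_echo_index(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  int64_t res = echo_index(static_cast<int8_t>(to<int64_t>(stack[0])));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("echo_index(int device_index) -> int");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("echo_index", &boxed_echo_index);
}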

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 24 additions & 0 deletions
@@ -164,3 +164,27 @@ def fill_infinity(t) -> Tensor:
     Returns: The modified tensor (same as input)
     """
     return torch.ops.libtorch_agnostic.fill_infinity.default(t)
+
+
+def test_device_guard(device_index) -> int:
+    """
+    Tests the DeviceGuard functionality by setting a guard to device_index and returning the current device index.
+
+    Args:
+        device_index: Device index to set the guard to
+
+    Returns: The device index that is current while the guard is active
+    """
+    return torch.ops.libtorch_agnostic.test_device_guard.default(device_index)
+
+
+def test_stream(device_index) -> int:
+    """
+    Tests the Stream functionality by getting the current stream ID for the specified device.
+
+    Args:
+        device_index: Device index to get the stream for
+
+    Returns: Stream ID as an integer
+    """
+    return torch.ops.libtorch_agnostic.test_stream.default(device_index)

test/cpp_extensions/libtorch_agnostic_extension/setup.py

Lines changed: 9 additions & 2 deletions
@@ -4,7 +4,8 @@
 
 from setuptools import find_packages, setup
 
-from torch.utils.cpp_extension import BuildExtension, CppExtension
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
 
 
 ROOT_DIR = Path(__file__).parent
@@ -35,10 +36,16 @@ def get_extension():
         "cxx": ["-fdiagnostics-color=always"],
     }
 
+    extension = CppExtension
+    # allow including <cuda_runtime.h>
+    if torch.cuda.is_available():
+        extra_compile_args["cxx"].append("-DUSE_CUDA")
+        extension = CUDAExtension
+
    sources = list(CSRC_DIR.glob("**/*.cpp"))
 
    return [
-        CppExtension(
+        extension(
            "libtorch_agnostic._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 23 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 import torch
 from torch.testing._internal.common_device_type import (
+    deviceCountAtLeast,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
@@ -218,6 +219,28 @@ def test_fill_infinity(self, device):
         expected = torch.full_like(t, math.inf)
         self.assertEqual(out, expected)
 
+    @onlyCUDA
+    @deviceCountAtLeast(2)
+    def test_device_guard(self, device):
+        import libtorch_agnostic
+
+        device_index = 1
+        out = libtorch_agnostic.ops.test_device_guard(device_index)
+        self.assertEqual(out, device_index)
+
+    @onlyCUDA
+    def test_stream(self, device):
+        import libtorch_agnostic
+
+        stream = torch.cuda.Stream()
+        device = torch.cuda.current_device()
+
+        with stream:
+            expected_stream_id = torch.cuda.current_stream(0).stream_id
+            stream_id = libtorch_agnostic.ops.test_stream(device)
+
+        self.assertEqual(stream_id, expected_stream_id)
+
 instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 30 additions & 0 deletions
@@ -483,6 +483,36 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
     const char* overloadName,
     StableIValue* stack);
 
+// Device-generic guard for managing device context
+struct DeviceGuardOpaque;
+using DeviceGuardHandle = DeviceGuardOpaque*;
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_device_guard(
+    int32_t device_index,
+    DeviceGuardHandle* ret_guard // returns new reference
+);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_delete_device_guard(DeviceGuardHandle guard);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_device_guard_set_index(
+    DeviceGuardHandle guard,
+    int32_t device_index);
+
+// Device-generic stream for managing stream objects
+struct StreamOpaque;
+using StreamHandle = StreamOpaque*;
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_delete_stream(StreamHandle stream);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_stream_id(StreamHandle stream, int64_t* ret_stream_id);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream(
+    int32_t device_index,
+    StreamHandle* ret_stream // returns new reference
+);
+
 #ifdef USE_CUDA
 
 struct CUDAGuardOpaque;
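
For orientation only (not part of this commit): a sketch of driving these C entry points directly, without the torch::stable wrappers. It assumes the shim's usual convention that functions return AOTI_TORCH_SUCCESS on success and that handles documented as "returns new reference" are freed by the caller via the matching delete function.

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

// Hypothetical helper: switch to device_index, read the current stream id,
// and release both handles before returning.
int64_t current_stream_id_on(int32_t device_index) {
  DeviceGuardHandle guard = nullptr;
  if (aoti_torch_create_device_guard(device_index, &guard) != AOTI_TORCH_SUCCESS) {
    return -1; // no accelerator available, or guard creation failed
  }

  StreamHandle stream = nullptr;
  int64_t stream_id = -1;
  if (aoti_torch_get_current_stream(device_index, &stream) == AOTI_TORCH_SUCCESS) {
    aoti_torch_stream_id(stream, &stream_id);
    aoti_torch_delete_stream(stream); // new reference: caller frees it
  }

  aoti_torch_delete_device_guard(guard); // restores the previous device
  return stream_id;
}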

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 54 additions & 0 deletions
@@ -24,6 +24,10 @@
 #include <iostream>
 #include <vector>
 
+#include <c10/core/Device.h>
+#include <c10/core/DeviceGuard.h>
+#include <c10/core/Stream.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #else
@@ -1590,3 +1594,53 @@ AOTITorchError aoti_torch_call_dispatcher(
     }
   });
 }
+
+AOTITorchError aoti_torch_create_device_guard(
+    int32_t device_index,
+    DeviceGuardHandle* ret_guard // returns new reference
+) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    // checked=true will fail if no accelerator is available
+    const auto device_type =
+        at::accelerator::getAccelerator(/*checked=*/true).value();
+    c10::Device device(device_type, device_index);
+    c10::DeviceGuard* guard = new c10::DeviceGuard(device);
+    *ret_guard = reinterpret_cast<DeviceGuardHandle>(guard);
+  });
+}
+
+AOTITorchError aoti_torch_delete_device_guard(DeviceGuardHandle guard) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { delete reinterpret_cast<c10::DeviceGuard*>(guard); });
+}
+
+AOTITorchError aoti_torch_device_guard_set_index(
+    DeviceGuardHandle guard,
+    int32_t device_index) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { reinterpret_cast<c10::DeviceGuard*>(guard)->set_index(device_index); });
+}
+
+AOTITorchError aoti_torch_delete_stream(StreamHandle stream) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { delete reinterpret_cast<c10::Stream*>(stream); });
+}
+
+AOTITorchError aoti_torch_stream_id(
+    StreamHandle stream,
+    int64_t* ret_stream_id) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    c10::Stream* stream_ptr = reinterpret_cast<c10::Stream*>(stream);
+    *ret_stream_id = stream_ptr->id();
+  });
+}
+
+AOTITorchError aoti_torch_get_current_stream(
+    int32_t device_index,
+    StreamHandle* ret_stream) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    c10::Stream stream = at::accelerator::getCurrentStream(device_index);
+    c10::Stream* stream_ptr = new c10::Stream(stream);
+    *ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
+  });
+}

torch/csrc/stable/accelerator.h

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+#pragma once
+
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+
+namespace torch::stable::accelerator {
+
+namespace {
+inline void delete_device_guard(void* ptr) {
+  AOTI_TORCH_ERROR_CODE_CHECK(
+      aoti_torch_delete_device_guard(reinterpret_cast<DeviceGuardHandle>(ptr)));
+}
+
+} // namespace
+
+using DeviceIndex = int8_t; // this is from c10/core/Device.h
+using StreamId = int64_t; // this is from c10/core/Stream.h
+
+class DeviceGuard {
+ public:
+  explicit DeviceGuard() = delete;
+  explicit DeviceGuard(DeviceIndex device_index)
+      : guard_(nullptr, delete_device_guard) {
+    DeviceGuardHandle ptr = nullptr;
+    AOTI_TORCH_ERROR_CODE_CHECK(
+        aoti_torch_create_device_guard(device_index, &ptr));
+    guard_.reset(ptr);
+  }
+
+  void set_index(DeviceIndex device_index) {
+    AOTI_TORCH_ERROR_CODE_CHECK(
+        aoti_torch_device_guard_set_index(guard_.get(), device_index));
+  }
+
+ private:
+  std::unique_ptr<DeviceGuardOpaque, aot_inductor::DeleterFnPtr> guard_;
+};
+
+class Stream {
+ public:
+  explicit Stream() = delete;
+
+  // Construct a stable::Stream from a StreamHandle
+  // Steals ownership from the StreamHandle
+  explicit Stream(StreamHandle stream)
+      : stream_(stream, [](StreamHandle stream) {
+          AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(stream));
+        }) {}
+
+  StreamId id() const {
+    StreamId stream_id;
+    AOTI_TORCH_ERROR_CODE_CHECK(
+        aoti_torch_stream_id(stream_.get(), &stream_id));
+    return stream_id;
+  }
+
+ private:
+  std::shared_ptr<StreamOpaque> stream_;
+};
+
+Stream getCurrentStream(DeviceIndex device_index) {
+  StreamHandle stream = nullptr;
+  AOTI_TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_stream(device_index, &stream));
+  return Stream(stream);
+}
+
+} // namespace torch::stable::accelerator
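
Taken together with the kernel.cpp changes above, use of the new wrappers looks roughly like the following sketch. It is illustrative only, mirrors the test_device_guard/test_stream functions in this commit, and assumes an accelerator is present so the guard constructor does not throw.

#include <torch/csrc/stable/accelerator.h>

// Sketch mirroring the tests above: the RAII guard switches the current
// device for the enclosing scope, and getCurrentStream exposes the stream id.
void example(torch::stable::accelerator::DeviceIndex device_index) {
  using torch::stable::accelerator::DeviceGuard;
  using torch::stable::accelerator::getCurrentStream;

  DeviceGuard guard(device_index); // device is switched here
  int64_t id = getCurrentStream(device_index).id();
  (void)id;
} // previous device is restored when guard is destroyed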
