diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
index 63e9eb77dd34..2951d167c6f7 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
@@ -1,9 +1,14 @@
 #include
+#include <torch/csrc/stable/accelerator.h>
 #include
 #include
 #include
 #include
 
+#ifdef LAE_USE_CUDA
+#include <cuda_runtime.h>
+#endif
+
 #include
 
 void inline sgd_math(
@@ -320,3 +325,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("my_zero_", &boxed_my_zero_);
 }
+
+// Test functions for torch::stable::accelerator APIs
+
+#ifdef LAE_USE_CUDA
+int64_t test_device_guard(int64_t device_index) {
+  using torch::stable::accelerator::DeviceGuard;
+
+  STD_TORCH_CHECK(
+      device_index >= std::numeric_limits<int32_t>::min() &&
+          device_index <= std::numeric_limits<int32_t>::max(),
+      "Device index is out of range of DeviceIndex (int32_t).");
+
+  DeviceGuard guard(device_index);
+  int currentDevice;
+  cudaError_t err = cudaGetDevice(&currentDevice);
+  STD_TORCH_CHECK(err == cudaSuccess);
+  return currentDevice;
+}
+
+void boxed_test_device_guard(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+int64_t test_device_guard_set_index() {
+  using torch::stable::accelerator::DeviceGuard;
+
+  DeviceGuard guard(1);
+  guard.set_index(0);
+  int currentDevice;
+  cudaError_t err = cudaGetDevice(&currentDevice);
+  STD_TORCH_CHECK(err == cudaSuccess);
+  return currentDevice;
+}
+
+void boxed_test_device_guard_set_index(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int64_t res = test_device_guard_set_index();
+  stack[0] = from(res);
+}
+
+int64_t test_stream(int32_t device_index) {
+  STD_TORCH_CHECK(
+      device_index >= std::numeric_limits<int32_t>::min() &&
+          device_index <= std::numeric_limits<int32_t>::max(),
+      "Device index is out of range of DeviceIndex (int32_t).");
+
+  return torch::stable::accelerator::getCurrentStream(device_index).id();
+}
+
+void boxed_test_stream(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int64_t res = test_stream(static_cast<int32_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("test_device_guard(int device_index) -> int");
+  m.def("test_device_guard_set_index() -> int");
+  m.def("test_stream(int device_index) -> int");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("test_device_guard", &boxed_test_device_guard);
+  m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
+  m.impl("test_stream", &boxed_test_stream);
+}
+#endif // LAE_USE_CUDA
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
index 1694bfa1b396..dbf70d9a976b 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
@@ -164,3 +164,37 @@ def fill_infinity(t) -> Tensor:
     Returns: The modified tensor (same as input)
     """
     return torch.ops.libtorch_agnostic.fill_infinity.default(t)
+
+
+def test_device_guard(device_index) -> int:
+    """
+    Tests the DeviceGuard functionality by creating a device guard and returning the current device.
+
+    Args:
+        device_index: Device index to set the guard to
+
+    Returns: result of cudaGetDevice() as an integer after using the guard
+    """
+    return torch.ops.libtorch_agnostic.test_device_guard.default(device_index)
+
+
+def test_device_guard_set_index() -> int:
+    """
+    Tests the DeviceGuard set_index functionality by creating a device guard with index 1,
+    then setting it to index 0, and returning the current device.
+
+    Returns: result of cudaGetDevice() as an integer after using set_index
+    """
+    return torch.ops.libtorch_agnostic.test_device_guard_set_index.default()
+
+
+def test_stream(device_index) -> int:
+    """
+    Tests the Stream functionality by getting the current stream ID for the specified device.
+
+    Args:
+        device_index: Device index to get the stream for
+
+    Returns: Stream ID as an integer
+    """
+    return torch.ops.libtorch_agnostic.test_stream.default(device_index)
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_extension/setup.py
index 5cd18f5579f9..b7141a3e6fcd 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/setup.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/setup.py
@@ -4,7 +4,8 @@
 
 from setuptools import find_packages, setup
 
-from torch.utils.cpp_extension import BuildExtension, CppExtension
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
 
 
 ROOT_DIR = Path(__file__).parent
@@ -35,10 +36,16 @@ def get_extension():
         "cxx": ["-fdiagnostics-color=always"],
     }
 
+    extension = CppExtension
+    # allow including <cuda_runtime.h>
+    if torch.cuda.is_available():
+        extra_compile_args["cxx"].append("-DLAE_USE_CUDA")
+        extension = CUDAExtension
+
     sources = list(CSRC_DIR.glob("**/*.cpp"))
 
     return [
-        CppExtension(
+        extension(
             "libtorch_agnostic._C",
             sources=sorted(str(s) for s in sources),
             py_limited_api=True,
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
index bd409a0eb5a6..b5348fb7e319 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
@@ -5,6 +5,7 @@
 import torch
 
 from torch.testing._internal.common_device_type import (
+    deviceCountAtLeast,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
@@ -218,6 +219,38 @@ def test_fill_infinity(self, device):
         expected = torch.full_like(t, math.inf)
         self.assertEqual(out, expected)
 
+    @onlyCUDA
+    @deviceCountAtLeast(2)
+    def test_device_guard(self, device):
+        import libtorch_agnostic
+
+        device_index = 1
+        out = libtorch_agnostic.ops.test_device_guard(device_index)
+        self.assertEqual(out, device_index)
+
+    @onlyCUDA
+    @deviceCountAtLeast(2)
+    def test_device_guard_set_index(self, device):
+        import libtorch_agnostic
+
+        # This test creates a DeviceGuard with index 1, then sets it to index 0
+        # and returns the current device (should be 0)
+        out = libtorch_agnostic.ops.test_device_guard_set_index()
+        self.assertEqual(out, 0)
+
+    @onlyCUDA
+    def test_stream(self, device):
+        import libtorch_agnostic
+
+        stream = torch.cuda.Stream()
+        device = torch.cuda.current_device()
+
+        with stream:
+            expected_stream_id = torch.cuda.current_stream(0).stream_id
+            stream_id = libtorch_agnostic.ops.test_stream(device)
+
+        self.assertEqual(stream_id, expected_stream_id)
+
 
 instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index d6f32358cdcc..972234564772 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -493,6 +493,36 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
     const char* overloadName,
     StableIValue* stack);
 
+// Device-generic guard for managing device context
+struct DeviceGuardOpaque;
+using DeviceGuardHandle = DeviceGuardOpaque*;
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_device_guard(
+    int32_t device_index,
+    DeviceGuardHandle* ret_guard // returns new reference
+);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_delete_device_guard(DeviceGuardHandle guard);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_device_guard_set_index(
+    DeviceGuardHandle guard,
+    int32_t device_index);
+
+// Device-generic stream for managing stream objects
+struct StreamOpaque;
+using StreamHandle = StreamOpaque*;
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_delete_stream(StreamHandle stream);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_stream_id(StreamHandle stream, int64_t* ret_stream_id);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream(
+    int32_t device_index,
+    StreamHandle* ret_stream // returns new reference
+);
+
 #ifdef USE_CUDA
 
 struct CUDAGuardOpaque;
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index eff8276315a2..9b0ca53fad31 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -24,6 +24,10 @@
 #include
 #include
 
+#include <ATen/DeviceAccelerator.h>
+#include <c10/core/DeviceGuard.h>
+#include <c10/core/Stream.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #else
@@ -1612,3 +1616,55 @@ AOTITorchError aoti_torch_call_dispatcher(
     }
   });
 }
+
+AOTITorchError aoti_torch_create_device_guard(
+    int32_t device_index,
+    DeviceGuardHandle* ret_guard // returns new reference
+) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    // checked=true will fail if no accelerator is available
+    const auto device_type =
+        at::accelerator::getAccelerator(/*checked=*/true).value();
+    c10::Device device(device_type, device_index);
+    c10::DeviceGuard* guard = new c10::DeviceGuard(device);
+    *ret_guard = reinterpret_cast<DeviceGuardHandle>(guard);
+  });
+}
+
+AOTITorchError aoti_torch_delete_device_guard(DeviceGuardHandle guard) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { delete reinterpret_cast<c10::DeviceGuard*>(guard); });
+}
+
+AOTITorchError aoti_torch_device_guard_set_index(
+    DeviceGuardHandle guard,
+    int32_t device_index) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { reinterpret_cast<c10::DeviceGuard*>(guard)->set_index(device_index); });
+}
+
+AOTITorchError aoti_torch_delete_stream(StreamHandle stream) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { delete reinterpret_cast<c10::Stream*>(stream); });
+}
+
+AOTITorchError aoti_torch_stream_id(
+    StreamHandle stream,
+    int64_t* ret_stream_id) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    c10::Stream* stream_ptr = reinterpret_cast<c10::Stream*>(stream);
+    *ret_stream_id = stream_ptr->id();
+  });
+}
+
+// This function creates a new c10::Stream object and makes the StreamHandle
+// point to it. The caller is responsible for managing the object's lifecycle.
+AOTITorchError aoti_torch_get_current_stream(
+    int32_t device_index,
+    StreamHandle* ret_stream) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    c10::Stream stream = at::accelerator::getCurrentStream(device_index);
+    c10::Stream* stream_ptr = new c10::Stream(stream);
+    *ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
+  });
+}
diff --git a/torch/csrc/stable/accelerator.h b/torch/csrc/stable/accelerator.h
new file mode 100644
index 000000000000..0c983407e679
--- /dev/null
+++ b/torch/csrc/stable/accelerator.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include <memory>
+
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+
+namespace torch::stable::accelerator {
+
+namespace {
+inline void delete_device_guard(void* ptr) {
+  AOTI_TORCH_ERROR_CODE_CHECK(
+      aoti_torch_delete_device_guard(reinterpret_cast<DeviceGuardHandle>(ptr)));
+}
+
+} // namespace
+
+// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we
+// can converge on in this world as DeviceIndex in libtorch is not stable.
+using DeviceIndex = int32_t;
+using StreamId = int64_t; // this is from c10/core/Stream.h
+
+class DeviceGuard {
+ public:
+  explicit DeviceGuard() = delete;
+  explicit DeviceGuard(DeviceIndex device_index)
+      : guard_(nullptr, delete_device_guard) {
+    DeviceGuardHandle ptr = nullptr;
+    AOTI_TORCH_ERROR_CODE_CHECK(
+        aoti_torch_create_device_guard(device_index, &ptr));
+    guard_.reset(ptr);
+  }
+
+  void set_index(DeviceIndex device_index) {
+    AOTI_TORCH_ERROR_CODE_CHECK(
+        aoti_torch_device_guard_set_index(guard_.get(), device_index));
+  }
+
+ private:
+  std::unique_ptr<DeviceGuardOpaque, decltype(&delete_device_guard)> guard_;
+};
+
+class Stream {
+ public:
+  explicit Stream() = delete;
+
+  // Construct a stable::Stream from a StreamHandle
+  // Steals ownership from the StreamHandle
+  explicit Stream(StreamHandle stream)
+      : stream_(stream, [](StreamHandle stream) {
+          AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(stream));
+        }) {}
+
+  StreamId id() const {
+    StreamId stream_id;
+    AOTI_TORCH_ERROR_CODE_CHECK(
+        aoti_torch_stream_id(stream_.get(), &stream_id));
+    return stream_id;
+  }
+
+ private:
+  std::shared_ptr<StreamOpaque> stream_;
+};
+
+inline Stream getCurrentStream(DeviceIndex device_index) {
+  StreamHandle stream = nullptr;
+  AOTI_TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_stream(device_index, &stream));
+  return Stream(stream);
+}
+
+} // namespace torch::stable::accelerator
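
Usage sketch (illustration, not part of the patch): the test functions in kernel.cpp above already demonstrate the intended pattern, so this is only a condensed, hypothetical example of how a libtorch-agnostic extension might combine the two new stable APIs. The function name my_kernel is made up; everything else is taken from torch/csrc/stable/accelerator.h as added in this diff.

#include <torch/csrc/stable/accelerator.h>

#include <cstdint>

// Hypothetical kernel: switch to a caller-chosen device, then report the id
// of the stream that is current on that device.
int64_t my_kernel(int32_t device_index) {
  using torch::stable::accelerator::DeviceGuard;
  using torch::stable::accelerator::getCurrentStream;

  // The guard sets the current device for this scope; the underlying
  // c10::DeviceGuard restores the previous device when the guard is destroyed
  // (via aoti_torch_delete_device_guard).
  DeviceGuard guard(device_index);

  // Only the stream id is exposed through the stable Stream wrapper.
  return getCurrentStream(device_index).id();
}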