From 12655fc1c0c7ad3d7f89611043cf951142cf5e2f Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 1 Aug 2025 14:37:18 -0700 Subject: [PATCH 1/9] Add beginning of torch::stable::accelerator [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 194 +++++++++++++----- .../libtorch_agnostic/ops.py | 23 +++ .../libtorch_agnostic_extension/setup.py | 5 +- .../test/test_libtorch_agnostic.py | 23 +++ torch/csrc/inductor/aoti_torch/c/shim.h | 30 +++ .../csrc/inductor/aoti_torch/shim_common.cpp | 54 +++++ torch/csrc/stable/accelerator.h | 68 ++++++ 7 files changed, 344 insertions(+), 53 deletions(-) create mode 100644 torch/csrc/stable/accelerator.h diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 63e9eb77dd34..95e70dc56b62 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,25 +1,28 @@ #include +#include #include -#include #include +#include #include #include +#include + void inline sgd_math( - float* param_ptr, - float* grad_ptr, - float* out_ptr, - const float weight_decay, - const double lr, - const bool maximize, - int64_t size -){ + float* param_ptr, + float* grad_ptr, + float* out_ptr, + const float weight_decay, + const double lr, + const bool maximize, + int64_t size) { int64_t d = 0; for (; d < size; d++) { float grad_val = grad_ptr[d]; - if (maximize) grad_val = -grad_val; - if (weight_decay != 0.0){ + if (maximize) + grad_val = -grad_val; + if (weight_decay != 0.0) { grad_val += param_ptr[d] * weight_decay; } out_ptr[d] = param_ptr[d] - grad_val * float(lr); @@ -36,8 +39,8 @@ Tensor sgd_out_of_place( const bool maximize) { STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); - int64_t *param_sizes; - int64_t *param_strides; + int64_t* param_sizes; + int64_t* param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); aoti_torch_get_strides(param.get(), ¶m_strides); @@ -48,35 +51,45 @@ Tensor sgd_out_of_place( aoti_torch_get_device_type(param.get(), ¶m_device_type); AtenTensorHandle out_ath; - aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); + aoti_torch_empty_strided( + param.dim(), + param_sizes, + param_strides, + param_dtype, + param_device_type, + param.get_device(), + &out_ath); auto out = Tensor(out_ath); sgd_math( - reinterpret_cast(param.data_ptr()), - reinterpret_cast(grad.data_ptr()), - reinterpret_cast(out.data_ptr()), - weight_decay, - lr, - maximize, - param.numel() - ); + reinterpret_cast(param.data_ptr()), + reinterpret_cast(grad.data_ptr()), + reinterpret_cast(out.data_ptr()), + weight_decay, + lr, + maximize, + param.numel()); return out; } -void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_sgd_out_of_place( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = sgd_out_of_place( - to(stack[0]), - to(stack[1]), - float(to(stack[2])), - to(stack[3]), - to(stack[4])); + to(stack[0]), + to(stack[1]), + float(to(stack[2])), + to(stack[3]), + to(stack[4])); stack[0] = from(res); } STABLE_TORCH_LIBRARY(libtorch_agnostic, m) { - m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); + m.def( + "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool 
maximize) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { @@ -87,7 +100,10 @@ Tensor identity(Tensor t) { return t; } -void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_identity( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = identity(to(stack[0])); stack[0] = from(res); } @@ -112,7 +128,10 @@ Tensor my_abs(Tensor t) { return to(stack[0]); } -void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_my_abs( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor tensor_res = my_abs(to(stack[0])); stack[0] = from(tensor_res); } @@ -134,18 +153,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) { auto mf = aoti_torch_memory_format_contiguous_format(); stack[0] = from(t); - stack[1] = from(std::optional(t_dtype)); // dtype - stack[2] = from(std::nullopt); // layout - stack[3] = from(std::optional(device)); // device - stack[4] = from(std::optional(false)); // pin_memory - stack[5] = from(std::optional(mf)); // memory_format + stack[1] = from(std::optional(t_dtype)); // dtype + stack[2] = from(std::nullopt); // layout + stack[3] = from(std::optional(device)); // device + stack[4] = from(std::optional(false)); // pin_memory + stack[5] = from(std::optional(mf)); // memory_format aoti_torch_call_dispatcher("aten::ones_like", "", stack); return to(stack[0]); } -void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_my_ones_like( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = my_ones_like(to(stack[0]), stack[1]); stack[0] = from(res); } @@ -158,7 +180,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_ones_like", &boxed_my_ones_like); } -std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) { +std::tuple exp_neg_is_leaf( + Tensor t1, + Tensor t2, + Tensor t3) { StableIValue stack_exp[1]; stack_exp[0] = from(t1); aoti_torch_call_dispatcher("aten::exp", "", stack_exp); @@ -172,20 +197,25 @@ std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3 aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf); return std::make_tuple( - to(stack_exp[0]), - to(stack_neg[0]), - to(stack_is_leaf[0])); + to(stack_exp[0]), + to(stack_neg[0]), + to(stack_is_leaf[0])); } -void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto tuple = exp_neg_is_leaf(to(stack[0]), to(stack[1]), to(stack[2])); +void boxed_exp_neg_is_leaf( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto tuple = exp_neg_is_leaf( + to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(std::get<0>(tuple)); stack[1] = from(std::get<1>(tuple)); stack[2] = from(std::get<2>(tuple)); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { - m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); + m.def( + "exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { @@ -200,7 +230,10 @@ Tensor neg_exp(Tensor t) { return to(stack[0]); } -void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_neg_exp( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = neg_exp(to(stack[0])); stack[0] = from(res); } @@ -229,7 +262,10 @@ Tensor divide_neg_exp(Tensor t) { return to(stack_div[0]); } -void 
boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_divide_neg_exp( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = divide_neg_exp(to(stack[0])); stack[0] = from(res); } @@ -246,7 +282,10 @@ bool is_contiguous(Tensor t) { return t.is_contiguous(); } -void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_is_contiguous( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { bool res = is_contiguous(to(stack[0])); stack[0] = from(res); } @@ -263,8 +302,12 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { return transpose(t, dim0, dim1); } -void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_transpose(to(stack[0]), to(stack[1]), to(stack[2])); +void boxed_my_transpose( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto res = my_transpose( + to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(res); } @@ -273,7 +316,10 @@ Tensor my_empty_like(Tensor t) { return empty_like(t); } -void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_empty_like( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { auto res = my_empty_like(to(stack[0])); stack[0] = from(res); } @@ -308,7 +354,10 @@ Tensor my_zero_(Tensor t) { return zero_(t); } -void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_my_zero_( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { auto res = my_zero_(to(stack[0])); stack[0] = from(res); } @@ -320,3 +369,46 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { m.impl("my_zero_", &boxed_my_zero_); } + +// Test functions for torch::stable::accelerator APIs + +int test_device_guard(int8_t device_index) { + using torch::stable::accelerator::DeviceGuard; + + DeviceGuard guard(device_index); + int currentDevice; + cudaError_t err = cudaGetDevice(¤tDevice); + STD_TORCH_CHECK(err == cudaSuccess); + return currentDevice; +} + +void boxed_test_device_guard( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + int res = test_device_guard(static_cast(to(stack[0]))); + stack[0] = from(res); +} + +int64_t test_stream(int8_t device_index) { + auto id = torch::stable::accelerator::getCurrentStream(device_index).id(); + return id; +} + +void boxed_test_stream( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + int64_t res = test_stream(static_cast(to(stack[0]))); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("test_device_guard(int device_index) -> int"); + m.def("test_stream(int device_index) -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("test_device_guard", &boxed_test_device_guard); + m.impl("test_stream", &boxed_test_stream); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 1694bfa1b396..1e1b873a7c0f 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -164,3 +164,26 @@ def fill_infinity(t) -> Tensor: Returns: The modified tensor (same as input) """ return torch.ops.libtorch_agnostic.fill_infinity.default(t) + +def 
test_device_guard(device_index) -> Tensor: + """ + Tests the DeviceGuard functionality by creating a device guard and returning an empty tensor. + + Args: + device_index: Device index to set the guard to + + Returns: A 3x3 empty tensor created on the device specified by device_index + """ + return torch.ops.libtorch_agnostic.test_device_guard.default(device_index) + + +def test_stream(device_index) -> int: + """ + Tests the Stream functionality by getting the current stream ID for the specified device. + + Args: + device_index: Device index to get the stream for + + Returns: Stream ID as an integer + """ + return torch.ops.libtorch_agnostic.test_stream.default(device_index) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_extension/setup.py index 5cd18f5579f9..65d57d2de4dc 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/setup.py @@ -4,7 +4,7 @@ from setuptools import find_packages, setup -from torch.utils.cpp_extension import BuildExtension, CppExtension +from torch.utils.cpp_extension import BuildExtension, CUDAExtension ROOT_DIR = Path(__file__).parent @@ -33,12 +33,13 @@ def run(self): def get_extension(): extra_compile_args = { "cxx": ["-fdiagnostics-color=always"], + "nvcc": ["-O2"], } sources = list(CSRC_DIR.glob("**/*.cpp")) return [ - CppExtension( + CUDAExtension( "libtorch_agnostic._C", sources=sorted(str(s) for s in sources), py_limited_api=True, diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index bd409a0eb5a6..b7720e061a5a 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -5,6 +5,7 @@ import torch from torch.testing._internal.common_device_type import ( + deviceCountAtLeast, instantiate_device_type_tests, onlyCPU, onlyCUDA, @@ -218,6 +219,28 @@ def test_fill_infinity(self, device): expected = torch.full_like(t, math.inf) self.assertEqual(out, expected) + @onlyCUDA + @deviceCountAtLeast(2) + def test_device_guard(self, device): + import libtorch_agnostic + + device_index = 1 + out = libtorch_agnostic.ops.test_device_guard(device_index) + self.assertEqual(out, device_index) + + @onlyCUDA + def test_stream(self, device): + import libtorch_agnostic + + stream = torch.cuda.Stream() + device = torch.cuda.current_device() + + with stream: + expected_stream_id = torch.cuda.current_stream(0).stream_id + stream_id = libtorch_agnostic.ops.test_stream(device) + + self.assertEqual(stream_id, expected_stream_id) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 9d512ce1f481..cb808b883bb2 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -483,6 +483,36 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher( const char* overloadName, StableIValue* stack); +// Device-generic guard for managing device context +struct DeviceGuardOpaque; +using DeviceGuardHandle = DeviceGuardOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_device_guard( + int32_t device_index, + DeviceGuardHandle* ret_guard // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError 
+aoti_torch_delete_device_guard(DeviceGuardHandle guard); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_device_guard_set_index( + DeviceGuardHandle guard, + int32_t device_index); + +// Device-generic stream for managing stream objects +struct StreamOpaque; +using StreamHandle = StreamOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_delete_stream(StreamHandle stream); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_stream_id(StreamHandle stream, int64_t* ret_stream_id); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream( + int32_t device_index, + StreamHandle* ret_stream // returns new reference +); + #ifdef USE_CUDA struct CUDAGuardOpaque; diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index a33198fd1ba0..1bea83ea810e 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -24,6 +24,10 @@ #include #include +#include +#include +#include + #ifndef AT_PER_OPERATOR_HEADERS #include #else @@ -1590,3 +1594,53 @@ AOTITorchError aoti_torch_call_dispatcher( } }); } + +AOTITorchError aoti_torch_create_device_guard( + int32_t device_index, + DeviceGuardHandle* ret_guard // returns new reference +) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + // checked=true will fail if no accelerator is available + const auto device_type = + at::accelerator::getAccelerator(/*checked=*/true).value(); + c10::Device device(device_type, device_index); + c10::DeviceGuard* guard = new c10::DeviceGuard(device); + *ret_guard = reinterpret_cast(guard); + }); +} + +AOTITorchError aoti_torch_delete_device_guard(DeviceGuardHandle guard) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( + { delete reinterpret_cast(guard); }); +} + +AOTITorchError aoti_torch_device_guard_set_index( + DeviceGuardHandle guard, + int32_t device_index) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( + { reinterpret_cast(guard)->set_index(device_index); }); +} + +AOTITorchError aoti_torch_delete_stream(StreamHandle stream) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( + { delete reinterpret_cast(stream); }); +} + +AOTITorchError aoti_torch_stream_id( + StreamHandle stream, + int64_t* ret_stream_id) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::Stream* stream_ptr = reinterpret_cast(stream); + *ret_stream_id = stream_ptr->id(); + }); +} + +AOTITorchError aoti_torch_get_current_stream( + int32_t device_index, + StreamHandle* ret_stream) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::Stream stream = at::accelerator::getCurrentStream(device_index); + c10::Stream* stream_ptr = new c10::Stream(stream); + *ret_stream = reinterpret_cast(stream_ptr); + }); +} diff --git a/torch/csrc/stable/accelerator.h b/torch/csrc/stable/accelerator.h new file mode 100644 index 000000000000..29319b9c5f3f --- /dev/null +++ b/torch/csrc/stable/accelerator.h @@ -0,0 +1,68 @@ +#pragma once + +#include + +#include + +namespace torch::stable::accelerator { + +namespace { +inline void delete_device_guard(void* ptr) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_delete_device_guard(reinterpret_cast(ptr))); +} + +} // namespace + +using DeviceIndex = int8_t; +using StreamId = int64_t; +class DeviceGuard { + public: + explicit DeviceGuard() = delete; + explicit DeviceGuard(DeviceIndex device_index) + : guard_(nullptr, delete_device_guard) { + DeviceGuardHandle ptr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_create_device_guard(device_index, &ptr)); + guard_.reset(ptr); + } + + void set_index(int32_t 
device_index) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_device_guard_set_index(guard_.get(), device_index)); + } + + private: + std::unique_ptr guard_; +}; + +class Stream { + public: + explicit Stream() = delete; + + // Construct a stable::Stream from a StreamHandle + // Steals ownership from the StreamHandle + explicit Stream(StreamHandle stream) + : stream_(stream, [](StreamHandle stream) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(stream)); + }) {} + + StreamId id() const { + StreamId stream_id; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_stream_id(stream_.get(), &stream_id)); + return stream_id; + } + + private: + std::shared_ptr stream_; +}; + +Stream getCurrentStream(DeviceIndex device_index) { + StreamHandle stream = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_stream(device_index, &stream)); + return Stream(stream); +} + +} // namespace torch::stable::accelerator From 117381f913ad876d234b413f4c229fb8bb5ba840 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 1 Aug 2025 14:43:36 -0700 Subject: [PATCH 2/9] Update on "Add beginning of torch::stable::accelerator" [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 152 ++++++------------ .../libtorch_agnostic/ops.py | 1 + 2 files changed, 54 insertions(+), 99 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 95e70dc56b62..a4fe1628468a 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,28 +1,28 @@ #include #include #include -#include #include +#include #include -#include - #include +#include + void inline sgd_math( - float* param_ptr, - float* grad_ptr, - float* out_ptr, - const float weight_decay, - const double lr, - const bool maximize, - int64_t size) { + float* param_ptr, + float* grad_ptr, + float* out_ptr, + const float weight_decay, + const double lr, + const bool maximize, + int64_t size +){ int64_t d = 0; for (; d < size; d++) { float grad_val = grad_ptr[d]; - if (maximize) - grad_val = -grad_val; - if (weight_decay != 0.0) { + if (maximize) grad_val = -grad_val; + if (weight_decay != 0.0){ grad_val += param_ptr[d] * weight_decay; } out_ptr[d] = param_ptr[d] - grad_val * float(lr); @@ -39,8 +39,8 @@ Tensor sgd_out_of_place( const bool maximize) { STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); - int64_t* param_sizes; - int64_t* param_strides; + int64_t *param_sizes; + int64_t *param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); aoti_torch_get_strides(param.get(), ¶m_strides); @@ -51,45 +51,35 @@ Tensor sgd_out_of_place( aoti_torch_get_device_type(param.get(), ¶m_device_type); AtenTensorHandle out_ath; - aoti_torch_empty_strided( - param.dim(), - param_sizes, - param_strides, - param_dtype, - param_device_type, - param.get_device(), - &out_ath); + aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); auto out = Tensor(out_ath); sgd_math( - reinterpret_cast(param.data_ptr()), - reinterpret_cast(grad.data_ptr()), - reinterpret_cast(out.data_ptr()), - weight_decay, - lr, - maximize, - param.numel()); + reinterpret_cast(param.data_ptr()), + reinterpret_cast(grad.data_ptr()), + reinterpret_cast(out.data_ptr()), + weight_decay, + lr, + maximize, + param.numel() + ); return out; } -void boxed_sgd_out_of_place( 
- StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = sgd_out_of_place( - to(stack[0]), - to(stack[1]), - float(to(stack[2])), - to(stack[3]), - to(stack[4])); + to(stack[0]), + to(stack[1]), + float(to(stack[2])), + to(stack[3]), + to(stack[4])); stack[0] = from(res); } STABLE_TORCH_LIBRARY(libtorch_agnostic, m) { - m.def( - "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); + m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { @@ -100,10 +90,7 @@ Tensor identity(Tensor t) { return t; } -void boxed_identity( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = identity(to(stack[0])); stack[0] = from(res); } @@ -128,10 +115,7 @@ Tensor my_abs(Tensor t) { return to(stack[0]); } -void boxed_my_abs( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor tensor_res = my_abs(to(stack[0])); stack[0] = from(tensor_res); } @@ -153,21 +137,18 @@ Tensor my_ones_like(Tensor t, StableIValue device) { auto mf = aoti_torch_memory_format_contiguous_format(); stack[0] = from(t); - stack[1] = from(std::optional(t_dtype)); // dtype - stack[2] = from(std::nullopt); // layout - stack[3] = from(std::optional(device)); // device - stack[4] = from(std::optional(false)); // pin_memory - stack[5] = from(std::optional(mf)); // memory_format + stack[1] = from(std::optional(t_dtype)); // dtype + stack[2] = from(std::nullopt); // layout + stack[3] = from(std::optional(device)); // device + stack[4] = from(std::optional(false)); // pin_memory + stack[5] = from(std::optional(mf)); // memory_format aoti_torch_call_dispatcher("aten::ones_like", "", stack); return to(stack[0]); } -void boxed_my_ones_like( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = my_ones_like(to(stack[0]), stack[1]); stack[0] = from(res); } @@ -180,10 +161,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_ones_like", &boxed_my_ones_like); } -std::tuple exp_neg_is_leaf( - Tensor t1, - Tensor t2, - Tensor t3) { +std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) { StableIValue stack_exp[1]; stack_exp[0] = from(t1); aoti_torch_call_dispatcher("aten::exp", "", stack_exp); @@ -197,25 +175,20 @@ std::tuple exp_neg_is_leaf( aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf); return std::make_tuple( - to(stack_exp[0]), - to(stack_neg[0]), - to(stack_is_leaf[0])); + to(stack_exp[0]), + to(stack_neg[0]), + to(stack_is_leaf[0])); } -void boxed_exp_neg_is_leaf( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { - auto tuple = exp_neg_is_leaf( - to(stack[0]), to(stack[1]), to(stack[2])); +void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto tuple = exp_neg_is_leaf(to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(std::get<0>(tuple)); stack[1] = from(std::get<1>(tuple)); stack[2] = from(std::get<2>(tuple)); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { - m.def( - "exp_neg_is_leaf(Tensor t1, Tensor 
t2, Tensor t3) -> (Tensor, Tensor, bool)"); + m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { @@ -230,10 +203,7 @@ Tensor neg_exp(Tensor t) { return to(stack[0]); } -void boxed_neg_exp( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = neg_exp(to(stack[0])); stack[0] = from(res); } @@ -262,10 +232,7 @@ Tensor divide_neg_exp(Tensor t) { return to(stack_div[0]); } -void boxed_divide_neg_exp( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = divide_neg_exp(to(stack[0])); stack[0] = from(res); } @@ -282,10 +249,7 @@ bool is_contiguous(Tensor t) { return t.is_contiguous(); } -void boxed_is_contiguous( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { bool res = is_contiguous(to(stack[0])); stack[0] = from(res); } @@ -302,12 +266,8 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { return transpose(t, dim0, dim1); } -void boxed_my_transpose( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { - auto res = my_transpose( - to(stack[0]), to(stack[1]), to(stack[2])); +void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_transpose(to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(res); } @@ -316,10 +276,7 @@ Tensor my_empty_like(Tensor t) { return empty_like(t); } -void boxed_empty_like( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { auto res = my_empty_like(to(stack[0])); stack[0] = from(res); } @@ -354,10 +311,7 @@ Tensor my_zero_(Tensor t) { return zero_(t); } -void boxed_my_zero_( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { auto res = my_zero_(to(stack[0])); stack[0] = from(res); } diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 1e1b873a7c0f..cfbd9a290989 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -165,6 +165,7 @@ def fill_infinity(t) -> Tensor: """ return torch.ops.libtorch_agnostic.fill_infinity.default(t) + def test_device_guard(device_index) -> Tensor: """ Tests the DeviceGuard functionality by creating a device guard and returning an empty tensor. 
From 6f02310c1ef9492255098ad0ef9ff0feeb37de3b Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Aug 2025 12:12:49 -0700 Subject: [PATCH 3/9] Update on "Add beginnings of torch::stable::accelerator" [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 153 ++++++++++++------ .../libtorch_agnostic_extension/setup.py | 12 +- 2 files changed, 110 insertions(+), 55 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index a4fe1628468a..79bbf366136f 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,28 +1,30 @@ #include #include #include -#include #include +#include #include +#ifdef USE_CUDA #include +#endif #include void inline sgd_math( - float* param_ptr, - float* grad_ptr, - float* out_ptr, - const float weight_decay, - const double lr, - const bool maximize, - int64_t size -){ + float* param_ptr, + float* grad_ptr, + float* out_ptr, + const float weight_decay, + const double lr, + const bool maximize, + int64_t size) { int64_t d = 0; for (; d < size; d++) { float grad_val = grad_ptr[d]; - if (maximize) grad_val = -grad_val; - if (weight_decay != 0.0){ + if (maximize) + grad_val = -grad_val; + if (weight_decay != 0.0) { grad_val += param_ptr[d] * weight_decay; } out_ptr[d] = param_ptr[d] - grad_val * float(lr); @@ -39,8 +41,8 @@ Tensor sgd_out_of_place( const bool maximize) { STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); - int64_t *param_sizes; - int64_t *param_strides; + int64_t* param_sizes; + int64_t* param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); aoti_torch_get_strides(param.get(), ¶m_strides); @@ -51,35 +53,45 @@ Tensor sgd_out_of_place( aoti_torch_get_device_type(param.get(), ¶m_device_type); AtenTensorHandle out_ath; - aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); + aoti_torch_empty_strided( + param.dim(), + param_sizes, + param_strides, + param_dtype, + param_device_type, + param.get_device(), + &out_ath); auto out = Tensor(out_ath); sgd_math( - reinterpret_cast(param.data_ptr()), - reinterpret_cast(grad.data_ptr()), - reinterpret_cast(out.data_ptr()), - weight_decay, - lr, - maximize, - param.numel() - ); + reinterpret_cast(param.data_ptr()), + reinterpret_cast(grad.data_ptr()), + reinterpret_cast(out.data_ptr()), + weight_decay, + lr, + maximize, + param.numel()); return out; } -void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_sgd_out_of_place( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = sgd_out_of_place( - to(stack[0]), - to(stack[1]), - float(to(stack[2])), - to(stack[3]), - to(stack[4])); + to(stack[0]), + to(stack[1]), + float(to(stack[2])), + to(stack[3]), + to(stack[4])); stack[0] = from(res); } STABLE_TORCH_LIBRARY(libtorch_agnostic, m) { - m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); + m.def( + "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { @@ -90,7 +102,10 @@ Tensor identity(Tensor t) { return t; } -void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_identity( + 
StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = identity(to(stack[0])); stack[0] = from(res); } @@ -115,7 +130,10 @@ Tensor my_abs(Tensor t) { return to(stack[0]); } -void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_my_abs( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor tensor_res = my_abs(to(stack[0])); stack[0] = from(tensor_res); } @@ -137,18 +155,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) { auto mf = aoti_torch_memory_format_contiguous_format(); stack[0] = from(t); - stack[1] = from(std::optional(t_dtype)); // dtype - stack[2] = from(std::nullopt); // layout - stack[3] = from(std::optional(device)); // device - stack[4] = from(std::optional(false)); // pin_memory - stack[5] = from(std::optional(mf)); // memory_format + stack[1] = from(std::optional(t_dtype)); // dtype + stack[2] = from(std::nullopt); // layout + stack[3] = from(std::optional(device)); // device + stack[4] = from(std::optional(false)); // pin_memory + stack[5] = from(std::optional(mf)); // memory_format aoti_torch_call_dispatcher("aten::ones_like", "", stack); return to(stack[0]); } -void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_my_ones_like( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = my_ones_like(to(stack[0]), stack[1]); stack[0] = from(res); } @@ -161,7 +182,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_ones_like", &boxed_my_ones_like); } -std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) { +std::tuple exp_neg_is_leaf( + Tensor t1, + Tensor t2, + Tensor t3) { StableIValue stack_exp[1]; stack_exp[0] = from(t1); aoti_torch_call_dispatcher("aten::exp", "", stack_exp); @@ -175,20 +199,25 @@ std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3 aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf); return std::make_tuple( - to(stack_exp[0]), - to(stack_neg[0]), - to(stack_is_leaf[0])); + to(stack_exp[0]), + to(stack_neg[0]), + to(stack_is_leaf[0])); } -void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto tuple = exp_neg_is_leaf(to(stack[0]), to(stack[1]), to(stack[2])); +void boxed_exp_neg_is_leaf( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto tuple = exp_neg_is_leaf( + to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(std::get<0>(tuple)); stack[1] = from(std::get<1>(tuple)); stack[2] = from(std::get<2>(tuple)); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { - m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); + m.def( + "exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { @@ -203,7 +232,10 @@ Tensor neg_exp(Tensor t) { return to(stack[0]); } -void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_neg_exp( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = neg_exp(to(stack[0])); stack[0] = from(res); } @@ -232,7 +264,10 @@ Tensor divide_neg_exp(Tensor t) { return to(stack_div[0]); } -void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_divide_neg_exp( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { Tensor res = divide_neg_exp(to(stack[0])); stack[0] = from(res); } @@ 
-249,7 +284,10 @@ bool is_contiguous(Tensor t) { return t.is_contiguous(); } -void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_is_contiguous( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { bool res = is_contiguous(to(stack[0])); stack[0] = from(res); } @@ -266,8 +304,12 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { return transpose(t, dim0, dim1); } -void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_transpose(to(stack[0]), to(stack[1]), to(stack[2])); +void boxed_my_transpose( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto res = my_transpose( + to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(res); } @@ -276,7 +318,10 @@ Tensor my_empty_like(Tensor t) { return empty_like(t); } -void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_empty_like( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { auto res = my_empty_like(to(stack[0])); stack[0] = from(res); } @@ -306,12 +351,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("fill_infinity", &boxed_fill_infinity); } - Tensor my_zero_(Tensor t) { return zero_(t); } -void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { +void boxed_my_zero_( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { auto res = my_zero_(to(stack[0])); stack[0] = from(res); } @@ -326,6 +373,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { // Test functions for torch::stable::accelerator APIs +#ifdef USE_CUDA int test_device_guard(int8_t device_index) { using torch::stable::accelerator::DeviceGuard; @@ -366,3 +414,4 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("test_device_guard", &boxed_test_device_guard); m.impl("test_stream", &boxed_test_stream); } +#endif // USE_CUDA diff --git a/test/cpp_extensions/libtorch_agnostic_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_extension/setup.py index 65d57d2de4dc..9593dcd1a5e0 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/setup.py @@ -4,7 +4,8 @@ from setuptools import find_packages, setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension ROOT_DIR = Path(__file__).parent @@ -33,13 +34,18 @@ def run(self): def get_extension(): extra_compile_args = { "cxx": ["-fdiagnostics-color=always"], - "nvcc": ["-O2"], } + extension = CppExtension + # allow including + if torch.cuda.is_available(): + extra_compile_args["cxx"].append("-DUSE_CUDA") + extension = CUDAExtension + sources = list(CSRC_DIR.glob("**/*.cpp")) return [ - CUDAExtension( + extension( "libtorch_agnostic._C", sources=sorted(str(s) for s in sources), py_limited_api=True, From 0ba77fbcd1f792e8b5f1364b0f1dbfdac0e51a52 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Aug 2025 12:14:15 -0700 Subject: [PATCH 4/9] Update on "Add beginnings of torch::stable::accelerator" [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 152 ++++++------------ 1 file changed, 53 insertions(+), 99 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 
79bbf366136f..4d128a86ee77 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,8 +1,7 @@ #include -#include #include -#include #include +#include #include #ifdef USE_CUDA @@ -12,19 +11,19 @@ #include void inline sgd_math( - float* param_ptr, - float* grad_ptr, - float* out_ptr, - const float weight_decay, - const double lr, - const bool maximize, - int64_t size) { + float* param_ptr, + float* grad_ptr, + float* out_ptr, + const float weight_decay, + const double lr, + const bool maximize, + int64_t size +){ int64_t d = 0; for (; d < size; d++) { float grad_val = grad_ptr[d]; - if (maximize) - grad_val = -grad_val; - if (weight_decay != 0.0) { + if (maximize) grad_val = -grad_val; + if (weight_decay != 0.0){ grad_val += param_ptr[d] * weight_decay; } out_ptr[d] = param_ptr[d] - grad_val * float(lr); @@ -41,8 +40,8 @@ Tensor sgd_out_of_place( const bool maximize) { STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); - int64_t* param_sizes; - int64_t* param_strides; + int64_t *param_sizes; + int64_t *param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); aoti_torch_get_strides(param.get(), ¶m_strides); @@ -53,45 +52,35 @@ Tensor sgd_out_of_place( aoti_torch_get_device_type(param.get(), ¶m_device_type); AtenTensorHandle out_ath; - aoti_torch_empty_strided( - param.dim(), - param_sizes, - param_strides, - param_dtype, - param_device_type, - param.get_device(), - &out_ath); + aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); auto out = Tensor(out_ath); sgd_math( - reinterpret_cast(param.data_ptr()), - reinterpret_cast(grad.data_ptr()), - reinterpret_cast(out.data_ptr()), - weight_decay, - lr, - maximize, - param.numel()); + reinterpret_cast(param.data_ptr()), + reinterpret_cast(grad.data_ptr()), + reinterpret_cast(out.data_ptr()), + weight_decay, + lr, + maximize, + param.numel() + ); return out; } -void boxed_sgd_out_of_place( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = sgd_out_of_place( - to(stack[0]), - to(stack[1]), - float(to(stack[2])), - to(stack[3]), - to(stack[4])); + to(stack[0]), + to(stack[1]), + float(to(stack[2])), + to(stack[3]), + to(stack[4])); stack[0] = from(res); } STABLE_TORCH_LIBRARY(libtorch_agnostic, m) { - m.def( - "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); + m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { @@ -102,10 +91,7 @@ Tensor identity(Tensor t) { return t; } -void boxed_identity( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = identity(to(stack[0])); stack[0] = from(res); } @@ -130,10 +116,7 @@ Tensor my_abs(Tensor t) { return to(stack[0]); } -void boxed_my_abs( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor tensor_res = my_abs(to(stack[0])); stack[0] = from(tensor_res); } @@ -155,21 +138,18 @@ Tensor my_ones_like(Tensor t, StableIValue device) { auto mf = aoti_torch_memory_format_contiguous_format(); stack[0] = 
from(t); - stack[1] = from(std::optional(t_dtype)); // dtype - stack[2] = from(std::nullopt); // layout - stack[3] = from(std::optional(device)); // device - stack[4] = from(std::optional(false)); // pin_memory - stack[5] = from(std::optional(mf)); // memory_format + stack[1] = from(std::optional(t_dtype)); // dtype + stack[2] = from(std::nullopt); // layout + stack[3] = from(std::optional(device)); // device + stack[4] = from(std::optional(false)); // pin_memory + stack[5] = from(std::optional(mf)); // memory_format aoti_torch_call_dispatcher("aten::ones_like", "", stack); return to(stack[0]); } -void boxed_my_ones_like( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = my_ones_like(to(stack[0]), stack[1]); stack[0] = from(res); } @@ -182,10 +162,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_ones_like", &boxed_my_ones_like); } -std::tuple exp_neg_is_leaf( - Tensor t1, - Tensor t2, - Tensor t3) { +std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) { StableIValue stack_exp[1]; stack_exp[0] = from(t1); aoti_torch_call_dispatcher("aten::exp", "", stack_exp); @@ -199,25 +176,20 @@ std::tuple exp_neg_is_leaf( aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf); return std::make_tuple( - to(stack_exp[0]), - to(stack_neg[0]), - to(stack_is_leaf[0])); + to(stack_exp[0]), + to(stack_neg[0]), + to(stack_is_leaf[0])); } -void boxed_exp_neg_is_leaf( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { - auto tuple = exp_neg_is_leaf( - to(stack[0]), to(stack[1]), to(stack[2])); +void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto tuple = exp_neg_is_leaf(to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(std::get<0>(tuple)); stack[1] = from(std::get<1>(tuple)); stack[2] = from(std::get<2>(tuple)); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { - m.def( - "exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); + m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { @@ -232,10 +204,7 @@ Tensor neg_exp(Tensor t) { return to(stack[0]); } -void boxed_neg_exp( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = neg_exp(to(stack[0])); stack[0] = from(res); } @@ -264,10 +233,7 @@ Tensor divide_neg_exp(Tensor t) { return to(stack_div[0]); } -void boxed_divide_neg_exp( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = divide_neg_exp(to(stack[0])); stack[0] = from(res); } @@ -284,10 +250,7 @@ bool is_contiguous(Tensor t) { return t.is_contiguous(); } -void boxed_is_contiguous( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { bool res = is_contiguous(to(stack[0])); stack[0] = from(res); } @@ -304,12 +267,8 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { return transpose(t, dim0, dim1); } -void boxed_my_transpose( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { - auto res = my_transpose( - to(stack[0]), to(stack[1]), to(stack[2])); +void 
boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_transpose(to(stack[0]), to(stack[1]), to(stack[2])); stack[0] = from(res); } @@ -318,10 +277,7 @@ Tensor my_empty_like(Tensor t) { return empty_like(t); } -void boxed_empty_like( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { auto res = my_empty_like(to(stack[0])); stack[0] = from(res); } @@ -351,14 +307,12 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("fill_infinity", &boxed_fill_infinity); } + Tensor my_zero_(Tensor t) { return zero_(t); } -void boxed_my_zero_( - StableIValue* stack, - uint64_t num_args, - uint64_t num_outputs) { +void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { auto res = my_zero_(to(stack[0])); stack[0] = from(res); } @@ -414,4 +368,4 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("test_device_guard", &boxed_test_device_guard); m.impl("test_stream", &boxed_test_stream); } -#endif // USE_CUDA +#endif // USE_CUDA \ No newline at end of file From 6d437516f576d49af2669768b1db3eb941275d68 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Aug 2025 12:25:27 -0700 Subject: [PATCH 5/9] Update on "Add beginnings of torch::stable::accelerator" [ghstack-poisoned] --- torch/csrc/stable/accelerator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/stable/accelerator.h b/torch/csrc/stable/accelerator.h index 29319b9c5f3f..0a237843f92f 100644 --- a/torch/csrc/stable/accelerator.h +++ b/torch/csrc/stable/accelerator.h @@ -27,7 +27,7 @@ class DeviceGuard { guard_.reset(ptr); } - void set_index(int32_t device_index) { + void set_index(DeviceIndex device_index) { AOTI_TORCH_ERROR_CODE_CHECK( aoti_torch_device_guard_set_index(guard_.get(), device_index)); } From fb964942cd06292ba965d3f6273d11d6a34da583 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Aug 2025 12:29:06 -0700 Subject: [PATCH 6/9] Update on "Add beginnings of torch::stable::accelerator" [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 2 +- torch/csrc/stable/accelerator.h | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 4d128a86ee77..92ccef5b93ba 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -368,4 +368,4 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("test_device_guard", &boxed_test_device_guard); m.impl("test_stream", &boxed_test_stream); } -#endif // USE_CUDA \ No newline at end of file +#endif // USE_CUDA diff --git a/torch/csrc/stable/accelerator.h b/torch/csrc/stable/accelerator.h index 0a237843f92f..d223494261c9 100644 --- a/torch/csrc/stable/accelerator.h +++ b/torch/csrc/stable/accelerator.h @@ -14,8 +14,9 @@ inline void delete_device_guard(void* ptr) { } // namespace -using DeviceIndex = int8_t; -using StreamId = int64_t; +using DeviceIndex = int8_t; // this is from c10/core/Device.h +using StreamId = int64_t; // this is from c10/core/Stream.h + class DeviceGuard { public: explicit DeviceGuard() = delete; From 0e247a1447bbd0a97e5d4fd675d5f5d209a23e77 Mon Sep 17 
00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Aug 2025 14:21:57 -0700 Subject: [PATCH 7/9] Update on "Add beginnings of torch::stable::accelerator" Adds - `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46 - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device - `set_index(DeviceIndex)` - `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque` - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor) - `id() -> StreamId` - `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream` [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 1 + torch/csrc/stable/accelerator.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 92ccef5b93ba..1405166d9126 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include diff --git a/torch/csrc/stable/accelerator.h b/torch/csrc/stable/accelerator.h index d223494261c9..100d16e71a60 100644 --- a/torch/csrc/stable/accelerator.h +++ b/torch/csrc/stable/accelerator.h @@ -14,7 +14,7 @@ inline void delete_device_guard(void* ptr) { } // namespace -using DeviceIndex = int8_t; // this is from c10/core/Device.h +using DeviceIndex = int8_t; // this is from c10/core/Device.h using StreamId = int64_t; // this is from c10/core/Stream.h class DeviceGuard { From aa47d4fd1b4e0d622fc2e3c2e68c209233fe1b19 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Aug 2025 14:43:38 -0700 Subject: [PATCH 8/9] Update on "Add beginnings of torch::stable::accelerator" Adds - `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46 - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device - `set_index(DeviceIndex)` - `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque` - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor) - `id() -> StreamId` - `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream` [ghstack-poisoned] --- torch/csrc/stable/accelerator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/stable/accelerator.h b/torch/csrc/stable/accelerator.h index 100d16e71a60..8ab3f7740046 100644 --- a/torch/csrc/stable/accelerator.h +++ b/torch/csrc/stable/accelerator.h @@ -14,8 +14,8 @@ inline void delete_device_guard(void* ptr) { } // namespace -using DeviceIndex = int8_t; // this is from c10/core/Device.h -using StreamId = int64_t; // this is from c10/core/Stream.h +using DeviceIndex = int8_t; // this is from c10/core/Device.h +using StreamId = int64_t; // this is from c10/core/Stream.h class DeviceGuard { public: From e28e5f677bfd9a508e0bdeafc7d7d9b6f8f87aca Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Aug 2025 15:12:16 
-0700 Subject: [PATCH 9/9] Update on "Add beginnings of torch::stable::accelerator" Adds - `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpaque` copied from https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46 - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but differs from the actual c10 DeviceGuard constructor that takes in a device) - `set_index(DeviceIndex)` - `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque` - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor) - `id() -> StreamId` - `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream` [ghstack-poisoned] --- .../libtorch_agnostic_extension/libtorch_agnostic/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index cfbd9a290989..0f2471a51a11 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -173,7 +173,7 @@ def test_device_guard(device_index) -> Tensor: Args: device_index: Device index to set the guard to - Returns: A 3x3 empty tensor created on the device specified by device_index + Returns: result of cudaGetDevice() as an integer after using the guard """ return torch.ops.libtorch_agnostic.test_device_guard.default(device_index)
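
For reference, a minimal usage sketch of the `torch::stable::accelerator` API added by this stack. It mirrors the `test_device_guard`/`test_stream` kernels in the patches above; the include paths, the `assert`-based error handling, and the function names here are assumptions for illustration only (the in-tree tests use `STD_TORCH_CHECK` and register these functions as boxed ops under `libtorch_agnostic`).

```cpp
// Sketch only: assumes the header is included via the path created in this
// stack (torch/csrc/stable/accelerator.h) and that CUDA is the active
// accelerator at runtime.
#include <torch/csrc/stable/accelerator.h>

#include <cuda_runtime_api.h>

#include <cassert>
#include <cstdint>

// Temporarily switch the current device; the underlying c10::DeviceGuard
// restores the previous device when the guard goes out of scope.
int current_device_under_guard(int8_t device_index) {
  torch::stable::accelerator::DeviceGuard guard(device_index);
  int current_device = -1;
  cudaError_t err = cudaGetDevice(&current_device);
  assert(err == cudaSuccess);
  return current_device;  // equals device_index while the guard is alive
}

// Query the id of the device's current stream via the stable wrapper.
int64_t current_stream_id(int8_t device_index) {
  return torch::stable::accelerator::getCurrentStream(device_index).id();
}
```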