
Commit b922dbd

Add beginning of torch::stable::accelerator
ghstack-source-id: fae00fc
Pull Request resolved: #159679
1 parent 08ea8fc commit b922dbd
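
Note: judging from the test kernels added in kernel.cpp below, torch/csrc/stable/accelerator.h exposes at least an RAII DeviceGuard constructed from a device index and a getCurrentStream(device_index) helper whose result has an id() accessor. The following is only a usage sketch inferred from those tests (example() is a made-up function name; the real declarations live in the header and may differ):

// Sketch only: names inferred from the tests in this commit.
#include <torch/csrc/stable/accelerator.h>

void example(int8_t device_index) {
  // RAII guard: makes device_index the current accelerator device and
  // restores the previous device when the guard goes out of scope.
  torch::stable::accelerator::DeviceGuard guard(device_index);

  // Query the current stream for that device and read its integer id.
  int64_t stream_id =
      torch::stable::accelerator::getCurrentStream(device_index).id();
  (void)stream_id; // e.g. hand the id back to Python or a launcher
}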

File tree

7 files changed (+355, -54 lines)

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 147 additions & 52 deletions
@@ -1,25 +1,30 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/accelerator.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
 #include <torch/headeronly/util/Exception.h>

+#ifdef USE_CUDA
+#include <cuda_runtime.h>
+#endif
+
 #include <optional>

 void inline sgd_math(
-  float* param_ptr,
-  float* grad_ptr,
-  float* out_ptr,
-  const float weight_decay,
-  const double lr,
-  const bool maximize,
-  int64_t size
-){
+    float* param_ptr,
+    float* grad_ptr,
+    float* out_ptr,
+    const float weight_decay,
+    const double lr,
+    const bool maximize,
+    int64_t size) {
   int64_t d = 0;
   for (; d < size; d++) {
     float grad_val = grad_ptr[d];
-    if (maximize) grad_val = -grad_val;
-    if (weight_decay != 0.0){
+    if (maximize)
+      grad_val = -grad_val;
+    if (weight_decay != 0.0) {
       grad_val += param_ptr[d] * weight_decay;
     }
     out_ptr[d] = param_ptr[d] - grad_val * float(lr);
@@ -36,8 +41,8 @@ Tensor sgd_out_of_place(
     const bool maximize) {
   STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");

-  int64_t *param_sizes;
-  int64_t *param_strides;
+  int64_t* param_sizes;
+  int64_t* param_strides;
   aoti_torch_get_sizes(param.get(), &param_sizes);
   aoti_torch_get_strides(param.get(), &param_strides);

@@ -48,35 +53,45 @@ Tensor sgd_out_of_place(
   aoti_torch_get_device_type(param.get(), &param_device_type);

   AtenTensorHandle out_ath;
-  aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
+  aoti_torch_empty_strided(
+      param.dim(),
+      param_sizes,
+      param_strides,
+      param_dtype,
+      param_device_type,
+      param.get_device(),
+      &out_ath);
   auto out = Tensor(out_ath);

   sgd_math(
-    reinterpret_cast<float*>(param.data_ptr()),
-    reinterpret_cast<float*>(grad.data_ptr()),
-    reinterpret_cast<float*>(out.data_ptr()),
-    weight_decay,
-    lr,
-    maximize,
-    param.numel()
-  );
+      reinterpret_cast<float*>(param.data_ptr()),
+      reinterpret_cast<float*>(grad.data_ptr()),
+      reinterpret_cast<float*>(out.data_ptr()),
+      weight_decay,
+      lr,
+      maximize,
+      param.numel());

   return out;
 }

-void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_sgd_out_of_place(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = sgd_out_of_place(
-    to<Tensor>(stack[0]),
-    to<Tensor>(stack[1]),
-    float(to<double>(stack[2])),
-    to<double>(stack[3]),
-    to<bool>(stack[4]));
+      to<Tensor>(stack[0]),
+      to<Tensor>(stack[1]),
+      float(to<double>(stack[2])),
+      to<double>(stack[3]),
+      to<bool>(stack[4]));

   stack[0] = from(res);
 }

 STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
-  m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
+  m.def(
+      "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
 }

 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
@@ -87,7 +102,10 @@ Tensor identity(Tensor t) {
   return t;
 }

-void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_identity(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = identity(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -112,7 +130,10 @@ Tensor my_abs(Tensor t) {
   return to<Tensor>(stack[0]);
 }

-void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_abs(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
   stack[0] = from(tensor_res);
 }
@@ -134,18 +155,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
   auto mf = aoti_torch_memory_format_contiguous_format();

   stack[0] = from(t);
-  stack[1] = from(std::optional(t_dtype));   // dtype
-  stack[2] = from(std::nullopt);             // layout
-  stack[3] = from(std::optional(device));    // device
-  stack[4] = from(std::optional(false));     // pin_memory
-  stack[5] = from(std::optional(mf));        // memory_format
+  stack[1] = from(std::optional(t_dtype)); // dtype
+  stack[2] = from(std::nullopt); // layout
+  stack[3] = from(std::optional(device)); // device
+  stack[4] = from(std::optional(false)); // pin_memory
+  stack[5] = from(std::optional(mf)); // memory_format

   aoti_torch_call_dispatcher("aten::ones_like", "", stack);

   return to<Tensor>(stack[0]);
 }

-void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_ones_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
   stack[0] = from(res);
 }
@@ -158,7 +182,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_ones_like", &boxed_my_ones_like);
 }

-std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
+std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(
+    Tensor t1,
+    Tensor t2,
+    Tensor t3) {
   StableIValue stack_exp[1];
   stack_exp[0] = from(t1);
   aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
@@ -172,20 +199,25 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
   aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);

   return std::make_tuple(
-    to<Tensor>(stack_exp[0]),
-    to<Tensor>(stack_neg[0]),
-    to<bool>(stack_is_leaf[0]));
+      to<Tensor>(stack_exp[0]),
+      to<Tensor>(stack_neg[0]),
+      to<bool>(stack_is_leaf[0]));
 }

-void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
+void boxed_exp_neg_is_leaf(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto tuple = exp_neg_is_leaf(
+      to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
   stack[0] = from(std::get<0>(tuple));
   stack[1] = from(std::get<1>(tuple));
   stack[2] = from(std::get<2>(tuple));
 }

 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
-  m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
+  m.def(
+      "exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
 }

 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -200,7 +232,10 @@ Tensor neg_exp(Tensor t) {
   return to<Tensor>(stack[0]);
 }

-void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -229,7 +264,10 @@ Tensor divide_neg_exp(Tensor t) {
   return to<Tensor>(stack_div[0]);
 }

-void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_divide_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -246,7 +284,10 @@ bool is_contiguous(Tensor t) {
   return t.is_contiguous();
 }

-void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_is_contiguous(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   bool res = is_contiguous(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -263,8 +304,12 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
   return transpose(t, dim0, dim1);
 }

-void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
+void boxed_my_transpose(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto res = my_transpose(
+      to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));

   stack[0] = from(res);
 }
@@ -273,7 +318,10 @@ Tensor my_empty_like(Tensor t) {
   return empty_like(t);
 }

-void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_empty_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_empty_like(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -303,12 +351,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("fill_infinity", &boxed_fill_infinity);
 }

-
 Tensor my_zero_(Tensor t) {
   return zero_(t);
 }

-void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_zero_(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_zero_(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -320,3 +370,48 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("my_zero_", &boxed_my_zero_);
 }
+
+// Test functions for torch::stable::accelerator APIs
+
+#ifdef USE_CUDA
+int test_device_guard(int8_t device_index) {
+  using torch::stable::accelerator::DeviceGuard;
+
+  DeviceGuard guard(device_index);
+  int currentDevice;
+  cudaError_t err = cudaGetDevice(&currentDevice);
+  STD_TORCH_CHECK(err == cudaSuccess);
+  return currentDevice;
+}
+
+void boxed_test_device_guard(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int res = test_device_guard(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+int64_t test_stream(int8_t device_index) {
+  auto id = torch::stable::accelerator::getCurrentStream(device_index).id();
+  return id;
+}
+
+void boxed_test_stream(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int64_t res = test_stream(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("test_device_guard(int device_index) -> int");
+  m.def("test_stream(int device_index) -> int");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("test_device_guard", &boxed_test_device_guard);
+  m.impl("test_stream", &boxed_test_stream);
+}
+#endif // USE_CUDA
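
A downstream extension that wants device-scoped work can follow the same pattern as the test kernels above: an unboxed function that holds a DeviceGuard, a boxed wrapper, and registration through the stable library macros. A hypothetical sketch (my_device_scoped_op and its schema are invented for illustration; only the DeviceGuard/getCurrentStream calls mirror this commit):

#ifdef USE_CUDA
// Hypothetical op: runs under a DeviceGuard and returns the current
// stream id for the requested device.
int64_t my_device_scoped_op(int64_t device_index) {
  torch::stable::accelerator::DeviceGuard guard(
      static_cast<int8_t>(device_index));
  // ... device-scoped work goes here; the guard restores the previous
  // device when it leaves scope ...
  return torch::stable::accelerator::getCurrentStream(
             static_cast<int8_t>(device_index))
      .id();
}

void boxed_my_device_scoped_op(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  int64_t res = my_device_scoped_op(to<int64_t>(stack[0]));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_device_scoped_op(int device_index) -> int");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_device_scoped_op", &boxed_my_device_scoped_op);
}
#endif // USE_CUDA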

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 24 additions & 0 deletions
@@ -164,3 +164,27 @@ def fill_infinity(t) -> Tensor:
     Returns: The modified tensor (same as input)
     """
     return torch.ops.libtorch_agnostic.fill_infinity.default(t)
+
+
+def test_device_guard(device_index) -> Tensor:
+    """
+    Tests the DeviceGuard functionality by creating a device guard and returning an empty tensor.
+
+    Args:
+        device_index: Device index to set the guard to
+
+    Returns: A 3x3 empty tensor created on the device specified by device_index
+    """
+    return torch.ops.libtorch_agnostic.test_device_guard.default(device_index)
+
+
+def test_stream(device_index) -> int:
+    """
+    Tests the Stream functionality by getting the current stream ID for the specified device.
+
+    Args:
+        device_index: Device index to get the stream for
+
+    Returns: Stream ID as an integer
+    """
+    return torch.ops.libtorch_agnostic.test_stream.default(device_index)

test/cpp_extensions/libtorch_agnostic_extension/setup.py

Lines changed: 9 additions & 2 deletions
@@ -4,7 +4,8 @@

 from setuptools import find_packages, setup

-from torch.utils.cpp_extension import BuildExtension, CppExtension
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension


 ROOT_DIR = Path(__file__).parent
@@ -35,10 +36,16 @@ def get_extension():
         "cxx": ["-fdiagnostics-color=always"],
     }

+    extension = CppExtension
+    # allow including <cuda_runtime.h>
+    if torch.cuda.is_available():
+        extra_compile_args["cxx"].append("-DUSE_CUDA")
+        extension = CUDAExtension
+
     sources = list(CSRC_DIR.glob("**/*.cpp"))

     return [
-        CppExtension(
+        extension(
             "libtorch_agnostic._C",
             sources=sorted(str(s) for s in sources),
             py_limited_api=True,
