
Commit 86dae2c

Add beginning of torch::stable::accelerator
ghstack-source-id: 1530b24
Pull Request resolved: #159679

1 parent: 08ea8fc

File tree: 7 files changed (+344 / -53 lines)


test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 143 additions & 51 deletions
@@ -1,25 +1,28 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/accelerator.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
 #include <torch/headeronly/util/Exception.h>
 
 #include <optional>
 
+#include <cuda_runtime.h>
+
 void inline sgd_math(
-  float* param_ptr,
-  float* grad_ptr,
-  float* out_ptr,
-  const float weight_decay,
-  const double lr,
-  const bool maximize,
-  int64_t size
-){
+    float* param_ptr,
+    float* grad_ptr,
+    float* out_ptr,
+    const float weight_decay,
+    const double lr,
+    const bool maximize,
+    int64_t size) {
   int64_t d = 0;
   for (; d < size; d++) {
     float grad_val = grad_ptr[d];
-    if (maximize) grad_val = -grad_val;
-    if (weight_decay != 0.0){
+    if (maximize)
+      grad_val = -grad_val;
+    if (weight_decay != 0.0) {
       grad_val += param_ptr[d] * weight_decay;
     }
     out_ptr[d] = param_ptr[d] - grad_val * float(lr);
@@ -36,8 +39,8 @@ Tensor sgd_out_of_place(
     const bool maximize) {
   STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
 
-  int64_t *param_sizes;
-  int64_t *param_strides;
+  int64_t* param_sizes;
+  int64_t* param_strides;
   aoti_torch_get_sizes(param.get(), &param_sizes);
   aoti_torch_get_strides(param.get(), &param_strides);
 
@@ -48,35 +51,45 @@ Tensor sgd_out_of_place(
   aoti_torch_get_device_type(param.get(), &param_device_type);
 
   AtenTensorHandle out_ath;
-  aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
+  aoti_torch_empty_strided(
+      param.dim(),
+      param_sizes,
+      param_strides,
+      param_dtype,
+      param_device_type,
+      param.get_device(),
+      &out_ath);
   auto out = Tensor(out_ath);
 
   sgd_math(
-    reinterpret_cast<float*>(param.data_ptr()),
-    reinterpret_cast<float*>(grad.data_ptr()),
-    reinterpret_cast<float*>(out.data_ptr()),
-    weight_decay,
-    lr,
-    maximize,
-    param.numel()
-  );
+      reinterpret_cast<float*>(param.data_ptr()),
+      reinterpret_cast<float*>(grad.data_ptr()),
+      reinterpret_cast<float*>(out.data_ptr()),
+      weight_decay,
+      lr,
+      maximize,
+      param.numel());
 
   return out;
 }
 
-void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_sgd_out_of_place(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = sgd_out_of_place(
-    to<Tensor>(stack[0]),
-    to<Tensor>(stack[1]),
-    float(to<double>(stack[2])),
-    to<double>(stack[3]),
-    to<bool>(stack[4]));
+      to<Tensor>(stack[0]),
+      to<Tensor>(stack[1]),
+      float(to<double>(stack[2])),
+      to<double>(stack[3]),
+      to<bool>(stack[4]));
 
   stack[0] = from(res);
 }
 
 STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
-  m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
+  m.def(
+      "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
@@ -87,7 +100,10 @@ Tensor identity(Tensor t) {
   return t;
 }
 
-void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_identity(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = identity(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -112,7 +128,10 @@ Tensor my_abs(Tensor t) {
   return to<Tensor>(stack[0]);
 }
 
-void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_abs(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
   stack[0] = from(tensor_res);
 }
@@ -134,18 +153,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
   auto mf = aoti_torch_memory_format_contiguous_format();
 
   stack[0] = from(t);
-  stack[1] = from(std::optional(t_dtype));  // dtype
-  stack[2] = from(std::nullopt);            // layout
-  stack[3] = from(std::optional(device));   // device
-  stack[4] = from(std::optional(false));    // pin_memory
-  stack[5] = from(std::optional(mf));       // memory_format
+  stack[1] = from(std::optional(t_dtype)); // dtype
+  stack[2] = from(std::nullopt); // layout
+  stack[3] = from(std::optional(device)); // device
+  stack[4] = from(std::optional(false)); // pin_memory
+  stack[5] = from(std::optional(mf)); // memory_format
 
   aoti_torch_call_dispatcher("aten::ones_like", "", stack);
 
   return to<Tensor>(stack[0]);
 }
 
-void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_ones_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
   stack[0] = from(res);
 }
@@ -158,7 +180,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_ones_like", &boxed_my_ones_like);
 }
 
-std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
+std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(
+    Tensor t1,
+    Tensor t2,
+    Tensor t3) {
   StableIValue stack_exp[1];
   stack_exp[0] = from(t1);
   aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
@@ -172,20 +197,25 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
   aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
 
   return std::make_tuple(
-    to<Tensor>(stack_exp[0]),
-    to<Tensor>(stack_neg[0]),
-    to<bool>(stack_is_leaf[0]));
+      to<Tensor>(stack_exp[0]),
+      to<Tensor>(stack_neg[0]),
+      to<bool>(stack_is_leaf[0]));
 }
 
-void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
+void boxed_exp_neg_is_leaf(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto tuple = exp_neg_is_leaf(
+      to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
   stack[0] = from(std::get<0>(tuple));
   stack[1] = from(std::get<1>(tuple));
   stack[2] = from(std::get<2>(tuple));
 }
 
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
-  m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
+  m.def(
+      "exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -200,7 +230,10 @@ Tensor neg_exp(Tensor t) {
   return to<Tensor>(stack[0]);
 }
 
-void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -229,7 +262,10 @@ Tensor divide_neg_exp(Tensor t) {
   return to<Tensor>(stack_div[0]);
 }
 
-void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_divide_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -246,7 +282,10 @@ bool is_contiguous(Tensor t) {
   return t.is_contiguous();
 }
 
-void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_is_contiguous(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   bool res = is_contiguous(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -263,8 +302,12 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
   return transpose(t, dim0, dim1);
 }
 
-void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
+void boxed_my_transpose(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto res = my_transpose(
+      to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
 
   stack[0] = from(res);
 }
@@ -273,7 +316,10 @@ Tensor my_empty_like(Tensor t) {
   return empty_like(t);
 }
 
-void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_empty_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_empty_like(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -308,7 +354,10 @@ Tensor my_zero_(Tensor t) {
   return zero_(t);
 }
 
-void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_zero_(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_zero_(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -320,3 +369,46 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("my_zero_", &boxed_my_zero_);
 }
+
+// Test functions for torch::stable::accelerator APIs
+
+int test_device_guard(int8_t device_index) {
+  using torch::stable::accelerator::DeviceGuard;
+
+  DeviceGuard guard(device_index);
+  int currentDevice;
+  cudaError_t err = cudaGetDevice(&currentDevice);
+  STD_TORCH_CHECK(err == cudaSuccess);
+  return currentDevice;
+}
+
+void boxed_test_device_guard(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int res = test_device_guard(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+int64_t test_stream(int8_t device_index) {
+  auto id = torch::stable::accelerator::getCurrentStream(device_index).id();
+  return id;
+}
+
+void boxed_test_stream(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int64_t res = test_stream(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("test_device_guard(int device_index) -> int");
+  m.def("test_stream(int device_index) -> int");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_device_guard", &boxed_test_device_guard);
  m.impl("test_stream", &boxed_test_stream);
}
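
The two new ops are the first consumers of torch::stable::accelerator: DeviceGuard(device_index) is an RAII guard that makes the given device current for the enclosing scope (which is what the cudaGetDevice call inside test_device_guard observes) and restores the previous device on destruction. A minimal sketch of that behavior from the Python side, assuming a CUDA build with at least two devices and that the extension built by setup.py is importable as libtorch_agnostic; the restore-on-exit check is an assumption about DeviceGuard semantics, mirroring c10's device guards:

# Sketch only: assumes a CUDA build with >= 2 GPUs and that the
# extension built by setup.py is installed as `libtorch_agnostic`.
import torch
import libtorch_agnostic.ops as ops

torch.cuda.set_device(0)

# Inside the op, DeviceGuard(1) makes device 1 current, so the
# cudaGetDevice() call in test_device_guard should report 1 ...
assert ops.test_device_guard(1) == 1

# ... and, as an RAII guard, it should restore device 0 once the
# op returns (assumed guard semantics, not asserted by this diff).
assert torch.cuda.current_device() == 0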

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 23 additions & 0 deletions
@@ -164,3 +164,26 @@ def fill_infinity(t) -> Tensor:
     Returns: The modified tensor (same as input)
     """
     return torch.ops.libtorch_agnostic.fill_infinity.default(t)
+
+def test_device_guard(device_index) -> int:
+    """
+    Tests the DeviceGuard functionality by guarding the given device and returning the device index observed inside the guard.
+
+    Args:
+        device_index: Device index to set the guard to
+
+    Returns: The device index that is current while the guard is active (should equal device_index)
+    """
+    return torch.ops.libtorch_agnostic.test_device_guard.default(device_index)
+
+
+def test_stream(device_index) -> int:
+    """
+    Tests the Stream functionality by getting the current stream ID for the specified device.
+
+    Args:
+        device_index: Device index to get the stream for
+
+    Returns: Stream ID as an integer
+    """
+    return torch.ops.libtorch_agnostic.test_stream.default(device_index)
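
These wrappers are thin shims over torch.ops.libtorch_agnostic.<name>.default. A hedged cross-check for test_stream, under the assumption (not guaranteed by anything in this diff) that getCurrentStream(device_index).id() surfaces the same identifier torch.cuda exposes as Stream.stream_id:

# Sketch only: assumes the stable stream id matches torch.cuda's
# Stream.stream_id for the same device.
import torch
from libtorch_agnostic.ops import test_stream

device_index = 0
s = torch.cuda.Stream(device=device_index)
with torch.cuda.stream(s):
    # While `s` is current, the stable API should see the same stream.
    assert test_stream(device_index) == s.stream_id

# Outside the context, we should be back on the previous stream.
assert test_stream(device_index) == torch.cuda.current_stream(device_index).stream_id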

test/cpp_extensions/libtorch_agnostic_extension/setup.py

Lines changed: 3 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 from setuptools import find_packages, setup
 
-from torch.utils.cpp_extension import BuildExtension, CppExtension
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 
 ROOT_DIR = Path(__file__).parent
@@ -33,12 +33,13 @@ def run(self):
 def get_extension():
     extra_compile_args = {
         "cxx": ["-fdiagnostics-color=always"],
+        "nvcc": ["-O2"],
     }
 
     sources = list(CSRC_DIR.glob("**/*.cpp"))
 
     return [
-        CppExtension(
+        CUDAExtension(
             "libtorch_agnostic._C",
             sources=sorted(str(s) for s in sources),
             py_limited_api=True,
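
The switch from CppExtension to CUDAExtension is what lets kernel.cpp call into the CUDA runtime (cudaGetDevice): CUDAExtension adds the CUDA include paths and links against the CUDA runtime library, and the new "nvcc" entry in extra_compile_args would apply to any future .cu sources. A stripped-down sketch of the resulting configuration; the name and source path below are illustrative, and the real setup.py also uses find_packages and a custom build step:

# Sketch of the build config after this diff; paths and name are
# illustrative, not copied from the full setup.py.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="libtorch_agnostic",
    ext_modules=[
        CUDAExtension(
            "libtorch_agnostic._C",
            sources=["libtorch_agnostic/csrc/kernel.cpp"],
            py_limited_api=True,
            extra_compile_args={
                "cxx": ["-fdiagnostics-color=always"],
                "nvcc": ["-O2"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)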
