@@ -1,25 +1,30 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/accelerator.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
 #include <torch/headeronly/util/Exception.h>
 
+#ifdef USE_CUDA
+#include <cuda_runtime.h>
+#endif
+
 #include <optional>
 
 void inline sgd_math(
-  float* param_ptr,
-  float* grad_ptr,
-  float* out_ptr,
-  const float weight_decay,
-  const double lr,
-  const bool maximize,
-  int64_t size
-){
+    float* param_ptr,
+    float* grad_ptr,
+    float* out_ptr,
+    const float weight_decay,
+    const double lr,
+    const bool maximize,
+    int64_t size) {
   int64_t d = 0;
   for (; d < size; d++) {
     float grad_val = grad_ptr[d];
-    if (maximize) grad_val = -grad_val;
-    if (weight_decay != 0.0){
+    if (maximize)
+      grad_val = -grad_val;
+    if (weight_decay != 0.0) {
       grad_val += param_ptr[d] * weight_decay;
     }
     out_ptr[d] = param_ptr[d] - grad_val * float(lr);
@@ -36,8 +41,8 @@ Tensor sgd_out_of_place(
     const bool maximize) {
   STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
 
-  int64_t *param_sizes;
-  int64_t *param_strides;
+  int64_t* param_sizes;
+  int64_t* param_strides;
   aoti_torch_get_sizes(param.get(), &param_sizes);
   aoti_torch_get_strides(param.get(), &param_strides);
 
@@ -48,35 +53,45 @@ Tensor sgd_out_of_place(
   aoti_torch_get_device_type(param.get(), &param_device_type);
 
   AtenTensorHandle out_ath;
-  aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
+  aoti_torch_empty_strided(
+      param.dim(),
+      param_sizes,
+      param_strides,
+      param_dtype,
+      param_device_type,
+      param.get_device(),
+      &out_ath);
   auto out = Tensor(out_ath);
 
   sgd_math(
-    reinterpret_cast<float*>(param.data_ptr()),
-    reinterpret_cast<float*>(grad.data_ptr()),
-    reinterpret_cast<float*>(out.data_ptr()),
-    weight_decay,
-    lr,
-    maximize,
-    param.numel()
-  );
+      reinterpret_cast<float*>(param.data_ptr()),
+      reinterpret_cast<float*>(grad.data_ptr()),
+      reinterpret_cast<float*>(out.data_ptr()),
+      weight_decay,
+      lr,
+      maximize,
+      param.numel());
 
   return out;
 }
 
-void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_sgd_out_of_place(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = sgd_out_of_place(
-    to<Tensor>(stack[0]),
-    to<Tensor>(stack[1]),
-    float(to<double>(stack[2])),
-    to<double>(stack[3]),
-    to<bool>(stack[4]));
+      to<Tensor>(stack[0]),
+      to<Tensor>(stack[1]),
+      float(to<double>(stack[2])),
+      to<double>(stack[3]),
+      to<bool>(stack[4]));
 
   stack[0] = from(res);
 }
 
 STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
-  m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
+  m.def(
+      "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
@@ -87,7 +102,10 @@ Tensor identity(Tensor t) {
   return t;
 }
 
-void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_identity(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = identity(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -112,7 +130,10 @@ Tensor my_abs(Tensor t) {
   return to<Tensor>(stack[0]);
 }
 
-void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_abs(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
   stack[0] = from(tensor_res);
 }
@@ -134,18 +155,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
   auto mf = aoti_torch_memory_format_contiguous_format();
 
   stack[0] = from(t);
-  stack[1] = from(std::optional(t_dtype));  // dtype
-  stack[2] = from(std::nullopt);            // layout
-  stack[3] = from(std::optional(device));   // device
-  stack[4] = from(std::optional(false));    // pin_memory
-  stack[5] = from(std::optional(mf));       // memory_format
+  stack[1] = from(std::optional(t_dtype)); // dtype
+  stack[2] = from(std::nullopt); // layout
+  stack[3] = from(std::optional(device)); // device
+  stack[4] = from(std::optional(false)); // pin_memory
+  stack[5] = from(std::optional(mf)); // memory_format
 
   aoti_torch_call_dispatcher("aten::ones_like", "", stack);
 
   return to<Tensor>(stack[0]);
 }
 
-void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_ones_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
   stack[0] = from(res);
 }
@@ -158,7 +182,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_ones_like", &boxed_my_ones_like);
 }
 
-std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
+std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(
+    Tensor t1,
+    Tensor t2,
+    Tensor t3) {
   StableIValue stack_exp[1];
   stack_exp[0] = from(t1);
   aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
@@ -172,20 +199,25 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
   aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
 
   return std::make_tuple(
-    to<Tensor>(stack_exp[0]),
-    to<Tensor>(stack_neg[0]),
-    to<bool>(stack_is_leaf[0]));
+      to<Tensor>(stack_exp[0]),
+      to<Tensor>(stack_neg[0]),
+      to<bool>(stack_is_leaf[0]));
 }
 
-void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
+void boxed_exp_neg_is_leaf(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto tuple = exp_neg_is_leaf(
+      to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
   stack[0] = from(std::get<0>(tuple));
   stack[1] = from(std::get<1>(tuple));
   stack[2] = from(std::get<2>(tuple));
 }
 
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
-  m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
+  m.def(
+      "exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -200,7 +232,10 @@ Tensor neg_exp(Tensor t) {
   return to<Tensor>(stack[0]);
 }
 
-void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -229,7 +264,10 @@ Tensor divide_neg_exp(Tensor t) {
   return to<Tensor>(stack_div[0]);
 }
 
-void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_divide_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -246,7 +284,10 @@ bool is_contiguous(Tensor t) {
   return t.is_contiguous();
 }
 
-void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_is_contiguous(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   bool res = is_contiguous(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -263,8 +304,12 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
   return transpose(t, dim0, dim1);
 }
 
-void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
+void boxed_my_transpose(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto res = my_transpose(
+      to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
 
   stack[0] = from(res);
 }
@@ -273,7 +318,10 @@ Tensor my_empty_like(Tensor t) {
   return empty_like(t);
 }
 
-void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_empty_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_empty_like(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -303,12 +351,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("fill_infinity", &boxed_fill_infinity);
 }
 
-
 Tensor my_zero_(Tensor t) {
   return zero_(t);
 }
 
-void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_zero_(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_zero_(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -320,3 +370,48 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("my_zero_", &boxed_my_zero_);
 }
+
+// Test functions for torch::stable::accelerator APIs
+
+#ifdef USE_CUDA
+int test_device_guard(int8_t device_index) {
+  using torch::stable::accelerator::DeviceGuard;
+
+  DeviceGuard guard(device_index);
+  int currentDevice;
+  cudaError_t err = cudaGetDevice(&currentDevice);
+  STD_TORCH_CHECK(err == cudaSuccess);
+  return currentDevice;
+}
+
+void boxed_test_device_guard(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int res = test_device_guard(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+int64_t test_stream(int8_t device_index) {
+  auto id = torch::stable::accelerator::getCurrentStream(device_index).id();
+  return id;
+}
+
+void boxed_test_stream(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int64_t res = test_stream(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("test_device_guard(int device_index) -> int");
+  m.def("test_stream(int device_index) -> int");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("test_device_guard", &boxed_test_device_guard);
+  m.impl("test_stream", &boxed_test_stream);
+}
+#endif // USE_CUDA
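Note: a caller built against the same stable headers can exercise a registered op through the boxed dispatcher, mirroring the aten:: calls this file already makes. The sketch below is hypothetical and not part of this diff: the helper call_sgd and its hyperparameter values are illustrative, and it assumes the extension above is loaded so the libtorch_agnostic::sgd_out_of_place schema is registered.

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

// Hypothetical caller: packs arguments in schema order as StableIValues,
// dispatches by qualified op name, then reads the boxed result from stack[0].
Tensor call_sgd(Tensor param, Tensor grad) {
  StableIValue stack[5];
  stack[0] = from(param);
  stack[1] = from(grad);
  stack[2] = from(1e-4); // weight_decay (schema `float` travels as double)
  stack[3] = from(0.01); // lr
  stack[4] = from(false); // maximize
  aoti_torch_call_dispatcher("libtorch_agnostic::sgd_out_of_place", "", stack);
  return to<Tensor>(stack[0]); // the output tensor replaces the first argument
}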