 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/accelerator.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
 #include <torch/headeronly/util/Exception.h>
 
 #include <optional>
 
+#include <cuda_runtime.h>
+
 void inline sgd_math(
-  float* param_ptr,
-  float* grad_ptr,
-  float* out_ptr,
-  const float weight_decay,
-  const double lr,
-  const bool maximize,
-  int64_t size
-){
+    float* param_ptr,
+    float* grad_ptr,
+    float* out_ptr,
+    const float weight_decay,
+    const double lr,
+    const bool maximize,
+    int64_t size) {
   int64_t d = 0;
   for (; d < size; d++) {
     float grad_val = grad_ptr[d];
-    if (maximize) grad_val = -grad_val;
-    if (weight_decay != 0.0){
+    if (maximize)
+      grad_val = -grad_val;
+    if (weight_decay != 0.0) {
       grad_val += param_ptr[d] * weight_decay;
     }
     out_ptr[d] = param_ptr[d] - grad_val * float(lr);
@@ -36,8 +39,8 @@ Tensor sgd_out_of_place(
     const bool maximize) {
   STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
 
-  int64_t *param_sizes;
-  int64_t *param_strides;
+  int64_t* param_sizes;
+  int64_t* param_strides;
   aoti_torch_get_sizes(param.get(), &param_sizes);
   aoti_torch_get_strides(param.get(), &param_strides);
 
@@ -48,35 +51,45 @@ Tensor sgd_out_of_place(
   aoti_torch_get_device_type(param.get(), &param_device_type);
 
   AtenTensorHandle out_ath;
-  aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
+  aoti_torch_empty_strided(
+      param.dim(),
+      param_sizes,
+      param_strides,
+      param_dtype,
+      param_device_type,
+      param.get_device(),
+      &out_ath);
   auto out = Tensor(out_ath);
 
   sgd_math(
-    reinterpret_cast<float*>(param.data_ptr()),
-    reinterpret_cast<float*>(grad.data_ptr()),
-    reinterpret_cast<float*>(out.data_ptr()),
-    weight_decay,
-    lr,
-    maximize,
-    param.numel()
-  );
+      reinterpret_cast<float*>(param.data_ptr()),
+      reinterpret_cast<float*>(grad.data_ptr()),
+      reinterpret_cast<float*>(out.data_ptr()),
+      weight_decay,
+      lr,
+      maximize,
+      param.numel());
 
   return out;
 }
 
-void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_sgd_out_of_place(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = sgd_out_of_place(
-    to<Tensor>(stack[0]),
-    to<Tensor>(stack[1]),
-    float(to<double>(stack[2])),
-    to<double>(stack[3]),
-    to<bool>(stack[4]));
+      to<Tensor>(stack[0]),
+      to<Tensor>(stack[1]),
+      float(to<double>(stack[2])),
+      to<double>(stack[3]),
+      to<bool>(stack[4]));
 
   stack[0] = from(res);
 }
 
 STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
-  m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
+  m.def(
+      "sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
@@ -87,7 +100,10 @@ Tensor identity(Tensor t) {
   return t;
 }
 
-void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_identity(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = identity(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -112,7 +128,10 @@ Tensor my_abs(Tensor t) {
   return to<Tensor>(stack[0]);
 }
 
-void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_abs(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
   stack[0] = from(tensor_res);
 }
@@ -134,18 +153,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
   auto mf = aoti_torch_memory_format_contiguous_format();
 
   stack[0] = from(t);
-  stack[1] = from(std::optional(t_dtype)); // dtype
-  stack[2] = from(std::nullopt); // layout
-  stack[3] = from(std::optional(device)); // device
-  stack[4] = from(std::optional(false)); // pin_memory
-  stack[5] = from(std::optional(mf)); // memory_format
+  stack[1] = from(std::optional(t_dtype)); // dtype
+  stack[2] = from(std::nullopt); // layout
+  stack[3] = from(std::optional(device)); // device
+  stack[4] = from(std::optional(false)); // pin_memory
+  stack[5] = from(std::optional(mf)); // memory_format
 
   aoti_torch_call_dispatcher("aten::ones_like", "", stack);
 
   return to<Tensor>(stack[0]);
 }
 
-void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_ones_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
   stack[0] = from(res);
 }
@@ -158,7 +180,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_ones_like", &boxed_my_ones_like);
 }
 
-std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
+std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(
+    Tensor t1,
+    Tensor t2,
+    Tensor t3) {
   StableIValue stack_exp[1];
   stack_exp[0] = from(t1);
   aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
@@ -172,20 +197,25 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
   aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
 
   return std::make_tuple(
-    to<Tensor>(stack_exp[0]),
-    to<Tensor>(stack_neg[0]),
-    to<bool>(stack_is_leaf[0]));
+      to<Tensor>(stack_exp[0]),
+      to<Tensor>(stack_neg[0]),
+      to<bool>(stack_is_leaf[0]));
 }
 
-void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
+void boxed_exp_neg_is_leaf(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto tuple = exp_neg_is_leaf(
+      to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
   stack[0] = from(std::get<0>(tuple));
   stack[1] = from(std::get<1>(tuple));
   stack[2] = from(std::get<2>(tuple));
 }
 
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
-  m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
+  m.def(
+      "exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -200,7 +230,10 @@ Tensor neg_exp(Tensor t) {
   return to<Tensor>(stack[0]);
 }
 
-void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -229,7 +262,10 @@ Tensor divide_neg_exp(Tensor t) {
   return to<Tensor>(stack_div[0]);
 }
 
-void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_divide_neg_exp(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -246,7 +282,10 @@ bool is_contiguous(Tensor t) {
   return t.is_contiguous();
 }
 
-void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_is_contiguous(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   bool res = is_contiguous(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -263,8 +302,12 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
   return transpose(t, dim0, dim1);
 }
 
-void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
+void boxed_my_transpose(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto res = my_transpose(
+      to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
 
   stack[0] = from(res);
 }
@@ -273,7 +316,10 @@ Tensor my_empty_like(Tensor t) {
   return empty_like(t);
 }
 
-void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_empty_like(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_empty_like(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -308,7 +354,10 @@ Tensor my_zero_(Tensor t) {
   return zero_(t);
 }
 
-void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+void boxed_my_zero_(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
   auto res = my_zero_(to<Tensor>(stack[0]));
   stack[0] = from(res);
 }
@@ -320,3 +369,46 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("my_zero_", &boxed_my_zero_);
 }
+
+// Test functions for torch::stable::accelerator APIs
+
+int test_device_guard(int8_t device_index) {
+  using torch::stable::accelerator::DeviceGuard;
+
+  DeviceGuard guard(device_index);
+  int currentDevice;
+  cudaError_t err = cudaGetDevice(&currentDevice);
+  STD_TORCH_CHECK(err == cudaSuccess);
+  return currentDevice;
+}
+
+void boxed_test_device_guard(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int res = test_device_guard(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+int64_t test_stream(int8_t device_index) {
+  auto id = torch::stable::accelerator::getCurrentStream(device_index).id();
+  return id;
+}
+
+void boxed_test_stream(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  int64_t res = test_stream(static_cast<int8_t>(to<int64_t>(stack[0])));
+  stack[0] = from(res);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("test_device_guard(int device_index) -> int");
+  m.def("test_stream(int device_index) -> int");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("test_device_guard", &boxed_test_device_guard);
+  m.impl("test_stream", &boxed_test_stream);
+}
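
Note: the update that sgd_math implements can be sanity-checked in isolation. Below is a minimal standalone sketch (plain C++, no torch dependency; the main() harness and the concrete numbers are illustrative only, not part of the extension) of the same elementwise step: negate the gradient when maximizing, fold weight decay into it, then take one step of size lr.

#include <cstdio>

int main() {
  float param[3] = {1.0f, 2.0f, 3.0f};
  float grad[3] = {0.1f, 0.2f, 0.3f};
  float out[3];
  const float weight_decay = 0.01f;
  const double lr = 0.5;
  const bool maximize = false;

  // Same elementwise step as sgd_math above.
  for (int d = 0; d < 3; d++) {
    float grad_val = grad[d];
    if (maximize)
      grad_val = -grad_val; // ascend instead of descend
    if (weight_decay != 0.0f)
      grad_val += param[d] * weight_decay; // L2 penalty folded into the gradient
    out[d] = param[d] - grad_val * float(lr);
  }

  // Prints: 0.945000 1.890000 2.835000
  std::printf("%f %f %f\n", out[0], out[1], out[2]);
  return 0;
}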
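
Note: test_device_guard assumes that torch::stable::accelerator::DeviceGuard switches the active CUDA device for the guard's lifetime, which is why cudaGetDevice is expected to report device_index while the guard is alive. Below is a minimal sketch of that RAII set-and-restore idiom, built directly on the CUDA runtime; ScopedCudaDevice is a hypothetical name introduced here for illustration, not the real guard implementation.

#include <cuda_runtime.h>

class ScopedCudaDevice {
 public:
  explicit ScopedCudaDevice(int device) {
    cudaGetDevice(&prev_); // remember the device that was active
    cudaSetDevice(device); // switch for the lifetime of this scope
  }
  ~ScopedCudaDevice() {
    cudaSetDevice(prev_); // restore on scope exit
  }

 private:
  int prev_ = 0;
};

Restoring in the destructor keeps the pattern exception-safe: the previous device comes back no matter how the scope exits, which is exactly the behavior test_device_guard exercises.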