1
1
#include < tensor-array/core/tensor.hh>
2
+ #include < tensor-array/core/data_type_wrapper.hh>
2
3
#include < pybind11/pybind11.h>
3
4
#include < pybind11/numpy.h>
4
5
#include < pybind11/operators.h>
5
6
6
7
using namespace tensor_array ::value;
8
+ using namespace tensor_array ::datatype;
7
9
8
10
template <typename T>
9
11
TensorBase convert_numpy_to_tensor_base (pybind11::array_t <T> py_buf)
@@ -20,12 +22,22 @@ TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
20
22
return static_cast <unsigned int >(dim);
21
23
}
22
24
);
25
+ warp_type (warp_type (typeid (T)));
23
26
return TensorBase (typeid (T), shape_vec, info.ptr );
24
27
}
25
28
29
+ pybind11::dtype get_py_type (const std::type_info& info)
30
+ {
31
+ if (info == typeid (bool ))
32
+ return pybind11::dtype::of<bool >();
33
+ if (info == typeid (float ))
34
+ return pybind11::dtype::of<float >();
35
+ throw std::exception ();
36
+ }
37
+
26
38
pybind11::array convert_tensor_to_numpy (const Tensor& tensor)
27
39
{
28
- const TensorBase& base_tensor = tensor.get_buffer ();
40
+ const TensorBase& base_tensor = tensor.get_buffer (). change_device ({tensor_array::devices::CPU, 0 }) ;
29
41
std::vector<pybind11::size_t > shape_vec (base_tensor.shape ().size ());
30
42
std::transform
31
43
(
@@ -37,8 +49,9 @@ pybind11::array convert_tensor_to_numpy(const Tensor& tensor)
37
49
return static_cast <pybind11::size_t >(dim);
38
50
}
39
51
);
40
- pybind11::array arr = pybind11::array ();
41
- return arr;
52
+ auto ty0 = pybind11::detail::get_type_info (base_tensor.type ());
53
+ pybind11::dtype ty1 = get_py_type (base_tensor.type ());
54
+ return pybind11::array (ty1, shape_vec, base_tensor.data ());
42
55
}
43
56
44
57
Tensor python_tuple_slice (const Tensor& t, pybind11::tuple tuple_slice)
@@ -107,15 +120,34 @@ std::size_t python_len(const Tensor& t)
107
120
return shape_list.size () != 0 ? shape_list.begin ()[0 ]: 1U ;
108
121
}
109
122
110
- std::string tensor_to_string (const Tensor& t)
123
+ pybind11::str tensor_to_string (const Tensor& t)
111
124
{
112
- std::ostringstream osstream;
113
- osstream << t;
114
- return osstream.str ();
125
+ return pybind11::repr (convert_tensor_to_numpy (t));
126
+ }
127
+
128
+ Tensor tensor_cast_1 (const Tensor& t, DataType dtype)
129
+ {
130
+ return t.tensor_cast (warp_type (dtype));
115
131
}
116
132
117
133
PYBIND11_MODULE (tensor2, m)
118
134
{
135
+ pybind11::enum_<DataType>(m, " DataType" )
136
+ .value (" BOOL" , BOOL_DTYPE)
137
+ .value (" S_INT_8" , S_INT_8)
138
+ .value (" S_INT_16" , S_INT_16)
139
+ .value (" S_INT_32" , S_INT_32)
140
+ .value (" S_INT_64" , S_INT_64)
141
+ .value (" FLOAT" , FLOAT_DTYPE)
142
+ .value (" DOUBLE" , DOUBLE_DTYPE)
143
+ .value (" HALF" , HALF_DTYPE)
144
+ .value (" BFLOAT16" , BF16_DTYPE)
145
+ .value (" U_INT_8" , U_INT_8)
146
+ .value (" U_INT_16" , U_INT_16)
147
+ .value (" U_INT_32" , U_INT_32)
148
+ .value (" U_INT_64" , U_INT_64)
149
+ .export_values ();
150
+
119
151
pybind11::class_<Tensor>(m, " Tensor" )
120
152
.def (pybind11::init ())
121
153
.def (pybind11::init (&convert_numpy_to_tensor_base<float >))
@@ -138,11 +170,22 @@ PYBIND11_MODULE(tensor2, m)
138
170
.def (hash (pybind11::self))
139
171
.def (" transpose" , &Tensor::transpose)
140
172
.def (" calc_grad" , &Tensor::calc_grad)
173
+ .def (" sin" , &Tensor::sin)
174
+ .def (" sin" , &Tensor::sin)
175
+ .def (" cos" , &Tensor::cos)
176
+ .def (" tan" , &Tensor::tan)
177
+ .def (" sinh" , &Tensor::sinh)
178
+ .def (" cosh" , &Tensor::cosh)
179
+ .def (" tanh" , &Tensor::tanh)
180
+ .def (" log" , &Tensor::log)
181
+ .def (" clone" , &Tensor::clone)
182
+ .def (" cast" , &tensor_cast_1)
141
183
.def (" add" , &add)
142
184
.def (" multiply" , &multiply)
143
185
.def (" divide" , ÷)
144
186
.def (" matmul" , &matmul)
145
187
.def (" condition" , &condition)
188
+ .def (" numpy" , &convert_tensor_to_numpy)
146
189
.def (" __getitem__" , &python_index)
147
190
.def (" __getitem__" , &python_slice)
148
191
.def (" __getitem__" , &python_tuple_slice)
0 commit comments