1
- #include < tensor-array/core/tensor.hh>
2
- #include < pybind11/pybind11.h>
3
- #include < pybind11/numpy.h>
4
- #include < pybind11/operators.h>
5
-
6
- using namespace tensor_array ::value;
7
-
8
- template <typename T>
9
- TensorBase convert_numpy_to_tensor_base (pybind11::array_t <T> py_buf)
10
- {
11
- pybind11::buffer_info info = py_buf.request ();
12
- std::vector<unsigned int > shape_vec (info.ndim );
13
- std::transform
14
- (
15
- info.shape .cbegin (),
16
- info.shape .cend (),
17
- shape_vec.begin (),
18
- [](pybind11::size_t dim)
19
- {
20
- return static_cast <unsigned int >(dim);
21
- }
22
- );
23
- return TensorBase (typeid (T), shape_vec, info.ptr );
24
- }
25
-
26
- std::string tensor_to_string (const Tensor t)
27
- {
28
- std::ostringstream osstream;
29
- osstream << t;
30
- return osstream.str ();
31
- }
32
-
33
- PYBIND11_MODULE (tensor2, m)
34
- {
35
- pybind11::class_<Tensor>(m, " Tensor" )
36
- .def (pybind11::init ())
37
- .def (pybind11::init (&convert_numpy_to_tensor_base<float >))
38
- .def (pybind11::self + pybind11::self)
39
- .def (pybind11::self - pybind11::self)
40
- .def (pybind11::self * pybind11::self)
41
- .def (pybind11::self / pybind11::self)
42
- .def (pybind11::self += pybind11::self)
43
- .def (pybind11::self -= pybind11::self)
44
- .def (pybind11::self *= pybind11::self)
45
- .def (pybind11::self /= pybind11::self)
46
- .def (pybind11::self == pybind11::self)
47
- .def (pybind11::self != pybind11::self)
48
- .def (pybind11::self >= pybind11::self)
49
- .def (pybind11::self <= pybind11::self)
50
- .def (pybind11::self > pybind11::self)
51
- .def (pybind11::self < pybind11::self)
52
- .def (+pybind11::self)
53
- .def (-pybind11::self)
54
- .def (" __matmul__" , &matmul)
55
- .def (" __repr__" , &tensor_to_string);
1
+ #include < tensor-array/core/tensor.hh>
2
+ #include < pybind11/pybind11.h>
3
+ #include < pybind11/numpy.h>
4
+ #include < pybind11/operators.h>
5
+
6
+ using namespace tensor_array ::value;
7
+
8
+ template <typename T>
9
+ TensorBase convert_numpy_to_tensor_base (pybind11::array_t <T> py_buf)
10
+ {
11
+ pybind11::buffer_info info = py_buf.request ();
12
+ std::vector<unsigned int > shape_vec (info.ndim );
13
+ std::transform
14
+ (
15
+ info.shape .cbegin (),
16
+ info.shape .cend (),
17
+ shape_vec.begin (),
18
+ [](pybind11::size_t dim)
19
+ {
20
+ return static_cast <unsigned int >(dim);
21
+ }
22
+ );
23
+ return TensorBase (typeid (T), shape_vec, info.ptr );
24
+ }
25
+
26
+ pybind11::array convert_tensor_to_numpy (const Tensor& tensor)
27
+ {
28
+ const TensorBase& base_tensor = tensor.get_buffer ();
29
+ std::vector<pybind11::size_t > shape_vec (base_tensor.shape ().size ());
30
+ std::transform
31
+ (
32
+ base_tensor.shape ().begin (),
33
+ base_tensor.shape ().end (),
34
+ shape_vec.begin (),
35
+ [](unsigned int dim)
36
+ {
37
+ return static_cast <pybind11::size_t >(dim);
38
+ }
39
+ );
40
+ pybind11::array arr = pybind11::array ();
41
+ return arr;
42
+ }
43
+
44
+ Tensor python_tuple_slice (const Tensor& t, pybind11::tuple tuple_slice)
45
+ {
46
+ std::vector<Tensor::Slice> t_slices;
47
+ for (size_t i = 0 ; i < tuple_slice.size (); i++)
48
+ {
49
+ ssize_t start, stop, step;
50
+ ssize_t length;
51
+ pybind11::slice py_slice = tuple_slice[i].cast <pybind11::slice>();
52
+ if (!py_slice.compute (t.get_buffer ().shape ().begin ()[i], &start, &stop, &step, &length))
53
+ throw std::runtime_error (" Invalid slice" );
54
+ t_slices.insert
55
+ (
56
+ t_slices.begin () + i,
57
+ Tensor::Slice
58
+ {
59
+ static_cast <int >(start),
60
+ static_cast <int >(stop),
61
+ static_cast <int >(step)
62
+ }
63
+ );
64
+ }
65
+
66
+ #ifdef __GNUC__
67
+ struct
68
+ {
69
+ const Tensor::Slice* it;
70
+ std::size_t sz;
71
+ } test;
72
+ test.it = t_slices.data ();
73
+ test.sz = t_slices.size ();
74
+ std::initializer_list<Tensor::Slice>& t_slice_list = reinterpret_cast <std::initializer_list<Tensor::Slice>&>(test);
75
+ #endif
76
+ return t[t_slice_list];
77
+ }
78
+
79
+ Tensor python_slice (const Tensor& t, pybind11::slice py_slice)
80
+ {
81
+ std::vector<Tensor::Slice> t_slices;
82
+ ssize_t start, stop, step;
83
+ ssize_t length;
84
+ if (!py_slice.compute (t.get_buffer ().shape ().begin ()[0 ], &start, &stop, &step, &length))
85
+ throw std::runtime_error (" Invalid slice" );
86
+ return t
87
+ [
88
+ {
89
+ Tensor::Slice
90
+ {
91
+ static_cast <int >(start),
92
+ static_cast <int >(stop),
93
+ static_cast <int >(step)
94
+ }
95
+ }
96
+ ];
97
+ }
98
+
99
+ Tensor python_index (const Tensor& t, unsigned int i)
100
+ {
101
+ return t[i];
102
+ }
103
+
104
+ std::size_t python_len (const Tensor& t)
105
+ {
106
+ std::initializer_list<unsigned int > shape_list = t.get_buffer ().shape ();
107
+ return shape_list.size () != 0 ? shape_list.begin ()[0 ]: 1U ;
108
+ }
109
+
110
+ std::string tensor_to_string (const Tensor& t)
111
+ {
112
+ std::ostringstream osstream;
113
+ osstream << t;
114
+ return osstream.str ();
115
+ }
116
+
117
+ PYBIND11_MODULE (tensor2, m)
118
+ {
119
+ pybind11::class_<Tensor>(m, " Tensor" )
120
+ .def (pybind11::init ())
121
+ .def (pybind11::init (&convert_numpy_to_tensor_base<float >))
122
+ .def (pybind11::self + pybind11::self)
123
+ .def (pybind11::self - pybind11::self)
124
+ .def (pybind11::self * pybind11::self)
125
+ .def (pybind11::self / pybind11::self)
126
+ .def (pybind11::self += pybind11::self)
127
+ .def (pybind11::self -= pybind11::self)
128
+ .def (pybind11::self *= pybind11::self)
129
+ .def (pybind11::self /= pybind11::self)
130
+ .def (pybind11::self == pybind11::self)
131
+ .def (pybind11::self != pybind11::self)
132
+ .def (pybind11::self >= pybind11::self)
133
+ .def (pybind11::self <= pybind11::self)
134
+ .def (pybind11::self > pybind11::self)
135
+ .def (pybind11::self < pybind11::self)
136
+ .def (+pybind11::self)
137
+ .def (-pybind11::self)
138
+ .def (hash (pybind11::self))
139
+ .def (" transpose" , &Tensor::transpose)
140
+ .def (" calc_grad" , &Tensor::calc_grad)
141
+ .def (" add" , &add)
142
+ .def (" multiply" , &multiply)
143
+ .def (" divide" , ÷)
144
+ .def (" matmul" , &matmul)
145
+ .def (" condition" , &condition)
146
+ .def (" __getitem__" , &python_index)
147
+ .def (" __getitem__" , &python_slice)
148
+ .def (" __getitem__" , &python_tuple_slice)
149
+ .def (" __len__" , &python_len)
150
+ .def (" __matmul__" , &matmul)
151
+ .def (" __repr__" , &tensor_to_string);
56
152
}
0 commit comments