Skip to content

Commit 4e0545c

Browse files
test
1 parent 24a46e1 commit 4e0545c

File tree

3 files changed

+159
-56
lines changed

3 files changed

+159
-56
lines changed

include/tensor_bind.cc

Lines changed: 151 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,152 @@
1-
#include <tensor-array/core/tensor.hh>
2-
#include <pybind11/pybind11.h>
3-
#include <pybind11/numpy.h>
4-
#include <pybind11/operators.h>
5-
6-
using namespace tensor_array::value;
7-
8-
template <typename T>
9-
TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
10-
{
11-
pybind11::buffer_info info = py_buf.request();
12-
std::vector<unsigned int> shape_vec(info.ndim);
13-
std::transform
14-
(
15-
info.shape.cbegin(),
16-
info.shape.cend(),
17-
shape_vec.begin(),
18-
[](pybind11::size_t dim)
19-
{
20-
return static_cast<unsigned int>(dim);
21-
}
22-
);
23-
return TensorBase(typeid(T), shape_vec, info.ptr);
24-
}
25-
26-
std::string tensor_to_string(const Tensor t)
27-
{
28-
std::ostringstream osstream;
29-
osstream << t;
30-
return osstream.str();
31-
}
32-
33-
PYBIND11_MODULE(tensor2, m)
34-
{
35-
pybind11::class_<Tensor>(m, "Tensor")
36-
.def(pybind11::init())
37-
.def(pybind11::init(&convert_numpy_to_tensor_base<float>))
38-
.def(pybind11::self + pybind11::self)
39-
.def(pybind11::self - pybind11::self)
40-
.def(pybind11::self * pybind11::self)
41-
.def(pybind11::self / pybind11::self)
42-
.def(pybind11::self += pybind11::self)
43-
.def(pybind11::self -= pybind11::self)
44-
.def(pybind11::self *= pybind11::self)
45-
.def(pybind11::self /= pybind11::self)
46-
.def(pybind11::self == pybind11::self)
47-
.def(pybind11::self != pybind11::self)
48-
.def(pybind11::self >= pybind11::self)
49-
.def(pybind11::self <= pybind11::self)
50-
.def(pybind11::self > pybind11::self)
51-
.def(pybind11::self < pybind11::self)
52-
.def(+pybind11::self)
53-
.def(-pybind11::self)
54-
.def("__matmul__", &matmul)
55-
.def("__repr__", &tensor_to_string);
1+
#include <tensor-array/core/tensor.hh>
2+
#include <pybind11/pybind11.h>
3+
#include <pybind11/numpy.h>
4+
#include <pybind11/operators.h>
5+
6+
using namespace tensor_array::value;
7+
8+
template <typename T>
9+
TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
10+
{
11+
pybind11::buffer_info info = py_buf.request();
12+
std::vector<unsigned int> shape_vec(info.ndim);
13+
std::transform
14+
(
15+
info.shape.cbegin(),
16+
info.shape.cend(),
17+
shape_vec.begin(),
18+
[](pybind11::size_t dim)
19+
{
20+
return static_cast<unsigned int>(dim);
21+
}
22+
);
23+
return TensorBase(typeid(T), shape_vec, info.ptr);
24+
}
25+
26+
pybind11::array convert_tensor_to_numpy(const Tensor& tensor)
27+
{
28+
const TensorBase& base_tensor = tensor.get_buffer();
29+
std::vector<pybind11::size_t> shape_vec(base_tensor.shape().size());
30+
std::transform
31+
(
32+
base_tensor.shape().begin(),
33+
base_tensor.shape().end(),
34+
shape_vec.begin(),
35+
[](unsigned int dim)
36+
{
37+
return static_cast<pybind11::size_t>(dim);
38+
}
39+
);
40+
pybind11::array arr = pybind11::array();
41+
return arr;
42+
}
43+
44+
Tensor python_tuple_slice(const Tensor& t, pybind11::tuple tuple_slice)
45+
{
46+
std::vector<Tensor::Slice> t_slices;
47+
for (size_t i = 0; i < tuple_slice.size(); i++)
48+
{
49+
ssize_t start, stop, step;
50+
ssize_t length;
51+
pybind11::slice py_slice = tuple_slice[i].cast<pybind11::slice>();
52+
if (!py_slice.compute(t.get_buffer().shape().begin()[i], &start, &stop, &step, &length))
53+
throw std::runtime_error("Invalid slice");
54+
t_slices.insert
55+
(
56+
t_slices.begin() + i,
57+
Tensor::Slice
58+
{
59+
static_cast<int>(start),
60+
static_cast<int>(stop),
61+
static_cast<int>(step)
62+
}
63+
);
64+
}
65+
66+
#ifdef __GNUC__
67+
struct
68+
{
69+
const Tensor::Slice* it;
70+
std::size_t sz;
71+
} test;
72+
test.it = t_slices.data();
73+
test.sz = t_slices.size();
74+
std::initializer_list<Tensor::Slice>& t_slice_list = reinterpret_cast<std::initializer_list<Tensor::Slice>&>(test);
75+
#endif
76+
return t[t_slice_list];
77+
}
78+
79+
Tensor python_slice(const Tensor& t, pybind11::slice py_slice)
80+
{
81+
std::vector<Tensor::Slice> t_slices;
82+
ssize_t start, stop, step;
83+
ssize_t length;
84+
if (!py_slice.compute(t.get_buffer().shape().begin()[0], &start, &stop, &step, &length))
85+
throw std::runtime_error("Invalid slice");
86+
return t
87+
[
88+
{
89+
Tensor::Slice
90+
{
91+
static_cast<int>(start),
92+
static_cast<int>(stop),
93+
static_cast<int>(step)
94+
}
95+
}
96+
];
97+
}
98+
99+
Tensor python_index(const Tensor& t, unsigned int i)
100+
{
101+
return t[i];
102+
}
103+
104+
std::size_t python_len(const Tensor& t)
105+
{
106+
std::initializer_list<unsigned int> shape_list = t.get_buffer().shape();
107+
return shape_list.size() != 0 ? shape_list.begin()[0]: 1U;
108+
}
109+
110+
std::string tensor_to_string(const Tensor& t)
111+
{
112+
std::ostringstream osstream;
113+
osstream << t;
114+
return osstream.str();
115+
}
116+
117+
PYBIND11_MODULE(tensor2, m)
118+
{
119+
pybind11::class_<Tensor>(m, "Tensor")
120+
.def(pybind11::init())
121+
.def(pybind11::init(&convert_numpy_to_tensor_base<float>))
122+
.def(pybind11::self + pybind11::self)
123+
.def(pybind11::self - pybind11::self)
124+
.def(pybind11::self * pybind11::self)
125+
.def(pybind11::self / pybind11::self)
126+
.def(pybind11::self += pybind11::self)
127+
.def(pybind11::self -= pybind11::self)
128+
.def(pybind11::self *= pybind11::self)
129+
.def(pybind11::self /= pybind11::self)
130+
.def(pybind11::self == pybind11::self)
131+
.def(pybind11::self != pybind11::self)
132+
.def(pybind11::self >= pybind11::self)
133+
.def(pybind11::self <= pybind11::self)
134+
.def(pybind11::self > pybind11::self)
135+
.def(pybind11::self < pybind11::self)
136+
.def(+pybind11::self)
137+
.def(-pybind11::self)
138+
.def(hash(pybind11::self))
139+
.def("transpose", &Tensor::transpose)
140+
.def("calc_grad", &Tensor::calc_grad)
141+
.def("add", &add)
142+
.def("multiply", &multiply)
143+
.def("divide", &divide)
144+
.def("matmul", &matmul)
145+
.def("condition", &condition)
146+
.def("__getitem__", &python_index)
147+
.def("__getitem__", &python_slice)
148+
.def("__getitem__", &python_tuple_slice)
149+
.def("__len__", &python_len)
150+
.def("__matmul__", &matmul)
151+
.def("__repr__", &tensor_to_string);
56152
}

lib/tensor2.so

358 KB
Binary file not shown.

main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33

44
print("hello")
55

6-
t1 = t.Tensor([[1], [2]])
6+
t1 = t.Tensor([[1, 2, 3], [4, 5, 6]])
7+
print("tensor len", t1.__len__())
8+
t1 = t1[::, ::2]
9+
print(t1)
10+
t1 = t1.transpose(0, 1)
11+
print("tensor len", t1.__len__())
12+
print(t1)
713
t2 = t1 + t1
814
print(t2)
15+
print(t2 > t1)

0 commit comments

Comments
 (0)