Skip to content

Commit b12c477

Browse files
test
1 parent 4e0545c commit b12c477

File tree

7 files changed

+72
-14
lines changed

7 files changed

+72
-14
lines changed

include/tensor_bind.cc

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#include <tensor-array/core/tensor.hh>
2+
#include <tensor-array/core/data_type_wrapper.hh>
23
#include <pybind11/pybind11.h>
34
#include <pybind11/numpy.h>
45
#include <pybind11/operators.h>
56

67
using namespace tensor_array::value;
8+
using namespace tensor_array::datatype;
79

810
template <typename T>
911
TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
@@ -20,12 +22,22 @@ TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
2022
return static_cast<unsigned int>(dim);
2123
}
2224
);
25+
warp_type(warp_type(typeid(T)));
2326
return TensorBase(typeid(T), shape_vec, info.ptr);
2427
}
2528

29+
pybind11::dtype get_py_type(const std::type_info& info)
30+
{
31+
if (info == typeid(bool))
32+
return pybind11::dtype::of<bool>();
33+
if (info == typeid(float))
34+
return pybind11::dtype::of<float>();
35+
throw std::exception();
36+
}
37+
2638
pybind11::array convert_tensor_to_numpy(const Tensor& tensor)
2739
{
28-
const TensorBase& base_tensor = tensor.get_buffer();
40+
const TensorBase& base_tensor = tensor.get_buffer().change_device({tensor_array::devices::CPU, 0});
2941
std::vector<pybind11::size_t> shape_vec(base_tensor.shape().size());
3042
std::transform
3143
(
@@ -37,8 +49,9 @@ pybind11::array convert_tensor_to_numpy(const Tensor& tensor)
3749
return static_cast<pybind11::size_t>(dim);
3850
}
3951
);
40-
pybind11::array arr = pybind11::array();
41-
return arr;
52+
auto ty0 = pybind11::detail::get_type_info(base_tensor.type());
53+
pybind11::dtype ty1 = get_py_type(base_tensor.type());
54+
return pybind11::array(ty1, shape_vec, base_tensor.data());
4255
}
4356

4457
Tensor python_tuple_slice(const Tensor& t, pybind11::tuple tuple_slice)
@@ -107,15 +120,34 @@ std::size_t python_len(const Tensor& t)
107120
return shape_list.size() != 0 ? shape_list.begin()[0]: 1U;
108121
}
109122

110-
std::string tensor_to_string(const Tensor& t)
123+
pybind11::str tensor_to_string(const Tensor& t)
111124
{
112-
std::ostringstream osstream;
113-
osstream << t;
114-
return osstream.str();
125+
return pybind11::repr(convert_tensor_to_numpy(t));
126+
}
127+
128+
Tensor tensor_cast_1(const Tensor& t, DataType dtype)
129+
{
130+
return t.tensor_cast(warp_type(dtype));
115131
}
116132

117133
PYBIND11_MODULE(tensor2, m)
118134
{
135+
pybind11::enum_<DataType>(m, "DataType")
136+
.value("BOOL", BOOL_DTYPE)
137+
.value("S_INT_8", S_INT_8)
138+
.value("S_INT_16", S_INT_16)
139+
.value("S_INT_32", S_INT_32)
140+
.value("S_INT_64", S_INT_64)
141+
.value("FLOAT", FLOAT_DTYPE)
142+
.value("DOUBLE", DOUBLE_DTYPE)
143+
.value("HALF", HALF_DTYPE)
144+
.value("BFLOAT16", BF16_DTYPE)
145+
.value("U_INT_8", U_INT_8)
146+
.value("U_INT_16", U_INT_16)
147+
.value("U_INT_32", U_INT_32)
148+
.value("U_INT_64", U_INT_64)
149+
.export_values();
150+
119151
pybind11::class_<Tensor>(m, "Tensor")
120152
.def(pybind11::init())
121153
.def(pybind11::init(&convert_numpy_to_tensor_base<float>))
@@ -138,11 +170,22 @@ PYBIND11_MODULE(tensor2, m)
138170
.def(hash(pybind11::self))
139171
.def("transpose", &Tensor::transpose)
140172
.def("calc_grad", &Tensor::calc_grad)
173+
.def("sin", &Tensor::sin)
174+
.def("sin", &Tensor::sin)
175+
.def("cos", &Tensor::cos)
176+
.def("tan", &Tensor::tan)
177+
.def("sinh", &Tensor::sinh)
178+
.def("cosh", &Tensor::cosh)
179+
.def("tanh", &Tensor::tanh)
180+
.def("log", &Tensor::log)
181+
.def("clone", &Tensor::clone)
182+
.def("cast", &tensor_cast_1)
141183
.def("add", &add)
142184
.def("multiply", &multiply)
143185
.def("divide", &divide)
144186
.def("matmul", &matmul)
145187
.def("condition", &condition)
188+
.def("numpy", &convert_tensor_to_numpy)
146189
.def("__getitem__", &python_index)
147190
.def("__getitem__", &python_slice)
148191
.def("__getitem__", &python_tuple_slice)

lib/tensor2.so

-3.32 MB
Binary file not shown.
12.8 MB
Binary file not shown.

src/TensorArray/core/tensor2.so

4 MB
Binary file not shown.

src/TensorArray/layers/Linear.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
from typing import Any
3+
4+
5+
class Linear:
6+
def __init__(self) -> None:
7+
self.w
8+
self.b
9+
pass
10+
11+
def __call__(self, input) -> Any:
12+
return input @ self.w + self.b

src/TensorArray/layers/__init__.py

Whitespace-only changes.

main.py renamed to src/main.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1-
from lib import tensor2 as t
2-
import numpy as np
1+
from TensorArray.core import tensor2 as t
32

43
print("hello")
54

65
t1 = t.Tensor([[1, 2, 3], [4, 5, 6]])
6+
t2 = t1.clone()
77
print("tensor len", t1.__len__())
8-
t1 = t1[::, ::2]
8+
print(t1)
9+
t1 = t1[::]
910
print(t1)
1011
t1 = t1.transpose(0, 1)
1112
print("tensor len", t1.__len__())
12-
print(t1)
13-
t2 = t1 + t1
14-
print(t2)
15-
print(t2 > t1)
13+
t4 = t1 @ t2
14+
t3 = t1 * t1
15+
print(t4)
16+
print(t3)
17+
18+
print(t1 != t3)

0 commit comments

Comments
 (0)