Skip to content

Commit 77cf4a1

Browse files
test linear
1 parent e65644a commit 77cf4a1

File tree

7 files changed

+43
-13
lines changed

7 files changed

+43
-13
lines changed

include/tensor_bind.cc

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,34 @@ TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
2323
return static_cast<unsigned int>(dim);
2424
}
2525
);
26-
warp_type(warp_type(typeid(T)));
27-
return TensorBase(typeid(T), shape_vec, info.ptr);
26+
return TensorBase(warp_type(warp_type(typeid(T))), shape_vec, info.ptr);
2827
}
2928

3029
pybind11::dtype get_py_type(const std::type_info& info)
3130
{
31+
if (info == typeid(std::int8_t))
32+
return pybind11::dtype::of<std::int8_t>();
33+
if (info == typeid(std::int16_t))
34+
return pybind11::dtype::of<std::int16_t>();
35+
if (info == typeid(std::int32_t))
36+
return pybind11::dtype::of<std::int32_t>();
37+
if (info == typeid(std::int64_t))
38+
return pybind11::dtype::of<std::int64_t>();
39+
if (info == typeid(std::uint8_t))
40+
return pybind11::dtype::of<std::uint8_t>();
41+
if (info == typeid(std::uint16_t))
42+
return pybind11::dtype::of<std::uint16_t>();
43+
if (info == typeid(std::uint32_t))
44+
return pybind11::dtype::of<std::uint32_t>();
45+
if (info == typeid(std::uint64_t))
46+
return pybind11::dtype::of<std::uint64_t>();
3247
if (info == typeid(bool))
3348
return pybind11::dtype::of<bool>();
3449
if (info == typeid(float))
3550
return pybind11::dtype::of<float>();
36-
throw std::exception();
51+
if (info == typeid(double))
52+
return pybind11::dtype::of<double>();
53+
throw std::runtime_error("no dtype");
3754
}
3855

3956
pybind11::array convert_tensor_to_numpy(const Tensor& self)
@@ -125,6 +142,11 @@ pybind11::tuple tensor_shape(const Tensor& self)
125142
return pybind11::cast(std::vector(self.get_buffer().shape()));
126143
}
127144

145+
DataType tensor_type(const Tensor& self)
146+
{
147+
return warp_type(self.get_buffer().type());
148+
}
149+
128150
Tensor tensor_copying(const Tensor& self)
129151
{
130152
return self;
@@ -205,6 +227,7 @@ PYBIND11_MODULE(tensor2, m)
205227
.def("condition", &condition)
206228
.def("numpy", &convert_tensor_to_numpy)
207229
.def("shape", &tensor_shape)
230+
.def("dtype", &tensor_type)
208231
.def("__getitem__", &python_index)
209232
.def("__getitem__", &python_slice)
210233
.def("__getitem__", &python_tuple_slice)

src/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222

2323
layer1 = Linear(3)
2424

25-
print(layer1(Tensor([[1, 2, 3], [4, 5, 6]])))
25+
result_layer_1 = layer1(Tensor([[1, 2, 3], [4, 5, 6]]))
26+
print(result_layer_1)

src/tensor_array/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from tensor_array.core.tensor2 import Tensor
2-
from tensor_array.core.tensor2 import zeros
2+
from tensor_array.core.tensor2 import zeros
3+
from tensor_array.core.tensor2 import DataType
935 KB
Binary file not shown.

src/tensor_array/core/tensor2.so

22.4 KB
Binary file not shown.

src/tensor_array/layers/layer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,15 @@ def __init__(self) -> None:
2121

2222
def __call__(self, *args: Any, **kwds: Any) -> Any:
2323
if not self.__dict__['is_running']:
24-
self.init_value(*args, **kwds)
24+
list_arg = ((t.shape(), t.dtype()) for t in args if isinstance(t, Tensor))
25+
dict_kwargs = {
26+
key: (val.shape(), val.dtype())
27+
for key, val in kwds
28+
if isinstance(val, Tensor)
29+
}
30+
self.init_value(*list_arg, **dict_kwargs)
2531
super().__setattr__('is_running', True)
26-
self.calculate(*args, **kwds)
32+
return self.calculate(*args, **kwds)
2733

2834
def init_value(self, *args: Any, **kwds: Any) -> Any:
2935
pass

src/tensor_array/layers/util/linear.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,20 @@
22
from .. import Parameter
33
from tensor_array.core import Tensor
44
from tensor_array.core import zeros
5+
from tensor_array.core import DataType
56
from typing import Any
67

78

89
class Linear(Layer):
910
def __init__(self, bias) -> None:
10-
super(Linear, self).__init__()
11+
super().__init__()
1112
self.bias_shape = bias
12-
self.b = Parameter(zeros(shape = (bias,)))
13+
self.b = Parameter(zeros(shape = (bias,), dtype = DataType.FLOAT))
1314

1415
def init_value(self, t):
15-
self.w = Parameter(zeros(shape = (t.shape()[-1], self.bias_shape)))
16+
shape, dtype = t
17+
self.w = Parameter(zeros(shape = (shape[-1], self.bias_shape), dtype = dtype))
1618

1719
def calculate(self, t):
18-
print("t", t)
19-
print("w", self.w)
20-
print("b", self.b)
2120
return t @ self.w + self.b
2221

0 commit comments

Comments
 (0)