Skip to content

Commit 428ed03

Browse files
committed
Adding tests & minor bugfix
1 parent e21ecb0 commit 428ed03

File tree

3 files changed

+77
-10
lines changed

3 files changed

+77
-10
lines changed

include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace pybind11
3737
{
3838
static auto get(handle src)
3939
{
40-
return array_t<T, array::f_style>::ensure(src);
40+
return array_t<T, array::f_style | array::forcecast>::ensure(src);
4141
}
4242
};
4343

@@ -51,8 +51,7 @@ namespace pybind11
5151
{
5252
static auto get(handle src)
5353
{
54-
auto buf = xtensor_get_buffer<T, L>::get(src);
55-
return buf;
54+
return xtensor_get_buffer<T, L>::get(src);
5655
}
5756
};
5857

@@ -61,11 +60,7 @@ namespace pybind11
6160
{
6261
static auto get(handle src)
6362
{
64-
auto buf = xtensor_get_buffer<T, L>::get(src);
65-
if (buf.ndim() != N) {
66-
return false;
67-
}
68-
return buf;
63+
return xtensor_get_buffer<T, L>::get(src);
6964
}
7065
};
7166

@@ -98,6 +93,27 @@ namespace pybind11
9893
};
9994

10095

96+
template <class T>
97+
struct xtensor_verify
98+
{
99+
template <class B>
100+
static bool get(const B& buf)
101+
{
102+
return true;
103+
}
104+
};
105+
106+
template <class T, std::size_t N, xt::layout_type L>
107+
struct xtensor_verify<xt::xtensor<T, N, L>>
108+
{
109+
template <class B>
110+
static bool get(const B& buf)
111+
{
112+
return buf.ndim() == N;
113+
}
114+
};
115+
116+
101117
// Casts a strided expression type to numpy array.If given a base,
102118
// the numpy array references the src data, otherwise it'll make a copy.
103119
// The writeable attributes lets you specify writeable flag for the array.
@@ -192,11 +208,14 @@ namespace pybind11
192208
if (!buf) {
193209
return false;
194210
}
211+
if (!xtensor_verify<Type>::get(buf)) {
212+
return false;
213+
}
195214

196215
std::vector<size_t> shape(buf.ndim());
197216
std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin());
198-
value = Type(shape);
199-
std::copy(buf.data(), buf.data() + buf.size(), value.begin());
217+
value = Type::from_shape(shape);
218+
std::copy(buf.data(), buf.data() + buf.size(), value.data());
200219

201220
return true;
202221
}

test_python/main.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,33 @@ xt::pyarray<double> example2(xt::pyarray<double>& m)
3333
return m + 2;
3434
}
3535

36+
xt::xarray<int> example3_xarray(const xt::xarray<int>& m)
37+
{
38+
return xt::transpose(m) + 2;
39+
}
40+
41+
xt::xarray<int, xt::layout_type::column_major> example3_xarray_colmajor(
42+
const xt::xarray<int, xt::layout_type::column_major>& m)
43+
{
44+
return xt::transpose(m) + 2;
45+
}
46+
47+
xt::xtensor<int, 3> example3_xtensor3(const xt::xtensor<int, 3>& m)
48+
{
49+
return xt::transpose(m) + 2;
50+
}
51+
52+
xt::xtensor<int, 2> example3_xtensor2(const xt::xtensor<int, 2>& m)
53+
{
54+
return xt::transpose(m) + 2;
55+
}
56+
57+
xt::xtensor<int, 2, xt::layout_type::column_major> example3_xtensor2_colmajor(
58+
const xt::xtensor<int, 2, xt::layout_type::column_major>& m)
59+
{
60+
return xt::transpose(m) + 2;
61+
}
62+
3663
// Readme Examples
3764

3865
double readme_example1(xt::pyarray<double>& m)
@@ -249,6 +276,11 @@ PYBIND11_MODULE(xtensor_python_test, m)
249276

250277
m.def("example1", example1);
251278
m.def("example2", example2);
279+
m.def("example3_xarray", example3_xarray);
280+
m.def("example3_xarray_colmajor", example3_xarray_colmajor);
281+
m.def("example3_xtensor3", example3_xtensor3);
282+
m.def("example3_xtensor2", example3_xtensor2);
283+
m.def("example3_xtensor2_colmajor", example3_xtensor2_colmajor);
252284

253285
m.def("complex_overload", no_complex_overload);
254286
m.def("complex_overload", complex_overload);

test_python/test_pyarray.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ def test_example2(self):
3636
y = xt.example2(x)
3737
np.testing.assert_allclose(y, res, 1e-12)
3838

39+
def test_example3(self):
40+
x = np.arange(2 * 3).reshape(2, 3)
41+
xc = np.asfortranarray(x)
42+
y = np.arange(2 * 3 * 4).reshape(2, 3, 4)
43+
v = y[1:, 1:, 0]
44+
z = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
45+
np.testing.assert_array_equal(xt.example3_xarray(x), x.T + 2)
46+
np.testing.assert_array_equal(xt.example3_xarray_colmajor(xc), xc.T + 2)
47+
np.testing.assert_array_equal(xt.example3_xtensor3(y), y.T + 2)
48+
np.testing.assert_array_equal(xt.example3_xtensor2(x), x.T + 2)
49+
np.testing.assert_array_equal(xt.example3_xtensor2(y[1:, 1:, 0]), v.T + 2)
50+
np.testing.assert_array_equal(xt.example3_xtensor2_colmajor(xc), xc.T + 2)
51+
52+
with self.assertRaises(TypeError):
53+
xt.example3_xtensor3(x)
54+
3955
def test_vectorize(self):
4056
x1 = np.array([[0, 1], [2, 3]])
4157
x2 = np.array([0, 1])

0 commit comments

Comments
 (0)