Skip to content

Commit dea83d3

Browse files
committed
add pybind casters for strided_views, array_adaptor, and tensor_adaptor
1 parent d6f87cf commit dea83d3

File tree

6 files changed

+112
-11
lines changed

6 files changed

+112
-11
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "pyarray_backstrides.hpp"
2222
#include "pycontainer.hpp"
2323
#include "pystrides_adaptor.hpp"
24+
#include "pynative_casters.hpp"
2425
#include "xtensor_type_caster_base.hpp"
2526

2627
namespace xt
@@ -91,11 +92,6 @@ namespace pybind11
9192
}
9293
};
9394

94-
// Type caster for casting xarray to ndarray
95-
template <class T, xt::layout_type L>
96-
struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
97-
{
98-
};
9995
}
10096
}
10197

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/***************************************************************************
2+
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay *
3+
* Copyright (c) QuantStack *
4+
* *
5+
* Distributed under the terms of the BSD 3-Clause License. *
6+
* *
7+
* The full license is in the file LICENSE, distributed with this software. *
8+
****************************************************************************/
9+
10+
#ifndef PYNATIVE_CASTER_HPP
11+
#define PYNATIVE_CASTER_HPP
12+
13+
#include "xtensor_type_caster_base.hpp"
14+
15+
16+
namespace pybind11
17+
{
18+
namespace detail
19+
{
20+
// Type caster for casting xarray to ndarray
21+
template <class T, xt::layout_type L>
22+
struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
23+
{
24+
};
25+
26+
// Type caster for casting xt::xtensor to ndarray
27+
template <class T, std::size_t N, xt::layout_type L>
28+
struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
29+
{
30+
};
31+
32+
// Type caster for casting xt::xstrided_view to ndarray
33+
template <class CT, class S, xt::layout_type L, class FST>
34+
struct type_caster<xt::xstrided_view<CT, S, L, FST>> : xtensor_type_caster_base<xt::xstrided_view<CT, S, L, FST>>
35+
{
36+
};
37+
38+
// Type caster for casting xt::xarray_adaptor to ndarray
39+
template <class EC, xt::layout_type L, class SC, class Tag>
40+
struct type_caster<xt::xarray_adaptor<EC, L, SC, Tag>> : xtensor_type_caster_base<xt::xarray_adaptor<EC, L, SC, Tag>>
41+
{
42+
};
43+
44+
// Type caster for casting xt::xtensor_adaptor to ndarray
45+
template <class EC, std::size_t N, xt::layout_type L, class Tag>
46+
struct type_caster<xt::xtensor_adaptor<EC, N, L, Tag>> : xtensor_type_caster_base<xt::xtensor_adaptor<EC, N, L, Tag>>
47+
{
48+
};
49+
}
50+
}
51+
52+
#endif

include/xtensor-python/pytensor.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "pycontainer.hpp"
2323
#include "pystrides_adaptor.hpp"
24+
#include "pynative_casters.hpp"
2425
#include "xtensor_type_caster_base.hpp"
2526

2627
namespace xt
@@ -99,11 +100,6 @@ namespace pybind11
99100
}
100101
};
101102

102-
// Type caster for casting xt::xtensor to ndarray
103-
template <class T, std::size_t N, xt::layout_type L>
104-
struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
105-
{
106-
};
107103
}
108104
}
109105

include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace pybind11
3939
std::vector<std::size_t> python_shape(src.shape().size());
4040
std::copy(src.shape().begin(), src.shape().end(), python_shape.begin());
4141

42-
array a(python_shape, python_strides, src.begin(), base);
42+
array a(python_shape, python_strides, src.data(), base);
4343

4444
if (!writeable)
4545
{

test_python/main.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "xtensor-python/pyarray.hpp"
1616
#include "xtensor-python/pytensor.hpp"
1717
#include "xtensor-python/pyvectorize.hpp"
18+
#include "xtensor/xadapt.hpp"
19+
#include "xtensor/xstrided_view.hpp"
1820

1921
namespace py = pybind11;
2022
using complex_t = std::complex<double>;
@@ -133,6 +135,34 @@ class C
133135
array_type m_array;
134136
};
135137

138+
struct test_native_casters
139+
{
140+
using array_type = xt::xarray<double>;
141+
array_type a{{0, 1, 2},{3, 4, 5}};
142+
143+
const auto & get_array(){
144+
return a;
145+
}
146+
147+
auto get_strided_view(){
148+
return xt::strided_view(a, {xt::range(0, 1), xt::range(0, 3, 2)});
149+
}
150+
151+
auto get_array_adapter(){
152+
using shape_type = std::vector<size_t>;
153+
shape_type shape = {2, 2};
154+
shape_type stride = {3, 2};
155+
return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
156+
}
157+
158+
auto get_tensor_adapter(){
159+
using shape_type = std::array<size_t, 2>;
160+
shape_type shape = {2, 2};
161+
shape_type stride = {3, 2};
162+
return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
163+
}
164+
};
165+
136166
xt::pyarray<A> dtype_to_python()
137167
{
138168
A a1{123, 321, 'a', {1, 2, 3}};
@@ -257,4 +287,12 @@ PYBIND11_MODULE(xtensor_python_test, m)
257287

258288
m.def("diff_shape_overload", [](xt::pytensor<int, 1> a) { return 1; });
259289
m.def("diff_shape_overload", [](xt::pytensor<int, 2> a) { return 2; });
290+
291+
py::class_<test_native_casters>(m, "test_native_casters")
292+
.def(py::init<>())
293+
.def("get_array", &test_native_casters::get_array, py::return_value_policy::reference_internal)
294+
.def("get_strided_view", &test_native_casters::get_strided_view)
295+
.def("get_array_adapter", &test_native_casters::get_array_adapter)
296+
.def("get_tensor_adapter", &test_native_casters::get_tensor_adapter);
297+
260298
}

test_python/test_pyarray.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,25 @@ def test_diff_shape_overload(self):
166166
# FIXME: the TypeError information is not informative
167167
xt.diff_shape_overload(np.ones((2, 2, 2)))
168168

169+
def test_native_casters(self):
170+
obj = xt.test_native_casters()
171+
arr = obj.get_array()
172+
173+
strided_view = obj.get_strided_view()
174+
strided_view[0, 1] = -1
175+
self.assertEqual(strided_view.shape, (1, 2))
176+
self.assertEqual(arr[0, 2], -1)
177+
178+
adapter = obj.get_array_adapter()
179+
self.assertEqual(adapter.shape, (2, 2))
180+
adapter[1, 1] = -2
181+
self.assertEqual(arr[1, 2], -2)
182+
183+
adapter = obj.get_tensor_adapter()
184+
self.assertEqual(adapter.shape, (2, 2))
185+
adapter[1, 1] = -3
186+
self.assertEqual(arr[1, 2], -3)
187+
169188

170189
class AttributeTest(TestCase):
171190

0 commit comments

Comments
 (0)