Skip to content

Commit fa97cf0

Browse files
committed
add pybind casters for strided_views, array_adaptor, and tensor_adaptor
1 parent d6f87cf commit fa97cf0

File tree

7 files changed

+193
-25
lines changed

7 files changed

+193
-25
lines changed

CMakeLists.txt

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,16 @@ message(STATUS "Found numpy: ${NUMPY_INCLUDE_DIRS}")
6262
# =====
6363

6464
set(XTENSOR_PYTHON_HEADERS
65-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray.hpp
66-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray_backstrides.hpp
67-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pycontainer.hpp
68-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pystrides_adaptor.hpp
69-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
70-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
71-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
72-
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
73-
)
65+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray.hpp
66+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray_backstrides.hpp
67+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pycontainer.hpp
68+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pynative_casters.hpp
69+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pystrides_adaptor.hpp
70+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
71+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
72+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
73+
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
74+
)
7475

7576
add_library(xtensor-python INTERFACE)
7677
target_include_directories(xtensor-python INTERFACE

include/xtensor-python/pyarray.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "pyarray_backstrides.hpp"
2222
#include "pycontainer.hpp"
2323
#include "pystrides_adaptor.hpp"
24+
#include "pynative_casters.hpp"
2425
#include "xtensor_type_caster_base.hpp"
2526

2627
namespace xt
@@ -91,11 +92,6 @@ namespace pybind11
9192
}
9293
};
9394

94-
// Type caster for casting xarray to ndarray
95-
template <class T, xt::layout_type L>
96-
struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
97-
{
98-
};
9995
}
10096
}
10197

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/***************************************************************************
2+
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay *
3+
* Copyright (c) QuantStack *
4+
* *
5+
* Distributed under the terms of the BSD 3-Clause License. *
6+
* *
7+
* The full license is in the file LICENSE, distributed with this software. *
8+
****************************************************************************/
9+
10+
#ifndef PYNATIVE_CASTERS_HPP
11+
#define PYNATIVE_CASTERS_HPP
12+
13+
#include "xtensor_type_caster_base.hpp"
14+
15+
16+
namespace pybind11
17+
{
18+
namespace detail
19+
{
20+
// Type caster for casting xarray to ndarray
21+
template <class T, xt::layout_type L>
22+
struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
23+
{
24+
};
25+
26+
// Type caster for casting xt::xtensor to ndarray
27+
template <class T, std::size_t N, xt::layout_type L>
28+
struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
29+
{
30+
};
31+
32+
// Type caster for casting xt::xstrided_view to ndarray
33+
template <class CT, class S, xt::layout_type L, class FST>
34+
struct type_caster<xt::xstrided_view<CT, S, L, FST>> : xtensor_type_caster_base<xt::xstrided_view<CT, S, L, FST>>
35+
{
36+
};
37+
38+
// Type caster for casting xt::xarray_adaptor to ndarray
39+
template <class EC, xt::layout_type L, class SC, class Tag>
40+
struct type_caster<xt::xarray_adaptor<EC, L, SC, Tag>> : xtensor_type_caster_base<xt::xarray_adaptor<EC, L, SC, Tag>>
41+
{
42+
};
43+
44+
// Type caster for casting xt::xtensor_adaptor to ndarray
45+
template <class EC, std::size_t N, xt::layout_type L, class Tag>
46+
struct type_caster<xt::xtensor_adaptor<EC, N, L, Tag>> : xtensor_type_caster_base<xt::xtensor_adaptor<EC, N, L, Tag>>
47+
{
48+
};
49+
}
50+
}
51+
52+
#endif

include/xtensor-python/pytensor.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "pycontainer.hpp"
2323
#include "pystrides_adaptor.hpp"
24+
#include "pynative_casters.hpp"
2425
#include "xtensor_type_caster_base.hpp"
2526

2627
namespace xt
@@ -99,11 +100,6 @@ namespace pybind11
99100
}
100101
};
101102

102-
// Type caster for casting xt::xtensor to ndarray
103-
template <class T, std::size_t N, xt::layout_type L>
104-
struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
105-
{
106-
};
107103
}
108104
}
109105

include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace pybind11
2323
{
2424
namespace detail
2525
{
26-
// Casts an xtensor (or xarray) type to numpy array.If given a base,
26+
// Casts a strided expression type to numpy array.If given a base,
2727
// the numpy array references the src data, otherwise it'll make a copy.
2828
// The writeable attributes lets you specify writeable flag for the array.
2929
template <typename Type>
@@ -39,7 +39,7 @@ namespace pybind11
3939
std::vector<std::size_t> python_shape(src.shape().size());
4040
std::copy(src.shape().begin(), src.shape().end(), python_shape.begin());
4141

42-
array a(python_shape, python_strides, src.begin(), base);
42+
array a(python_shape, python_strides, &*(src.begin()), base);
4343

4444
if (!writeable)
4545
{
@@ -49,8 +49,8 @@ namespace pybind11
4949
return a.release();
5050
}
5151

52-
// Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
53-
// reference the xtensor object's data with `base` as the python-registered base class (if omitted,
52+
// Takes an lvalue ref to some strided expression type and a (python) base object, creating a numpy array that
53+
// reference the expression object's data with `base` as the python-registered base class (if omitted,
5454
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
5555
// non-writeable if the given type is const.
5656
template <typename Type, typename CType>
@@ -59,7 +59,7 @@ namespace pybind11
5959
return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
6060
}
6161

62-
// Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
62+
// Takes a pointer to a strided expression, builds a capsule around it, then returns a numpy
6363
// array that references the encapsulated data with a python-side reference to the capsule to tie
6464
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
6565
// not the CType of the pointer given is const.
@@ -70,7 +70,7 @@ namespace pybind11
7070
return xtensor_ref_array<Type>(*src, base);
7171
}
7272

73-
// Base class of type_caster for xtensor and xarray
73+
// Base class of type_caster for strided expressions
7474
template <class Type>
7575
struct xtensor_type_caster_base
7676
{

test_python/main.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "xtensor-python/pyarray.hpp"
1616
#include "xtensor-python/pytensor.hpp"
1717
#include "xtensor-python/pyvectorize.hpp"
18+
#include "xtensor/xadapt.hpp"
19+
#include "xtensor/xstrided_view.hpp"
1820

1921
namespace py = pybind11;
2022
using complex_t = std::complex<double>;
@@ -133,6 +135,49 @@ class C
133135
array_type m_array;
134136
};
135137

138+
struct test_native_casters
139+
{
140+
using array_type = xt::xarray<double>;
141+
array_type a = xt::ones<double>({50, 50});
142+
143+
const auto & get_array()
144+
{
145+
return a;
146+
}
147+
148+
auto get_strided_view()
149+
{
150+
return xt::strided_view(a, {xt::range(0, 1), xt::range(0, 3, 2)});
151+
}
152+
153+
auto get_array_adapter()
154+
{
155+
using shape_type = std::vector<size_t>;
156+
shape_type shape = {2, 2};
157+
shape_type stride = {3, 2};
158+
return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
159+
}
160+
161+
auto get_tensor_adapter()
162+
{
163+
using shape_type = std::array<size_t, 2>;
164+
shape_type shape = {2, 2};
165+
shape_type stride = {3, 2};
166+
return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
167+
}
168+
169+
auto get_owning_array_adapter()
170+
{
171+
size_t size = 100;
172+
int * data = new int[size];
173+
std::fill(data, data + size, 1);
174+
175+
using shape_type = std::vector<size_t>;
176+
shape_type shape = {size};
177+
return xt::adapt(std::move(data), size, xt::acquire_ownership(), shape);
178+
}
179+
};
180+
136181
xt::pyarray<A> dtype_to_python()
137182
{
138183
A a1{123, 321, 'a', {1, 2, 3}};
@@ -257,4 +302,15 @@ PYBIND11_MODULE(xtensor_python_test, m)
257302

258303
m.def("diff_shape_overload", [](xt::pytensor<int, 1> a) { return 1; });
259304
m.def("diff_shape_overload", [](xt::pytensor<int, 2> a) { return 2; });
305+
306+
py::class_<test_native_casters>(m, "test_native_casters")
307+
.def(py::init<>())
308+
.def("get_array", &test_native_casters::get_array, py::return_value_policy::reference_internal) // memory managed by the class instance
309+
.def("get_strided_view", &test_native_casters::get_strided_view, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned view
310+
.def("get_array_adapter", &test_native_casters::get_array_adapter, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned adapter
311+
.def("get_tensor_adapter", &test_native_casters::get_tensor_adapter, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned adapter
312+
.def("get_owning_array_adapter", &test_native_casters::get_owning_array_adapter) // auto memory management as the adapter owns its memory
313+
.def("view_keep_alive_member_function", [](test_native_casters & self, xt::pyarray<double> & a) // keep_alive<0, 2>() => do not free second parameter before the returned view
314+
{return xt::reshape_view(a, {a.size(), });},
315+
py::keep_alive<0, 2>());
260316
}

test_python/test_pyarray.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,73 @@ def test_diff_shape_overload(self):
166166
# FIXME: the TypeError information is not informative
167167
xt.diff_shape_overload(np.ones((2, 2, 2)))
168168

169+
def test_native_casters(self):
170+
import gc
171+
172+
# check keep alive policy for get_strided_view()
173+
gc.collect()
174+
obj = xt.test_native_casters()
175+
a = obj.get_strided_view()
176+
obj = None
177+
gc.collect()
178+
_ = np.zeros((100, 100))
179+
self.assertEqual(a.sum(), a.size)
180+
181+
# check keep alive policy for get_array_adapter()
182+
gc.collect()
183+
obj = xt.test_native_casters()
184+
a = obj.get_array_adapter()
185+
obj = None
186+
gc.collect()
187+
_ = np.zeros((100, 100))
188+
self.assertEqual(a.sum(), a.size)
189+
190+
# check keep alive policy for get_array_adapter()
191+
gc.collect()
192+
obj = xt.test_native_casters()
193+
a = obj.get_tensor_adapter()
194+
obj = None
195+
gc.collect()
196+
_ = np.zeros((100, 100))
197+
self.assertEqual(a.sum(), a.size)
198+
199+
# check keep alive policy for get_owning_array_adapter()
200+
gc.collect()
201+
obj = xt.test_native_casters()
202+
a = obj.get_owning_array_adapter()
203+
gc.collect()
204+
_ = np.zeros((100, 100))
205+
self.assertEqual(a.sum(), a.size)
206+
207+
# check keep alive policy for view_keep_alive_member_function()
208+
gc.collect()
209+
a = np.ones((100, 100))
210+
b = obj.view_keep_alive_member_function(a)
211+
obj = None
212+
a = None
213+
gc.collect()
214+
_ = np.zeros((100, 100))
215+
self.assertEqual(b.sum(), b.size)
216+
217+
# check shared buffer (insure that no copy is done)
218+
obj = xt.test_native_casters()
219+
arr = obj.get_array()
220+
221+
strided_view = obj.get_strided_view()
222+
strided_view[0, 1] = -1
223+
self.assertEqual(strided_view.shape, (1, 2))
224+
self.assertEqual(arr[0, 2], -1)
225+
226+
adapter = obj.get_array_adapter()
227+
self.assertEqual(adapter.shape, (2, 2))
228+
adapter[1, 1] = -2
229+
self.assertEqual(arr[0, 5], -2)
230+
231+
adapter = obj.get_tensor_adapter()
232+
self.assertEqual(adapter.shape, (2, 2))
233+
adapter[1, 1] = -3
234+
self.assertEqual(arr[0, 5], -3)
235+
169236

170237
class AttributeTest(TestCase):
171238

0 commit comments

Comments
 (0)