Skip to content

Commit 06d86c9

Browse files
authored
Merge pull request #267 from tdegeus/qad
Adding possibility to 'cast' or copy to `xt::xarray` etc
2 parents 43b244e + af91def commit 06d86c9

File tree

9 files changed

+276
-18
lines changed

9 files changed

+276
-18
lines changed

.azure-pipelines/unix-build.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ steps:
4545
displayName: Example - readme 1
4646
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/readme_example_1
4747
48+
- script: |
49+
source activate xtensor-python
50+
cmake -Bbuild -DPython_EXECUTABLE=`which python`
51+
cd build
52+
cmake --build .
53+
cp ../example.py .
54+
python example.py
55+
cd ..
56+
displayName: Example - Copy 'cast'
57+
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/copy_cast
58+
4859
- script: |
4960
source activate xtensor-python
5061
cmake -Bbuild -DPython_EXECUTABLE=`which python`

docs/source/examples.rst

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,54 @@ Then we can test the module:
143143
Since we did not install the module,
144144
we should compile and run the example from the same folder.
145145
To install, please consult
146-
`this *pybind11* / *CMake* example <https://github.com/pybind/cmake_example>`_.
146+
`this pybind11 / CMake example <https://github.com/pybind/cmake_example>`_.
147147
**Tip**: take care to modify that example with the correct *CMake* case ``Python_EXECUTABLE``.
148+
149+
Fall-back cast
150+
==============
151+
152+
The previous example showed you how to design your module to be flexible in accepting data.
153+
From C++ we used ``xt::xarray<double>``,
154+
whereas for the Python API we used ``xt::pyarray<double>`` to operate directly on the memory
155+
of a NumPy array from Python (without copying the data).
156+
157+
Sometimes, you might not have the flexibility to design your module's methods
158+
with template parameters.
159+
This might occur when you want to ``override`` functions
160+
(though it is recommended to use CRTP to still use templates).
161+
In this case we can still bind the module in Python using *xtensor-python*,
162+
however, we have to copy the data from a (NumPy) array.
163+
This means that although the following signatures are quite different when used from C++,
164+
as follows:
165+
166+
1. *Constant reference*: read from the data, without copying it.
167+
168+
.. code-block:: cpp
169+
170+
void foo(const xt::xarray<double>& a);
171+
172+
2. *Reference*: read from and/or write to the data, without copying it.
173+
174+
.. code-block:: cpp
175+
176+
void foo(xt::xarray<double>& a);
177+
178+
3. *Copy*: copy the data.
179+
180+
.. code-block:: cpp
181+
182+
void foo(xt::xarray<double> a);
183+
184+
The Python will all cases result in a copy to a temporary variable
185+
(though the last signature will lead to a copy to a temporary variable, and another copy to ``a``).
186+
On the one hand, this is more costly than when using ``xt::pyarray`` and ``xt::pyxtensor``,
187+
on the other hand, it means that all changes you make to a reference, are made to the temporary
188+
copy, and are thus lost.
189+
190+
Still, it might be a convenient way to create Python bindings, using a minimal effort.
191+
Consider this example:
192+
193+
:download:`main.cpp <examples/copy_cast/main.cpp>`
194+
195+
.. literalinclude:: examples/copy_cast/main.cpp
196+
:language: cpp
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
cmake_minimum_required(VERSION 3.1..3.19)
2+
3+
project(mymodule)
4+
5+
find_package(pybind11 CONFIG REQUIRED)
6+
find_package(xtensor REQUIRED)
7+
find_package(xtensor-python REQUIRED)
8+
find_package(Python REQUIRED COMPONENTS NumPy)
9+
10+
pybind11_add_module(mymodule main.cpp)
11+
target_link_libraries(mymodule PUBLIC pybind11::module xtensor-python Python::NumPy)
12+
13+
target_compile_definitions(mymodule PRIVATE VERSION_INFO=0.1.0)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import mymodule
2+
import numpy as np
3+
4+
c = np.array([[1, 2, 3], [4, 5, 6]])
5+
assert np.isclose(np.sum(np.sin(c)), mymodule.sum_of_sines(c))
6+
assert np.isclose(np.sum(np.cos(c)), mymodule.sum_of_cosines(c))
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <numeric>
2+
#include <xtensor.hpp>
3+
#include <pybind11/pybind11.h>
4+
#define FORCE_IMPORT_ARRAY
5+
#include <xtensor-python/pyarray.hpp>
6+
7+
template <class T>
8+
double sum_of_sines(T& m)
9+
{
10+
auto sines = xt::sin(m); // sines does not actually hold values.
11+
return std::accumulate(sines.begin(), sines.end(), 0.0);
12+
}
13+
14+
// In the Python API this a reference to a temporary variable
15+
double sum_of_cosines(const xt::xarray<double>& m)
16+
{
17+
auto cosines = xt::cos(m); // cosines does not actually hold values.
18+
return std::accumulate(cosines.begin(), cosines.end(), 0.0);
19+
}
20+
21+
PYBIND11_MODULE(mymodule, m)
22+
{
23+
xt::import_numpy();
24+
m.doc() = "Test module for xtensor python bindings";
25+
m.def("sum_of_sines", sum_of_sines<xt::pyarray<double>>, "Sum the sines of the input values");
26+
m.def("sum_of_cosines", sum_of_cosines, "Sum the cosines of the input values");
27+
}

include/xtensor-python/pynative_casters.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "xtensor_type_caster_base.hpp"
1414

15-
1615
namespace pybind11
1716
{
1817
namespace detail

include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,97 @@ namespace pybind11
2323
{
2424
namespace detail
2525
{
26+
template <typename T, xt::layout_type L>
27+
struct pybind_array_getter_impl
28+
{
29+
static auto run(handle src)
30+
{
31+
return array_t<T, array::c_style | array::forcecast>::ensure(src);
32+
}
33+
};
34+
35+
template <typename T>
36+
struct pybind_array_getter_impl<T, xt::layout_type::column_major>
37+
{
38+
static auto run(handle src)
39+
{
40+
return array_t<T, array::f_style | array::forcecast>::ensure(src);
41+
}
42+
};
43+
44+
template <class T>
45+
struct pybind_array_getter
46+
{
47+
};
48+
49+
template <class T, xt::layout_type L>
50+
struct pybind_array_getter<xt::xarray<T, L>>
51+
{
52+
static auto run(handle src)
53+
{
54+
return pybind_array_getter_impl<T, L>::run(src);
55+
}
56+
};
57+
58+
template <class T, std::size_t N, xt::layout_type L>
59+
struct pybind_array_getter<xt::xtensor<T, N, L>>
60+
{
61+
static auto run(handle src)
62+
{
63+
return pybind_array_getter_impl<T, L>::run(src);
64+
}
65+
};
66+
67+
template <class CT, class S, xt::layout_type L, class FST>
68+
struct pybind_array_getter<xt::xstrided_view<CT, S, L, FST>>
69+
{
70+
static auto run(handle /*src*/)
71+
{
72+
return false;
73+
}
74+
};
75+
76+
template <class EC, xt::layout_type L, class SC, class Tag>
77+
struct pybind_array_getter<xt::xarray_adaptor<EC, L, SC, Tag>>
78+
{
79+
static auto run(handle src)
80+
{
81+
auto buf = pybind_array_getter_impl<EC, L>::run(src);
82+
return buf;
83+
}
84+
};
85+
86+
template <class EC, std::size_t N, xt::layout_type L, class Tag>
87+
struct pybind_array_getter<xt::xtensor_adaptor<EC, N, L, Tag>>
88+
{
89+
static auto run(handle /*src*/)
90+
{
91+
return false;
92+
}
93+
};
94+
95+
96+
template <class T>
97+
struct pybind_array_dim_checker
98+
{
99+
template <class B>
100+
static bool run(const B& buf)
101+
{
102+
return true;
103+
}
104+
};
105+
106+
template <class T, std::size_t N, xt::layout_type L>
107+
struct pybind_array_dim_checker<xt::xtensor<T, N, L>>
108+
{
109+
template <class B>
110+
static bool run(const B& buf)
111+
{
112+
return buf.ndim() == N;
113+
}
114+
};
115+
116+
26117
// Casts a strided expression type to numpy array.If given a base,
27118
// the numpy array references the src data, otherwise it'll make a copy.
28119
// The writeable attributes lets you specify writeable flag for the array.
@@ -74,10 +165,6 @@ namespace pybind11
74165
template <class Type>
75166
struct xtensor_type_caster_base
76167
{
77-
bool load(handle /*src*/, bool)
78-
{
79-
return false;
80-
}
81168

82169
private:
83170

@@ -106,6 +193,36 @@ namespace pybind11
106193

107194
public:
108195

196+
PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<typename Type::value_type>::name + _("]"));
197+
198+
bool load(handle src, bool convert)
199+
{
200+
using T = typename Type::value_type;
201+
202+
if (!convert && !array_t<T>::check_(src))
203+
{
204+
return false;
205+
}
206+
207+
auto buf = pybind_array_getter<Type>::run(src);
208+
209+
if (!buf)
210+
{
211+
return false;
212+
}
213+
if (!pybind_array_dim_checker<Type>::run(buf))
214+
{
215+
return false;
216+
}
217+
218+
std::vector<size_t> shape(buf.ndim());
219+
std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin());
220+
value = Type::from_shape(shape);
221+
std::copy(buf.data(), buf.data() + buf.size(), value.data());
222+
223+
return true;
224+
}
225+
109226
// Normal returned non-reference, non-const value:
110227
static handle cast(Type&& src, return_value_policy /* policy */, handle parent)
111228
{
@@ -151,18 +268,6 @@ namespace pybind11
151268
{
152269
return cast_impl(src, policy, parent);
153270
}
154-
155-
#ifdef PYBIND11_DESCR // The macro is removed from pybind11 since 2.3
156-
static PYBIND11_DESCR name()
157-
{
158-
return _("xt::xtensor");
159-
}
160-
#else
161-
static constexpr auto name = _("xt::xtensor");
162-
#endif
163-
164-
template <typename T>
165-
using cast_op_type = cast_op_type<T>;
166271
};
167272
}
168273
}

test_python/main.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,33 @@ xt::pyarray<double> example2(xt::pyarray<double>& m)
3333
return m + 2;
3434
}
3535

36+
xt::xarray<int> example3_xarray(const xt::xarray<int>& m)
37+
{
38+
return xt::transpose(m) + 2;
39+
}
40+
41+
xt::xarray<int, xt::layout_type::column_major> example3_xarray_colmajor(
42+
const xt::xarray<int, xt::layout_type::column_major>& m)
43+
{
44+
return xt::transpose(m) + 2;
45+
}
46+
47+
xt::xtensor<int, 3> example3_xtensor3(const xt::xtensor<int, 3>& m)
48+
{
49+
return xt::transpose(m) + 2;
50+
}
51+
52+
xt::xtensor<int, 2> example3_xtensor2(const xt::xtensor<int, 2>& m)
53+
{
54+
return xt::transpose(m) + 2;
55+
}
56+
57+
xt::xtensor<int, 2, xt::layout_type::column_major> example3_xtensor2_colmajor(
58+
const xt::xtensor<int, 2, xt::layout_type::column_major>& m)
59+
{
60+
return xt::transpose(m) + 2;
61+
}
62+
3663
// Readme Examples
3764

3865
double readme_example1(xt::pyarray<double>& m)
@@ -249,6 +276,11 @@ PYBIND11_MODULE(xtensor_python_test, m)
249276

250277
m.def("example1", example1);
251278
m.def("example2", example2);
279+
m.def("example3_xarray", example3_xarray);
280+
m.def("example3_xarray_colmajor", example3_xarray_colmajor);
281+
m.def("example3_xtensor3", example3_xtensor3);
282+
m.def("example3_xtensor2", example3_xtensor2);
283+
m.def("example3_xtensor2_colmajor", example3_xtensor2_colmajor);
252284

253285
m.def("complex_overload", no_complex_overload);
254286
m.def("complex_overload", complex_overload);

test_python/test_pyarray.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ def test_example2(self):
3636
y = xt.example2(x)
3737
np.testing.assert_allclose(y, res, 1e-12)
3838

39+
def test_example3(self):
40+
x = np.arange(2 * 3).reshape(2, 3)
41+
xc = np.asfortranarray(x)
42+
y = np.arange(2 * 3 * 4).reshape(2, 3, 4)
43+
v = y[1:, 1:, 0]
44+
z = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
45+
np.testing.assert_array_equal(xt.example3_xarray(x), x.T + 2)
46+
np.testing.assert_array_equal(xt.example3_xarray_colmajor(xc), xc.T + 2)
47+
np.testing.assert_array_equal(xt.example3_xtensor3(y), y.T + 2)
48+
np.testing.assert_array_equal(xt.example3_xtensor2(x), x.T + 2)
49+
np.testing.assert_array_equal(xt.example3_xtensor2(y[1:, 1:, 0]), v.T + 2)
50+
np.testing.assert_array_equal(xt.example3_xtensor2_colmajor(xc), xc.T + 2)
51+
52+
with self.assertRaises(TypeError):
53+
xt.example3_xtensor3(x)
54+
3955
def test_vectorize(self):
4056
x1 = np.array([[0, 1], [2, 3]])
4157
x2 = np.array([0, 1])

0 commit comments

Comments
 (0)