Skip to content

Commit c801639

Browse files
committed
Adding possibility to 'cast' or copy to xt::xarray etc
1 parent 2a68837 commit c801639

File tree

6 files changed

+157
-17
lines changed

6 files changed

+157
-17
lines changed

.azure-pipelines/unix-build.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ steps:
4242
displayName: Example - readme 1
4343
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/readme_example_1
4444
45+
- script: |
46+
source activate xtensor-python
47+
cmake . -DPYTHON_EXECUTABLE=`which python`
48+
cmake --build .
49+
python example.py
50+
displayName: Example - Copy 'cast'
51+
workingDirectory: $(Build.SourcesDirectory)/docs/source/examples/copy_cast
52+
4553
- script: |
4654
source activate xtensor-python
4755
cmake . -DPYTHON_EXECUTABLE=`which python`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
cmake_minimum_required(VERSION 3.1..3.19)
2+
3+
project(mymodule)
4+
5+
find_package(pybind11 CONFIG REQUIRED)
6+
find_package(xtensor REQUIRED)
7+
find_package(xtensor-python REQUIRED)
8+
find_package(Python REQUIRED COMPONENTS NumPy)
9+
10+
pybind11_add_module(mymodule main.cpp)
11+
target_link_libraries(mymodule PUBLIC pybind11::module xtensor-python Python::NumPy)
12+
13+
target_compile_definitions(mymodule PRIVATE VERSION_INFO=0.1.0)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import mymodule
2+
import numpy as np
3+
4+
c = np.array([[1, 2, 3], [4, 5, 6]])
5+
assert np.isclose(np.sum(np.sin(c)), mymodule.sum_of_sines(c))
6+
assert np.isclose(np.sum(np.cos(c)), mymodule.sum_of_cosines(c))
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include <numeric>
2+
#include <xtensor.hpp>
3+
#include <pybind11/pybind11.h>
4+
#define FORCE_IMPORT_ARRAY
5+
#include <xtensor-python/pyarray.hpp>
6+
7+
double sum_of_sines(xt::pyarray<double>& m)
8+
{
9+
auto sines = xt::sin(m); // sines does not actually hold values.
10+
return std::accumulate(sines.begin(), sines.end(), 0.0);
11+
}
12+
13+
double sum_of_cosines(const xt::xarray<double>& m)
14+
{
15+
auto cosines = xt::cos(m); // cosines does not actually hold values.
16+
return std::accumulate(cosines.begin(), cosines.end(), 0.0);
17+
}
18+
19+
PYBIND11_MODULE(mymodule, m)
20+
{
21+
xt::import_numpy();
22+
m.doc() = "Test module for xtensor python bindings";
23+
m.def("sum_of_sines", sum_of_sines, "Sum the sines of the input values");
24+
m.def("sum_of_cosines", sum_of_cosines, "Sum the cosines of the input values");
25+
}

include/xtensor-python/pynative_casters.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "xtensor_type_caster_base.hpp"
1414

15-
1615
namespace pybind11
1716
{
1817
namespace detail

include/xtensor-python/xtensor_type_caster_base.hpp

Lines changed: 105 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,87 @@ namespace pybind11
2323
{
2424
namespace detail
2525
{
26+
template <typename T, xt::layout_type L>
27+
struct xtensor_get_buffer
28+
{
29+
template <typename H>
30+
static auto get(H src)
31+
{
32+
return array_t<T, array::c_style | array::forcecast>::ensure(src);
33+
}
34+
};
35+
36+
template <typename T>
37+
struct xtensor_get_buffer<T, xt::layout_type::column_major>
38+
{
39+
template <typename H>
40+
static auto get(H src)
41+
{
42+
return array_t<T, array::f_style>::ensure(src);
43+
}
44+
};
45+
46+
template <class T>
47+
struct xtensor_check_buffer
48+
{
49+
};
50+
51+
template <class T, xt::layout_type L>
52+
struct xtensor_check_buffer<xt::xarray<T, L>>
53+
{
54+
template <typename H>
55+
static auto get(H src)
56+
{
57+
auto buf = xtensor_get_buffer<T, L>::get(src);
58+
return buf;
59+
}
60+
};
61+
62+
template <class T, std::size_t N, xt::layout_type L>
63+
struct xtensor_check_buffer<xt::xtensor<T, N, L>>
64+
{
65+
template <typename H>
66+
static auto get(H src)
67+
{
68+
auto buf = xtensor_get_buffer<T, L>::get(src);
69+
if (buf.ndim() != N) {
70+
return false;
71+
}
72+
return buf;
73+
}
74+
};
75+
76+
template <class CT, class S, xt::layout_type L, class FST>
77+
struct xtensor_check_buffer<xt::xstrided_view<CT, S, L, FST>>
78+
{
79+
template <typename H>
80+
static auto get(H /*src*/)
81+
{
82+
return false;
83+
}
84+
};
85+
86+
template <class EC, xt::layout_type L, class SC, class Tag>
87+
struct xtensor_check_buffer<xt::xarray_adaptor<EC, L, SC, Tag>>
88+
{
89+
template <typename H>
90+
static auto get(H /*src*/)
91+
{
92+
return false;
93+
}
94+
};
95+
96+
template <class EC, std::size_t N, xt::layout_type L, class Tag>
97+
struct xtensor_check_buffer<xt::xtensor_adaptor<EC, N, L, Tag>>
98+
{
99+
template <typename H>
100+
static auto get(H /*src*/)
101+
{
102+
return false;
103+
}
104+
};
105+
106+
26107
// Casts a strided expression type to numpy array.If given a base,
27108
// the numpy array references the src data, otherwise it'll make a copy.
28109
// The writeable attributes lets you specify writeable flag for the array.
@@ -74,10 +155,6 @@ namespace pybind11
74155
template <class Type>
75156
struct xtensor_type_caster_base
76157
{
77-
bool load(handle /*src*/, bool)
78-
{
79-
return false;
80-
}
81158

82159
private:
83160

@@ -106,6 +183,30 @@ namespace pybind11
106183

107184
public:
108185

186+
PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<typename Type::value_type>::name + _("]"));
187+
188+
bool load(handle src, bool convert)
189+
{
190+
using T = typename Type::value_type;
191+
192+
if (!convert && !array_t<T>::check_(src)) {
193+
return false;
194+
}
195+
196+
auto buf = xtensor_check_buffer<Type>::get(src);
197+
198+
if (!buf) {
199+
return false;
200+
}
201+
202+
std::vector<size_t> shape(buf.ndim());
203+
std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin());
204+
value = Type(shape);
205+
std::copy(buf.data(), buf.data() + buf.size(), value.begin());
206+
207+
return true;
208+
}
209+
109210
// Normal returned non-reference, non-const value:
110211
static handle cast(Type&& src, return_value_policy /* policy */, handle parent)
111212
{
@@ -151,18 +252,6 @@ namespace pybind11
151252
{
152253
return cast_impl(src, policy, parent);
153254
}
154-
155-
#ifdef PYBIND11_DESCR // The macro is removed from pybind11 since 2.3
156-
static PYBIND11_DESCR name()
157-
{
158-
return _("xt::xtensor");
159-
}
160-
#else
161-
static constexpr auto name = _("xt::xtensor");
162-
#endif
163-
164-
template <typename T>
165-
using cast_op_type = cast_op_type<T>;
166255
};
167256
}
168257
}

0 commit comments

Comments
 (0)