-
Notifications
You must be signed in to change notification settings - Fork 64
Adding possibility to 'cast' or copy to xt::xarray
etc
#267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c801639
e21ecb0
428ed03
5c55426
acc30ed
af91def
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
cmake_minimum_required(VERSION 3.1..3.19) | ||
|
||
project(mymodule) | ||
|
||
find_package(pybind11 CONFIG REQUIRED) | ||
find_package(xtensor REQUIRED) | ||
find_package(xtensor-python REQUIRED) | ||
find_package(Python REQUIRED COMPONENTS NumPy) | ||
|
||
pybind11_add_module(mymodule main.cpp) | ||
target_link_libraries(mymodule PUBLIC pybind11::module xtensor-python Python::NumPy) | ||
|
||
target_compile_definitions(mymodule PRIVATE VERSION_INFO=0.1.0) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import mymodule | ||
import numpy as np | ||
|
||
c = np.array([[1, 2, 3], [4, 5, 6]]) | ||
assert np.isclose(np.sum(np.sin(c)), mymodule.sum_of_sines(c)) | ||
assert np.isclose(np.sum(np.cos(c)), mymodule.sum_of_cosines(c)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#include <numeric> | ||
#include <xtensor.hpp> | ||
#include <pybind11/pybind11.h> | ||
#define FORCE_IMPORT_ARRAY | ||
#include <xtensor-python/pyarray.hpp> | ||
|
||
template <class T> | ||
double sum_of_sines(T& m) | ||
{ | ||
auto sines = xt::sin(m); // sines does not actually hold values. | ||
return std::accumulate(sines.begin(), sines.end(), 0.0); | ||
} | ||
|
||
// In the Python API this a reference to a temporary variable | ||
double sum_of_cosines(const xt::xarray<double>& m) | ||
{ | ||
auto cosines = xt::cos(m); // cosines does not actually hold values. | ||
return std::accumulate(cosines.begin(), cosines.end(), 0.0); | ||
} | ||
|
||
PYBIND11_MODULE(mymodule, m) | ||
{ | ||
xt::import_numpy(); | ||
m.doc() = "Test module for xtensor python bindings"; | ||
m.def("sum_of_sines", sum_of_sines<xt::pyarray<double>>, "Sum the sines of the input values"); | ||
m.def("sum_of_cosines", sum_of_cosines, "Sum the cosines of the input values"); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,6 @@ | |
|
||
#include "xtensor_type_caster_base.hpp" | ||
|
||
|
||
namespace pybind11 | ||
{ | ||
namespace detail | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,97 @@ namespace pybind11 | |
{ | ||
namespace detail | ||
{ | ||
template <typename T, xt::layout_type L> | ||
struct pybind_array_getter_impl | ||
{ | ||
static auto run(handle src) | ||
{ | ||
return array_t<T, array::c_style | array::forcecast>::ensure(src); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct pybind_array_getter_impl<T, xt::layout_type::column_major> | ||
{ | ||
static auto run(handle src) | ||
{ | ||
return array_t<T, array::f_style | array::forcecast>::ensure(src); | ||
} | ||
}; | ||
|
||
template <class T> | ||
struct pybind_array_getter | ||
{ | ||
}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can omit the definition of the generic case. A simple declaratoin should be enough. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought so too, but my compiler complained: xtensor_type_caster_base.hpp:50:16: error: no template named 'xtensor_check_buffer'; did you mean 'xtensor_get_buffer'?
struct xtensor_check_buffer<xt::xarray<T, L>>
^~~~~~~~~~~~~~~~~~~~
xtensor_get_buffer |
||
|
||
template <class T, xt::layout_type L> | ||
struct pybind_array_getter<xt::xarray<T, L>> | ||
{ | ||
static auto run(handle src) | ||
{ | ||
return pybind_array_getter_impl<T, L>::run(src); | ||
} | ||
}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpicking: in line with the previous renaming proposal: |
||
|
||
template <class T, std::size_t N, xt::layout_type L> | ||
struct pybind_array_getter<xt::xtensor<T, N, L>> | ||
{ | ||
static auto run(handle src) | ||
{ | ||
return pybind_array_getter_impl<T, L>::run(src); | ||
} | ||
}; | ||
|
||
template <class CT, class S, xt::layout_type L, class FST> | ||
struct pybind_array_getter<xt::xstrided_view<CT, S, L, FST>> | ||
{ | ||
static auto run(handle /*src*/) | ||
{ | ||
return false; | ||
} | ||
}; | ||
|
||
template <class EC, xt::layout_type L, class SC, class Tag> | ||
struct pybind_array_getter<xt::xarray_adaptor<EC, L, SC, Tag>> | ||
{ | ||
static auto run(handle src) | ||
{ | ||
auto buf = pybind_array_getter_impl<EC, L>::run(src); | ||
return buf; | ||
} | ||
}; | ||
|
||
template <class EC, std::size_t N, xt::layout_type L, class Tag> | ||
struct pybind_array_getter<xt::xtensor_adaptor<EC, N, L, Tag>> | ||
{ | ||
static auto run(handle /*src*/) | ||
{ | ||
return false; | ||
} | ||
}; | ||
|
||
|
||
template <class T> | ||
struct pybind_array_dim_checker | ||
{ | ||
template <class B> | ||
static bool run(const B& buf) | ||
{ | ||
return true; | ||
} | ||
}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpicking: |
||
|
||
template <class T, std::size_t N, xt::layout_type L> | ||
struct pybind_array_dim_checker<xt::xtensor<T, N, L>> | ||
{ | ||
template <class B> | ||
static bool run(const B& buf) | ||
{ | ||
return buf.ndim() == N; | ||
} | ||
}; | ||
|
||
|
||
// Casts a strided expression type to numpy array.If given a base, | ||
// the numpy array references the src data, otherwise it'll make a copy. | ||
// The writeable attributes lets you specify writeable flag for the array. | ||
|
@@ -74,10 +165,6 @@ namespace pybind11 | |
template <class Type> | ||
struct xtensor_type_caster_base | ||
{ | ||
bool load(handle /*src*/, bool) | ||
{ | ||
return false; | ||
} | ||
|
||
private: | ||
|
||
|
@@ -106,6 +193,36 @@ namespace pybind11 | |
|
||
public: | ||
|
||
PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<typename Type::value_type>::name + _("]")); | ||
|
||
bool load(handle src, bool convert) | ||
{ | ||
using T = typename Type::value_type; | ||
|
||
if (!convert && !array_t<T>::check_(src)) | ||
{ | ||
return false; | ||
} | ||
|
||
auto buf = pybind_array_getter<Type>::run(src); | ||
|
||
if (!buf) | ||
{ | ||
return false; | ||
} | ||
if (!pybind_array_dim_checker<Type>::run(buf)) | ||
{ | ||
return false; | ||
} | ||
|
||
std::vector<size_t> shape(buf.ndim()); | ||
std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin()); | ||
value = Type::from_shape(shape); | ||
std::copy(buf.data(), buf.data() + buf.size(), value.data()); | ||
|
||
return true; | ||
} | ||
|
||
// Normal returned non-reference, non-const value: | ||
static handle cast(Type&& src, return_value_policy /* policy */, handle parent) | ||
{ | ||
|
@@ -151,18 +268,6 @@ namespace pybind11 | |
{ | ||
return cast_impl(src, policy, parent); | ||
} | ||
|
||
#ifdef PYBIND11_DESCR // The macro is removed from pybind11 since 2.3 | ||
static PYBIND11_DESCR name() | ||
{ | ||
return _("xt::xtensor"); | ||
} | ||
#else | ||
static constexpr auto name = _("xt::xtensor"); | ||
#endif | ||
|
||
template <typename T> | ||
using cast_op_type = cast_op_type<T>; | ||
}; | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpicking: what about the following renaming?
xtensor_get_buffer
=>pybind_array_getter_impl
get
=>run