diff --git a/include/xtensor-python/pyarray.hpp b/include/xtensor-python/pyarray.hpp index 08733990..79de0b5b 100644 --- a/include/xtensor-python/pyarray.hpp +++ b/include/xtensor-python/pyarray.hpp @@ -158,6 +158,7 @@ namespace xt using inner_shape_type = typename base_type::inner_shape_type; using inner_strides_type = typename base_type::inner_strides_type; using inner_backstrides_type = typename base_type::inner_backstrides_type; + constexpr static std::size_t rank = SIZE_MAX; pyarray(); pyarray(const value_type& t); @@ -514,7 +515,7 @@ namespace xt { return; } - + m_shape = inner_shape_type(reinterpret_cast(PyArray_SHAPE(this->python_array())), static_cast(PyArray_NDIM(this->python_array()))); m_strides = inner_strides_type(reinterpret_cast(PyArray_STRIDES(this->python_array())), diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index 906bfff2..1dc9e313 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -168,6 +168,7 @@ namespace xt using inner_shape_type = typename base_type::inner_shape_type; using inner_strides_type = typename base_type::inner_strides_type; using inner_backstrides_type = typename base_type::inner_backstrides_type; + constexpr static std::size_t rank = N; pytensor(); pytensor(nested_initializer_list_t t); @@ -471,7 +472,7 @@ namespace xt { return; } - + if (PyArray_NDIM(this->python_array()) != N) { throw std::runtime_error("NumPy: ndarray has incorrect number of dimensions"); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0517e765..7174802a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -86,6 +86,7 @@ set(XTENSOR_PYTHON_TESTS test_pyarray.cpp test_pytensor.cpp test_pyvectorize.cpp + test_sfinae.cpp ) add_executable(test_xtensor_python ${XTENSOR_PYTHON_TESTS} ${XTENSOR_PYTHON_HEADERS}) diff --git a/test/test_sfinae.cpp b/test/test_sfinae.cpp new file mode 100644 index 00000000..a6144857 --- /dev/null +++ b/test/test_sfinae.cpp @@ -0,0 +1,53 @@ +/*************************************************************************** +* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay * +* Copyright (c) QuantStack * +* * +* Distributed under the terms of the BSD 3-Clause License. * +* * +* The full license is in the file LICENSE, distributed with this software. * +****************************************************************************/ + +#include + +#include "gtest/gtest.h" +#include "xtensor-python/pytensor.hpp" +#include "xtensor-python/pyarray.hpp" +#include "xtensor/xarray.hpp" +#include "xtensor/xtensor.hpp" + +namespace xt +{ + template ::value, int> = 0> + inline bool sfinae_has_fixed_rank(E&&) + { + return false; + } + + template ::value, int> = 0> + inline bool sfinae_has_fixed_rank(E&&) + { + return true; + } + + TEST(sfinae, fixed_rank) + { + xt::pyarray a = {{9, 9, 9}, {9, 9, 9}}; + xt::pytensor b = {9, 9}; + xt::pytensor c = {{9, 9}, {9, 9}}; + + EXPECT_TRUE(sfinae_has_fixed_rank(a) == false); + EXPECT_TRUE(sfinae_has_fixed_rank(b) == true); + EXPECT_TRUE(sfinae_has_fixed_rank(c) == true); + } + + TEST(sfinae, get_rank) + { + xt::pytensor A = xt::zeros({2}); + xt::pytensor B = xt::zeros({2, 2}); + xt::pyarray C = xt::zeros({2, 2}); + + EXPECT_TRUE(xt::get_rank::value == 1ul); + EXPECT_TRUE(xt::get_rank::value == 2ul); + EXPECT_TRUE(xt::get_rank::value == SIZE_MAX); + } +}