Skip to content

Commit 1be71a7

Browse files
committed
Adding rank member
1 parent d6f87cf commit 1be71a7

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ namespace xt
158158
using inner_shape_type = typename base_type::inner_shape_type;
159159
using inner_strides_type = typename base_type::inner_strides_type;
160160
using inner_backstrides_type = typename base_type::inner_backstrides_type;
161+
constexpr static std::size_t rank = SIZE_MAX;
161162

162163
pyarray();
163164
pyarray(const value_type& t);
@@ -514,7 +515,7 @@ namespace xt
514515
{
515516
return;
516517
}
517-
518+
518519
m_shape = inner_shape_type(reinterpret_cast<size_type*>(PyArray_SHAPE(this->python_array())),
519520
static_cast<size_type>(PyArray_NDIM(this->python_array())));
520521
m_strides = inner_strides_type(reinterpret_cast<difference_type*>(PyArray_STRIDES(this->python_array())),

include/xtensor-python/pytensor.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ namespace xt
168168
using inner_shape_type = typename base_type::inner_shape_type;
169169
using inner_strides_type = typename base_type::inner_strides_type;
170170
using inner_backstrides_type = typename base_type::inner_backstrides_type;
171+
constexpr static std::size_t rank = N;
171172

172173
pytensor();
173174
pytensor(nested_initializer_list_t<T, N> t);
@@ -471,7 +472,7 @@ namespace xt
471472
{
472473
return;
473474
}
474-
475+
475476
if (PyArray_NDIM(this->python_array()) != N)
476477
{
477478
throw std::runtime_error("NumPy: ndarray has incorrect number of dimensions");

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ set(XTENSOR_PYTHON_TESTS
8686
test_pyarray.cpp
8787
test_pytensor.cpp
8888
test_pyvectorize.cpp
89+
test_sfinae.cpp
8990
)
9091

9192
add_executable(test_xtensor_python ${XTENSOR_PYTHON_TESTS} ${XTENSOR_PYTHON_HEADERS})

test/test_sfinae.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/***************************************************************************
2+
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay *
3+
* Copyright (c) QuantStack *
4+
* *
5+
* Distributed under the terms of the BSD 3-Clause License. *
6+
* *
7+
* The full license is in the file LICENSE, distributed with this software. *
8+
****************************************************************************/
9+
10+
#include <limits>
11+
12+
#include "gtest/gtest.h"
13+
#include "xtensor-python/pytensor.hpp"
14+
#include "xtensor-python/pyarray.hpp"
15+
#include "xtensor/xarray.hpp"
16+
#include "xtensor/xtensor.hpp"
17+
18+
namespace xt
19+
{
20+
template <class E, std::enable_if_t<!xt::has_fixed_rank_t<E>::value, int> = 0>
21+
inline bool sfinae_has_fixed_rank(E&&)
22+
{
23+
return false;
24+
}
25+
26+
template <class E, std::enable_if_t<xt::has_fixed_rank_t<E>::value, int> = 0>
27+
inline bool sfinae_has_fixed_rank(E&&)
28+
{
29+
return true;
30+
}
31+
32+
TEST(sfinae, fixed_rank)
33+
{
34+
xt::pyarray<size_t> a = {{9, 9, 9}, {9, 9, 9}};
35+
xt::pytensor<size_t, 1> b = {9, 9};
36+
xt::pytensor<size_t, 2> c = {{9, 9}, {9, 9}};
37+
38+
EXPECT_TRUE(sfinae_has_fixed_rank(a) == false);
39+
EXPECT_TRUE(sfinae_has_fixed_rank(b) == true);
40+
EXPECT_TRUE(sfinae_has_fixed_rank(c) == true);
41+
}
42+
43+
TEST(sfinae, get_rank)
44+
{
45+
xt::pytensor<double, 1> A = xt::zeros<double>({2});
46+
xt::pytensor<double, 2> B = xt::zeros<double>({2, 2});
47+
xt::pyarray<double> C = xt::zeros<double>({2, 2});
48+
49+
EXPECT_TRUE(xt::get_rank<decltype(A)>::value == 1ul);
50+
EXPECT_TRUE(xt::get_rank<decltype(B)>::value == 2ul);
51+
EXPECT_TRUE(xt::get_rank<decltype(C)>::value == SIZE_MAX);
52+
}
53+
}

0 commit comments

Comments
 (0)