23 changes: 19 additions & 4 deletions include/xtensor-python/pytensor.hpp
@@ -100,11 +100,26 @@ namespace pybind11
}
};

}
} // namespace detail
}

namespace xt
{
namespace detail {

template <std::size_t N>
struct numpy_strides
{
npy_intp value[N];
};

template <>
struct numpy_strides<0>
{
npy_intp* value = nullptr;
};

} // namespace detail

template <class T, std::size_t N, layout_type L>
struct xiterable_inner_types<pytensor<T, N, L>>
@@ -433,8 +448,8 @@ namespace xt
template <class T, std::size_t N, layout_type L>
inline void pytensor<T, N, L>::init_tensor(const shape_type& shape, const strides_type& strides)
{
npy_intp python_strides[N];
std::transform(strides.begin(), strides.end(), python_strides,
detail::numpy_strides<N> python_strides;
std::transform(strides.begin(), strides.end(), python_strides.value,
[](auto v) { return sizeof(T) * v; });
int flags = NPY_ARRAY_ALIGNED;
if (!std::is_const<T>::value)
@@ -445,7 +460,7 @@

auto tmp = pybind11::reinterpret_steal<pybind11::object>(
PyArray_NewFromDescr(&PyArray_Type, (PyArray_Descr*) dtype.release().ptr(), static_cast<int>(shape.size()),
const_cast<npy_intp*>(shape.data()), python_strides,
const_cast<npy_intp*>(shape.data()), python_strides.value,
nullptr, flags, nullptr));

if (!tmp)
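Reviewer note (not part of the PR): the core change replaces the stack array `npy_intp python_strides[N]` with the `detail::numpy_strides<N>` helper so that the 0-d case compiles, since a zero-length array is not valid ISO C++. A minimal self-contained sketch of the idea, using `std::ptrdiff_t` as a stand-in for `npy_intp` (the real type would require the NumPy headers) and a hypothetical `numpy_strides_sketch` name:

```cpp
#include <cstddef>

// Fixed-size stride buffer for N > 0, mirroring the helper added in the PR.
template <std::size_t N>
struct numpy_strides_sketch
{
    std::ptrdiff_t value[N];
};

// 0-d case: no strides exist, so expose a null pointer instead of an
// ill-formed zero-length array; PyArray_NewFromDescr accepts nullptr here.
template <>
struct numpy_strides_sketch<0>
{
    std::ptrdiff_t* value = nullptr;
};

int main()
{
    numpy_strides_sketch<3> s3;  // s3.value decays to std::ptrdiff_t*
    numpy_strides_sketch<0> s0;  // s0.value is already a null pointer
    return (s0.value == nullptr && s3.value != nullptr) ? 0 : 1;
}
```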
9 changes: 9 additions & 0 deletions test/test_pytensor.cpp
@@ -65,6 +65,15 @@ namespace xt
EXPECT_THROW(pyt3::from_shape(shp), std::runtime_error);
}

TEST(pytensor, scalar_from_shape)
{
std::array<size_t, 0> shape;
auto a = pytensor<double, 0>::from_shape(shape);
pytensor<double, 0> b(1.2);
EXPECT_TRUE(a.size() == b.size());
EXPECT_TRUE(xt::has_shape(a, b.shape()));
}

TEST(pytensor, strided_constructor)
{
central_major_result<container_type> cmr;
7 changes: 7 additions & 0 deletions test_python/main.cpp
@@ -227,6 +227,11 @@ void col_major_array(xt::pyarray<double, xt::layout_type::column_major>& arg)
}
}

xt::pytensor<int, 0> xscalar(const xt::pytensor<int, 1>& arg)
{
return xt::sum(arg);
}

template <class T>
using ndarray = xt::pyarray<T, xt::layout_type::row_major>;

@@ -285,6 +290,8 @@ PYBIND11_MODULE(xtensor_python_test, m)
m.def("col_major_array", col_major_array);
m.def("row_major_tensor", row_major_tensor);

m.def("xscalar", xscalar);

py::class_<C>(m, "C")
.def(py::init<>())
.def_property_readonly(
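For context (not part of the PR): the `xscalar` binding above relies on core xtensor's behavior that a full reduction of a 1-d tensor is a 0-d tensor. A hedged standalone sketch against plain xtensor, with no Python involved, assuming the classic `<xtensor/xtensor.hpp>` and `<xtensor/xmath.hpp>` include paths:

```cpp
#include <cassert>
#include <xtensor/xmath.hpp>
#include <xtensor/xtensor.hpp>

int main()
{
    // Summing over all axes of a 1-d tensor yields a 0-d expression,
    // which the binding returns as a pytensor<int, 0>.
    xt::xtensor<int, 1> arg = {1, 2, 3, 4};
    xt::xtensor<int, 0> result = xt::sum(arg);
    assert(result() == 10);  // the single element of the 0-d tensor
    return 0;
}
```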
4 changes: 4 additions & 0 deletions test_python/test_pyarray.py
@@ -151,6 +151,10 @@ def test_col_row_major(self):
xt.col_major_array(varF)
xt.col_major_array(varF[:, :, 0]) # still col major!

def test_xscalar(self):
var = np.arange(50, dtype=int)
self.assertTrue(np.sum(var) == xt.xscalar(var))

def test_bad_argument_call(self):
with self.assertRaises(TypeError):
xt.simple_array("foo")