diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index b7886f6..3c3b91b 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -58,7 +58,14 @@ namespace pybind11 } } - value = type::ensure(src); + try + { + value = type::ensure(src); + } + catch (const std::runtime_error&) + { + return false; + } return static_cast(value); } diff --git a/test_python/main.cpp b/test_python/main.cpp index 4b3343a..2f4a0df 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -253,4 +253,7 @@ PYBIND11_MODULE(xtensor_python_test, m) m.def("simple_array", [](xt::pyarray) { return 1; } ); m.def("simple_tensor", [](xt::pytensor) { return 2; } ); + + m.def("diff_shape_overload", [](xt::pytensor a) { return 1; }); + m.def("diff_shape_overload", [](xt::pytensor a) { return 2; }); } diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index e70a2fa..3d8f0af 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -135,13 +135,13 @@ def test_col_row_major(self): with self.assertRaises(RuntimeError): xt.col_major_array(var) - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): xt.row_major_tensor(var.T) - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): xt.row_major_tensor(var[:, ::2, ::2]) - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): # raise for wrong dimension xt.row_major_tensor(var[0, 0, :]) @@ -157,6 +157,15 @@ def test_bad_argument_call(self): with self.assertRaises(TypeError): xt.simple_tensor("foo") + def test_diff_shape_overload(self): + self.assertEqual(1, xt.diff_shape_overload(np.ones(2))) + self.assertEqual(2, xt.diff_shape_overload(np.ones((2, 2)))) + + with self.assertRaises(TypeError): + # FIXME: the TypeError information is not informative + xt.diff_shape_overload(np.ones((2, 2, 2))) + + class AttributeTest(TestCase): def setUp(self):