From 2af1f403a3db0ee0bc40cf8d7d4b637239d4ea79 Mon Sep 17 00:00:00 2001 From: zhujun98 Date: Tue, 6 Aug 2019 20:47:47 +0200 Subject: [PATCH] Fix pytensor overload resolution in pybind11. --- include/xtensor-python/pytensor.hpp | 9 ++++++++- test_python/main.cpp | 3 +++ test_python/test_pyarray.py | 15 ++++++++++++--- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index b7886f6..3c3b91b 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -58,7 +58,14 @@ namespace pybind11 } } - value = type::ensure(src); + try + { + value = type::ensure(src); + } + catch (const std::runtime_error&) + { + return false; + } return static_cast(value); } diff --git a/test_python/main.cpp b/test_python/main.cpp index 4b3343a..2f4a0df 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -253,4 +253,7 @@ PYBIND11_MODULE(xtensor_python_test, m) m.def("simple_array", [](xt::pyarray) { return 1; } ); m.def("simple_tensor", [](xt::pytensor) { return 2; } ); + + m.def("diff_shape_overload", [](xt::pytensor a) { return 1; }); + m.def("diff_shape_overload", [](xt::pytensor a) { return 2; }); } diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index e70a2fa..3d8f0af 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -135,13 +135,13 @@ def test_col_row_major(self): with self.assertRaises(RuntimeError): xt.col_major_array(var) - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): xt.row_major_tensor(var.T) - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): xt.row_major_tensor(var[:, ::2, ::2]) - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): # raise for wrong dimension xt.row_major_tensor(var[0, 0, :]) @@ -157,6 +157,15 @@ def test_bad_argument_call(self): with self.assertRaises(TypeError): xt.simple_tensor("foo") + def test_diff_shape_overload(self): + self.assertEqual(1, xt.diff_shape_overload(np.ones(2))) + self.assertEqual(2, xt.diff_shape_overload(np.ones((2, 2)))) + + with self.assertRaises(TypeError): + # FIXME: the TypeError information is not informative + xt.diff_shape_overload(np.ones((2, 2, 2))) + + class AttributeTest(TestCase): def setUp(self):