Skip to content

Commit 025d1e5

Browse files
committed
Fix pytensor overload resolution in pybind11.
1 parent 0a8a00b commit 025d1e5

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

include/xtensor-python/pytensor.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,14 @@ namespace pybind11
5858
}
5959
}
6060

61-
value = type::ensure(src);
61+
try
62+
{
63+
value = type::ensure(src);
64+
}
65+
catch (const std::runtime_error&)
66+
{
67+
return false;
68+
}
6269
return static_cast<bool>(value);
6370
}
6471

test_python/main.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,4 +253,7 @@ PYBIND11_MODULE(xtensor_python_test, m)
253253

254254
m.def("simple_array", [](xt::pyarray<int>) { return 1; } );
255255
m.def("simple_tensor", [](xt::pytensor<int, 1>) { return 2; } );
256+
257+
m.def("diff_shape_overload", [](xt::pytensor<int, 1> a) { return 1; });
258+
m.def("diff_shape_overload", [](xt::pytensor<int, 2> a) { return 2; });
256259
}

test_python/test_pyarray.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,13 @@ def test_col_row_major(self):
135135
with self.assertRaises(RuntimeError):
136136
xt.col_major_array(var)
137137

138-
with self.assertRaises(RuntimeError):
138+
with self.assertRaises(TypeError):
139139
xt.row_major_tensor(var.T)
140140

141-
with self.assertRaises(RuntimeError):
141+
with self.assertRaises(TypeError):
142142
xt.row_major_tensor(var[:, ::2, ::2])
143143

144-
with self.assertRaises(RuntimeError):
144+
with self.assertRaises(TypeError):
145145
# raise for wrong dimension
146146
xt.row_major_tensor(var[0, 0, :])
147147

@@ -157,6 +157,15 @@ def test_bad_argument_call(self):
157157
with self.assertRaises(TypeError):
158158
xt.simple_tensor("foo")
159159

160+
def test_diff_shape_overload(self):
161+
self.assertEqual(1, xt.diff_shape_overload(np.ones(2)))
162+
self.assertEqual(2, xt.diff_shape_overload(np.ones((2, 2))))
163+
164+
with self.assertRaises(TypeError):
165+
# FIXME: the TypeError information is not informative
166+
xt.diff_shape_overload(np.ones((2, 2, 2)))
167+
168+
160169
class AttributeTest(TestCase):
161170

162171
def setUp(self):

0 commit comments

Comments
 (0)