Skip to content

Commit 31e71d7

Browse files
authored
Merge pull request numpy#8662 from eric-wieser/ufunc-outer-subclass
ENH: preserve subclasses in ufunc.outer
2 parents ddd59a9 + e043bb9 commit 31e71d7

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

numpy/core/src/umath/ufunc_object.c

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5396,6 +5396,8 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
53965396
PyArrayObject *ap1 = NULL, *ap2 = NULL, *ap_new = NULL;
53975397
PyObject *new_args, *tmp;
53985398
PyObject *shape1, *shape2, *newshape;
5399+
static PyObject *_numpy_matrix;
5400+
53995401

54005402
errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override);
54015403
if (errval) {
@@ -5428,7 +5430,18 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
54285430
if (tmp == NULL) {
54295431
return NULL;
54305432
}
5431-
ap1 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
5433+
5434+
npy_cache_import(
5435+
"numpy",
5436+
"matrix",
5437+
&_numpy_matrix);
5438+
5439+
if (PyObject_IsInstance(tmp, _numpy_matrix)) {
5440+
ap1 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
5441+
}
5442+
else {
5443+
ap1 = (PyArrayObject *) PyArray_FROM_O(tmp);
5444+
}
54325445
Py_DECREF(tmp);
54335446
if (ap1 == NULL) {
54345447
return NULL;
@@ -5437,7 +5450,12 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
54375450
if (tmp == NULL) {
54385451
return NULL;
54395452
}
5440-
ap2 = (PyArrayObject *)PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
5453+
if (PyObject_IsInstance(tmp, _numpy_matrix)) {
5454+
ap2 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
5455+
}
5456+
else {
5457+
ap2 = (PyArrayObject *) PyArray_FROM_O(tmp);
5458+
}
54415459
Py_DECREF(tmp);
54425460
if (ap2 == NULL) {
54435461
Py_DECREF(ap1);

numpy/core/tests/test_umath.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2924,3 +2924,14 @@ def test_signaling_nan_exceptions():
29242924
with assert_no_warnings():
29252925
a = np.ndarray(shape=(), dtype='float32', buffer=b'\x00\xe0\xbf\xff')
29262926
np.isnan(a)
2927+
2928+
@pytest.mark.parametrize("arr", [
2929+
np.arange(2),
2930+
np.matrix([0, 1]),
2931+
np.matrix([[0, 1], [2, 5]]),
2932+
])
2933+
def test_outer_subclass_preserve(arr):
2934+
# for gh-8661
2935+
class foo(np.ndarray): pass
2936+
actual = np.multiply.outer(arr.view(foo), arr.view(foo))
2937+
assert actual.__class__.__name__ == 'foo'

0 commit comments

Comments
 (0)