From 9c2a9954e4501cbd9aac176b8fda5f98020bf705 Mon Sep 17 00:00:00 2001 From: Ben Walsh Date: Sun, 10 Jul 2011 12:52:52 +0100 Subject: [PATCH] Stop _array_find_type looking for a supertype of every list element and bool. --- numpy/core/src/multiarray/common.c | 37 ++++++++++++++++-------------- numpy/core/tests/test_datetime.py | 5 ++++ 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c index 9633141320fd..89d4c5f63ac7 100644 --- a/numpy/core/src/multiarray/common.c +++ b/numpy/core/src/multiarray/common.c @@ -110,12 +110,8 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max) goto finish; } - if (minitype == NULL) { - minitype = PyArray_DescrFromType(PyArray_BOOL); - } - else { - Py_INCREF(minitype); - } + Py_XINCREF(minitype); + if (max < 0) { goto deflt; } @@ -237,8 +233,7 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max) PyErr_Clear(); goto deflt; } - if (l == 0 && minitype->type_num == PyArray_BOOL) { - Py_DECREF(minitype); + if (l == 0 && minitype == NULL) { minitype = PyArray_DescrFromType(NPY_DEFAULT_TYPE); if (minitype == NULL) { return NULL; @@ -253,17 +248,21 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max) } chktype = _array_find_type(ip, minitype, max-1); if (chktype == NULL) { - Py_DECREF(minitype); + Py_XDECREF(minitype); return NULL; } - newtype = PyArray_PromoteTypes(chktype, minitype); - Py_DECREF(minitype); - minitype = newtype; - Py_DECREF(chktype); + if (minitype == NULL) { + minitype = chktype; + } else { + newtype = PyArray_PromoteTypes(chktype, minitype); + Py_DECREF(minitype); + minitype = newtype; + Py_DECREF(chktype); + } Py_DECREF(ip); } chktype = minitype; - Py_INCREF(minitype); + minitype = NULL; goto finish; } @@ -272,9 +271,13 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max) chktype = _use_default_type(op); finish: - outtype = PyArray_PromoteTypes(chktype, minitype); - Py_DECREF(chktype); - Py_DECREF(minitype); + if (minitype == NULL) { + outtype = chktype; + } else { + outtype = PyArray_PromoteTypes(chktype, minitype); + Py_DECREF(chktype); + Py_DECREF(minitype); + } if (outtype == NULL) { return NULL; } diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index 271a5dea9785..ea03ab9b3393 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -169,6 +169,11 @@ def test_datetime_scalar_construction(self): assert_raises(TypeError, np.datetime64, datetime.datetime(1920,4,14,13,20), 'D') + def test_datetime_array_find_type(self): + dt = np.datetime64('1970-01-01', 'M') + arr = np.array([dt]) + assert_equal(arr.dtype, np.dtype('M8[M]')) + def test_timedelta_scalar_construction(self): # Construct with different units assert_equal(np.timedelta64(7, 'D'),