diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 07b7df7266d7..9814a9f6540e 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -337,6 +337,16 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis) if (axis < 0) { axis += ndim; } + + if (ndim == 1 & axis != 0) { + char msg[] = "axis != 0 for ndim == 1; this will raise an error in " + "future versions of numpy"; + if (DEPRECATE(msg) < 0) { + return NULL; + } + axis = 0; + } + if (axis < 0 || axis >= ndim) { PyErr_Format(PyExc_IndexError, "axis %d out of bounds [0, %d)", orig_axis, ndim); diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index 2017ca7a36af..b3f781980231 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -1,7 +1,7 @@ import warnings import numpy as np -from numpy.testing import (TestCase, assert_, assert_raises, assert_equal, - assert_array_equal, run_module_suite) +from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal, + assert_equal, run_module_suite) from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d, vstack, hstack, newaxis, concatenate) @@ -40,6 +40,7 @@ def test_r1array(self): assert_(atleast_1d(3.0).shape == (1,)) assert_(atleast_1d([[2,3],[4,5]]).shape == (2,2)) + class TestAtleast2d(TestCase): def test_0D_array(self): a = array(1); b = array(2); @@ -100,6 +101,7 @@ def test_3D_array(self): desired = [a,b] assert_array_equal(res,desired) + class TestHstack(TestCase): def test_0D_array(self): a = array(1); b = array(2); @@ -119,6 +121,7 @@ def test_2D_array(self): desired = array([[1,1],[2,2]]) assert_array_equal(res,desired) + class TestVstack(TestCase): def test_0D_array(self): a = array(1); b = array(2); @@ -159,5 +162,71 @@ def test_concatenate_axis_None(): '0', '1', '2', 'x']) assert_array_equal(r,d) + +def test_concatenate(): + # Test concatenate function + # No arrays raise ValueError + assert_raises(ValueError, concatenate, ()) + # Scalars cannot be concatenated + assert_raises(ValueError, concatenate, (0,)) + assert_raises(ValueError, concatenate, (array(0),)) + # One sequence returns unmodified (but as array) + r4 = list(range(4)) + assert_array_equal(concatenate((r4,)), r4) + # Any sequence + assert_array_equal(concatenate((tuple(r4),)), r4) + assert_array_equal(concatenate((array(r4),)), r4) + # 1D default concatenation + r3 = list(range(3)) + assert_array_equal(concatenate((r4, r3)), r4 + r3) + # Mixed sequence types + assert_array_equal(concatenate((tuple(r4), r3)), r4 + r3) + assert_array_equal(concatenate((array(r4), r3)), r4 + r3) + # Explicit axis specification + assert_array_equal(concatenate((r4, r3), 0), r4 + r3) + # Including negative + assert_array_equal(concatenate((r4, r3), -1), r4 + r3) + # 2D + a23 = array([[10, 11, 12], [13, 14, 15]]) + a13 = array([[0, 1, 2]]) + res = array([[10, 11, 12], [13, 14, 15], [0, 1, 2]]) + assert_array_equal(concatenate((a23, a13)), res) + assert_array_equal(concatenate((a23, a13), 0), res) + assert_array_equal(concatenate((a23.T, a13.T), 1), res.T) + assert_array_equal(concatenate((a23.T, a13.T), -1), res.T) + # Arrays much match shape + assert_raises(ValueError, concatenate, (a23.T, a13.T), 0) + # 3D + res = arange(2 * 3 * 7).reshape((2, 3, 7)) + a0 = res[..., :4] + a1 = res[..., 4:6] + a2 = res[..., 6:] + assert_array_equal(concatenate((a0, a1, a2), 2), res) + assert_array_equal(concatenate((a0, a1, a2), -1), res) + assert_array_equal(concatenate((a0.T, a1.T, a2.T), 0), res.T) + + +def test_concatenate_sloppy0(): + # Versions of numpy < 1.7.0 ignored axis argument value for 1D arrays. We + # allow this for now, but in due course we will raise an error + r4 = list(range(4)) + r3 = list(range(3)) + assert_array_equal(concatenate((r4, r3), 0), r4 + r3) + warnings.simplefilter('ignore', DeprecationWarning) + try: + assert_array_equal(concatenate((r4, r3), -10), r4 + r3) + assert_array_equal(concatenate((r4, r3), 10), r4 + r3) + finally: + warnings.filters.pop(0) + # Confurm DepractionWarning raised + warnings.simplefilter('always', DeprecationWarning) + warnings.simplefilter('error', DeprecationWarning) + try: + assert_raises(DeprecationWarning, concatenate, (r4, r3), 10) + finally: + warnings.filters.pop(0) + warnings.filters.pop(0) + + if __name__ == "__main__": run_module_suite()