diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 026ad603a500..8a047fddab1c 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -365,78 +365,93 @@ def stack(arrays, axis=0, out=None): return _nx.concatenate(expanded_arrays, axis=axis, out=out) -class _Recurser(object): +def _block_check_depths_match(arrays, parent_index=[]): """ - Utility class for recursing over nested iterables + Recursive function checking that the depths of nested lists in `arrays` + all match. Mismatch raises a ValueError as described in the block + docstring below. + + The entire index (rather than just the depth) needs to be calculated + for each innermost list, in case an error needs to be raised, so that + the index of the offending list can be printed as part of the error. + + The parameter `parent_index` is the full index of `arrays` within the + nested lists passed to _block_check_depths_match at the top of the + recursion. + The return value is a pair. The first item returned is the full index + of an element (specifically the first element) from the bottom of the + nesting in `arrays`. An empty list at the bottom of the nesting is + represented by a `None` index. + The second item is the maximum of the ndims of the arrays nested in + `arrays`. """ - def __init__(self, recurse_if): - self.recurse_if = recurse_if - - def map_reduce(self, x, f_map=lambda x, **kwargs: x, - f_reduce=lambda x, **kwargs: x, - f_kwargs=lambda **kwargs: kwargs, - **kwargs): - """ - Iterate over the nested list, applying: - * ``f_map`` (T -> U) to items - * ``f_reduce`` (Iterable[U] -> U) to mapped items - - For instance, ``map_reduce([[1, 2], 3, 4])`` is:: - - f_reduce([ - f_reduce([ - f_map(1), - f_map(2) - ]), - f_map(3), - f_map(4) - ]]) - - - State can be passed down through the calls with `f_kwargs`, - to iterables of mapped items. When kwargs are passed, as in - ``map_reduce([[1, 2], 3, 4], **kw)``, this becomes:: - - kw1 = f_kwargs(**kw) - kw2 = f_kwargs(**kw1) - f_reduce([ - f_reduce([ - f_map(1), **kw2) - f_map(2, **kw2) - ], **kw1), - f_map(3, **kw1), - f_map(4, **kw1) - ]], **kw) - """ - def f(x, **kwargs): - if not self.recurse_if(x): - return f_map(x, **kwargs) - else: - next_kwargs = f_kwargs(**kwargs) - return f_reduce(( - f(xi, **next_kwargs) - for xi in x - ), **kwargs) - return f(x, **kwargs) - - def walk(self, x, index=()): - """ - Iterate over x, yielding (index, value, entering), where - - * ``index``: a tuple of indices up to this point - * ``value``: equal to ``x[index[0]][...][index[-1]]``. On the first iteration, is - ``x`` itself - * ``entering``: bool. The result of ``recurse_if(value)`` - """ - do_recurse = self.recurse_if(x) - yield index, x, do_recurse - - if not do_recurse: - return - for i, xi in enumerate(x): - # yield from ... - for v in self.walk(xi, index + (i,)): - yield v + def format_index(index): + idx_str = ''.join('[{}]'.format(i) for i in index if i is not None) + return 'arrays' + idx_str + if type(arrays) is tuple: + # not strictly necessary, but saves us from: + # - more than one way to do things - no point treating tuples like + # lists + # - horribly confusing behaviour that results when tuples are + # treated like ndarray + raise TypeError( + '{} is a tuple. ' + 'Only lists can be used to arrange blocks, and np.block does ' + 'not allow implicit conversion from tuple to ndarray.'.format( + format_index(parent_index) + ) + ) + elif type(arrays) is list and len(arrays) > 0: + idxs_ndims = (_block_check_depths_match(arr, parent_index + [i]) + for i, arr in enumerate(arrays)) + + first_index, max_arr_ndim = next(idxs_ndims) + for index, ndim in idxs_ndims: + if ndim > max_arr_ndim: + max_arr_ndim = ndim + if len(index) != len(first_index): + raise ValueError( + "List depths are mismatched. First element was at depth " + "{}, but there is an element at depth {} ({})".format( + len(first_index), + len(index), + format_index(index) + ) + ) + return first_index, max_arr_ndim + elif type(arrays) is list and len(arrays) == 0: + # We've 'bottomed out' on an empty list + return parent_index + [None], 0 + else: + # We've 'bottomed out' - arrays is either a scalar or an array + return parent_index, _nx.ndim(arrays) + + +def _block(arrays, max_depth, result_ndim): + """ + Internal implementation of block. `arrays` is the argument passed to + block. `max_depth` is the depth of nested lists within `arrays` and + `result_ndim` is the greatest of the dimensions of the arrays in + `arrays` and the depth of the lists in `arrays` (see block docstring + for details). + """ + def atleast_nd(a, ndim): + # Ensures `a` has at least `ndim` dimensions by prepending + # ones to `a.shape` as necessary + return array(a, ndmin=ndim, copy=False, subok=True) + + def block_recursion(arrays, depth=0): + if depth < max_depth: + if len(arrays) == 0: + raise ValueError('Lists cannot be empty') + arrs = [block_recursion(arr, depth+1) for arr in arrays] + return _nx.concatenate(arrs, axis=-(max_depth-depth)) + else: + # We've 'bottomed out' - arrays is either a scalar or an array + # type(arrays) is not list + return atleast_nd(arrays, result_ndim) + + return block_recursion(arrays) def block(arrays): @@ -587,81 +602,6 @@ def block(arrays): """ - def atleast_nd(x, ndim): - x = asanyarray(x) - diff = max(ndim - x.ndim, 0) - return x[(None,)*diff + (Ellipsis,)] - - def format_index(index): - return 'arrays' + ''.join('[{}]'.format(i) for i in index) - - rec = _Recurser(recurse_if=lambda x: type(x) is list) - - # ensure that the lists are all matched in depth - list_ndim = None - any_empty = False - for index, value, entering in rec.walk(arrays): - if type(value) is tuple: - # not strictly necessary, but saves us from: - # - more than one way to do things - no point treating tuples like - # lists - # - horribly confusing behaviour that results when tuples are - # treated like ndarray - raise TypeError( - '{} is a tuple. ' - 'Only lists can be used to arrange blocks, and np.block does ' - 'not allow implicit conversion from tuple to ndarray.'.format( - format_index(index) - ) - ) - if not entering: - curr_depth = len(index) - elif len(value) == 0: - curr_depth = len(index) + 1 - any_empty = True - else: - continue - - if list_ndim is not None and list_ndim != curr_depth: - raise ValueError( - "List depths are mismatched. First element was at depth {}, " - "but there is an element at depth {} ({})".format( - list_ndim, - curr_depth, - format_index(index) - ) - ) - list_ndim = curr_depth - - # do this here so we catch depth mismatches first - if any_empty: - raise ValueError('Lists cannot be empty') - - # convert all the arrays to ndarrays - arrays = rec.map_reduce(arrays, - f_map=asanyarray, - f_reduce=list - ) - - # determine the maximum dimension of the elements - elem_ndim = rec.map_reduce(arrays, - f_map=lambda xi: xi.ndim, - f_reduce=max - ) - ndim = max(list_ndim, elem_ndim) - - # first axis to concatenate along - first_axis = ndim - list_ndim - - # Make all the elements the same dimension - arrays = rec.map_reduce(arrays, - f_map=lambda xi: atleast_nd(xi, ndim), - f_reduce=list - ) - - # concatenate innermost lists on the right, outermost on the left - return rec.map_reduce(arrays, - f_reduce=lambda xs, axis: _nx.concatenate(list(xs), axis=axis), - f_kwargs=lambda axis: dict(axis=axis+1), - axis=first_axis - ) + bottom_index, arr_ndim = _block_check_depths_match(arrays) + list_ndim = len(bottom_index) + return _block(arrays, list_ndim, max(arr_ndim, list_ndim)) diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index 5c1e569b7d9a..deb2a407d78c 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -560,6 +560,28 @@ def test_tuple(self): assert_raises_regex(TypeError, 'tuple', np.block, ([1, 2], [3, 4])) assert_raises_regex(TypeError, 'tuple', np.block, [(1, 2), (3, 4)]) + def test_different_ndims(self): + a = 1. + b = 2 * np.ones((1, 2)) + c = 3 * np.ones((1, 1, 3)) + + result = np.block([a, b, c]) + expected = np.array([[[1., 2., 2., 3., 3., 3.]]]) + + assert_equal(result, expected) + + def test_different_ndims_depths(self): + a = 1. + b = 2 * np.ones((1, 2)) + c = 3 * np.ones((1, 2, 3)) + + result = np.block([[a, b], [c]]) + expected = np.array([[[1., 2., 2.], + [3., 3., 3.], + [3., 3., 3.]]]) + + assert_equal(result, expected) + if __name__ == "__main__": run_module_suite()