MAINT: Simplify block implementation #9667
Merged
+110 −148

Commits (23), all by j-towns:
a5c6f0d  Simplify block implementation
2dcc9aa  np.block style improvements
e787a9f  Reflect asanyarray behaviour in block
95adb77  Add empty list comment to block depth check
07a3f43  Use strict type checking (not isinstance)
19fc68c  Re-add tuple type-check comment
997ac2c  Re-add `atleast_nd` function.
7eb1044  Add detailed comment to _block_check_depths_match
ff7f726  Extend comments _block_check_depths_match
6ecd2b4  Try not recomputing list_ndim
1211b70  Slight simplification to logic
ffa6cf6  Add two tests for different arr_ndims
3ed6936  Simplify further - matching docstring logic
5f9f1fa  Rename list_ndim to max_depth
8a83a5f  rm extra line from near top of shape_base
5a0557a  Pre-calculate max array ndim
c2b5be5  Further slight simplifications
bd6729d  Update block docstrings
ad278f3  Fix python 3.4 sequence error
2c1734b  Correct empty list ndim
eaddf39  Avoid using zip(*...) syntax
a5cbc93  Use builtin next method
a691f2d  Rm unnecessary enumerate
Changes from all commits:

@@ -365,78 +365,93 @@ def stack(arrays, axis=0, out=None):
    return _nx.concatenate(expanded_arrays, axis=axis, out=out)

Removed (the old `_Recurser` helper class):

class _Recurser(object):
    """
    Utility class for recursing over nested iterables
    """
    def __init__(self, recurse_if):
        self.recurse_if = recurse_if

    def map_reduce(self, x, f_map=lambda x, **kwargs: x,
                   f_reduce=lambda x, **kwargs: x,
                   f_kwargs=lambda **kwargs: kwargs,
                   **kwargs):
        """
        Iterate over the nested list, applying:
        * ``f_map`` (T -> U) to items
        * ``f_reduce`` (Iterable[U] -> U) to mapped items

        For instance, ``map_reduce([[1, 2], 3, 4])`` is::

            f_reduce([
              f_reduce([
                f_map(1),
                f_map(2)
              ]),
              f_map(3),
              f_map(4)
            ]])

        State can be passed down through the calls with `f_kwargs`,
        to iterables of mapped items. When kwargs are passed, as in
        ``map_reduce([[1, 2], 3, 4], **kw)``, this becomes::

            kw1 = f_kwargs(**kw)
            kw2 = f_kwargs(**kw1)
            f_reduce([
              f_reduce([
                f_map(1), **kw2)
                f_map(2, **kw2)
              ], **kw1),
              f_map(3, **kw1),
              f_map(4, **kw1)
            ]], **kw)
        """
        def f(x, **kwargs):
            if not self.recurse_if(x):
                return f_map(x, **kwargs)
            else:
                next_kwargs = f_kwargs(**kwargs)
                return f_reduce((
                    f(xi, **next_kwargs)
                    for xi in x
                ), **kwargs)
        return f(x, **kwargs)

    def walk(self, x, index=()):
        """
        Iterate over x, yielding (index, value, entering), where

        * ``index``: a tuple of indices up to this point
        * ``value``: equal to ``x[index[0]][...][index[-1]]``. On the first
          iteration, is ``x`` itself
        * ``entering``: bool. The result of ``recurse_if(value)``
        """
        do_recurse = self.recurse_if(x)
        yield index, x, do_recurse

        if not do_recurse:
            return
        for i, xi in enumerate(x):
            # yield from ...
            for v in self.walk(xi, index + (i,)):
                yield v

Added (the new recursive helpers `_block_check_depths_match` and `_block`):

def _block_check_depths_match(arrays, parent_index=[]):
    """
    Recursive function checking that the depths of nested lists in `arrays`
    all match. Mismatch raises a ValueError as described in the block
    docstring below.

    The entire index (rather than just the depth) needs to be calculated
    for each innermost list, in case an error needs to be raised, so that
    the index of the offending list can be printed as part of the error.

    The parameter `parent_index` is the full index of `arrays` within the
    nested lists passed to _block_check_depths_match at the top of the
    recursion.

    The return value is a pair. The first item returned is the full index
    of an element (specifically the first element) from the bottom of the
    nesting in `arrays`. An empty list at the bottom of the nesting is
    represented by a `None` index.

    The second item is the maximum of the ndims of the arrays nested in
    `arrays`.
    """
    def format_index(index):
        idx_str = ''.join('[{}]'.format(i) for i in index if i is not None)
        return 'arrays' + idx_str

    if type(arrays) is tuple:
        # not strictly necessary, but saves us from:
        # - more than one way to do things - no point treating tuples like
        #   lists
        # - horribly confusing behaviour that results when tuples are
        #   treated like ndarray
        raise TypeError(
            '{} is a tuple. '
            'Only lists can be used to arrange blocks, and np.block does '
            'not allow implicit conversion from tuple to ndarray.'.format(
                format_index(parent_index)
            )
        )
    elif type(arrays) is list and len(arrays) > 0:
        idxs_ndims = (_block_check_depths_match(arr, parent_index + [i])
                      for i, arr in enumerate(arrays))

        first_index, max_arr_ndim = next(idxs_ndims)
        for index, ndim in idxs_ndims:
            if ndim > max_arr_ndim:
                max_arr_ndim = ndim
            if len(index) != len(first_index):
                raise ValueError(
                    "List depths are mismatched. First element was at depth "
                    "{}, but there is an element at depth {} ({})".format(
                        len(first_index),
                        len(index),
                        format_index(index)
                    )
                )
        return first_index, max_arr_ndim
    elif type(arrays) is list and len(arrays) == 0:
        # We've 'bottomed out' on an empty list
        return parent_index + [None], 0
    else:
        # We've 'bottomed out' - arrays is either a scalar or an array
        return parent_index, _nx.ndim(arrays)


def _block(arrays, max_depth, result_ndim):
    """
    Internal implementation of block. `arrays` is the argument passed to
    block. `max_depth` is the depth of nested lists within `arrays` and
    `result_ndim` is the greatest of the dimensions of the arrays in
    `arrays` and the depth of the lists in `arrays` (see block docstring
    for details).
    """
    def atleast_nd(a, ndim):
        # Ensures `a` has at least `ndim` dimensions by prepending
        # ones to `a.shape` as necessary
        return array(a, ndmin=ndim, copy=False, subok=True)

    def block_recursion(arrays, depth=0):
        if depth < max_depth:
            if len(arrays) == 0:
                raise ValueError('Lists cannot be empty')
            arrs = [block_recursion(arr, depth+1) for arr in arrays]
            return _nx.concatenate(arrs, axis=-(max_depth-depth))
        else:
            # We've 'bottomed out' - arrays is either a scalar or an array
            # type(arrays) is not list
            return atleast_nd(arrays, result_ndim)

    return block_recursion(arrays)


def block(arrays):
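To make the new helpers' contract concrete, here is a small usage sketch. It is not part of the diff; it uses only the public np.block API, and the bottom_index / max_arr_ndim values in the comments are worked out from the docstrings above.

    import numpy as np

    A = 2 * np.eye(2)                       # ndim 2
    B = 3 * np.eye(3)                       # ndim 2

    # Depth-2 nesting: _block_check_depths_match returns the index of the
    # first innermost element ([0, 0], so list_ndim == 2) and the largest
    # leaf ndim (2), giving result_ndim == max(2, 2) == 2.  block_recursion
    # then concatenates the inner lists along axis -1 and the outer list
    # along axis -2.
    C = np.block([
        [A,               np.zeros((2, 3))],
        [np.ones((3, 2)), B               ],
    ])
    print(C.shape)                          # (5, 5)

    # Leaves with smaller ndim (here a 1-D array) are promoted to
    # result_ndim by atleast_nd before concatenation.
    v = np.block([[np.zeros((1, 2)), np.array([7., 8.])]])
    print(v.shape)                          # (1, 4)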
@@ -587,81 +602,6 @@ def block(arrays):

Context (the closing quotes of the block docstring):

    """

Removed (the old _Recurser-based body of block):

    def atleast_nd(x, ndim):
        x = asanyarray(x)
        diff = max(ndim - x.ndim, 0)
        return x[(None,)*diff + (Ellipsis,)]

    def format_index(index):
        return 'arrays' + ''.join('[{}]'.format(i) for i in index)

    rec = _Recurser(recurse_if=lambda x: type(x) is list)

    # ensure that the lists are all matched in depth
    list_ndim = None
    any_empty = False
    for index, value, entering in rec.walk(arrays):
        if type(value) is tuple:
            # not strictly necessary, but saves us from:
            # - more than one way to do things - no point treating tuples like
            #   lists
            # - horribly confusing behaviour that results when tuples are
            #   treated like ndarray
            raise TypeError(
                '{} is a tuple. '
                'Only lists can be used to arrange blocks, and np.block does '
                'not allow implicit conversion from tuple to ndarray.'.format(
                    format_index(index)
                )
            )
        if not entering:
            curr_depth = len(index)
        elif len(value) == 0:
            curr_depth = len(index) + 1
            any_empty = True
        else:
            continue

        if list_ndim is not None and list_ndim != curr_depth:
            raise ValueError(
                "List depths are mismatched. First element was at depth {}, "
                "but there is an element at depth {} ({})".format(
                    list_ndim,
                    curr_depth,
                    format_index(index)
                )
            )
        list_ndim = curr_depth

    # do this here so we catch depth mismatches first
    if any_empty:
        raise ValueError('Lists cannot be empty')

    # convert all the arrays to ndarrays
    arrays = rec.map_reduce(arrays,
                            f_map=asanyarray,
                            f_reduce=list
    )

    # determine the maximum dimension of the elements
    elem_ndim = rec.map_reduce(arrays,
                               f_map=lambda xi: xi.ndim,
                               f_reduce=max
    )
    ndim = max(list_ndim, elem_ndim)

    # first axis to concatenate along
    first_axis = ndim - list_ndim

    # Make all the elements the same dimension
    arrays = rec.map_reduce(arrays,
                            f_map=lambda xi: atleast_nd(xi, ndim),
                            f_reduce=list
    )

    # concatenate innermost lists on the right, outermost on the left
    return rec.map_reduce(arrays,
                          f_reduce=lambda xs, axis: _nx.concatenate(list(xs), axis=axis),
                          f_kwargs=lambda axis: dict(axis=axis+1),
                          axis=first_axis
    )

Added (the new body of block, which delegates to the helpers above):

    bottom_index, arr_ndim = _block_check_depths_match(arrays)
    list_ndim = len(bottom_index)
    return _block(arrays, list_ndim, max(arr_ndim, list_ndim))

Review thread attached to the removed tuple type-check comment:

Reviewer: Any reason for removing this comment?

j-towns: No, I actually hadn't meant to delete that, will re-include it.
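The error paths the helpers promise can be exercised directly through np.block. This is a hedged sketch, not code from the PR; the messages in the comments are paraphrased from the format strings in the diff above.

    import numpy as np

    # A tuple anywhere in the nesting is rejected up front by
    # _block_check_depths_match:
    try:
        np.block(([1, 2], [3, 4]))
    except TypeError as e:
        print(e)   # "arrays is a tuple. Only lists can be used to
                   #  arrange blocks, ..."

    # Mismatched nesting depths report the index of the offending element:
    try:
        np.block([1, [2]])
    except ValueError as e:
        print(e)   # "List depths are mismatched. First element was at
                   #  depth 1, but there is an element at depth 2
                   #  (arrays[1][0])"

    # Empty lists survive the depth check but are caught by block_recursion:
    try:
        np.block([])
    except ValueError as e:
        print(e)   # "Lists cannot be empty"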
Review thread on block_recursion's concatenation axis:

Reviewer: Wouldn't -(max_depth-depth) just be (depth-max_depth)?
Reviewer: Also, slight personal preference for having spaces around binary operators, for readability.
j-towns: It would, yeah. The reason I wrote it that way round was that I wanted to match the docstring as closely as possible (my original motivation for this PR was to make it a bit clearer to people reading the code how block actually works). The docstring says: […]

To me the correspondence between the docstring and the code is ever-so-slightly clearer with the expression the way round it currently is, and I think the effect on performance is probably negligible.
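A quick worked check of the axis arithmetic being discussed; this is an illustrative sketch, not part of the PR, and max_depth, depth and block_recursion refer to the names in the diff above.

    # For a depth-2 nesting, max_depth == 2, so block_recursion uses:
    #   innermost lists (depth == 1): axis = -(max_depth - depth) = -1
    #   outermost list  (depth == 0): axis = -(max_depth - depth) = -2
    # i.e. innermost lists are concatenated along the last axis and outer
    # lists along earlier axes, the same rule the old implementation
    # described as "concatenate innermost lists on the right, outermost on
    # the left".  The reviewer's spelling is algebraically identical:
    max_depth = 2
    for depth in (0, 1):
        assert -(max_depth - depth) == depth - max_depth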
j-towns: I often instinctively shorten things when they're being fed to a keyword argument, because you don't normally put spaces around the = sign. Don't … and … look a bit weird?
Reviewer: The style point is minor, though there are many tools (e.g. flake8) that consider not having spaces around binary operators an error regardless of context. Normally this is argued from the standpoint of readability (including in PEP8, when the rule was introduced), hence why I mentioned it. Adding parentheses around the keyword argument's value helps guide the eye when following these style tools' recommendations. Not strongly attached to this styling nit, though.