ENH: Implement take_along_axis as described in #8708 #8714
ece6ee1
to
626ec17
Compare
numpy/lib/shape_base.py
Outdated
]

def take_along_axis(arr, indices, axis):
    """
    Take the elements described by `indices` along each 1-D slice of the given
First sentence should fit on a single line. Maybe
Take elements from slices indexed along the given axis.
The problem is I need to disambiguate this from `take`, which is
Take elements from an array along an axis.
Should go to the mailing list in any case; maybe `take` is not even the most natural word in the end, which would somewhat remove the problem ;). The `take` description kind of only works for 1-D, and in 1-D the two do the same thing, so it's a bit of a twist :).
It doesn't have to be my example, but I do think it should be a single line.
You're not wrong - I'm just asking for help in coming up with an unambiguous description under that constraint :)
The only other idea I have right now is to call it a vectorized take/pick (which clashes with the vindex idea, but maybe that name is not great anyway; I think someone had suggested broadcasted index there too, which may be more logical anyway -- though that does not necessarily mean easier I guess, hehe).
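To make the naming problem in this thread concrete, here is a small sketch of how the two functions relate. It assumes the `np.take_along_axis` that later shipped in released NumPy rather than the exact code in this PR, and the array names are illustrative only.

```python
import numpy as np

# 1-D: take and take_along_axis select the same elements.
a1 = np.array([10, 30, 20])
i1 = np.array([2, 0])
same_in_1d = np.array_equal(np.take(a1, i1, axis=0),
                            np.take_along_axis(a1, i1, axis=0))

# 2-D: take applies the whole index array to every row, while
# take_along_axis pairs each row of the indices with the matching row of a2.
a2 = np.array([[10, 30, 20],
               [60, 40, 50]])
i2 = np.array([[2], [0]])
taken = np.take(a2, i2, axis=1)             # shape (2, 2, 1)
along = np.take_along_axis(a2, i2, axis=1)  # shape (2, 1)
```

In the 2-D case `take` inserts the full shape of the index array in place of the axis, whereas `take_along_axis` keeps the number of dimensions unchanged.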
numpy/lib/shape_base.py
Outdated
This computes::

    out[a..., k..., b...] = arr[a..., indices[a..., k..., b...], b...]
Unify use of notation for indices (e.g., here lowercase a, b; below upper case; ideally use the standard "integers", i.e., i--n). If possible, do use `axis` directly.
Here I'm using the convention that `a = range(0, A)`, i.e. the capital letters are the shape, and the lowercase ones indices for that dimension
There's no way I can use `axis` directly here, short of some ascii art pointing to the middle index on the right
I could use `out[i..., j..., k...] = arr[i..., indices[i..., j..., k...], k...]`, and then `Ni`, `Nj`, `Nk` further down?
I like the `i,j,k`, `Ni,Nj,Nk`, or perhaps `i1, i2, i3`, `N1,N2,N3`.
I've added a fixup commit to apply this. I'll squash once everything else is approved
numpy/lib/shape_base.py
Outdated
    array([30, 60])
    >>> ai = np.argmax(a, axis=1); ai
    array([1, 0], dtype=int64)
    >>> np.take_along_axis(a, ai, axis=1)
Could you add an example where one keeps the dimension?
Again, I think keeping dimensions is out of scope here, and belongs in #8710.
"keeping the dimension" is something that can be done either before or after `take_along_axis`.
In principle we could think about making keepdims (well, kind of the inverse) a kwarg here too.
So you are lazy about a C version, too bad ;P
@seberg: What would it do though? Dimensions are already kept, in that `out.ndim == indices.ndim`.
But there is no problem with argmax needing an expand dims, since we should just add a keepdims?
Right - it's inconvenient right now to have to call `np.expand_dims`, but that's really a bug of not having `keepdims` in argmethods. And as for calling squeeze on the return value (vs what I proposed before) - you're likely in for a bad time if you start squeezing axes in the middle of your multidimensional data anyway.
Having `assert(ind.ndim == x.ndim)` still gives us plenty of freedom to come up with better semantics in future
Another example of how this restriction is sufficient (with `keepdims`) - indexing the brightest pixel:
assert img.shape == (400, 300, 3)
brightness = np.sum(img, axis=2, keepdims=True) # again - always keep your dims!
argbrightrow = np.argmax(brightness, axis=1, keepdims=True)
brightest_by_row = np.take_along_axis(img, argbrightrow, axis=1)
Having assert(ind.ndim == x.ndim) still gives us plenty of freedom to come up with better semantics in future
We could remove this and add the default broadcasting behaviour, but this risks confusing users who try to do `np.take_along_axis(a, a.argmax(axis=1), axis=1)` (no `keepdims`), which would allow this to sometimes silently do the wrong thing
So for now, I guess let's just say we don't put it in, and put it up for discussion on the mailing list. If anyone can present a good reason/usecase especially for the way you first described it, I am good with allowing it. At least the 0-D case (dim is missing), it seems rather intuitive after all....
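As a rough illustration of the `indices.ndim == arr.ndim` restriction discussed in this thread, the sketch below uses `np.expand_dims` to restore the axis that `argmax` collapses. It assumes the behaviour of `np.take_along_axis` in released NumPy, where a dimension mismatch raises `ValueError`; the version under review here may have differed.

```python
import numpy as np

a = np.array([[10, 30, 20],
              [60, 40, 50]])

ai = np.argmax(a, axis=1)              # shape (2,): axis 1 has been collapsed
ai_kept = np.expand_dims(ai, axis=1)   # shape (2, 1): same ndim as a again
row_max = np.take_along_axis(a, ai_kept, axis=1)

# Passing the collapsed indices is rejected rather than silently guessed at.
try:
    np.take_along_axis(a, ai, axis=1)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```

Rejecting the mismatched case keeps the door open for choosing broadcasting semantics later without breaking existing callers.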
numpy/lib/shape_base.py
Outdated

    >>> np.max(a, axis=1)
    array([30, 60])
    >>> ai = np.argmax(a, axis=1); ai
I would not use `;` in example code; just make another line...
This is pretty par for the course for numpy docstrings - grep for >>> (\w+)\s*=.*;\s*\1
Bad practice anyway.
Resolved
else:
    axis = normalize_axis_index(axis, arr.ndim)
if not _nx.issubdtype(indices.dtype, _nx.integer):
    raise IndexError('arrays used as indices must be of integer type')
This should be a `TypeError`, I think (at least, `[1, 2][1.]` raises `TypeError`).
Just copying `np.array([1, 2])[np.array(1.)]` here, which gives `IndexError`. This is just that error message, but without the bit about booleans.
Bools are not considered ints here right?
Correct, `np.issubdtype(np.bool_, np.integer)` is false. There's a test for this error in this PR.
OK, should have tried with arrays; not very logical but best to stick with numpy practice here.
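The exception types compared in this thread can be checked directly. The following is a small sketch; the `error_name` helper is hypothetical and exists only to capture the exception class.

```python
import numpy as np

def error_name(fn):
    """Return the name of the exception raised by fn(), or None."""
    try:
        fn()
    except Exception as exc:
        return type(exc).__name__
    return None

values = [1, 2]
float_index = 1.0

# A Python list rejects a float index with TypeError ...
list_error = error_name(lambda: values[float_index])
# ... while a NumPy array indexed by a float array raises IndexError.
array_error = error_name(lambda: np.array(values)[np.array(float_index)])

# Booleans do not count as integers for this dtype check.
bool_is_integer = np.issubdtype(np.bool_, np.integer)
```

This is why raising `IndexError` for non-integer `indices` matches existing NumPy practice, even though plain Python sequences use `TypeError`.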
numpy/lib/tests/test_shape_base.py
Outdated
)
from numpy.testing import (
    run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
    assert_raises, assert_warns
)


class TestTakeAlongAxis(TestCase):
    def test_argequivalent(self):
Just to be sure, also add the `keepdims` versions.
I'd argue for postponing that to #8710 as well, and then no special handling would be needed.
I have added a test to verify that `expand_dims` works before and after though, which is sort of the same thing
Does this last commit belong in this PR? |
I think this is very nice. Generally, I think it is better to keep one PR to one logical commit, so in that sense the Agreed though that this should be passed by the mailing list. |
1f93135
to
1296473
Compare
I'll send something out to the mailing list once my repeat confirmation email arrives and I actually remember to click on it during the 3-day window. |
☔ The latest upstream changes (presumably #8795) made this pull request unmergeable. Please resolve the merge conflicts. |
1296473
to
9242bce
Compare
☔ The latest upstream changes (presumably #8847) made this pull request unmergeable. Please resolve the merge conflicts. |
9242bce
to
a2e68cc
Compare
☔ The latest upstream changes (presumably #8886) made this pull request unmergeable. Please resolve the merge conflicts. |
Could the reviewers involved here either sign off on it, merge it, or make a complaint. |
@charris: The ball is in my court, I think - there was indecision about how to deal with broadcasting, with the suggestion of me consulting the mailing list - I have not done so. I think it would be easier to design / put forth a case for this when |
OK, I'll punt. Thanks for the update. |
☔ The latest upstream changes (presumably #9050) made this pull request unmergeable. Please resolve the merge conflicts. |
Still in abeyance. Should I punt this on to 1.15? |
There is not much need for a specific milestone, is there? I don't remember this, but for broadcasting, possibly we could do a minimal thing first that can be generalized later if it is too tricky to decide?
Removed the milestone. |
3f44cd1
to
8b7c244
Compare
I've rebased this just to avoid bitrot down the line, and moved the DOC commit to a separate pr (#9946). If nothing else, getting the doc change to existing functions into 1.14 makes it easier to pitch the new feature for 1.15. |
numpy/lib/shape_base.py
Outdated

Or equivalently (where `...` alone is the builtin `Ellipsis`):

    out[i..., ..., k...] = arr[i..., :, k...][indices[i..., ..., k...]]
Based on the discussion #9946, I suppose this would be better described as
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices[ii + s_[...,] + kk]]
Edit: updated
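The loop description above can be turned into a runnable reference and compared against the released `np.take_along_axis`. This is only a sketch: `take_along_axis_reference` is a hypothetical helper, it assumes a non-negative `axis`, and it assumes `indices` matches `arr` on every dimension other than `axis`.

```python
import numpy as np

def take_along_axis_reference(arr, indices, axis):
    """Slow reference: index each 1-D slice of arr along axis separately."""
    Ni, Nk = arr.shape[:axis], arr.shape[axis + 1:]
    J = indices.shape[axis]
    out = np.empty(Ni + (J,) + Nk, dtype=arr.dtype)
    for ii in np.ndindex(Ni):
        for kk in np.ndindex(Nk):
            arr_1d = arr[ii + np.s_[:,] + kk]          # 1-D slice of the data
            indices_1d = indices[ii + np.s_[:,] + kk]  # matching 1-D indices
            out[ii + np.s_[:,] + kk] = arr_1d[indices_1d]
    return out

rng = np.random.default_rng(0)
arr = rng.integers(0, 100, size=(2, 3, 4))
indices = np.argsort(arr, axis=1)

result = take_along_axis_reference(arr, indices, axis=1)
expected = np.take_along_axis(arr, indices, axis=1)
```

With `argsort` indices, both results should also equal `np.sort(arr, axis=1)`, which is the motivating use case from #8708.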
Extracted from numpygh-8714 [ci-skip]
Hi. I am quite new to I wanted to do a similar thing for a tree reduction, and came up with an implementation for The shape of the result is the same as the shape of the index.
take_along_axis
from functools import reduce
from operator import mul

import numpy as np


def take_along_axis(arr, ind, axis):
    """Take elements from an array according to an index along axis.

    Parameters
    ----------
    arr : np.ndarray
    ind : np.ndarray
        Indexing of an array along an axis. This cannot be an int.
        For a 2D array arr and axis=0, this means: from each column of
        arr, select elements along the column by the indices at the
        corresponding column of ind. In other words, the ith column of
        the result is arr[ind[:, i], i].
    axis : int

    Returns
    -------
    result : np.ndarray

    Example
    -------
    >>> import numpy as np
    >>> arr = np.flip(np.arange(8).reshape(2, 4), axis=1)  # test array
    >>> arr
    array([[3, 2, 1, 0],
           [7, 6, 5, 4]])
    >>> ind = np.argsort(arr, axis=1)  # indexing along axis 1
    >>> result = take_along_axis(arr, ind, axis=1)
    >>> result
    array([[0, 1, 2, 3],
           [4, 5, 6, 7]])
    >>> answer = np.sort(arr, axis=1)
    >>> np.all(result == answer)
    True
    """
    # Does not check if axis or the ind are legal (axis must be non-negative)
    shape = arr.shape
    before = reduce(mul, shape[:axis], 1)  # product of dimensions before axis
    at = shape[axis]
    after = arr.size // at // before       # product of dimensions after axis
    a = arr.reshape(before, at, after)
    # Fancy indices must be a tuple; indexing with a list is not supported
    idx = (
        np.arange(before).reshape(before, 1, 1),
        ind.reshape(before, -1, after),
        np.arange(after).reshape(1, 1, after),
    )
    return a[idx].reshape(ind.shape)
You'll need to add my fork of the repository (https://github.com/eric-wieser/numpy.git) as a remote, then you should just be able to checkout the Is your goal to compare my implementation with yours? |
Thanks. I found it. I mostly want to learn by looking at another solution. |
Waiting eagerly for this to get in! Any reason why there's no activity?
# normalize inputs
arr = asanyarray(arr)
indices = asanyarray(indices)
if axis is None:
Could you add support for arr.ndim > indices.ndim?
It's a single line here (+ unit test + docs):
indices = indices.reshape((1, ) * (arr.ndim - indices.ndim) + indices.shape)
Use case / example:
x = np.array([[5, 3, 2, 8, 1],
[0, 7, 1, 3, 2]])
# Completely arbitrary y = f(x0, x1, ..., xn), embarrassingly parallel along axis=-1
# Here we only have x0, but we could have more.
y = x.sum(axis=0)
# Sort the x's, moving the ones that cause the smallest y's to the left
take_along_axis(x, np.argsort(y))
I don't think your use-case is well-motivated. A more explicit way to achieve that would be:
y = x.sum(axis=0, keepdims=True)
take_along_axis(x, np.argsort(y, axis=1), axis=1)
The whole point is that with the one-liner addition f(x) can add or remove axes at will (some manual broadcasting required if it replaces or transposes axes, which however happens automatically if you e.g. wrap this in `xarray.apply_ufunc`).
Note that in the comments above, we decided that perhaps it's best to not allow any case other than `indices.ndim == arr.ndim`, since there's no obvious right choice.
`take_along_axis` only really makes sense if you endeavor to keep all your axes aligned. `xarray` can probably solve that by axis names alone, but in numpy you need to indicate that by axis position. Therefore, you can't afford to let your axes collapse, and have numpy guess which one you lost: in your case, you're advocating for it to guess the left-most one should be reinserted - but this is only the case because you did `sum(axis=0)`.
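The explicit `keepdims` alternative suggested in this thread can be run as follows. This sketch relies on the released `np.take_along_axis`, which lets `indices` broadcast against `arr` on the non-indexed dimensions. The `kind="stable"` argument is an addition made here so that the tied column sums are ordered deterministically.

```python
import numpy as np

x = np.array([[5, 3, 2, 8, 1],
              [0, 7, 1, 3, 2]])

# Keep the summed axis so y stays aligned with x: shape (1, 5).
y = x.sum(axis=0, keepdims=True)
order = np.argsort(y, axis=1, kind="stable")   # shape (1, 5)

# The length-1 leading axis of `order` broadcasts over both rows of x.
x_sorted = np.take_along_axis(x, order, axis=1)
```

Because the axis position is stated at every step, nothing has to be guessed about which dimension was lost.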
    return tuple(fancy_index)


def take_along_axis(arr, indices, axis):
default to axis=-1 like 99% of other numpy functions?
    return arr[_make_along_axis_idx(arr, indices, axis)]


def put_along_axis(arr, indices, values, axis):
default to axis=-1 like 99% of other numpy functions?
I wish that were true. `concatenate` defaults to `axis=0`, and other functions default to `axis=None`. Since the axis is a key part of the function, it seems best just to require it.
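For reference, a minimal sketch of `put_along_axis` with the axis passed explicitly, as argued for above. It assumes the `np.put_along_axis` from released NumPy, and the value `99` is arbitrary.

```python
import numpy as np

a = np.array([[10, 30, 20],
              [60, 40, 50]])

# Locate the maximum of each row, keeping the indexed axis in place.
ai = np.expand_dims(np.argmax(a, axis=1), axis=1)   # shape (2, 1)

# Overwrite each row's maximum in place; the axis is stated, not defaulted.
np.put_along_axis(a, ai, 99, axis=1)
```

Writing `axis=1` at the call site keeps the example unambiguous regardless of what default, if any, a given NumPy version provides.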
doc/release/1.14.0-notes.rst
Outdated
@@ -214,6 +214,20 @@ Chebyshev points of the first kind. A new ``Chebyshev.interpolate`` class
method adds support for interpolation over arbitrary intervals using the scaled
and shifted Chebyshev points of the first kind.

New ``np.take_along_axis`` and ``np.put_along_axis`` functions
numpy/lib/shape_base.py
Outdated

    arr[i..., :, k...][indices[i..., ..., k...]] = values[i..., ..., k...]

.. versionadded:: 1.13.0
numpy/lib/shape_base.py
Outdated

    out[i..., ..., k...] = arr[i..., :, k...][indices[i..., ..., k...]]

.. versionadded:: 1.13.0
679a857
to
162ed85
Compare
This is the reduced version that does not allow any insertion of extra dimensions
…h apply_along_axis
a5cc638
to
4eec0ce
Compare
I've split up the commits into one supporting restricted and obvious broadcasting, and one that adds the less obvious behavior that matches The first commit stands alone at #11105, and I think we should focus on getting minimal functionality in before trying to come up with non-obvious extensions. |
superseded by #11105, which was merged
Edit: Superseded by the simpler #11105
See #8708 and earlier issues linked there for discussion of the need for this function.
Let's keep discussion here to the implementation.
In future, it would be nice to implement this with npyiter in C code for speed, but this is a good starting point, and likely just as fast as what is currently being used in the wild.