Skip to content

ENH: Implement take_along_axis as described in #8708 #8714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

eric-wieser
Copy link
Member

@eric-wieser eric-wieser commented Feb 28, 2017

Edit: Superceded by the simpler #11105

See #8708 and earlier issues linked there for discussion of the need for this function.

Let's keep discussion here to the implementation.

In future, it would be nice to implement this with npyiter in C code for speed, but this is a good starting point, and likely just as fast as what is currently being used in the wild.

]

def take_along_axis(arr, indices, axis):
"""
Take the elements described by `indices` along each 1-D slice of the given
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First sentence should fit on a single line. Maybe

Take elements from slices indexed along the given axis.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is I need to disambiguate this from take, which is

Take elements from an array along an axis.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should go to the mailing list in any case, maybe take is not even the most natural word in the end, which would somewhat remove the problem ;). The take description kind of only works for 1-D, and in 1-D the two do the same thing, so its a bit of a twist :).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't have to be my example, but I do think it should be a single line.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're not wrong - I'm just asking for help in coming up with an unambiguous description under that constraint :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only other idea I have right now is to call it a vectorized take/pick (which bites with the vindex idea, but maybe that name is not great anyway, I think someone had suggest broadcasted index there too, which may be more logical anyway -- though that does not necessarily mean easier I guess, hehe).


This computes::

out[a..., k..., b...] = arr[a..., indices[a..., k..., b...], b...]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unify use of notation for indices (e.g., here lowercase a, b; below upper case; ideally use the standard "integers", i.e., i--n). If possible, do use axis directly.

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I'm using the convention that a = range(0, A), ie the capital letters are the shape, and the lowercase ones indices for that dimension

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no way I can use axis directly here, short of some ascii art pointing to the middle index on the right

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could use out[i..., j..., k...] = arr[i..., indices[i..., j..., k...], k...], and then Ni, Nk, Nk further down?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the i,j,k, Ni,Nj,Nk, or perhaps i1, i2, i3, N1,N2,N3.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a fixup commit to apply this. I'll squash once everything else is approved

array([30, 60])
>>> ai = np.argmax(a, axis=1); ai
array([1, 0], dtype=int64)
>>> np.take_along_axis(a, ai, axis=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add an example where one keeps the dimension?

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I think keeping dimensions is out of scope here, and belongs in #8710.

"keeping the dimension" is something that can be done either before or after take_along_axis.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle we could think about making keepdims (well, kind of the inverse) a kwarg here too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you are lazy about a C version, too bad ;P

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seberg: What would it do though? Dimensions are already kept, in that out.ndim == indices.ndim.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there is no problem with argmax needing an expand dims, since we should just add a keepdims?

Copy link
Member Author

@eric-wieser eric-wieser Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - it's inconvenient right now to have to call np.expand_dims, but that's really a bug of not having keepdims in argmethods. And as for calling squeeze on the return value (vs what I proposed before) - you're likely in for a bad time if you start squeezing axes in the middle of your multidimensional data anyway.

Having assert(ind.ndim == x.ndim) still gives us plenty of freedom to come up with better semantics in future

Copy link
Member Author

@eric-wieser eric-wieser Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another example of how this restriction is sufficient (with keepdims) - indexing the brightest pixel:

assert img.shape == (400, 300, 3)
brightness = np.sum(img, axis=2, keepdims=True)  # again - always keep your dims!
argbrightrow = np.argmax(brightness, axis=1, keepdims=True)
brightest_by_row = np.take_along_axis(img, argbrightrow, axis=1)

Copy link
Member Author

@eric-wieser eric-wieser Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having assert(ind.ndim == x.ndim) still gives us plenty of freedom to come up with better semantics in future

We could remove this and add the default broadcasting behaviour, but this risks confusing users who try to do np.take_along_axis(a, a.argmax(axis=1),axis=1),(no keepdim) which would allow this to sometimes silently do the wrong thing

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for now, I guess lets just say we don't put it in, and put it up for discussion on the mailing list. If anyone can present a good reason/usecase especially for the way you first described it, I am good with allowing it. At least the 0-D case (dim is missing), it seems rather intuitive after all....


>>> np.max(a, axis=1)
array([30, 60])
>>> ai = np.argmax(a, axis=1); ai
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not use ; in example code; just make another line...

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty par for the course for numpy docstrings - grep for >>> (\w+)\s*=.*;\s*\1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bad practice anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved

else:
axis = normalize_axis_index(axis, arr.ndim)
if not _nx.issubdtype(indices.dtype, _nx.integer):
raise IndexError('arrays used as indices must be of integer type')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a TypeError, I think (at least, [1, 2][1.] raises TypeError).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just copying np.array([1, 2])[np.array(1.)] here, which gives IndexError. This is just that error message, but without the bit about booleans.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bools are not considered ints here right?

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, np.issubdtype(np.bool_, np.integer) is false. There's a test for this error in this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, should have tried with arrays; not very logical but best to stick with numpy practice here.

)
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
assert_raises, assert_warns
)


class TestTakeAlongAxis(TestCase):
def test_argequivalent(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be sure, also add the keepdims versions.

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd argue for postponing that to #8710 as well, and then no special handling would be needed.
I have added a test to verify that expand_dims works before and after though, which is sort of the same thing

@eric-wieser
Copy link
Member Author

Does this last commit belong in this PR?

@mhvk
Copy link
Contributor

mhvk commented Feb 28, 2017

I think this is very nice. Generally, I think it is better to keep one PR to one logical commit, so in that sense the np.ma.median one should not be here, but it is nice to see an immediate use here. (In other words, either way is fine to me.)

Agreed though that this should be passed by the mailing list.

@eric-wieser eric-wieser force-pushed the take_along_axis branch 3 times, most recently from 1f93135 to 1296473 Compare March 9, 2017 19:55
@eric-wieser
Copy link
Member Author

eric-wieser commented Mar 9, 2017

put_along_axis is now in too (with the original semantics, not the proposed ones in above comments). One thing I noticed while doing that is that np.put is not a very good dual to np.take, as it has no axis argument (#8765), and has a peculiar "repeat as necessary" behaviour.

I'll send something out to the mailing list once my repeat confirmation email arrives and I actually remember to click on it during the 3-day window.

@homu
Copy link
Contributor

homu commented Mar 23, 2017

☔ The latest upstream changes (presumably #8795) made this pull request unmergeable. Please resolve the merge conflicts.

@homu
Copy link
Contributor

homu commented Mar 27, 2017

☔ The latest upstream changes (presumably #8847) made this pull request unmergeable. Please resolve the merge conflicts.

@homu
Copy link
Contributor

homu commented Apr 21, 2017

☔ The latest upstream changes (presumably #8886) made this pull request unmergeable. Please resolve the merge conflicts.

@charris
Copy link
Member

charris commented Apr 26, 2017

Could the reviewers involved here either sign off on it, merge it, or make a complaint.

@eric-wieser
Copy link
Member Author

@charris: The ball is in my court, I think - there was indecision about how to deal with broadcasting, with the suggestion of me consulting the mailing list - I have not done so.

I think it would be easier to design / put forth a case for this when argmin and argmax acquire keepdims arguments - so perhaps we should punt it to 1.14

@charris
Copy link
Member

charris commented Apr 26, 2017

OK, I'll punt. Thanks for the update.

@charris charris modified the milestones: 1.14.0 release, 1.13.0 release Apr 26, 2017
@homu
Copy link
Contributor

homu commented May 10, 2017

☔ The latest upstream changes (presumably #9050) made this pull request unmergeable. Please resolve the merge conflicts.

@charris
Copy link
Member

charris commented Oct 17, 2017

Still in abeyance. Should I punt this on to 1.15?

@seberg
Copy link
Member

seberg commented Oct 18, 2017

There is no much need for a specific milestone is there? I don't remember this, but for broadcasting, possibly we could do a minimal thing first that can be generalized later if it is too tricky to decide?

@charris
Copy link
Member

charris commented Oct 22, 2017

Removed the milestone.

@eric-wieser
Copy link
Member Author

I've rebased this just to avoid bitrot down the line, and moved the DOC commit to a separate pr (#9946). If nothing else, getting the doc change to existing functions into 1.14 makes it easier to pitch the new feature for 1.15.


Or equivalently (where `...` alone is the builtin `Ellipsis`):

out[i..., ..., k...] = arr[i..., :, k...][indices[i..., ..., k...]]
Copy link
Member Author

@eric-wieser eric-wieser Nov 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the discussion #9946, I suppose this would be better described as

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nj):
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices[ii + s_[...,] + kk]]

Edit: updated

eric-wieser added a commit to eric-wieser/numpy that referenced this pull request Nov 22, 2017
@rzu512
Copy link

rzu512 commented Dec 5, 2017

Hi. I am quite new to git. How to see your implementation for take_along axis, after I git clone this repository?

I wanted to do similar thing for a tree reduction, and came up with an implementation for take_along_axis.

The shape of the result is same as shape of the index.

take_along_axis
def take_along_axis(arr, ind, axis):
    """Take elements from an array according to an index along axis.
    Parameters
    ----------
        arr : np.ndarray
        ind : np.ndarray
            Indexing of an array along an axis. This cannot be an int.
            For a 2D array arr and axis=1, this means: from each column of
            arr, select elements along the column by the indices at the
            corresponding column of ind. In other words, the ith column of
            the result is arr[ind[:, i], i].
        axis : int
    Returns
    -------
        result : np.ndarray
    Example
    -------
        >>> import numpy as np
        >>> arr = np.flip(np.arange(8).reshape(2, 4), axis=1)  # test array
        >>> arr
        array([[3, 2, 1, 0],
               [7, 6, 5, 4]])
        >>> ind = np.argsort(arr, axis=1)  # indexing along axis 1
        >>> result = take_along_axis(arr, ind, axis=1)
        >>> result
        array([[0, 1, 2, 3],
               [4, 5, 6, 7]])
        >>> answer = np.sort(arr, axis=1)
        >>> np.all(result == answer)
        True
    """
    # Does not check if axis or the ind are legal
    shape = arr.shape
    before = reduce(mul, shape[:axis], 1)
    at = shape[axis]
    after = arr.size // at // before
    a = arr.reshape(before, at, after)
    idx = [
        np.arange(before).reshape(before, 1, 1),
        ind.reshape(before, -1, after),
        np.arange(after).reshape(1, 1, after),
    ]
    return a[idx].reshape(ind.shape)

@eric-wieser
Copy link
Member Author

You'll need to add my fork of the repository (https://github.com/eric-wieser/numpy.git) as a remote, then you should just be able to checkout the take_along_axis.

Is your goal to compare my implementation with yours?

@rzu512
Copy link

rzu512 commented Dec 5, 2017

Thanks. I found it.

I mostly want to learn by looking at another solution.

Copy link
Contributor

@crusaderky crusaderky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Waiting eagerly for this to get in! Any reason why there's no activity?

# normalize inputs
arr = asanyarray(arr)
indices = asanyarray(indices)
if axis is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add support for arr.ndim > indices.ndim?
It's a single line here (+ unit test + docs):

indices = indices.reshape((1, ) * (arr.ndim - indices.ndim) + indices.shape)

Use case / example:

x = np.array([[5, 3, 2, 8, 1],
              [0, 7, 1, 3, 2]])
# Completely arbitrary y = f(x0, x1, ..., xn), embarassingly parallel along axis=-1
# Here we only have x0, but we could have more.
y = x.sum(axis=0)
# Sort the x's, moving the ones that cause the smallest y's to the left
take_along_axis(x, np.argsort(y))

Copy link
Member Author

@eric-wieser eric-wieser May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think your use-case is well-motivated. A more explicit way to achieve that would be:

y = x.sum(axis=0, keepdims=True)
take_along_axis(x, np.argsort(y, axis=1), axis=1)

Copy link
Contributor

@crusaderky crusaderky May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole point is that with the one-liner addition f(x) can add or remove axes at will (some manual broadcasting required if it replaces or transposes axes, which however happens automatically if you e.g. wrap this in xarray.apply_ufunc).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that in the comments above, we decide that perhaps it's best to not allow any case other than indices.ndim == arr.ndim, since there's no obvious right choice.

take_along_axis only really makes sense if you endeavor to keep all your axes aligned. xarray can probably solve that by axis names alone, but in numpy you need to indicate that by axis position. Therefore, you can't afford to let your axes collapse, and have numpy guess which one you lost: in your case, you're advocating for it to guess the left-most one should be reinserted - but this is only the case because you did sum(axis=0).

return tuple(fancy_index)


def take_along_axis(arr, indices, axis):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to axis=-1 like 99% of other numpy functions?

return arr[_make_along_axis_idx(arr, indices, axis)]


def put_along_axis(arr, indices, values, axis):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to axis=-1 like 99% of other numpy functions?

Copy link
Member Author

@eric-wieser eric-wieser May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish that were true. concatenate defaults to axis=0, and other functions default to axis=None. Since the axis is a key part of the function, it seems best just to require it.

@@ -214,6 +214,20 @@ Chebyshev points of the first kind. A new ``Chebyshev.interpolate`` class
method adds support for interpolation over arbitrary intervals using the scaled
and shifted Chebyshev points of the first kind.

New ``np.take_along_axis`` and ``np.put_along_axis`` functions

This comment was marked as resolved.


arr[i..., :, k...][indices[i..., ..., k...]] = values[i..., ..., k...]

.. versionadded:: 1.13.0

This comment was marked as resolved.


out[i..., ..., k...] = arr[i..., :, k...][indices[i..., ..., k...]]

.. versionadded:: 1.13.0

This comment was marked as resolved.

@eric-wieser eric-wieser force-pushed the take_along_axis branch 2 times, most recently from 679a857 to 162ed85 Compare May 16, 2018 07:35
This is the reduced version that does not allow any insertion of extra dimensions
@eric-wieser
Copy link
Member Author

I've split up the commits into one supporting restricted and obvious broadcasting, and one that adds the less obvious behavior that matches apply_along_axis.

The first commit stands alone at #11105, and I think we should focus on getting minimal functionality in before trying to come up with non-obvious extensions.

@mhvk
Copy link
Contributor

mhvk commented May 29, 2018

superceded by #11105, which was merged

@mhvk mhvk closed this May 29, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants