
WIP: top_k draft implementation #26666


Open
wants to merge 15 commits into main

Conversation

@JuliaPoo (Contributor) commented Jun 12, 2024

This is a follow-up to the previous discussion at #15128. I've changed the mode keyword from that discussion into a largest bool flag, which follows the API of torch.topk.

The API implemented here is:

def top_k(
    a: ArrayLike,
    k: int,
    /,
    *,
    axis: None | int = ...,
    largest: bool = ...,
) -> tuple[NDArray[Any], NDArray[intp]]: ...
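For illustration, here is a minimal usage sketch of the proposed function. The outputs shown in the comments are assumptions (the exact ordering of the returned k elements is discussed further down in this thread):

```py
import numpy as np

a = np.array([[1, 4, 2, 3],
              [9, 6, 7, 8]])

# Two largest per row; the shown outputs assume values come back sorted ascending.
values, indices = np.top_k(a, 2, axis=-1)
# values  -> array([[3, 4], [8, 9]]);  indices -> array([[3, 1], [3, 0]])

# Two smallest per row.
values, indices = np.top_k(a, 2, axis=-1, largest=False)
# values  -> array([[1, 2], [6, 7]]);  indices -> array([[0, 2], [1, 2]])
```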

Carrying forward from the previous discussion, a parameter that might be useful is sorted. This is implemented in torch.topk, and follows from previous work at #19117.

Co-authored-by: quarrying
@rgommers rgommers self-requested a review June 13, 2024 16:03
@rgommers (Member)

Thanks @JuliaPoo! At first glance this looks quite good - the implementation is easy to understand, and docs, type annotation etc. all LGTM.

I experimented a little bit with it, and performance looks as expected: almost all the time is spent in np.argpartition and np.top_k is a lot faster than a full sort:

>>> x = np.random.rand(100_000)
>>> %timeit np.top_k(x, 5, largest=False)
759 µs ± 3.95 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
>>> %timeit np.argpartition(x, 5)
758 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
>>> %timeit np.sort(x)[:5]
2.95 ms ± 3.29 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

So far so good. For this implementation at least, there doesn't seem to be a significant performance or usability benefit to adding more than one function (e.g., top_k_values).
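For what it's worth, a values-only variant is already just an unpacking away, which is probably why a second function adds little (illustrative sketch, assuming the API above):

```py
values, _ = np.top_k(x, 5, largest=False)    # equivalent to a hypothetical top_k_values
values = np.top_k(x, 5, largest=False)[0]    # or index the returned tuple directly
```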


A tuple of ``(values, indices)`` is returned, where ``values`` holds the
largest/smallest elements of each row of the input array along the given
``axis`` and ``indices`` holds their positions.
Contributor

It would be useful to explicitly note the semantics in the presence of NaN values. Is this the same as sort(a)[:k] / sort(a)[-k:], or is it the same as sort(a[~isnan(a)])[:k] / sort(a[~isnan(a)])[-k:]?

Also: does the API make any guarantees about the order of the returned results?
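To make the two candidate NaN semantics concrete, here is a small illustration with plain NumPy (not the proposed function itself):

```py
a = np.array([2.0, np.nan, 1.0, 3.0])
np.sort(a)[-2:]                 # NaNs sort to the end     -> array([ 3., nan])
np.sort(a[~np.isnan(a)])[-2:]   # NaNs dropped beforehand  -> array([2., 3.])
```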

@JuliaPoo (Contributor Author) commented Jun 14, 2024
With regards to np.nan, from what I understand, the underlying np.argpartition is not intentional in how it treats np.nan. For floats, by the nature of how the partial sort is implemented, np.nan is unintentionally treated like np.inf, since it fails every comparison with a number. This might change in the future as the underlying implementation changes. Should I add a note that the treatment of np.nan is not defined?

About the order of the returned results: np.argpartition by default uses a partial sort which is unstable, so the returned indices are not guaranteed to be the first occurrence of the element. E.g., np.top_k([3,3], 1) returns (array([3]), array([1])). I'll add that as a note.
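For illustration, here is how the underlying partition primitives behave with a NaN today (a small sketch; the draft implementation itself may differ):

```py
a = np.array([np.nan, 1.0, 3.0, 2.0])
np.partition(a, 1)[:2]      # -> array([1., 2.]); the NaN loses every comparison,
                            #    so it ends up grouped with the largest elements
np.argpartition(a, 1)[:2]   # -> typically array([1, 3]), i.e. positions of the two smallest
```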

Member

NumPy uses a sort order pushing NaNs to the end consistently. I don't think we should change that.

Now, there is a problem with respect to adding a kwarg to choose a descending sort (which you propose here for top_k). In that case it might be argued that NaNs should also be sorted to the end!
And if we want that, it would require specific logic to sort in descending order (not just for unstable sorts).
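To illustrate the descending-order wrinkle: the usual trick of reversing an ascending sort moves the NaNs to the front, so keeping them at the end under a descending order would need dedicated handling (plain-NumPy sketch):

```py
a = np.array([2.0, np.nan, 1.0])
np.sort(a)          # -> array([ 1.,  2., nan])   ascending, NaNs at the end
np.sort(a)[::-1]    # -> array([nan,  2.,  1.])   naive "descending", NaNs now first
```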

Contributor

When I asked about the order of returned elements, what I had in mind was this:

np.top_k([1, 4, 2, 3], k=2)

As I see it, there are three logically-consistent conventions:

  • results are always sorted: return [3, 4]
  • results are always in the order they appear: return [4, 3]
  • order is not guaranteed: return either [3, 4] or [4, 3]

It would be helpful to specify in the documentation which of these is the case for NumPy's implementation.

Member

It looks like the current implementation returns results that are always sorted. Which seems to me like it's the nicest option for the user. If that falls out naturally, then great. And for other implementations, matching that behavior doesn't seem too costly performance-wise even if their implementation doesn't yield sorted values naturally, because the returned arrays are typically very small compared to the input arrays, so sorting the end result is fast.

Does that sound right to everyone?

@JuliaPoo based on the implementation, do you see a reason that this is hard to guarantee?
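For reference, here is a minimal sketch of how sorted output can be guaranteed cheaply on top of a partition. This is an assumption about one possible approach, not necessarily what this PR does, and `_top_k_sorted` is a hypothetical helper name:

```py
def _top_k_sorted(x, k):
    # O(n) selection of the k largest, then an O(k log k) sort of just those k.
    idx = np.argpartition(x, -k)[-k:]   # unordered indices of the k largest
    order = np.argsort(x[idx])          # sort only the selected values
    idx = idx[order]
    return x[idx], idx

_top_k_sorted(np.array([1, 4, 2, 3]), 2)
# -> (array([3, 4]), array([3, 1]))
```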

JuliaPoo added a commit to JuliaPoo/array-api-tests that referenced this pull request Jun 24, 2024
The purpose of this PR is to continue several threads of discussion
regarding `top_k`.

This follows roughly the specifications of `top_k` in
data-apis/array-api#722, with slight modifications to the API:

```py
def topk(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:
- `mode: Literal["largest", "smallest"]` is replaced with
`largest: bool`
- `axis` is no longer a kw-only arg. This improves compatibility
with `torch.topk`.

The tests implemented here follow the proposed `top_k`
implementation at numpy/numpy#26666.
@JuliaPoo (Contributor Author) commented Jul 1, 2024

I've drafted a PR to array-api-compat (data-apis/array-api-compat#158) that shows how compatible numpy's top_k implementation is with other existing implementations. Notably, it is directly compatible with torch.topk.

Are there any more concerns regarding this PR? Otherwise, I'll update the docstring to clarify the sort order of np.nan, similarly to #26716, and fix the merge conflict.

@JuliaPoo (Contributor Author) commented Jul 2, 2024

With #26716 merged, I've updated the top_k docstring to document the sort order of complex values and NaNs, and added a release note.

@JuliaPoo JuliaPoo force-pushed the issue-15128-topk-feat branch from d2f3d39 to 85147d6 on July 2, 2024 07:36
@JuliaPoo (Contributor Author)

Continuing from the discussion here, I've special-cased the types in np.typecodes["AllFloat"] (float16, float32, float64, longdouble, complex64, complex128, clongdouble) to push np.nans to the back; a rough sketch of one way to do this follows the list below.

I.e., np.top_k(np.array([1,2,3,np.nan], dtype="f"), 3) will return values [1,2,3] instead of previously [2,3,np.nan].

Other array dtypes that could contain np.nan, such as np.array([1, np.nan], dtype=object), will simply use the regular sort order (pushing NaNs to the front). The rationale behind this is that

  • it's too difficult to support the NaN behaviour for object arrays
  • it's a case that is probably quite rare (ufuncs like np.isnan don't support object arrays anyway)
  • for arbitrary dtypes, such as those containing Decimals, or for structured arrays, the onus should be on the user to handle np.nan
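For concreteness, here is one hypothetical way to get the "NaNs last" behaviour for the largest=True floating-point case. This is a sketch only, `_top_k_largest_nan_last` is an illustrative name, and it is not necessarily what this PR's code does (note also that np.isnan itself rejects object arrays, which matches the rationale above):

```py
def _top_k_largest_nan_last(x, k):
    x = np.asarray(x)
    key = np.where(np.isnan(x), -np.inf, x)   # demote NaNs so they lose every comparison
    idx = np.argpartition(key, -k)[-k:]       # indices of the k largest keys
    idx = idx[np.argsort(key[idx])]           # return them in sorted order
    return x[idx], idx

_top_k_largest_nan_last(np.array([1.0, 2.0, 3.0, np.nan]), 3)
# -> (array([1., 2., 3.]), array([0, 1, 2]))
```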

@rgommers (Member)

The sorting behavior is looking pretty good. I think the behavior for object arrays not being guaranteed is okay, that is very niche and can be documented as doing whatever np.sort does (same for custom dtypes). This does the right thing for integers and floating-point dtypes, which is a win.

@rgommers (Member)

Another question I'd like to raise is whether complex numbers should be supported at all. It doesn't seem useful, and we have a known issue with the semantics of sorting/comparing complex numbers being ill-defined. To illustrate:

>>> # 2+10j is the element with the largest absolute value
>>> x = np.asarray([1.5, 3, 4, 2+10j, np.nan, 9, 9], dtype=np.complex128)
>>> np.top_k(x, 3, largest=False)[0]
array([1.5 +0.j, 2. +10.j, 3.  +0.j])

For np.sort & co we've tended to not bother trying to do anything about this, but I question whether we should be introducing support in a new function like this. If a user wants to use top_k on an array with a complex dtype, it seems better to have them use either top_k(x.real) or top_k(np.abs(x)) and use the returned indices to get the k complex values they want.
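For reference, a small sketch of that workaround, assuming the proposed np.top_k from this PR (the real-valued key to rank by is up to the user):

```py
x = np.asarray([1.5, 3, 4, 2+10j, 9, 9], dtype=np.complex128)
_, idx = np.top_k(np.abs(x), 3)   # rank by magnitude; returns indices into x
x[idx]                            # the corresponding complex values, including 2+10j
```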

@JuliaPoo (Contributor Author)

I personally found it a little weird to be sorting complex values. Now that you mention it, I'm leaning towards np.top_k defaulting to np.sort ordering for complex.

@seberg (Member) commented Jul 19, 2024

We have discussed deprecating complex sorting before, and I still wouldn't mind it at all (I seem to recall an odd sort_complex function even exists for the odd case where it is used).
Sort order is also used by minimum/maximum, so it seems fine to just do it here and it would be fixed when the other places are... But I don't mind if you want to explicitly disallow it either.
