-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
WIP: top_k draft implementation #26666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Following previous discussion at numpy#15128. I made a small change to the interface in the previous discussion by changing the `mode` keyword into a `largest` bool flag. This follows API such as from [torch.topk](https://pytorch.org/docs/stable/generated/torch.topk.html). Carrying from the previous discussion, a parameter might be useful is `sorted`. This is also implemented in `torch.topk`, and follows from previous work at numpy#19117. Co-authored-by: quarrying
Thanks @JuliaPoo! At first glance this looks quite good - the implementation is easy to understand, and docs, type annotation etc. all LGTM. I experimented a little bit with it, and performance looks as expected: almost all the time is spent in >>> x = np.random.rand(100_000)
>>> %timeit np.top_k(x, 5, largest=False)
759 µs ± 3.95 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
>>> %timeit np.argpartition(x, 5)
758 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
>>> %timeit np.sort(x)[:5]
2.95 ms ± 3.29 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) So far so good. It doesn't seem like there'd be a significant performance or usability benefit, for this implementation at least, of adding more than one function (e.g., |
|
||
A tuple of ``(values, indices)`` is returned, where ``values`` and | ||
``indices`` of the largest/smallest elements of each row of the input | ||
array in the given ``axis``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be useful to explicitly note the semantics in the presence of NaN values. Is this the same as sort(a)[:k]
/ sort(a)[-k:]
, or it the same as sort(a[~isnan(a)])[:k]
/ sort(a[~isnan(a)])[-k:]
?
Also: does the API make any guarantees about the order of the returned results?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With regards to np.nan
, from what I understand, the underlying np.argpartition
is not intentional in how it treats np.nan
. For floats by the nature of how the partial sort is implemented, np.nan
is unintentionally treated like np.inf
since it fails for every comparison with a number. This might change in the future as the underlying implementation changes. Should I add a note that the treatment of np.nan
is not defined?
About the order of the returned results, np.argpartition
by default uses a partial sort which is unstable, so the returned indices is not guaranteed to be the first occurrence of the element. E.g., np.top_k([3,3], 1)
returns (array([3]), array([1]))
. I'll add that as a note.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NumPy uses a sort order pushing NaNs to the end consistently. I don't think we should change that.
Now, there is a problem with respect to adding a kwarg to choose a descending sort (which you propose here for top_k
). In that case it might be argued that NaNs should also be sorted to the end!
And if we want that, it would require specific logic to sort in descending order (not just for unstable sorts).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I asked about the order of returned elements, what I had in mind was this:
np.top_k([1, 4, 2, 3], k=2)
As I see it, there are three logically-consistent conventions:
- results are always sorted: return
[3, 4]
- results are always in the order they appear: return
[4, 3]
- order is not guaranteed: return either
[3, 4]
or[4, 3]
It would be helpful to specify in the documentation which of these is the case for NumPy's implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like the current implementation returns results that are always sorted. Which seems to me like it's the nicest option for the user. If that falls out naturally, then great. And for other implementations, matching that behavior doesn't seem too costly performance-wise even if their implementation doesn't yield sorted values naturally, because the returned arrays are typically very small compared to the input arrays, so sorting the end result is fast.
Does that sound right to everyone?
@JuliaPoo based on the implementation, do you see a reason that this is hard to guarantee?
The purpose of this PR is to continue several threads of discussion regarding `top_k`. This follows roughly the specifications of `top_k` in data-apis/array-api#722, with slight modifications to the API: ```py def topk( x: array, k: int, /, axis: Optional[int] = None, *, largest: bool = True, ) -> Tuple[array, array]: ... ``` Modifications: - `mode: Literal["largest", "smallest"]` is replaced with `largest: bool` - `axis` is no longer a kw-only arg. This makes `torch.topk` slightly more compatible. The tests implemented here follows the proposed `top_k` implementation at numpy/numpy#26666.
I've drafted a PR to Are there any more concerns regarding this PR? Otherwise, I'll update the docstring to clarify the sort order of |
With #26716 merged, I've updated the docstring for |
d2f3d39
to
85147d6
Compare
Continuing from the discussion here, I've special-cased types in I.e., Other array dtypes that could contain
|
The sorting behavior is looking pretty good. I think the behavior for object arrays not being guaranteed is okay, that is very niche and can be documented as doing whatever |
Another question I'd like to raise is whether complex numbers should be supported at all. It doesn't seem useful, and we have a known issue with the semantics of sorting/comparing complex numbers being ill-defined. To illustrate: >>> # 2+10j is the element with the largest absolute value
>>> x = np.asarray([1.5, 3, 4, 2+10j, np.nan, 9, 9], dtype=np.complex128)
>>> np.top_k(x, 3, largest=False)[0]
array([1.5 +0.j, 2. +10.j, 3. +0.j]) For |
I personally found it a little weird to be sorting complex. Now that you mentioned it, I'm leaning towards |
We have discussed deprecating sorting complex before, and I still wouldn't mind it at all (ISTR an odd |
This is a follow up of the previous discussions at #15128. I've changed the
mode
keyword (as in previous discussion) into alargest
bool flag. This follows API such as that from torch.topk.The API implemented here is:
Carrying forward from the previous discussion, a parameter that might be useful is
sorted
. This is implemented intorch.topk
, and follows from previous work at #19117.