-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
ENH: Adding where for argmin #21625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Adding where for argmin #21625
Conversation
I'm not sure about |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, this is a pretty thorough start! We need to pass this through the mailing list as an API decision at some point. I suspect adding where=
should be pretty straight forward since it matches max
(and reduce-like operations in general).
OTOH, I don't have any intuition for what initial=
means. initial=
is normally the starting value! np.max([1, 2], initial=100)
makes sense: it returns 100, since that is larger than all the others.
But using initial
as a "fill" value has a very different meaning. The only way I could make sense of it would be:
np.argmax([1, 2], initial=100)
returning some special value like -1
(indicating that initial
was the largest value. But that feels like it may be too special, at least unless someone comes around with a clear real-world use case. (It is much easier/better to do API decisions with a specific use-case, rather than just filling apparent holes in the API)
There are a couple of other issues that I suspect exist that would need fixing here however:
where=True
should be the default and be optimized out.where=arr[::2]
i.e. non-contiguous arrays must work and be tested! (it does not look like they will?)where
must work if shapes mismatch, as long as it can be broadcast to the input.
numpy/__init__.pyi
Outdated
@overload | ||
def argmax( | ||
self, | ||
axis: None = ..., | ||
out: None = ..., | ||
*, | ||
keepdims: bool = ..., | ||
initial: _ScalarLike_co = ..., | ||
where: _ArrayLikeBool_co = ..., | ||
) -> intp: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for any additional overloads here (or anywhere in this PR), as neither initial
nor where
affect the output type, dtype or shape. In this case you can simply add the new parameters to the existing overloads:
@overload
def argmax(
self,
axis: None = ...,
out: None = ...,
*,
keepdims: bool = ...,
+ initial: _ScalarLike_co = ...,
+ where: _ArrayLikeBool_co = ...,
) -> intp: ...
As neither `initial` nor `where` affect the output type, dtype or shape, they simply should be added to the existing overloads.
@seberg I totally agree with counter intuitive But
And leave |
I've been leisurely working on major issues pointed out by @seberg (non-contiguous case and mismatched shapes). First one is pretty straight forward (I believe), since there were already a piece for forcing alignment wp = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)where, NPY_BOOL, 0, 0); But the second one appeared trickier for me. I decided to use I've tried to recreate such case using import numpy as np
method = 'max'
n = 5
a = np.zeros([2, n], dtype=int)
where = np.ones([2, n], dtype=bool)
value = getattr(np.iinfo(a.dtype), method)
a[:, 0] = value
a[:, n - 1] = value
arg_method = getattr(a, 'arg' + method)
mask_args = dict(initial=0, where=where)
where[0, 0] = False
it = np.nditer(
[a, where],
['multi_index'],
[['readonly'], ['readonly']], casting='no', order='K'
)
with it:
while not it.finished:
it.debug_print()
print(it[0], it[1])
it.iternext() But that didn't help me much... The last straw was that gdb suddenly refused to show lines and step through code... Thread 1 "python" hit Breakpoint 1, 0x00007ffff6d0f0f0 in _PyArray_ArgMinMaxCommon () from /home/ivan/proj/m10an/numpy/numpy/core/_multiarray_umath.cpython-39-x86_64-linux-gnu.so
(gdb) list
1 /usr/local/src/conda/python-3.9.13/Programs/python.c: No such file or directory.
(gdb) s
Single stepping until exit from function _PyArray_ArgMinMaxCommon,
which has no line number information.
PyType_IsSubtype (a=0x7ffff70bb200 <PyArray_Type>, b=0x7ffff70d1d80 <PyGenericArrType_Type>) at /usr/local/src/conda/python-3.9.13/Objects/typeobject.c:1425
1425 /usr/local/src/conda/python-3.9.13/Objects/typeobject.c: No such file or directory. I hope there is a stupid mistake that I ignore or maybe |
it_ops[0] = ap; | ||
it_ops[1] = wp; | ||
iter = NpyIter_MultiNew(2, it_ops, 0, NPY_KEEPORDER, NPY_NO_CASTING, | ||
it_opflags, NULL); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You will have to pass a couple of flags here, probably: NPY_ITER_EXTERNAL_LOOP
and NPY_ITER_ZEROSIZE_OK
, maybe also NPY_ITER_REFS_OK
. Some of the following code does not necessarily make much sense without the external loop flag.
The crash looks like a PyArray_Check()
on some invalid data, but it may be in the code after the new creation. The main error here may be the missing check for error returns on iter
and some of the following functions (although I am not certain, it may well be that calling some of these is just invalid without the appropriate flags).
I am surprised you are missing debugging symbols on a local build, but try recompiling with CFLAGS=-g
or adding -g
to runtests.py
if you are using that.
Closing this, there is maybe a good start here so can be reopened. But it needs work and has not been active for 2 years. |
Addresses #14371