Skip to content

ENH: Adding where for argmin #21625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed

Conversation

m10an
Copy link
Contributor

@m10an m10an commented May 28, 2022

Addresses #14371

@m10an
Copy link
Contributor Author

m10an commented May 28, 2022

I'm not sure about np.nanargmin, np.nanargmax and MaskedArray's argmin, argmax.
Are initial and where needed in those functions?

@seberg seberg added the 62 - Python API Changes or additions to the Python API. Mailing list should usually be notified. label May 28, 2022
Copy link
Member

@seberg seberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this is a pretty thorough start! We need to pass this through the mailing list as an API decision at some point. I suspect adding where= should be pretty straight forward since it matches max (and reduce-like operations in general).

OTOH, I don't have any intuition for what initial= means. initial= is normally the starting value! np.max([1, 2], initial=100) makes sense: it returns 100, since that is larger than all the others.
But using initial as a "fill" value has a very different meaning. The only way I could make sense of it would be:

np.argmax([1, 2], initial=100)

returning some special value like -1 (indicating that initial was the largest value. But that feels like it may be too special, at least unless someone comes around with a clear real-world use case. (It is much easier/better to do API decisions with a specific use-case, rather than just filling apparent holes in the API)

There are a couple of other issues that I suspect exist that would need fixing here however:

  • where=True should be the default and be optimized out.
  • where=arr[::2] i.e. non-contiguous arrays must work and be tested! (it does not look like they will?)
  • where must work if shapes mismatch, as long as it can be broadcast to the input.

Comment on lines 1089 to 1098
@overload
def argmax(
self,
axis: None = ...,
out: None = ...,
*,
keepdims: bool = ...,
initial: _ScalarLike_co = ...,
where: _ArrayLikeBool_co = ...,
) -> intp: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for any additional overloads here (or anywhere in this PR), as neither initial nor where affect the output type, dtype or shape. In this case you can simply add the new parameters to the existing overloads:

    @overload
    def argmax(
        self,
        axis: None = ...,
        out: None = ...,
        *,
        keepdims: bool = ...,
+       initial: _ScalarLike_co = ...,
+       where: _ArrayLikeBool_co = ...,
    ) -> intp: ...

@BvB93 BvB93 linked an issue May 30, 2022 that may be closed by this pull request
m10an added 4 commits May 31, 2022 21:37
As neither `initial` nor `where` affect  the output type,
dtype or shape, they simply should be added to the existing overloads.
@m10an
Copy link
Contributor Author

m10an commented May 31, 2022

@seberg I totally agree with counter intuitive initial= argument. I just followed example, and realised only while implementing it :)

But where= I would use as simple masking, and in case of zero mask raise

ValueError: attempt to get argmax of an empty sequence

And leave initial= as parameter of reduce-functions family

@m10an m10an requested a review from BvB93 May 31, 2022 19:31
@m10an
Copy link
Contributor Author

m10an commented Jul 7, 2022

I've been leisurely working on major issues pointed out by @seberg (non-contiguous case and mismatched shapes).

First one is pretty straight forward (I believe), since there were already a piece for forcing alignment

wp = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)where, NPY_BOOL, 0, 0);

But the second one appeared trickier for me. I decided to use NpyIter and to not reinvent wheel I thought.
My first test went smoothly (test_masked), but the second one (test_masked_2d) exits with segmentation fault (core dumped) because of calling of NpyIter_MultiNew here.

I've tried to recreate such case using np.nditer:

import numpy as np

method = 'max'
n = 5
a = np.zeros([2, n], dtype=int)
where = np.ones([2, n], dtype=bool)
value = getattr(np.iinfo(a.dtype), method)

a[:, 0] = value
a[:, n - 1] = value

arg_method = getattr(a, 'arg' + method)
mask_args = dict(initial=0, where=where)
where[0, 0] = False

it = np.nditer(
    [a, where], 
    ['multi_index'],
    [['readonly'], ['readonly']], casting='no', order='K'
)
with it:
    while not it.finished:
        it.debug_print()
        print(it[0], it[1])
        it.iternext()

But that didn't help me much...

The last straw was that gdb suddenly refused to show lines and step through code...

Thread 1 "python" hit Breakpoint 1, 0x00007ffff6d0f0f0 in _PyArray_ArgMinMaxCommon () from /home/ivan/proj/m10an/numpy/numpy/core/_multiarray_umath.cpython-39-x86_64-linux-gnu.so
(gdb) list
1       /usr/local/src/conda/python-3.9.13/Programs/python.c: No such file or directory.
(gdb) s
Single stepping until exit from function _PyArray_ArgMinMaxCommon,
which has no line number information.
PyType_IsSubtype (a=0x7ffff70bb200 <PyArray_Type>, b=0x7ffff70d1d80 <PyGenericArrType_Type>) at /usr/local/src/conda/python-3.9.13/Objects/typeobject.c:1425
1425    /usr/local/src/conda/python-3.9.13/Objects/typeobject.c: No such file or directory.

I hope there is a stupid mistake that I ignore or maybe NpyIter is not best here.
I would appreciate your thought on this.

@m10an m10an requested a review from seberg July 7, 2022 06:34
it_ops[0] = ap;
it_ops[1] = wp;
iter = NpyIter_MultiNew(2, it_ops, 0, NPY_KEEPORDER, NPY_NO_CASTING,
it_opflags, NULL);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will have to pass a couple of flags here, probably: NPY_ITER_EXTERNAL_LOOP and NPY_ITER_ZEROSIZE_OK, maybe also NPY_ITER_REFS_OK. Some of the following code does not necessarily make much sense without the external loop flag.

The crash looks like a PyArray_Check() on some invalid data, but it may be in the code after the new creation. The main error here may be the missing check for error returns on iter and some of the following functions (although I am not certain, it may well be that calling some of these is just invalid without the appropriate flags).

I am surprised you are missing debugging symbols on a local build, but try recompiling with CFLAGS=-g or adding -g to runtests.py if you are using that.

@seberg
Copy link
Member

seberg commented Jan 13, 2024

Closing this, there is maybe a good start here so can be reopened. But it needs work and has not been active for 2 years.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
01 - Enhancement 55 - Needs work 62 - Python API Changes or additions to the Python API. Mailing list should usually be notified.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: Adding where for argmin
3 participants