Skip to content

Commit 94ae1c5

Browse files
authored
Merge pull request #13610 from eric-wieser/argwhere
ENH: Always produce a consistent shape in the result of `argwhere`
2 parents e4e12cb + b6a3ee3 commit 94ae1c5

File tree

4 files changed

+48
-9
lines changed

4 files changed

+48
-9
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
``argwhere`` now produces a consistent result on 0d arrays
2+
----------------------------------------------------------
3+
On N-d arrays, `numpy.argwhere` now always produces an array of shape
4+
``(n_non_zero, arr.ndim)``, even when ``arr.ndim == 0``. Previously, the
5+
last axis would have a dimension of 1 in this case.

numpy/core/numeric.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from . import overrides
2828
from . import umath
29+
from . import shape_base
2930
from .overrides import set_module
3031
from .umath import (multiply, invert, sin, PINF, NAN)
3132
from . import numerictypes
@@ -545,16 +546,19 @@ def argwhere(a):
545546
546547
Returns
547548
-------
548-
index_array : ndarray
549+
index_array : (N, a.ndim) ndarray
549550
Indices of elements that are non-zero. Indices are grouped by element.
551+
This array will have shape ``(N, a.ndim)`` where ``N`` is the number of
552+
non-zero items.
550553
551554
See Also
552555
--------
553556
where, nonzero
554557
555558
Notes
556559
-----
557-
``np.argwhere(a)`` is the same as ``np.transpose(np.nonzero(a))``.
560+
``np.argwhere(a)`` is almost the same as ``np.transpose(np.nonzero(a))``,
561+
but produces a result of the correct shape for a 0D array.
558562
559563
The output of ``argwhere`` is not suitable for indexing arrays.
560564
For this purpose use ``nonzero(a)`` instead.
@@ -572,6 +576,11 @@ def argwhere(a):
572576
[1, 2]])
573577
574578
"""
579+
# nonzero does not behave well on 0d, so promote to 1d
580+
if np.ndim(a) == 0:
581+
a = shape_base.atleast_1d(a)
582+
# then remove the added dimension
583+
return argwhere(a)[:,:0]
575584
return transpose(nonzero(a))
576585

577586

numpy/core/shape_base.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
from . import numeric as _nx
1111
from . import overrides
12-
from .numeric import array, asanyarray, newaxis
12+
from ._asarray import array, asanyarray
1313
from .multiarray import normalize_axis_index
14+
from . import fromnumeric as _from_nx
1415

1516

1617
array_function_dispatch = functools.partial(
@@ -123,7 +124,7 @@ def atleast_2d(*arys):
123124
if ary.ndim == 0:
124125
result = ary.reshape(1, 1)
125126
elif ary.ndim == 1:
126-
result = ary[newaxis, :]
127+
result = ary[_nx.newaxis, :]
127128
else:
128129
result = ary
129130
res.append(result)
@@ -193,9 +194,9 @@ def atleast_3d(*arys):
193194
if ary.ndim == 0:
194195
result = ary.reshape(1, 1, 1)
195196
elif ary.ndim == 1:
196-
result = ary[newaxis, :, newaxis]
197+
result = ary[_nx.newaxis, :, _nx.newaxis]
197198
elif ary.ndim == 2:
198-
result = ary[:, :, newaxis]
199+
result = ary[:, :, _nx.newaxis]
199200
else:
200201
result = ary
201202
res.append(result)
@@ -435,9 +436,9 @@ def stack(arrays, axis=0, out=None):
435436
# Internal functions to eliminate the overhead of repeated dispatch in one of
436437
# the two possible paths inside np.block.
437438
# Use getattr to protect against __array_function__ being disabled.
438-
_size = getattr(_nx.size, '__wrapped__', _nx.size)
439-
_ndim = getattr(_nx.ndim, '__wrapped__', _nx.ndim)
440-
_concatenate = getattr(_nx.concatenate, '__wrapped__', _nx.concatenate)
439+
_size = getattr(_from_nx.size, '__wrapped__', _from_nx.size)
440+
_ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim)
441+
_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate)
441442

442443

443444
def _block_format_index(index):

numpy/core/tests/test_numeric.py

+24
Original file line numberDiff line numberDiff line change
@@ -2583,6 +2583,30 @@ def test_no_overwrite(self):
25832583

25842584

25852585
class TestArgwhere(object):
2586+
2587+
@pytest.mark.parametrize('nd', [0, 1, 2])
2588+
def test_nd(self, nd):
2589+
# get an nd array with multiple elements in every dimension
2590+
x = np.empty((2,)*nd, bool)
2591+
2592+
# none
2593+
x[...] = False
2594+
assert_equal(np.argwhere(x).shape, (0, nd))
2595+
2596+
# only one
2597+
x[...] = False
2598+
x.flat[0] = True
2599+
assert_equal(np.argwhere(x).shape, (1, nd))
2600+
2601+
# all but one
2602+
x[...] = True
2603+
x.flat[0] = False
2604+
assert_equal(np.argwhere(x).shape, (x.size - 1, nd))
2605+
2606+
# all
2607+
x[...] = True
2608+
assert_equal(np.argwhere(x).shape, (x.size, nd))
2609+
25862610
def test_2D(self):
25872611
x = np.arange(6).reshape((2, 3))
25882612
assert_array_equal(np.argwhere(x > 1),

0 commit comments

Comments
 (0)