Skip to content

Commit ef4885f

Browse files
authored
MNT np.nan_to_num -> xpx.nan_to_num (scikit-learn#32033)
1 parent 00acd12 commit ef4885f

File tree

17 files changed

+279
-71
lines changed

17 files changed

+279
-71
lines changed

maint_tools/vendor_array_api_extra.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set -o nounset
66
set -o errexit
77

88
URL="https://github.com/data-apis/array-api-extra.git"
9-
VERSION="v0.8.0"
9+
VERSION="v0.8.2"
1010

1111
ROOT_DIR=sklearn/externals/array_api_extra
1212

sklearn/externals/array_api_extra/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, one_hot, pad
3+
from ._delegation import isclose, nan_to_num, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -17,7 +17,7 @@
1717
)
1818
from ._lib._lazy import lazy_apply
1919

20-
__version__ = "0.8.0"
20+
__version__ = "0.8.2"
2121

2222
# pylint: disable=duplicate-code
2323
__all__ = [
@@ -33,6 +33,7 @@
3333
"isclose",
3434
"kron",
3535
"lazy_apply",
36+
"nan_to_num",
3637
"nunique",
3738
"one_hot",
3839
"pad",

sklearn/externals/array_api_extra/_delegation.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["isclose", "one_hot", "pad"]
21+
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
2222

2323

2424
def isclose(
@@ -113,6 +113,85 @@ def isclose(
113113
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
114114

115115

116+
def nan_to_num(
    x: Array | float | complex,
    /,
    *,
    fill_value: int | float = 0.0,
    xp: ModuleType | None = None,
) -> Array:
    """
    Replace NaN with zero and infinity with large finite numbers (default behaviour).

    If `x` is inexact, NaN is replaced by zero or by the user defined value in the
    `fill_value` keyword, infinity is replaced by the largest finite floating
    point value representable by ``x.dtype``, and -infinity is replaced by the
    most negative finite floating point value representable by ``x.dtype``.

    For complex dtypes, the above is applied to each of the real and
    imaginary components of `x` separately.

    Parameters
    ----------
    x : array | float | complex
        Input data.
    fill_value : int | float, optional
        Value to be used to fill NaN values. If no value is passed
        then NaN values will be replaced with 0.0.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        `x`, with the non-finite values replaced.

    Raises
    ------
    TypeError
        If `fill_value` is a ``complex`` instance.

    See Also
    --------
    array_api.isnan : Shows which elements are Not a Number (NaN).

    Examples
    --------
    >>> import array_api_extra as xpx
    >>> import array_api_strict as xp
    >>> xpx.nan_to_num(xp.inf)
    1.7976931348623157e+308
    >>> xpx.nan_to_num(-xp.inf)
    -1.7976931348623157e+308
    >>> xpx.nan_to_num(xp.nan)
    0.0
    >>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
    >>> xpx.nan_to_num(x)
    array([ 1.79769313e+308, -1.79769313e+308,  0.00000000e+000, # may vary
           -1.28000000e+002,  1.28000000e+002])
    >>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
    >>> xpx.nan_to_num(y)
    array([  1.79769313e+308 +0.00000000e+000j, # may vary
             0.00000000e+000 +0.00000000e+000j,
             0.00000000e+000 +1.79769313e+308j])
    """
    # Reject complex fill values up front, before any backend is consulted.
    if isinstance(fill_value, complex):
        msg = "Complex fill values are not supported."
        raise TypeError(msg)

    xp = array_namespace(x) if xp is None else xp

    # for scalars we want to output an array
    y = xp.asarray(x)

    # These backends ship their own `nan_to_num`; delegate to it and only
    # forward the NaN replacement value.
    if (
        is_cupy_namespace(xp)
        or is_jax_namespace(xp)
        or is_numpy_namespace(xp)
        or is_torch_namespace(xp)
    ):
        return xp.nan_to_num(y, nan=fill_value)

    # Every other namespace goes through the generic implementation.
    return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp)
193+
194+
116195
def one_hot(
117196
x: Array,
118197
/,

sklearn/externals/array_api_extra/_lib/_at.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class _AtOp(Enum):
3737
MAX = "max"
3838

3939
# @override from Python 3.12
40-
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride]
40+
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride]
4141
"""
4242
Return string representation (useful for pytest logs).
4343

sklearn/externals/array_api_extra/_lib/_backends.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
from __future__ import annotations
44

55
from enum import Enum
6+
from typing import Any
67

7-
__all__ = ["Backend"]
8+
import numpy as np
9+
import pytest
10+
11+
__all__ = ["NUMPY_VERSION", "Backend"]
12+
13+
# First three components of the installed NumPy version as integers.
# Only the leading digits of each component are parsed, so pre-release
# versions such as "2.1.0rc1" give (2, 1, 0) instead of raising ValueError;
# a component with no leading digits counts as 0.
NUMPY_VERSION = tuple(
    int(part[: len(part) - len(part.lstrip("0123456789"))] or 0)  # pyright: ignore[reportUnknownArgumentType]
    for part in np.__version__.split(".")[:3]
)
814

915

1016
class Backend(Enum): # numpydoc ignore=PR02
@@ -30,12 +36,6 @@ class Backend(Enum): # numpydoc ignore=PR02
3036
JAX = "jax.numpy"
3137
JAX_GPU = "jax.numpy:gpu"
3238

33-
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
34-
"""Pretty-print parameterized test names."""
35-
return (
36-
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
37-
)
38-
3939
@property
4040
def modname(self) -> str: # numpydoc ignore=RT01
4141
"""Module name to be imported."""
@@ -44,3 +44,29 @@ def modname(self) -> str: # numpydoc ignore=RT01
4444
def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
4545
"""Check if this backend uses the same module as others."""
4646
return any(self.modname == other.modname for other in others)
47+
48+
def pytest_param(self) -> Any:
    """
    Wrap this backend in a pytest parameter.

    The parameter id is the lower-cased member name in which ``_gpu`` and
    ``_readonly`` are rewritten as ``:gpu`` and ``:readonly``.

    Returns
    -------
    pytest.mark.ParameterSet
        This backend, with its id and any marks that apply to it.
    """
    marks = []

    if self.like(Backend.ARRAY_API_STRICT):
        needs_recent_numpy = pytest.mark.skipif(
            NUMPY_VERSION < (1, 26),
            reason="array_api_strict is untested on NumPy <1.26",
        )
        marks.append(needs_recent_numpy)

    if self.like(Backend.DASK, Backend.JAX):
        # Monkey-patched by lazy_xp_function
        marks.append(pytest.mark.thread_unsafe)

    label = self.name.lower()
    for tag in ("gpu", "readonly"):
        label = label.replace(f"_{tag}", f":{tag}")

    return pytest.param(self, id=label, marks=marks)  # pyright: ignore[reportUnknownArgumentType]

sklearn/externals/array_api_extra/_lib/_funcs.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535

3636
@overload
37-
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
37+
def apply_where( # numpydoc ignore=GL08
3838
cond: Array,
3939
args: Array | tuple[Array, ...],
4040
f1: Callable[..., Array],
@@ -46,7 +46,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G
4646

4747

4848
@overload
49-
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
49+
def apply_where( # numpydoc ignore=GL08
5050
cond: Array,
5151
args: Array | tuple[Array, ...],
5252
f1: Callable[..., Array],
@@ -57,7 +57,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G
5757
) -> Array: ...
5858

5959

60-
def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
60+
def apply_where( # numpydoc ignore=PR01,PR02
6161
cond: Array,
6262
args: Array | tuple[Array, ...],
6363
f1: Callable[..., Array],
@@ -143,7 +143,7 @@ def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
143143
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
144144

145145

146-
def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
146+
def _apply_where( # numpydoc ignore=PR01,RT01
147147
cond: Array,
148148
f1: Callable[..., Array],
149149
f2: Callable[..., Array] | None,
@@ -268,7 +268,7 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
268268
for axis in range(-ndim, 0):
269269
sizes = {shape[axis] for shape in shapes if axis >= -len(shape)}
270270
# Dask uses NaN for unknown shape, which predates the Array API spec for None
271-
none_size = None in sizes or math.nan in sizes
271+
none_size = None in sizes or math.nan in sizes # noqa: PLW0177
272272
sizes -= {1, None, math.nan}
273273
if len(sizes) > 1:
274274
msg = (
@@ -738,6 +738,47 @@ def kron(
738738
return xp.reshape(result, res_shape)
739739

740740

741+
def nan_to_num( # numpydoc ignore=PR01,RT01
742+
x: Array,
743+
/,
744+
fill_value: int | float = 0.0,
745+
*,
746+
xp: ModuleType,
747+
) -> Array:
748+
"""See docstring in `array_api_extra._delegation.py`."""
749+
750+
def perform_replacements( # numpydoc ignore=PR01,RT01
751+
x: Array,
752+
fill_value: int | float,
753+
xp: ModuleType,
754+
) -> Array:
755+
"""Internal function to perform the replacements."""
756+
x = xp.where(xp.isnan(x), fill_value, x)
757+
758+
# convert infinities to finite values
759+
finfo = xp.finfo(x.dtype)
760+
idx_posinf = xp.isinf(x) & ~xp.signbit(x)
761+
idx_neginf = xp.isinf(x) & xp.signbit(x)
762+
x = xp.where(idx_posinf, finfo.max, x)
763+
return xp.where(idx_neginf, finfo.min, x)
764+
765+
if xp.isdtype(x.dtype, "complex floating"):
766+
return perform_replacements(
767+
xp.real(x),
768+
fill_value,
769+
xp,
770+
) + 1j * perform_replacements(
771+
xp.imag(x),
772+
fill_value,
773+
xp,
774+
)
775+
776+
if xp.isdtype(x.dtype, "numeric"):
777+
return perform_replacements(x, fill_value, xp)
778+
779+
return x
780+
781+
741782
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
742783
"""
743784
Count the number of unique elements in an array.
@@ -813,8 +854,7 @@ def pad(
813854
else:
814855
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))
815856

816-
# https://github.com/python/typeshed/issues/13376
817-
slices: list[slice] = [] # type: ignore[explicit-any]
857+
slices: list[slice] = []
818858
newshape: list[int] = []
819859
for ax, w_tpl in enumerate(pad_width_seq):
820860
if len(w_tpl) != 2:
@@ -826,6 +866,7 @@ def pad(
826866
if w_tpl[0] == 0 and w_tpl[1] == 0:
827867
sl = slice(None, None, None)
828868
else:
869+
stop: int | None
829870
start, stop = w_tpl
830871
stop = None if stop == 0 else -stop
831872

sklearn/externals/array_api_extra/_lib/_lazy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
from numpy.typing import ArrayLike
2424

25-
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[explicit-any]
25+
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic
2626
else:
2727
# Sphinx hack
2828
NumPyObject = Any
@@ -31,7 +31,7 @@
3131

3232

3333
@overload
34-
def lazy_apply( # type: ignore[decorated-any, valid-type]
34+
def lazy_apply( # type: ignore[valid-type]
3535
func: Callable[P, Array | ArrayLike],
3636
*args: Array | complex | None,
3737
shape: tuple[int | None, ...] | None = None,
@@ -43,7 +43,7 @@ def lazy_apply( # type: ignore[decorated-any, valid-type]
4343

4444

4545
@overload
46-
def lazy_apply( # type: ignore[decorated-any, valid-type]
46+
def lazy_apply( # type: ignore[valid-type]
4747
func: Callable[P, Sequence[Array | ArrayLike]],
4848
*args: Array | complex | None,
4949
shape: Sequence[tuple[int | None, ...]],
@@ -313,7 +313,7 @@ def _is_jax_jit_enabled(xp: ModuleType) -> bool: # numpydoc ignore=PR01,RT01
313313
return True
314314

315315

316-
def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
316+
def _lazy_apply_wrapper( # numpydoc ignore=PR01,RT01
317317
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
318318
as_numpy: bool,
319319
multi_output: bool,
@@ -331,7 +331,7 @@ def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,R
331331

332332
# On Dask, @wraps causes the graph key to contain the wrapped function's name
333333
@wraps(func)
334-
def wrapper( # type: ignore[decorated-any,explicit-any]
334+
def wrapper(
335335
*args: Array | complex | None, **kwargs: Any
336336
) -> tuple[Array, ...]: # numpydoc ignore=GL08
337337
args_list = []
@@ -343,7 +343,7 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
343343
if as_numpy:
344344
import numpy as np
345345

346-
arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # noqa: PLW2901
346+
arg = cast(Array, np.asarray(arg)) # pyright: ignore[reportInvalidCast] # noqa: PLW2901
347347
args_list.append(arg)
348348
assert device is not None
349349

sklearn/externals/array_api_extra/_lib/_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _is_materializable(x: Array) -> bool:
110110
return not is_torch_array(x) or x.device.type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
111111

112112

113-
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
113+
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
114114
"""
115115
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
116116
"""

sklearn/externals/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def is_torch_array(x: object, /) -> TypeGuard[Array]: ...
3636
def is_lazy_array(x: object, /) -> TypeGuard[Array]: ...
3737
def is_writeable_array(x: object, /) -> TypeGuard[Array]: ...
3838
def size(x: Array, /) -> int | None: ...
39-
def to_device( # type: ignore[explicit-any]
39+
def to_device(
4040
x: Array,
4141
device: Device, # pylint: disable=redefined-outer-name
4242
/,

0 commit comments

Comments
 (0)