We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5e25db7 commit 60e4acdCopy full SHA for 60e4acd
sklearn/utils/validation.py
@@ -95,15 +95,15 @@ def _assert_all_finite(
95
):
96
"""Like assert_all_finite, but only for ndarray."""
97
98
- xp, _ = get_namespace(X)
+ xp, is_array_api = get_namespace(X)
99
100
if _get_config()["assume_finite"]:
101
return
102
103
X = xp.asarray(X)
104
105
# for object dtype data, we only check for NaNs (GH-13254)
106
- if X.dtype == np.dtype("object") and not allow_nan:
+ if not is_array_api and X.dtype == np.dtype("object") and not allow_nan:
107
if _object_dtype_isnan(X).any():
108
raise ValueError("Input contains NaN")
109
0 commit comments