Skip to content

Commit 0eeec7e

Browse files
glemaitreogriselMicky774thomasjpfan
committed
MAINT fix the way to call stats.mode (#23633)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent c722368 commit 0eeec7e

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

sklearn/impute/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import numpy as np
1010
import numpy.ma as ma
1111
from scipy import sparse as sp
12-
from scipy import stats
1312

1413
from ..base import BaseEstimator, TransformerMixin
14+
from ..utils.fixes import _mode
1515
from ..utils.sparsefuncs import _get_median
1616
from ..utils.validation import check_is_fitted
1717
from ..utils.validation import FLOAT_DTYPES
@@ -51,7 +51,7 @@ def _most_frequent(array, extra_value, n_repeat):
5151
if count == most_frequent_count
5252
)
5353
else:
54-
mode = stats.mode(array)
54+
mode = _mode(array)
5555
most_frequent_value = mode[0][0]
5656
most_frequent_count = mode[1][0]
5757
else:

sklearn/neighbors/_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# License: BSD 3 clause (C) INRIA, University of Amsterdam
1010

1111
import numpy as np
12-
from scipy import stats
12+
from ..utils.fixes import _mode
1313
from ..utils.extmath import weighted_mode
1414
from ..utils.validation import _is_arraylike, _num_samples
1515

@@ -241,7 +241,7 @@ def predict(self, X):
241241
y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype)
242242
for k, classes_k in enumerate(classes_):
243243
if weights is None:
244-
mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
244+
mode, _ = _mode(_y[neigh_ind, k], axis=1)
245245
else:
246246
mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1)
247247

sklearn/utils/fixes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,10 @@ def threadpool_info():
163163

164164

165165
threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__
166+
167+
168+
# TODO: Remove when SciPy 1.9 is the minimum supported version
169+
def _mode(a, axis=0):
170+
if sp_version >= parse_version("1.9.0"):
171+
return scipy.stats.mode(a, axis=axis, keepdims=True)
172+
return scipy.stats.mode(a, axis=axis)

sklearn/utils/tests/test_extmath.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
from scipy import sparse
99
from scipy import linalg
10-
from scipy import stats
1110
from scipy.sparse.linalg import eigsh
1211
from scipy.special import expit
1312

@@ -20,6 +19,7 @@
2019
from sklearn.utils._testing import assert_array_equal
2120
from sklearn.utils._testing import assert_array_almost_equal
2221
from sklearn.utils._testing import skip_if_32bit
22+
from sklearn.utils.fixes import _mode
2323

2424
from sklearn.utils.extmath import density, _safe_accumulator_op
2525
from sklearn.utils.extmath import randomized_svd, _randomized_eigsh
@@ -57,7 +57,7 @@ def test_uniform_weights():
5757
weights = np.ones(x.shape)
5858

5959
for axis in (None, 0, 1):
60-
mode, score = stats.mode(x, axis)
60+
mode, score = _mode(x, axis)
6161
mode2, score2 = weighted_mode(x, weights, axis=axis)
6262

6363
assert_array_equal(mode, mode2)

0 commit comments

Comments
 (0)