Skip to content

Commit b10b73a

Browse files
AlexandreAbrahamjeremiedbbbetatim
authored
Fix uncomparable values in SimpleImputer tie-breaking strategy (scikit-learn#31820)
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai> Co-authored-by: Tim Head <betatim@gmail.com>
1 parent faf69cb commit b10b73a

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- Fixed a bug in :class:`impute.SimpleImputer` with `strategy="most_frequent"` when
2+
there is a tie in the most frequent value and the input data has mixed types.
3+
By :user:`Alexandre Abraham <AlexandreAbraham>`.

sklearn/impute/_base.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@ def _check_inputs_dtype(X, missing_values):
3838
)
3939

4040

41+
def _safe_min(items):
42+
"""Compute the minimum of a list of potentially non-comparable values.
43+
44+
If values cannot be directly compared due to type incompatibility, the object with
45+
the lowest string representation is returned.
46+
"""
47+
try:
48+
return min(items)
49+
except TypeError as e:
50+
if "'<' not supported between" in str(e):
51+
return min(items, key=lambda x: (str(type(x)), str(x)))
52+
raise # pragma: no cover
53+
54+
4155
def _most_frequent(array, extra_value, n_repeat):
4256
"""Compute the most frequent value in a 1d array extended with
4357
[extra_value] * n_repeat, where extra_value is assumed to be not part
@@ -50,10 +64,12 @@ def _most_frequent(array, extra_value, n_repeat):
5064
counter = Counter(array)
5165
most_frequent_count = counter.most_common(1)[0][1]
5266
# tie breaking similarly to scipy.stats.mode
53-
most_frequent_value = min(
54-
value
55-
for value, count in counter.items()
56-
if count == most_frequent_count
67+
most_frequent_value = _safe_min(
68+
[
69+
value
70+
for value, count in counter.items()
71+
if count == most_frequent_count
72+
]
5773
)
5874
else:
5975
mode = _mode(array)
@@ -72,7 +88,7 @@ def _most_frequent(array, extra_value, n_repeat):
7288
return most_frequent_value
7389
elif most_frequent_count == n_repeat:
7490
# tie breaking similarly to scipy.stats.mode
75-
return min(most_frequent_value, extra_value)
91+
return _safe_min([most_frequent_value, extra_value])
7692

7793

7894
class _BaseImputer(TransformerMixin, BaseEstimator):

sklearn/impute/tests/test_impute.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,6 +1529,26 @@ def test_most_frequent(expected, array, dtype, extra_value, n_repeat):
15291529
)
15301530

15311531

1532+
@pytest.mark.parametrize(
1533+
"expected,array",
1534+
[
1535+
("a", ["a", "b"]),
1536+
(1, [1, 2]),
1537+
(None, [None, "a"]),
1538+
(None, [None, 1]),
1539+
(None, [None, "a", 1]),
1540+
(1, [1, "1"]),
1541+
(1, ["1", 1]),
1542+
],
1543+
)
1544+
def test_most_frequent_tie_object(expected, array):
1545+
"""Check the tie breaking behavior of the most frequent strategy.
1546+
1547+
Non-regression test for issue #31717.
1548+
"""
1549+
assert expected == _most_frequent(np.array(array, dtype=object), None, 0)
1550+
1551+
15321552
@pytest.mark.parametrize(
15331553
"initial_strategy", ["mean", "median", "most_frequent", "constant"]
15341554
)

0 commit comments

Comments
 (0)