Skip to content

Commit 8ac5387

Browse files
thomasjpfan authored and jeremiedbb committed
FIX Allow OrdinalEncoder's encoded_missing_value set to the cardinality (#25704)
1 parent a73efcb commit 8ac5387

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

doc/whats_new/v1.2.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ Changelog
7272
when the global configuration sets `transform_output="pandas"`.
7373
:pr:`25500` by :user:`Guillaume Lemaitre <glemaitre>`.
7474

75+
:mod:`sklearn.preprocessing`
76+
............................
77+
78+
- |Fix| :class:`preprocessing.OrdinalEncoder` now correctly supports
79+
`encoded_missing_value` or `unknown_value` set to the cardinality of a feature's categories
80+
when there are missing values in the training data. :pr:`25704` by `Thomas Fan`_.
81+
7582
:mod:`sklearn.utils`
7683
....................
7784

sklearn/preprocessing/_encoders.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,24 +1257,30 @@ def fit(self, X, y=None):
12571257
# `_fit` will only raise an error when `self.handle_unknown="error"`
12581258
self._fit(X, handle_unknown=self.handle_unknown, force_all_finite="allow-nan")
12591259

1260-
if self.handle_unknown == "use_encoded_value":
1261-
for feature_cats in self.categories_:
1262-
if 0 <= self.unknown_value < len(feature_cats):
1263-
raise ValueError(
1264-
"The used value for unknown_value "
1265-
f"{self.unknown_value} is one of the "
1266-
"values already used for encoding the "
1267-
"seen categories."
1268-
)
1260+
cardinalities = [len(categories) for categories in self.categories_]
12691261

12701262
# stores the missing indices per category
12711263
self._missing_indices = {}
12721264
for cat_idx, categories_for_idx in enumerate(self.categories_):
12731265
for i, cat in enumerate(categories_for_idx):
12741266
if is_scalar_nan(cat):
12751267
self._missing_indices[cat_idx] = i
1268+
1269+
# missing values are not considered part of the cardinality
1270+
# when considering unknown categories or encoded_missing_value
1271+
cardinalities[cat_idx] -= 1
12761272
continue
12771273

1274+
if self.handle_unknown == "use_encoded_value":
1275+
for cardinality in cardinalities:
1276+
if 0 <= self.unknown_value < cardinality:
1277+
raise ValueError(
1278+
"The used value for unknown_value "
1279+
f"{self.unknown_value} is one of the "
1280+
"values already used for encoding the "
1281+
"seen categories."
1282+
)
1283+
12781284
if self._missing_indices:
12791285
if np.dtype(self.dtype).kind != "f" and is_scalar_nan(
12801286
self.encoded_missing_value
@@ -1293,9 +1299,9 @@ def fit(self, X, y=None):
12931299
# known category
12941300
invalid_features = [
12951301
cat_idx
1296-
for cat_idx, categories_for_idx in enumerate(self.categories_)
1302+
for cat_idx, cardinality in enumerate(cardinalities)
12971303
if cat_idx in self._missing_indices
1298-
and 0 <= self.encoded_missing_value < len(categories_for_idx)
1304+
and 0 <= self.encoded_missing_value < cardinality
12991305
]
13001306

13011307
if invalid_features:

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,3 +1977,15 @@ def test_predefined_categories_dtype():
19771977
for n, cat in enumerate(enc.categories_):
19781978
assert cat.dtype == object
19791979
assert_array_equal(categories[n], cat)
1980+
1981+
1982+
def test_ordinal_encoder_missing_unknown_encoding_max():
1983+
"""Check missing value or unknown encoding can equal the cardinality."""
1984+
X = np.array([["dog"], ["cat"], [np.nan]], dtype=object)
1985+
X_trans = OrdinalEncoder(encoded_missing_value=2).fit_transform(X)
1986+
assert_allclose(X_trans, [[1], [0], [2]])
1987+
1988+
enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=2).fit(X)
1989+
X_test = np.array([["snake"]])
1990+
X_trans = enc.transform(X_test)
1991+
assert_allclose(X_trans, [[2]])

0 commit comments

Comments
 (0)