diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index 86f2d29cf4ecf..dc151871874d4 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -729,14 +729,15 @@ separate categories::
 See :ref:`dict_feature_extraction` for categorical features that are
 represented as a dict, not as scalars.
 
-.. _one_hot_encoder_infrequent_categories:
+.. _encoder_infrequent_categories:
 
 Infrequent categories
 ---------------------
 
-:class:`OneHotEncoder` supports aggregating infrequent categories into a single
-output for each feature. The parameters to enable the gathering of infrequent
-categories are `min_frequency` and `max_categories`.
+:class:`OneHotEncoder` and :class:`OrdinalEncoder` support aggregating
+infrequent categories into a single output for each feature. The parameters to
+enable the gathering of infrequent categories are `min_frequency` and
+`max_categories`.
 
 1. `min_frequency` is either an integer greater or equal to 1, or a float in
    the interval `(0.0, 1.0)`. If `min_frequency` is an integer, categories with
@@ -750,11 +751,47 @@ categories are `min_frequency` and `max_categories`.
    input feature. `max_categories` includes the feature that combines
    infrequent categories.
 
-In the following example, the categories, `'dog', 'snake'` are considered
-infrequent::
+In the following example with :class:`OrdinalEncoder`, the categories `'dog'`
+and `'snake'` are considered infrequent::
 
    >>> X = np.array([['dog'] * 5 + ['cat'] * 20 + ['rabbit'] * 10 +
    ...               ['snake'] * 3], dtype=object).T
+   >>> enc = preprocessing.OrdinalEncoder(min_frequency=6).fit(X)
+   >>> enc.infrequent_categories_
+   [array(['dog', 'snake'], dtype=object)]
+   >>> enc.transform(np.array([['dog'], ['cat'], ['rabbit'], ['snake']]))
+   array([[2.],
+          [0.],
+          [1.],
+          [2.]])
+
+:class:`OrdinalEncoder`'s `max_categories` does **not** take into account missing
+or unknown categories. Setting `unknown_value` or `encoded_missing_value` to an
+integer will increase the number of unique integer codes by one each. This can
+result in up to `max_categories + 2` integer codes. In the following example,
+"a" and "d" are considered infrequent and grouped together into a single
+category; "b" and "c" are their own categories; unknown values are encoded as 3,
+and missing values are encoded as 4.
+
+   >>> X_train = np.array(
+   ...     [["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 3 + [np.nan]],
+   ...     dtype=object).T
+   >>> enc = preprocessing.OrdinalEncoder(
+   ...     handle_unknown="use_encoded_value", unknown_value=3,
+   ...     max_categories=3, encoded_missing_value=4)
+   >>> _ = enc.fit(X_train)
+   >>> X_test = np.array([["a"], ["b"], ["c"], ["d"], ["e"], [np.nan]], dtype=object)
+   >>> enc.transform(X_test)
+   array([[2.],
+          [0.],
+          [1.],
+          [2.],
+          [3.],
+          [4.]])
+
+Similarly, :class:`OneHotEncoder` can be configured to group together infrequent
+categories::
+
    >>> enc = preprocessing.OneHotEncoder(min_frequency=6, sparse_output=False).fit(X)
    >>> enc.infrequent_categories_
    [array(['dog', 'snake'], dtype=object)]
diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index c50672a712f93..0830468d35835 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -381,6 +381,11 @@ Changelog
   :pr:`24935` by :user:`Seladus `, :user:`Guillaume Lemaitre `, and
   :user:`Dea María Léon `, :pr:`25257` by :user:`Gleb Levitski `.
 
+- |Feature| :class:`preprocessing.OrdinalEncoder` now supports grouping
+  infrequent categories into a single category.
+  Grouping infrequent categories is enabled by specifying how to select
+  infrequent categories with `min_frequency` or `max_categories`.
+  :pr:`25677` by `Thomas Fan`_.
+
 - |Fix| :class:`AdditiveChi2Sampler` is now stateless.
   The `sample_interval_` attribute is deprecated and will be removed in 1.5.
   :pr:`25190` by :user:`Vincent Maladière `.
diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py
index f985a4a4e18b3..1962f571cfbaa 100644
--- a/sklearn/preprocessing/_encoders.py
+++ b/sklearn/preprocessing/_encoders.py
@@ -68,8 +68,14 @@ def _check_X(self, X, force_all_finite=True):
         return X_columns, n_samples, n_features
 
     def _fit(
-        self, X, handle_unknown="error", force_all_finite=True, return_counts=False
+        self,
+        X,
+        handle_unknown="error",
+        force_all_finite=True,
+        return_counts=False,
+        return_and_ignore_missing_for_infrequent=False,
     ):
+        self._check_infrequent_enabled()
         self._check_n_features(X, reset=True)
         self._check_feature_names(X, reset=True)
         X_list, n_samples, n_features = self._check_X(
@@ -86,13 +92,14 @@ def _fit(
 
         self.categories_ = []
         category_counts = []
+        compute_counts = return_counts or self._infrequent_enabled
 
         for i in range(n_features):
             Xi = X_list[i]
 
             if self.categories == "auto":
-                result = _unique(Xi, return_counts=return_counts)
-                if return_counts:
+                result = _unique(Xi, return_counts=compute_counts)
+                if compute_counts:
                     cats, counts = result
                     category_counts.append(counts)
                 else:
@@ -139,7 +146,7 @@ def _fit(
                             " during fit".format(diff, i)
                         )
                         raise ValueError(msg)
-                if return_counts:
+                if compute_counts:
                     category_counts.append(_get_counts(Xi, cats))
 
             self.categories_.append(cats)
@@ -147,10 +154,31 @@ def _fit(
         output = {"n_samples": n_samples}
         if return_counts:
             output["category_counts"] = category_counts
+
+        missing_indices = {}
+        if return_and_ignore_missing_for_infrequent:
+            for feature_idx, categories_for_idx in enumerate(self.categories_):
+                for category_idx, category in enumerate(categories_for_idx):
+                    if is_scalar_nan(category):
+                        missing_indices[feature_idx] = category_idx
+                        break
+            output["missing_indices"] = missing_indices
+
+        if self._infrequent_enabled:
+            self._fit_infrequent_category_mapping(
+                n_samples,
+                category_counts,
+                missing_indices,
+            )
         return output
 
     def _transform(
-        self, X, handle_unknown="error", force_all_finite=True, warn_on_unknown=False
+        self,
+        X,
+        handle_unknown="error",
+        force_all_finite=True,
+        warn_on_unknown=False,
+        ignore_category_indices=None,
     ):
         self._check_feature_names(X, reset=False)
         self._check_n_features(X, reset=False)
@@ -207,8 +235,209 @@ def _transform(
                     UserWarning,
                 )
 
+        self._map_infrequent_categories(X_int, X_mask, ignore_category_indices)
         return X_int, X_mask
 
+    @property
+    def infrequent_categories_(self):
+        """Infrequent categories for each feature."""
+        # raises an AttributeError if `_infrequent_indices` is not defined
+        infrequent_indices = self._infrequent_indices
+        return [
+            None if indices is None else category[indices]
+            for category, indices in zip(self.categories_, infrequent_indices)
+        ]
+
+    def _check_infrequent_enabled(self):
+        """
+        This function checks whether `_infrequent_enabled` is True or False.
+        This has to be called after parameter validation in the fit function.
+ """ + max_categories = getattr(self, "max_categories", None) + min_frequency = getattr(self, "min_frequency", None) + self._infrequent_enabled = ( + max_categories is not None and max_categories >= 1 + ) or min_frequency is not None + + def _identify_infrequent(self, category_count, n_samples, col_idx): + """Compute the infrequent indices. + + Parameters + ---------- + category_count : ndarray of shape (n_cardinality,) + Category counts. + + n_samples : int + Number of samples. + + col_idx : int + Index of the current category. Only used for the error message. + + Returns + ------- + output : ndarray of shape (n_infrequent_categories,) or None + If there are infrequent categories, indices of infrequent + categories. Otherwise None. + """ + if isinstance(self.min_frequency, numbers.Integral): + infrequent_mask = category_count < self.min_frequency + elif isinstance(self.min_frequency, numbers.Real): + min_frequency_abs = n_samples * self.min_frequency + infrequent_mask = category_count < min_frequency_abs + else: + infrequent_mask = np.zeros(category_count.shape[0], dtype=bool) + + n_current_features = category_count.size - infrequent_mask.sum() + 1 + if self.max_categories is not None and self.max_categories < n_current_features: + # max_categories includes the one infrequent category + frequent_category_count = self.max_categories - 1 + if frequent_category_count == 0: + # All categories are infrequent + infrequent_mask[:] = True + else: + # stable sort to preserve original count order + smallest_levels = np.argsort(category_count, kind="mergesort")[ + :-frequent_category_count + ] + infrequent_mask[smallest_levels] = True + + output = np.flatnonzero(infrequent_mask) + return output if output.size > 0 else None + + def _fit_infrequent_category_mapping( + self, n_samples, category_counts, missing_indices + ): + """Fit infrequent categories. + + Defines the private attribute: `_default_to_infrequent_mappings`. For + feature `i`, `_default_to_infrequent_mappings[i]` defines the mapping + from the integer encoding returned by `super().transform()` into + infrequent categories. If `_default_to_infrequent_mappings[i]` is None, + there were no infrequent categories in the training set. + + For example if categories 0, 2 and 4 were frequent, while categories + 1, 3, 5 were infrequent for feature 7, then these categories are mapped + to a single output: + `_default_to_infrequent_mappings[7] = array([0, 3, 1, 3, 2, 3])` + + Defines private attribute: `_infrequent_indices`. `_infrequent_indices[i]` + is an array of indices such that + `categories_[i][_infrequent_indices[i]]` are all the infrequent category + labels. If the feature `i` has no infrequent categories + `_infrequent_indices[i]` is None. + + .. versionadded:: 1.1 + + Parameters + ---------- + n_samples : int + Number of samples in training set. + category_counts: list of ndarray + `category_counts[i]` is the category counts corresponding to + `self.categories_[i]`. + missing_indices : dict + Dict mapping from feature_idx to category index with a missing value. 
+ """ + # Remove missing value from counts, so it is not considered as infrequent + if missing_indices: + category_counts_ = [] + for feature_idx, count in enumerate(category_counts): + if feature_idx in missing_indices: + category_counts_.append( + np.delete(count, missing_indices[feature_idx]) + ) + else: + category_counts_.append(count) + else: + category_counts_ = category_counts + + self._infrequent_indices = [ + self._identify_infrequent(category_count, n_samples, col_idx) + for col_idx, category_count in enumerate(category_counts_) + ] + + # compute mapping from default mapping to infrequent mapping + self._default_to_infrequent_mappings = [] + + for feature_idx, infreq_idx in enumerate(self._infrequent_indices): + cats = self.categories_[feature_idx] + # no infrequent categories + if infreq_idx is None: + self._default_to_infrequent_mappings.append(None) + continue + + n_cats = len(cats) + if feature_idx in missing_indices: + # Missing index was removed from ths category when computing + # infrequent indices, thus we need to decrease the number of + # total categories when considering the infrequent mapping. + n_cats -= 1 + + # infrequent indices exist + mapping = np.empty(n_cats, dtype=np.int64) + n_infrequent_cats = infreq_idx.size + + # infrequent categories are mapped to the last element. + n_frequent_cats = n_cats - n_infrequent_cats + mapping[infreq_idx] = n_frequent_cats + + frequent_indices = np.setdiff1d(np.arange(n_cats), infreq_idx) + mapping[frequent_indices] = np.arange(n_frequent_cats) + + self._default_to_infrequent_mappings.append(mapping) + + def _map_infrequent_categories(self, X_int, X_mask, ignore_category_indices): + """Map infrequent categories to integer representing the infrequent category. + + This modifies X_int in-place. Values that were invalid based on `X_mask` + are mapped to the infrequent category if there was an infrequent + category for that feature. + + Parameters + ---------- + X_int: ndarray of shape (n_samples, n_features) + Integer encoded categories. + + X_mask: ndarray of shape (n_samples, n_features) + Bool mask for valid values in `X_int`. + + ignore_category_indices : dict + Dictionary mapping from feature_idx to category index to ignore. + Ignored indexes will not be grouped and the original ordinal encoding + will remain. + """ + if not self._infrequent_enabled: + return + + ignore_category_indices = ignore_category_indices or {} + + for col_idx in range(X_int.shape[1]): + infrequent_idx = self._infrequent_indices[col_idx] + if infrequent_idx is None: + continue + + X_int[~X_mask[:, col_idx], col_idx] = infrequent_idx[0] + if self.handle_unknown == "infrequent_if_exist": + # All the unknown values are now mapped to the + # infrequent_idx[0], which makes the unknown values valid + # This is needed in `transform` when the encoding is formed + # using `X_mask`. + X_mask[:, col_idx] = True + + # Remaps encoding in `X_int` where the infrequent categories are + # grouped together. + for i, mapping in enumerate(self._default_to_infrequent_mappings): + if mapping is None: + continue + + if i in ignore_category_indices: + # Update rows that are **not** ignored + rows_to_update = X_int[:, i] != ignore_category_indices[i] + else: + rows_to_update = slice(None) + + X_int[rows_to_update, i] = np.take(mapping, X_int[rows_to_update, i]) + def _more_tags(self): return {"X_types": ["categorical"]} @@ -319,7 +548,7 @@ class OneHotEncoder(_BaseEncoder): :meth:`inverse_transform` will handle an unknown category as with `handle_unknown='ignore'`. 
         `min_frequency` and `max_categories`. Read more in the
-        :ref:`User Guide <one_hot_encoder_infrequent_categories>`.
+        :ref:`User Guide <encoder_infrequent_categories>`.
 
         .. versionchanged:: 1.1
             `'infrequent_if_exist'` was added to automatically handle unknown
@@ -336,7 +565,7 @@ class OneHotEncoder(_BaseEncoder):
         `min_frequency * n_samples` will be considered infrequent.
 
         .. versionadded:: 1.1
-        Read more in the :ref:`User Guide <one_hot_encoder_infrequent_categories>`.
+        Read more in the :ref:`User Guide <encoder_infrequent_categories>`.
 
     max_categories : int, default=None
         Specifies an upper limit to the number of output features for each input
@@ -346,7 +575,7 @@ class OneHotEncoder(_BaseEncoder):
         there is no limit to the number of output features.
 
         .. versionadded:: 1.1
-        Read more in the :ref:`User Guide <one_hot_encoder_infrequent_categories>`.
+        Read more in the :ref:`User Guide <encoder_infrequent_categories>`.
 
     feature_name_combiner : "concat" or callable, default="concat"
         Callable with signature `def callable(input_feature, category)` that returns a
@@ -527,25 +756,6 @@ def __init__(
         self.max_categories = max_categories
         self.feature_name_combiner = feature_name_combiner
 
-    @property
-    def infrequent_categories_(self):
-        """Infrequent categories for each feature."""
-        # raises an AttributeError if `_infrequent_indices` is not defined
-        infrequent_indices = self._infrequent_indices
-        return [
-            None if indices is None else category[indices]
-            for category, indices in zip(self.categories_, infrequent_indices)
-        ]
-
-    def _check_infrequent_enabled(self):
-        """
-        This functions checks whether _infrequent_enabled is True or False.
-        This has to be called after parameter validation in the fit function.
-        """
-        self._infrequent_enabled = (
-            self.max_categories is not None and self.max_categories >= 1
-        ) or self.min_frequency is not None
-
     def _map_drop_idx_to_infrequent(self, feature_idx, drop_idx):
         """Convert `drop_idx` into the index for infrequent categories.
 
@@ -688,141 +898,6 @@ def _set_drop_idx(self):
 
         self.drop_idx_ = np.asarray(drop_idx_, dtype=object)
 
-    def _identify_infrequent(self, category_count, n_samples, col_idx):
-        """Compute the infrequent indices.
-
-        Parameters
-        ----------
-        category_count : ndarray of shape (n_cardinality,)
-            Category counts.
-
-        n_samples : int
-            Number of samples.
-
-        col_idx : int
-            Index of the current category. Only used for the error message.
-
-        Returns
-        -------
-        output : ndarray of shape (n_infrequent_categories,) or None
-            If there are infrequent categories, indices of infrequent
-            categories. Otherwise None.
-        """
-        if isinstance(self.min_frequency, numbers.Integral):
-            infrequent_mask = category_count < self.min_frequency
-        elif isinstance(self.min_frequency, numbers.Real):
-            min_frequency_abs = n_samples * self.min_frequency
-            infrequent_mask = category_count < min_frequency_abs
-        else:
-            infrequent_mask = np.zeros(category_count.shape[0], dtype=bool)
-
-        n_current_features = category_count.size - infrequent_mask.sum() + 1
-        if self.max_categories is not None and self.max_categories < n_current_features:
-            # stable sort to preserve original count order
-            smallest_levels = np.argsort(category_count, kind="mergesort")[
-                : -self.max_categories + 1
-            ]
-            infrequent_mask[smallest_levels] = True
-
-        output = np.flatnonzero(infrequent_mask)
-        return output if output.size > 0 else None
-
-    def _fit_infrequent_category_mapping(self, n_samples, category_counts):
-        """Fit infrequent categories.
-
-        Defines the private attribute: `_default_to_infrequent_mappings`.
For - feature `i`, `_default_to_infrequent_mappings[i]` defines the mapping - from the integer encoding returned by `super().transform()` into - infrequent categories. If `_default_to_infrequent_mappings[i]` is None, - there were no infrequent categories in the training set. - - For example if categories 0, 2 and 4 were frequent, while categories - 1, 3, 5 were infrequent for feature 7, then these categories are mapped - to a single output: - `_default_to_infrequent_mappings[7] = array([0, 3, 1, 3, 2, 3])` - - Defines private attribute: `_infrequent_indices`. `_infrequent_indices[i]` - is an array of indices such that - `categories_[i][_infrequent_indices[i]]` are all the infrequent category - labels. If the feature `i` has no infrequent categories - `_infrequent_indices[i]` is None. - - .. versionadded:: 1.1 - - Parameters - ---------- - n_samples : int - Number of samples in training set. - category_counts: list of ndarray - `category_counts[i]` is the category counts corresponding to - `self.categories_[i]`. - """ - self._infrequent_indices = [ - self._identify_infrequent(category_count, n_samples, col_idx) - for col_idx, category_count in enumerate(category_counts) - ] - - # compute mapping from default mapping to infrequent mapping - self._default_to_infrequent_mappings = [] - - for cats, infreq_idx in zip(self.categories_, self._infrequent_indices): - # no infrequent categories - if infreq_idx is None: - self._default_to_infrequent_mappings.append(None) - continue - - n_cats = len(cats) - # infrequent indices exist - mapping = np.empty(n_cats, dtype=np.int64) - n_infrequent_cats = infreq_idx.size - - # infrequent categories are mapped to the last element. - n_frequent_cats = n_cats - n_infrequent_cats - mapping[infreq_idx] = n_frequent_cats - - frequent_indices = np.setdiff1d(np.arange(n_cats), infreq_idx) - mapping[frequent_indices] = np.arange(n_frequent_cats) - - self._default_to_infrequent_mappings.append(mapping) - - def _map_infrequent_categories(self, X_int, X_mask): - """Map infrequent categories to integer representing the infrequent category. - - This modifies X_int in-place. Values that were invalid based on `X_mask` - are mapped to the infrequent category if there was an infrequent - category for that feature. - - Parameters - ---------- - X_int: ndarray of shape (n_samples, n_features) - Integer encoded categories. - - X_mask: ndarray of shape (n_samples, n_features) - Bool mask for valid values in `X_int`. - """ - if not self._infrequent_enabled: - return - - for col_idx in range(X_int.shape[1]): - infrequent_idx = self._infrequent_indices[col_idx] - if infrequent_idx is None: - continue - - X_int[~X_mask[:, col_idx], col_idx] = infrequent_idx[0] - if self.handle_unknown == "infrequent_if_exist": - # All the unknown values are now mapped to the - # infrequent_idx[0], which makes the unknown values valid - # This is needed in `transform` when the encoding is formed - # using `X_mask`. - X_mask[:, col_idx] = True - - # Remaps encoding in `X_int` where the infrequent categories are - # grouped together. - for i, mapping in enumerate(self._default_to_infrequent_mappings): - if mapping is None: - continue - X_int[:, i] = np.take(mapping, X_int[:, i]) - def _compute_transformed_categories(self, i, remove_dropped=True): """Compute the transformed categories used for column `i`. 
@@ -905,18 +980,11 @@ def fit(self, X, y=None):
             )
             self.sparse_output = self.sparse
 
-        self._check_infrequent_enabled()
-
-        fit_results = self._fit(
+        self._fit(
             X,
             handle_unknown=self.handle_unknown,
             force_all_finite="allow-nan",
-            return_counts=self._infrequent_enabled,
         )
-        if self._infrequent_enabled:
-            self._fit_infrequent_category_mapping(
-                fit_results["n_samples"], fit_results["category_counts"]
-            )
         self._set_drop_idx()
         self._n_features_outs = self._compute_n_features_outs()
         return self
@@ -952,7 +1020,6 @@ def transform(self, X):
             force_all_finite="allow-nan",
             warn_on_unknown=warn_on_unknown,
         )
-        self._map_infrequent_categories(X_int, X_mask)
 
         n_samples, n_features = X_int.shape
 
@@ -1210,6 +1277,34 @@ class OrdinalEncoder(OneToOneFeatureMixin, _BaseEncoder):
 
         .. versionadded:: 1.1
 
+    min_frequency : int or float, default=None
+        Specifies the minimum frequency below which a category will be
+        considered infrequent.
+
+        - If `int`, categories with a smaller cardinality will be considered
+          infrequent.
+
+        - If `float`, categories with a smaller cardinality than
+          `min_frequency * n_samples` will be considered infrequent.
+
+        .. versionadded:: 1.3
+        Read more in the :ref:`User Guide <encoder_infrequent_categories>`.
+
+    max_categories : int, default=None
+        Specifies an upper limit to the number of output categories for each input
+        feature when considering infrequent categories. If there are infrequent
+        categories, `max_categories` includes the category representing the
+        infrequent categories along with the frequent categories. If `None`,
+        there is no limit to the number of output categories.
+
+        `max_categories` does **not** take into account missing or unknown
+        categories. Setting `unknown_value` or `encoded_missing_value` to an
+        integer will increase the number of unique integer codes by one each.
+        This can result in up to `max_categories + 2` integer codes.
+
+        .. versionadded:: 1.3
+        Read more in the :ref:`User Guide <encoder_infrequent_categories>`.
+
     Attributes
     ----------
     categories_ : list of arrays
@@ -1228,6 +1323,15 @@ class OrdinalEncoder(OneToOneFeatureMixin, _BaseEncoder):
 
         .. versionadded:: 1.0
 
+    infrequent_categories_ : list of ndarray
+        Defined only if infrequent categories are enabled by setting
+        `min_frequency` or `max_categories` to a non-default value.
+        `infrequent_categories_[i]` are the infrequent categories for feature
+        `i`. If the feature `i` has no infrequent categories
+        `infrequent_categories_[i]` is None.
+
+        .. versionadded:: 1.3
+
     See Also
     --------
     OneHotEncoder : Performs a one-hot encoding of categorical features. This encoding
@@ -1282,6 +1386,27 @@ class OrdinalEncoder(OneToOneFeatureMixin, _BaseEncoder):
     array([[ 1.,  0.],
            [ 0.,  1.],
            [ 0., -1.]])
+
+    Infrequent categories are enabled by setting `max_categories` or `min_frequency`.
+    In the following example, "a" and "d" are considered infrequent and grouped
+    together into a single category; "b" and "c" are their own categories; unknown
+    values are encoded as 3, and missing values are encoded as 4.
+
+    >>> X_train = np.array(
+    ...     [["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 3 + [np.nan]],
+    ...     dtype=object).T
+    >>> enc = OrdinalEncoder(
+    ...     handle_unknown="use_encoded_value", unknown_value=3,
+    ...     max_categories=3, encoded_missing_value=4)
+    >>> _ = enc.fit(X_train)
+    >>> X_test = np.array([["a"], ["b"], ["c"], ["d"], ["e"], [np.nan]], dtype=object)
+    >>> enc.transform(X_test)
+    array([[2.],
+           [0.],
+           [1.],
+           [2.],
+           [3.],
+           [4.]])
     """
 
     _parameter_constraints: dict = {
@@ -1290,6 +1415,12 @@ class OrdinalEncoder(OneToOneFeatureMixin, _BaseEncoder):
        "encoded_missing_value": [Integral, type(np.nan)],
        "handle_unknown": [StrOptions({"error", "use_encoded_value"})],
        "unknown_value": [Integral, type(np.nan), None],
+        "max_categories": [Interval(Integral, 1, None, closed="left"), None],
+        "min_frequency": [
+            Interval(Integral, 1, None, closed="left"),
+            Interval(RealNotInt, 0, 1, closed="neither"),
+            None,
+        ],
     }
 
     def __init__(
@@ -1300,12 +1431,16 @@ def __init__(
         handle_unknown="error",
         unknown_value=None,
         encoded_missing_value=np.nan,
+        min_frequency=None,
+        max_categories=None,
     ):
         self.categories = categories
         self.dtype = dtype
         self.handle_unknown = handle_unknown
         self.unknown_value = unknown_value
         self.encoded_missing_value = encoded_missing_value
+        self.min_frequency = min_frequency
+        self.max_categories = max_categories
 
     def fit(self, X, y=None):
         """
@@ -1350,9 +1485,21 @@ def fit(self, X, y=None):
             )
 
         # `_fit` will only raise an error when `self.handle_unknown="error"`
-        self._fit(X, handle_unknown=self.handle_unknown, force_all_finite="allow-nan")
+        fit_results = self._fit(
+            X,
+            handle_unknown=self.handle_unknown,
+            force_all_finite="allow-nan",
+            return_and_ignore_missing_for_infrequent=True,
+        )
+        self._missing_indices = fit_results["missing_indices"]
 
         cardinalities = [len(categories) for categories in self.categories_]
+        if self._infrequent_enabled:
+            # Cardinality decreases because the infrequent categories are grouped
+            # together
+            for feature_idx, infrequent in enumerate(self.infrequent_categories_):
+                if infrequent is not None:
+                    cardinalities[feature_idx] -= len(infrequent)
 
-        # stores the missing indices per category
-        self._missing_indices = {}
@@ -1426,7 +1573,10 @@ def transform(self, X):
             Transformed input.
         """
         X_int, X_mask = self._transform(
-            X, handle_unknown=self.handle_unknown, force_all_finite="allow-nan"
+            X,
+            handle_unknown=self.handle_unknown,
+            force_all_finite="allow-nan",
+            ignore_category_indices=self._missing_indices,
         )
         X_trans = X_int.astype(self.dtype, copy=False)
 
@@ -1471,6 +1621,9 @@ def inverse_transform(self, X):
         X_tr = np.empty((n_samples, n_features), dtype=dt)
 
         found_unknown = {}
+        infrequent_masks = {}
+
+        infrequent_indices = getattr(self, "_infrequent_indices", None)
 
         for i in range(n_features):
             labels = X[:, i]
@@ -1480,22 +1633,44 @@ def inverse_transform(self, X):
                 X_i_mask = _get_mask(labels, self.encoded_missing_value)
                 labels[X_i_mask] = self._missing_indices[i]
 
+            rows_to_update = slice(None)
+            categories = self.categories_[i]
+
+            if infrequent_indices is not None and infrequent_indices[i] is not None:
+                # Compute mask for frequent categories
+                infrequent_encoding_value = len(categories) - len(infrequent_indices[i])
+                infrequent_masks[i] = labels == infrequent_encoding_value
+                rows_to_update = ~infrequent_masks[i]
+
+                # Remap categories to be only frequent categories. The infrequent
+                # categories will be mapped to "infrequent_sklearn" later
+                frequent_categories_mask = np.ones_like(categories, dtype=bool)
+                frequent_categories_mask[infrequent_indices[i]] = False
+                categories = categories[frequent_categories_mask]
+
             if self.handle_unknown == "use_encoded_value":
                 unknown_labels = _get_mask(labels, self.unknown_value)
+                found_unknown[i] = unknown_labels
 
                 known_labels = ~unknown_labels
-                X_tr[known_labels, i] = self.categories_[i][
-                    labels[known_labels].astype("int64", copy=False)
-                ]
-                found_unknown[i] = unknown_labels
-            else:
-                X_tr[:, i] = self.categories_[i][labels.astype("int64", copy=False)]
+                if isinstance(rows_to_update, np.ndarray):
+                    rows_to_update &= known_labels
+                else:
+                    rows_to_update = known_labels
 
-        # insert None values for unknown values
-        if found_unknown:
+            labels_int = labels[rows_to_update].astype("int64", copy=False)
+            X_tr[rows_to_update, i] = categories[labels_int]
+
+        if found_unknown or infrequent_masks:
             X_tr = X_tr.astype(object, copy=False)
+
+        # insert None values for unknown values
+        if found_unknown:
             for idx, mask in found_unknown.items():
                 X_tr[mask, idx] = None
 
+        if infrequent_masks:
+            for idx, mask in infrequent_masks.items():
+                X_tr[mask, idx] = "infrequent_sklearn"
+
         return X_tr
diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py
index a4fea0ee92dbc..ffd5eda5195d0 100644
--- a/sklearn/preprocessing/tests/test_encoders.py
+++ b/sklearn/preprocessing/tests/test_encoders.py
@@ -2051,3 +2051,256 @@ def test_drop_idx_infrequent_categories():
         ["x0_b", "x0_c", "x0_d", "x0_e", "x0_infrequent_sklearn"],
     )
     assert ohe.drop_idx_ is None
+
+
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {"max_categories": 3},
+        {"min_frequency": 6},
+        {"min_frequency": 9},
+        {"min_frequency": 0.24},
+        {"min_frequency": 0.16},
+        {"max_categories": 3, "min_frequency": 8},
+        {"max_categories": 4, "min_frequency": 6},
+    ],
+)
+def test_ordinal_encoder_infrequent_three_levels(kwargs):
+    """Test parameters for grouping 'a' and 'd' into the infrequent category."""
+
+    X_train = np.array([["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 3]).T
+    ordinal = OrdinalEncoder(
+        handle_unknown="use_encoded_value", unknown_value=-1, **kwargs
+    ).fit(X_train)
+    assert_array_equal(ordinal.categories_, [["a", "b", "c", "d"]])
+    assert_array_equal(ordinal.infrequent_categories_, [["a", "d"]])
+
+    X_test = [["a"], ["b"], ["c"], ["d"], ["z"]]
+    expected_trans = [[2], [0], [1], [2], [-1]]
+
+    X_trans = ordinal.transform(X_test)
+    assert_allclose(X_trans, expected_trans)
+
+    X_inverse = ordinal.inverse_transform(X_trans)
+    expected_inverse = [
+        ["infrequent_sklearn"],
+        ["b"],
+        ["c"],
+        ["infrequent_sklearn"],
+        [None],
+    ]
+    assert_array_equal(X_inverse, expected_inverse)
+
+
+def test_ordinal_encoder_infrequent_three_levels_user_cats():
+    """Test that the order of the categories provided by a user is respected.
+
+    In this case 'c' is encoded as the first category and 'b' is encoded
+    as the second one.
+ """ + + X_train = np.array( + [["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 3], dtype=object + ).T + ordinal = OrdinalEncoder( + categories=[["c", "d", "b", "a"]], + max_categories=3, + handle_unknown="use_encoded_value", + unknown_value=-1, + ).fit(X_train) + assert_array_equal(ordinal.categories_, [["c", "d", "b", "a"]]) + assert_array_equal(ordinal.infrequent_categories_, [["d", "a"]]) + + X_test = [["a"], ["b"], ["c"], ["d"], ["z"]] + expected_trans = [[2], [1], [0], [2], [-1]] + + X_trans = ordinal.transform(X_test) + assert_allclose(X_trans, expected_trans) + + X_inverse = ordinal.inverse_transform(X_trans) + expected_inverse = [ + ["infrequent_sklearn"], + ["b"], + ["c"], + ["infrequent_sklearn"], + [None], + ] + assert_array_equal(X_inverse, expected_inverse) + + +def test_ordinal_encoder_infrequent_mixed(): + """Test when feature 0 has infrequent categories and feature 1 does not.""" + + X = np.column_stack(([0, 1, 3, 3, 3, 3, 2, 0, 3], [0, 0, 0, 0, 1, 1, 1, 1, 1])) + + ordinal = OrdinalEncoder(max_categories=3).fit(X) + + assert_array_equal(ordinal.infrequent_categories_[0], [1, 2]) + assert ordinal.infrequent_categories_[1] is None + + X_test = [[3, 0], [1, 1]] + expected_trans = [[1, 0], [2, 1]] + + X_trans = ordinal.transform(X_test) + assert_allclose(X_trans, expected_trans) + + X_inverse = ordinal.inverse_transform(X_trans) + expected_inverse = np.array([[3, 0], ["infrequent_sklearn", 1]], dtype=object) + assert_array_equal(X_inverse, expected_inverse) + + +def test_ordinal_encoder_infrequent_multiple_categories_dtypes(): + """Test infrequent categories with a pandas DataFrame with multiple dtypes.""" + + pd = pytest.importorskip("pandas") + categorical_dtype = pd.CategoricalDtype(["bird", "cat", "dog", "snake"]) + X = pd.DataFrame( + { + "str": ["a", "f", "c", "f", "f", "a", "c", "b", "b"], + "int": [5, 3, 0, 10, 10, 12, 0, 3, 5], + "categorical": pd.Series( + ["dog"] * 4 + ["cat"] * 3 + ["snake"] + ["bird"], + dtype=categorical_dtype, + ), + }, + columns=["str", "int", "categorical"], + ) + + ordinal = OrdinalEncoder(max_categories=3).fit(X) + # X[:, 0] 'a', 'b', 'c' have the same frequency. 'a' and 'b' will be + # considered infrequent because they appear first when sorted + + # X[:, 1] 0, 3, 5, 10 has frequency 2 and 12 has frequency 1. + # 0, 3, 12 will be considered infrequent because they appear first when + # sorted. 
+
+    # X[:, 2] "snake" and "bird" are infrequent
+
+    assert_array_equal(ordinal.infrequent_categories_[0], ["a", "b"])
+    assert_array_equal(ordinal.infrequent_categories_[1], [0, 3, 12])
+    assert_array_equal(ordinal.infrequent_categories_[2], ["bird", "snake"])
+
+    X_test = pd.DataFrame(
+        {
+            "str": ["a", "b", "f", "c"],
+            "int": [12, 0, 10, 5],
+            "categorical": pd.Series(
+                ["cat"] + ["snake"] + ["bird"] + ["dog"],
+                dtype=categorical_dtype,
+            ),
+        },
+        columns=["str", "int", "categorical"],
+    )
+    expected_trans = [[2, 2, 0], [2, 2, 2], [1, 1, 2], [0, 0, 1]]
+
+    X_trans = ordinal.transform(X_test)
+    assert_allclose(X_trans, expected_trans)
+
+
+def test_ordinal_encoder_infrequent_custom_mapping():
+    """Check behavior of unknown_value and encoded_missing_value with infrequent."""
+    X_train = np.array(
+        [["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 3 + [np.nan]], dtype=object
+    ).T
+
+    ordinal = OrdinalEncoder(
+        handle_unknown="use_encoded_value",
+        unknown_value=2,
+        max_categories=2,
+        encoded_missing_value=3,
+    ).fit(X_train)
+    assert_array_equal(ordinal.infrequent_categories_, [["a", "c", "d"]])
+
+    X_test = np.array([["a"], ["b"], ["c"], ["d"], ["e"], [np.nan]], dtype=object)
+    expected_trans = [[1], [0], [1], [1], [2], [3]]
+
+    X_trans = ordinal.transform(X_test)
+    assert_allclose(X_trans, expected_trans)
+
+
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {"max_categories": 6},
+        {"min_frequency": 2},
+    ],
+)
+def test_ordinal_encoder_all_frequent(kwargs):
+    """When all categories are frequent, the encoding matches the default encoder."""
+    X_train = np.array(
+        [["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 3], dtype=object
+    ).T
+
+    adjusted_encoder = OrdinalEncoder(
+        **kwargs, handle_unknown="use_encoded_value", unknown_value=-1
+    ).fit(X_train)
+    default_encoder = OrdinalEncoder(
+        handle_unknown="use_encoded_value", unknown_value=-1
+    ).fit(X_train)
+
+    X_test = [["a"], ["b"], ["c"], ["d"], ["e"]]
+
+    assert_allclose(
+        adjusted_encoder.transform(X_test), default_encoder.transform(X_test)
+    )
+
+
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {"max_categories": 1},
+        {"min_frequency": 100},
+    ],
+)
+def test_ordinal_encoder_all_infrequent(kwargs):
+    """When all categories are infrequent, they are all encoded as zero."""
+    X_train = np.array(
+        [["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 3], dtype=object
+    ).T
+    encoder = OrdinalEncoder(
+        **kwargs, handle_unknown="use_encoded_value", unknown_value=-1
+    ).fit(X_train)
+
+    X_test = [["a"], ["b"], ["c"], ["d"], ["e"]]
+    assert_allclose(encoder.transform(X_test), [[0], [0], [0], [0], [-1]])
+
+
+def test_ordinal_encoder_missing_appears_frequent():
+    """Check behavior when missing value appears frequently."""
+    X = np.array(
+        [[np.nan] * 20 + ["dog"] * 10 + ["cat"] * 5 + ["snake"] + ["deer"]],
+        dtype=object,
+    ).T
+    ordinal = OrdinalEncoder(max_categories=3).fit(X)
+
+    X_test = np.array([["snake", "cat", "dog", np.nan]], dtype=object).T
+    X_trans = ordinal.transform(X_test)
+    assert_allclose(X_trans, [[2], [0], [1], [np.nan]])
+
+
+def test_ordinal_encoder_missing_appears_infrequent():
+    """Check behavior when missing value appears infrequently."""
+
+    # feature 0 has infrequent categories
+    # feature 1 has no infrequent categories
+    X = np.array(
+        [
+            [np.nan] + ["dog"] * 10 + ["cat"] * 5 + ["snake"] + ["deer"],
+            ["red"] * 9 + ["green"] * 9,
+        ],
+        dtype=object,
+    ).T
+    ordinal = OrdinalEncoder(min_frequency=4).fit(X)
+
+    X_test = np.array(
+        [
+            ["snake", "red"],
+            ["deer", "green"],
+            [np.nan, "green"],
+            ["dog", "green"],
+ ["cat", "red"], + ], + dtype=object, + ) + X_trans = ordinal.transform(X_test) + assert_allclose(X_trans, [[2, 1], [2, 0], [np.nan, 0], [1, 0], [0, 1]])