diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst
index 87814e102ad98..6740041b072c7 100644
--- a/doc/whats_new/v1.6.rst
+++ b/doc/whats_new/v1.6.rst
@@ -88,6 +88,13 @@ Changelog
   whether to raise an exception if a subset of the scorers in multimetric scoring
   fails or to return an error code.
   :pr:`28992` by :user:`Stefanie Senger `.
 
+:mod:`sklearn.preprocessing`
+............................
+
+- |Feature| :class:`preprocessing.TargetEncoder` now supports the `sample_weight`
+  parameter in the `fit` and `fit_transform` methods. :pr:`29110`
+  by :user:`Duarte São José ` and :user:`Miguel Parece `.
+
 Thanks to everyone who has contributed to the maintenance and improvement
 of the project since version 1.5, including:
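For reviewers, a minimal usage sketch of the API this patch adds (illustrative data; assumes a build of this branch — not part of the diff):

```python
import numpy as np
from sklearn.preprocessing import TargetEncoder

rng = np.random.RandomState(0)
X = rng.choice(["a", "b", "c"], size=(100, 1))  # one categorical feature
y = rng.rand(100)                               # continuous target
w = rng.rand(100)                               # per-sample weights

enc = TargetEncoder(smooth="auto", random_state=0)
# `sample_weight` is the new third argument of `fit` / `fit_transform`.
X_trans = enc.fit_transform(X, y, w)

# With this patch the fitted global statistic is the *weighted* target mean.
assert np.isclose(enc.target_mean_, np.average(y, weights=w))
```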
diff --git a/sklearn/preprocessing/_target_encoder.py b/sklearn/preprocessing/_target_encoder.py
index b3b7c3d5e7bd9..055b427d98494 100644
--- a/sklearn/preprocessing/_target_encoder.py
+++ b/sklearn/preprocessing/_target_encoder.py
@@ -7,6 +7,7 @@
 from ..utils.multiclass import type_of_target
 from ..utils.validation import (
     _check_feature_names_in,
+    _check_sample_weight,
     _check_y,
     check_consistent_length,
     check_is_fitted,
@@ -209,7 +210,7 @@ def __init__(
         self.random_state = random_state
 
     @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
         """Fit the :class:`TargetEncoder` to X and y.
 
         Parameters
@@ -220,16 +221,19 @@ def fit(self, X, y):
         y : array-like of shape (n_samples,)
             The target data used to encode the categories.
 
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If `None`, all samples are weighted equally.
+
         Returns
         -------
         self : object
             Fitted encoder.
         """
-        self._fit_encodings_all(X, y)
+        self._fit_encodings_all(X, y, sample_weight)
         return self
 
     @_fit_context(prefer_skip_nested_validation=True)
-    def fit_transform(self, X, y):
+    def fit_transform(self, X, y, sample_weight=None):
         """Fit :class:`TargetEncoder` and transform X with the target encoding.
 
         .. note::
@@ -245,6 +249,9 @@ def fit_transform(self, X, y):
         y : array-like of shape (n_samples,)
             The target data used to encode the categories.
 
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If `None`, all samples are weighted equally.
+
         Returns
         -------
         X_trans : ndarray of shape (n_samples, n_features) or \
@@ -253,11 +260,14 @@ def fit_transform(self, X, y):
             (n_samples, (n_features * n_classes))
             Transformed input.
         """
         from ..model_selection import KFold, StratifiedKFold  # avoid circular import
 
-        X_ordinal, X_known_mask, y_encoded, n_categories = self._fit_encodings_all(X, y)
+        X_ordinal, X_known_mask, y_encoded, n_categories = self._fit_encodings_all(
+            X, y, sample_weight
+        )
         # The cv splitter is voluntarily restricted to *KFold to enforce non
         # overlapping validation folds, otherwise the fit_transform output will
         # not be well-specified.
+
         if self.target_type_ == "continuous":
             cv = KFold(self.cv, shuffle=self.shuffle, random_state=self.random_state)
         else:
@@ -273,17 +283,18 @@ def fit_transform(self, X, y):
             )
         else:
             X_out = np.empty_like(X_ordinal, dtype=np.float64)
-
+        sample_weight = _check_sample_weight(sample_weight, X)
         for train_idx, test_idx in cv.split(X, y):
             X_train, y_train = X_ordinal[train_idx, :], y_encoded[train_idx]
-            y_train_mean = np.mean(y_train, axis=0)
-
+            sample_weight_train = sample_weight[train_idx]
+            y_train_mean = np.average(y_train, weights=sample_weight_train, axis=0)
             if self.target_type_ == "multiclass":
                 encodings = self._fit_encoding_multiclass(
                     X_train,
                     y_train,
                     n_categories,
                     y_train_mean,
+                    sample_weight_train,
                 )
             else:
                 encodings = self._fit_encoding_binary_or_continuous(
                     X_train,
                     y_train,
                     n_categories,
                     y_train_mean,
+                    sample_weight_train,
                 )
             self._transform_X_ordinal(
                 X_out,
@@ -344,7 +356,7 @@ def transform(self, X):
         )
         return X_out
 
-    def _fit_encodings_all(self, X, y):
+    def _fit_encodings_all(self, X, y, sample_weight=None):
         """Fit a target encoding with all the data."""
         # avoid circular import
         from ..preprocessing import (
@@ -353,6 +365,7 @@ def _fit_encodings_all(self, X, y):
         )
 
         check_consistent_length(X, y)
+        sample_weight = _check_sample_weight(sample_weight, X)
         self._fit(X, handle_unknown="ignore", force_all_finite="allow-nan")
 
         if self.target_type == "auto":
@@ -380,7 +393,7 @@ def _fit_encodings_all(self, X, y):
         else:  # continuous
             y = _check_y(y, y_numeric=True, estimator=self)
-            self.target_mean_ = np.mean(y, axis=0)
+            self.target_mean_ = np.average(y, weights=sample_weight, axis=0)
 
         X_ordinal, X_known_mask = self._transform(
             X, handle_unknown="ignore", force_all_finite="allow-nan"
@@ -396,6 +409,7 @@ def _fit_encodings_all(self, X, y):
                 y,
                 n_categories,
                 self.target_mean_,
+                sample_weight,
             )
         else:
             encodings = self._fit_encoding_binary_or_continuous(
@@ -403,35 +417,44 @@ def _fit_encodings_all(self, X, y):
                 y,
                 n_categories,
                 self.target_mean_,
+                sample_weight,
             )
         self.encodings_ = encodings
 
         return X_ordinal, X_known_mask, y, n_categories
 
     def _fit_encoding_binary_or_continuous(
-        self, X_ordinal, y, n_categories, target_mean
+        self, X_ordinal, y, n_categories, target_mean, sample_weight
     ):
         """Learn target encodings."""
         if self.smooth == "auto":
-            y_variance = np.var(y)
+            y_variance = np.sum(
+                sample_weight * (y - target_mean) ** 2
+            ) / np.sum(sample_weight)
+
             encodings = _fit_encoding_fast_auto_smooth(
                 X_ordinal,
                 y,
+                sample_weight,
                 n_categories,
                 target_mean,
                 y_variance,
             )
         else:
             encodings = _fit_encoding_fast(
                 X_ordinal,
                 y,
+                sample_weight,
                 n_categories,
                 self.smooth,
                 target_mean,
             )
         return encodings
 
-    def _fit_encoding_multiclass(self, X_ordinal, y, n_categories, target_mean):
+    def _fit_encoding_multiclass(
+        self, X_ordinal, y, n_categories, target_mean, sample_weight
+    ):
         """Learn multiclass encodings.
 
         Learn encodings for each class (c) then reorder encodings such that
@@ -452,6 +475,7 @@ def _fit_encoding_multiclass(self, X_ordinal, y, n_categories, target_mean):
                 y_class,
                 n_categories,
                 target_mean[i],
+                sample_weight,
             )
             encodings.extend(encoding)
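The weighted `y_variance` above and the Cython kernels below implement the weighted variant of the `smooth="auto"` empirical Bayes shrinkage. As a plain-NumPy reference for what is computed per category (an illustrative sketch, not code from this patch; `weighted_auto_smooth_encoding` is a made-up name):

```python
import numpy as np

def weighted_auto_smooth_encoding(codes, y, w, n_categories, y_mean, y_variance):
    """Weighted shrinkage encoding for one feature, mirroring the Cython logic."""
    encodings = np.full(n_categories, y_mean)
    for c in range(n_categories):
        mask = codes == c
        if not mask.any():
            continue  # unseen category: keep the global mean
        y_c, w_c = y[mask], w[mask]
        sum_w = w_c.sum()                                  # weighted count
        mean_c = np.average(y_c, weights=w_c)              # weighted category mean
        var_c = np.sum(w_c * (y_c - mean_c) ** 2) / sum_w  # weighted category variance
        # lambda = var(y) * sum_w / (var(y) * sum_w + var_c), as in the .pyx kernel
        lambda_ = y_variance * sum_w / (y_variance * sum_w + var_c)
        encodings[c] = lambda_ * mean_c + (1 - lambda_) * y_mean
    return encodings
```

With all weights equal to one, `sum_w` reduces to the plain count and `var_c` to the plain variance, so the unweighted behaviour is recovered.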
diff --git a/sklearn/preprocessing/_target_encoder_fast.pyx b/sklearn/preprocessing/_target_encoder_fast.pyx
index dca5f78e8d60f..031669071eef5 100644
--- a/sklearn/preprocessing/_target_encoder_fast.pyx
+++ b/sklearn/preprocessing/_target_encoder_fast.pyx
@@ -15,11 +15,17 @@ ctypedef fused Y_DTYPE:
     int32_t
     float64_t
     float32_t
 
+ctypedef fused W_DTYPE:
+    int64_t
+    int32_t
+    float64_t
+    float32_t
 
 def _fit_encoding_fast(
     INT_DTYPE[:, ::1] X_int,
     const Y_DTYPE[:] y,
+    const W_DTYPE[:] sample_weight,
     int64_t[::1] n_categories,
     double smooth,
     double y_mean,
@@ -65,8 +71,8 @@ def _fit_encoding_fast(
             # -1 are unknown categories, which are not counted
             if X_int_tmp == -1:
                 continue
-            sums[X_int_tmp] += y[sample_idx]
-            counts[X_int_tmp] += 1.0
+            sums[X_int_tmp] += y[sample_idx] * sample_weight[sample_idx]
+            counts[X_int_tmp] += sample_weight[sample_idx]
 
         for cat_idx in range(n_cats):
             if counts[cat_idx] == 0:
@@ -80,6 +86,7 @@ def _fit_encoding_fast(
 def _fit_encoding_fast_auto_smooth(
     INT_DTYPE[:, ::1] X_int,
     const Y_DTYPE[:] y,
+    const W_DTYPE[:] sample_weight,
     int64_t[::1] n_categories,
     double y_mean,
     double y_variance,
@@ -99,7 +106,7 @@ def _fit_encoding_fast_auto_smooth(
         int n_features = X_int.shape[1]
         int64_t max_n_cats = np.max(n_categories)
         double[::1] means = np.empty(max_n_cats, dtype=np.float64)
-        int64_t[::1] counts = np.empty(max_n_cats, dtype=np.int64)
+        double[::1] weighted_counts = np.empty(max_n_cats, dtype=np.float64)
        double[::1] sum_of_squared_diffs = np.empty(max_n_cats, dtype=np.float64)
         double lambda_
         list encodings = []
@@ -124,21 +131,21 @@ def _fit_encoding_fast_auto_smooth(
         for cat_idx in range(n_cats):
             means[cat_idx] = 0.0
-            counts[cat_idx] = 0
+            weighted_counts[cat_idx] = 0.0
             sum_of_squared_diffs[cat_idx] = 0.0
 
-        # first pass to compute the mean
+        # first pass to compute the weighted mean
         for sample_idx in range(n_samples):
             X_int_tmp = X_int[sample_idx, feat_idx]
             # -1 are unknown categories, which are not counted
             if X_int_tmp == -1:
                 continue
-            counts[X_int_tmp] += 1
-            means[X_int_tmp] += y[sample_idx]
+            weighted_counts[X_int_tmp] += sample_weight[sample_idx]
+            means[X_int_tmp] += y[sample_idx] * sample_weight[sample_idx]
 
         for cat_idx in range(n_cats):
-            means[cat_idx] /= counts[cat_idx]
+            means[cat_idx] /= weighted_counts[cat_idx]
 
         # second pass to compute the sum of squared differences
         for sample_idx in range(n_samples):
@@ -146,13 +153,13 @@ def _fit_encoding_fast_auto_smooth(
             if X_int_tmp == -1:
                 continue
             diff = y[sample_idx] - means[X_int_tmp]
-            sum_of_squared_diffs[X_int_tmp] += diff * diff
+            sum_of_squared_diffs[X_int_tmp] += diff * diff * sample_weight[sample_idx]
 
         for cat_idx in range(n_cats):
             lambda_ = (
-                y_variance * counts[cat_idx] /
-                (y_variance * counts[cat_idx] + sum_of_squared_diffs[cat_idx] /
-                 counts[cat_idx])
+                y_variance * weighted_counts[cat_idx] /
+                (y_variance * weighted_counts[cat_idx] + sum_of_squared_diffs[cat_idx] /
+                 weighted_counts[cat_idx])
             )
             if isnan(lambda_):
                 # A nan can happen when:
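One sanity check worth keeping in mind when reviewing the kernels above (a sketch under the assumption that this branch is installed, not one of the tests added below): integer sample weights should be equivalent to repeating rows.

```python
import numpy as np
from numpy.testing import assert_allclose
from sklearn.preprocessing import TargetEncoder

rng = np.random.RandomState(0)
X = rng.choice(["a", "b", "c"], size=(30, 1))
y = rng.rand(30)
w = rng.randint(1, 4, size=30)  # integer weights

# Weighting a sample by an integer w should match repeating it w times.
enc_w = TargetEncoder(smooth=5.0).fit(X, y, sample_weight=w)
enc_r = TargetEncoder(smooth=5.0).fit(np.repeat(X, w, axis=0), np.repeat(y, w))
assert_allclose(enc_w.encodings_[0], enc_r.encodings_[0])
```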
diff --git a/sklearn/preprocessing/tests/test_target_encoder.py b/sklearn/preprocessing/tests/test_target_encoder.py
index c1e707b9bff98..af51ef715048e 100644
--- a/sklearn/preprocessing/tests/test_target_encoder.py
+++ b/sklearn/preprocessing/tests/test_target_encoder.py
@@ -22,32 +22,64 @@
 )
 
 
-def _encode_target(X_ordinal, y_numeric, n_categories, smooth):
+def _encode_target(X_ordinal, y_numeric, n_categories, smooth, sample_weight=None):
     """Simple Python implementation of target encoding."""
     cur_encodings = np.zeros(n_categories, dtype=np.float64)
-    y_mean = np.mean(y_numeric)
+    if sample_weight is not None:
+        y_mean = np.average(y_numeric, weights=sample_weight)
+    else:
+        y_mean = np.mean(y_numeric)
 
     if smooth == "auto":
-        y_variance = np.var(y_numeric)
+        if sample_weight is not None:
+            y_variance = np.sum(sample_weight * (y_numeric - y_mean) ** 2) / np.sum(
+                sample_weight
+            )
+        else:
+            y_variance = np.var(y_numeric)
+
         for c in range(n_categories):
-            y_subset = y_numeric[X_ordinal == c]
-            n_i = y_subset.shape[0]
+            mask = X_ordinal == c
+            y_subset = y_numeric[mask]
 
-            if n_i == 0:
-                cur_encodings[c] = y_mean
-                continue
+            if y_subset.shape[0] == 0:
+                # empty category: fall back to the global mean in both branches
+                cur_encodings[c] = y_mean
+                continue
 
-            y_subset_variance = np.var(y_subset)
+            if sample_weight is not None:
+                weights_subset = sample_weight[mask]
+                n_i = np.sum(weights_subset)
+                y_subset_mean = np.average(y_subset, weights=weights_subset)
+                y_subset_variance = np.sum(
+                    weights_subset * (y_subset - y_subset_mean) ** 2
+                ) / np.sum(weights_subset)
+            else:
+                n_i = y_subset.shape[0]
+                y_subset_mean = np.mean(y_subset)
+                y_subset_variance = np.var(y_subset)
+
             m = y_subset_variance / y_variance
             lambda_ = n_i / (n_i + m)
 
-            cur_encodings[c] = lambda_ * np.mean(y_subset) + (1 - lambda_) * y_mean
+            cur_encodings[c] = lambda_ * y_subset_mean + (1 - lambda_) * y_mean
         return cur_encodings
+
     else:  # float
         for c in range(n_categories):
-            y_subset = y_numeric[X_ordinal == c]
-            current_sum = np.sum(y_subset) + y_mean * smooth
-            current_cnt = y_subset.shape[0] + smooth
+            mask = X_ordinal == c
+            y_subset = y_numeric[mask]
+            if sample_weight is not None:
+                weights_subset = sample_weight[mask]
+                current_sum = np.sum(y_subset * weights_subset) + y_mean * smooth
+                current_cnt = np.sum(weights_subset) + smooth
+            else:
+                current_sum = np.sum(y_subset) + y_mean * smooth
+                current_cnt = y_subset.shape[0] + smooth
             cur_encodings[c] = current_sum / current_cnt
         return cur_encodings
@@ -61,9 +93,12 @@ def _encode_target(X_ordinal, y_numeric, n_categories, smooth):
         ("auto", 3),
     ],
 )
+@pytest.mark.parametrize("sample_weight", [None, np.random.RandomState(42).rand(90)])
 @pytest.mark.parametrize("smooth", [5.0, "auto"])
 @pytest.mark.parametrize("target_type", ["binary", "continuous"])
-def test_encoding(categories, unknown_value, global_random_seed, smooth, target_type):
+def test_encoding(
+    categories, unknown_value, global_random_seed, smooth, target_type, sample_weight
+):
     """Check encoding for binary and continuous targets.
     Compare the values returned by `TargetEncoder.fit_transform` against the
@@ -117,7 +152,14 @@ def test_encoding(categories, unknown_value, global_random_seed, smooth, target
     for train_idx, test_idx in cv.split(X_train_int_array, y_train):
         X_, y_ = X_train_int_array[train_idx, 0], y_numeric[train_idx]
-        cur_encodings = _encode_target(X_, y_, n_categories, smooth)
+        if sample_weight is not None:
+            sample_weight_ = sample_weight[train_idx]
+            cur_encodings = _encode_target(
+                X_, y_, n_categories, smooth, sample_weight=sample_weight_
+            )
+        else:
+            cur_encodings = _encode_target(X_, y_, n_categories, smooth)
+
         expected_X_fit_transform[test_idx, 0] = cur_encodings[
             X_train_int_array[test_idx, 0]
         ]
@@ -131,10 +173,12 @@ def test_encoding(categories, unknown_value, global_random_seed, smooth, target
         random_state=global_random_seed,
     )
 
-    X_fit_transform = target_encoder.fit_transform(X_train, y_train)
+    X_fit_transform = target_encoder.fit_transform(X_train, y_train, sample_weight)
 
     assert target_encoder.target_type_ == target_type
+
     assert_allclose(X_fit_transform, expected_X_fit_transform)
+
     assert len(target_encoder.encodings_) == 1
     if target_type == "binary":
         assert_array_equal(target_encoder.classes_, target_names)
@@ -142,10 +186,19 @@ def test_encoding(categories, unknown_value, global_random_seed, smooth, target
         assert target_encoder.classes_ is None
 
     # compute encodings for all data to validate `transform`
-    y_mean = np.mean(y_numeric)
+    y_mean = (
+        np.average(y_numeric, weights=sample_weight)
+        if sample_weight is not None
+        else np.mean(y_numeric)
+    )
     expected_encodings = _encode_target(
-        X_train_int_array[:, 0], y_numeric, n_categories, smooth
+        X_train_int_array[:, 0],
+        y_numeric,
+        n_categories,
+        smooth,
+        sample_weight=sample_weight,
     )
+
     assert_allclose(target_encoder.encodings_[0], expected_encodings)
     assert target_encoder.target_mean_ == pytest.approx(y_mean)
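To make the weighted fixed-`smooth` formula used in `_encode_target` concrete, a small worked example (numbers chosen purely for illustration):

```python
import numpy as np

# One category with y = [1.0, 0.0] and weights [2.0, 1.0];
# global y_mean = 0.5 and smooth = 5.0.
y_subset = np.array([1.0, 0.0])
weights_subset = np.array([2.0, 1.0])
y_mean, smooth = 0.5, 5.0

current_sum = np.sum(y_subset * weights_subset) + y_mean * smooth  # 2.0 + 2.5 = 4.5
current_cnt = np.sum(weights_subset) + smooth                      # 3.0 + 5.0 = 8.0
print(current_sum / current_cnt)  # 0.5625, the weighted encoding for this category
```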