diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index 6e331f6d0f1f8..f6c644a534b23 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -255,6 +255,10 @@ Changelog
   now raise consistent error messages.
   :pr:`15084` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- |Fix| :class:`ensemble.AdaBoostRegressor` where the loss should be normalized
+  by the max of the samples with non-null weights only.
+  :pr:`14294` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.feature_extraction`
 .................................
 
diff --git a/sklearn/ensemble/_weight_boosting.py b/sklearn/ensemble/_weight_boosting.py
index 6a2a0ef312881..da6b4aa427fe3 100644
--- a/sklearn/ensemble/_weight_boosting.py
+++ b/sklearn/ensemble/_weight_boosting.py
@@ -38,6 +38,7 @@
 from ..utils.extmath import stable_cumsum
 from ..metrics import accuracy_score, r2_score
 from ..utils.validation import check_is_fitted
+from ..utils.validation import _check_sample_weight
 from ..utils.validation import has_fit_parameter
 from ..utils.validation import _num_samples
 
@@ -117,20 +118,10 @@ def fit(self, X, y, sample_weight=None):
 
         X, y = self._validate_data(X, y)
 
-        if sample_weight is None:
-            # Initialize weights to 1 / n_samples
-            sample_weight = np.empty(_num_samples(X), dtype=np.float64)
-            sample_weight[:] = 1. / _num_samples(X)
-        else:
-            sample_weight = check_array(sample_weight, ensure_2d=False)
-            # Normalize existing weights
-            sample_weight = sample_weight / sample_weight.sum(dtype=np.float64)
-
-            # Check that the sample weights sum is positive
-            if sample_weight.sum() <= 0:
-                raise ValueError(
-                    "Attempting to fit with a non-positive "
-                    "weighted number of samples.")
+        sample_weight = _check_sample_weight(sample_weight, X, np.float64)
+        sample_weight /= sample_weight.sum()
+        if np.any(sample_weight < 0):
+            raise ValueError("sample_weight cannot contain negative weights")
 
         # Check parameters
         self._validate_estimator()
@@ -1029,13 +1020,10 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
         estimator = self._make_estimator(random_state=random_state)
 
         # Weighted sampling of the training set with replacement
-        # For NumPy >= 1.7.0 use np.random.choice
-        cdf = stable_cumsum(sample_weight)
-        cdf /= cdf[-1]
-        uniform_samples = random_state.random_sample(_num_samples(X))
-        bootstrap_idx = cdf.searchsorted(uniform_samples, side='right')
-        # searchsorted returns a scalar
-        bootstrap_idx = np.array(bootstrap_idx, copy=False)
+        bootstrap_idx = random_state.choice(
+            np.arange(_num_samples(X)), size=_num_samples(X), replace=True,
+            p=sample_weight
+        )
 
         # Fit on the bootstrapped sample and obtain a prediction
         # for all samples in the training set
@@ -1045,18 +1033,21 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
         y_predict = estimator.predict(X)
 
         error_vect = np.abs(y_predict - y)
-        error_max = error_vect.max()
+        sample_mask = sample_weight > 0
+        masked_sample_weight = sample_weight[sample_mask]
+        masked_error_vector = error_vect[sample_mask]
 
-        if error_max != 0.:
-            error_vect /= error_max
+        error_max = masked_error_vector.max()
+        if error_max != 0:
+            masked_error_vector /= error_max
 
         if self.loss == 'square':
-            error_vect **= 2
+            masked_error_vector **= 2
         elif self.loss == 'exponential':
-            error_vect = 1. - np.exp(- error_vect)
+            masked_error_vector = 1. - np.exp(-masked_error_vector)
 
         # Calculate the average loss
-        estimator_error = (sample_weight * error_vect).sum()
+        estimator_error = (masked_sample_weight * masked_error_vector).sum()
 
         if estimator_error <= 0:
             # Stop if fit is perfect
@@ -1074,9 +1065,9 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
         estimator_weight = self.learning_rate * np.log(1. / beta)
 
         if not iboost == self.n_estimators - 1:
-            sample_weight *= np.power(
-                beta,
-                (1. - error_vect) * self.learning_rate)
+            sample_weight[sample_mask] *= np.power(
+                beta, (1. - masked_error_vector) * self.learning_rate
+            )
 
         return sample_weight, estimator_weight, estimator_error
 
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py
index a2c1373df95e5..f19c2cc09ce5e 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting.py
@@ -24,6 +24,7 @@
 from sklearn.metrics import mean_squared_error
 from sklearn.model_selection import train_test_split
 from sklearn.utils import check_random_state, tosequence
+from sklearn.utils._mocking import NoSampleWeightWrapper
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal
@@ -1292,20 +1293,6 @@ def test_early_stopping_stratified():
     gbc.fit(X, y)
 
 
-class _NoSampleWeightWrapper(BaseEstimator):
-    def __init__(self, est):
-        self.est = est
-
-    def fit(self, X, y):
-        self.est.fit(X, y)
-
-    def predict(self, X):
-        return self.est.predict(X)
-
-    def predict_proba(self, X):
-        return self.est.predict_proba(X)
-
-
 def _make_multiclass():
     return make_classification(n_classes=3, n_clusters_per_class=1)
 
@@ -1330,7 +1317,7 @@ def test_gradient_boosting_with_init(gb, dataset_maker, init_estimator):
     gb(init=init_est).fit(X, y, sample_weight=sample_weight)
 
     # init does not support sample weights
-    init_est = _NoSampleWeightWrapper(init_estimator())
+    init_est = NoSampleWeightWrapper(init_estimator())
     gb(init=init_est).fit(X, y)  # ok no sample weights
     with pytest.raises(ValueError,
                        match="estimator.*does not support sample weights"):
diff --git a/sklearn/ensemble/tests/test_weight_boosting.py b/sklearn/ensemble/tests/test_weight_boosting.py
index 03261424e2095..bf5c8deab79da 100755
--- a/sklearn/ensemble/tests/test_weight_boosting.py
+++ b/sklearn/ensemble/tests/test_weight_boosting.py
@@ -3,24 +3,29 @@
 import numpy as np
 import pytest
 
+from scipy.sparse import csc_matrix
+from scipy.sparse import csr_matrix
+from scipy.sparse import coo_matrix
+from scipy.sparse import dok_matrix
+from scipy.sparse import lil_matrix
+
 from sklearn.utils.testing import assert_array_equal, assert_array_less
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_raises, assert_raises_regexp
 
 from sklearn.base import BaseEstimator
+from sklearn.base import clone
+from sklearn.dummy import DummyClassifier, DummyRegressor
+from sklearn.linear_model import LinearRegression
 from sklearn.model_selection import train_test_split
 from sklearn.model_selection import GridSearchCV
 from sklearn.ensemble import AdaBoostClassifier
 from sklearn.ensemble import AdaBoostRegressor
 from sklearn.ensemble._weight_boosting import _samme_proba
-from scipy.sparse import csc_matrix
-from scipy.sparse import csr_matrix
-from scipy.sparse import coo_matrix
-from scipy.sparse import dok_matrix
-from scipy.sparse import lil_matrix
 from sklearn.svm import SVC, SVR
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.utils import shuffle
+from sklearn.utils._mocking import NoSampleWeightWrapper
 from sklearn import datasets
 
 
@@ -137,9 +142,10 @@ def test_iris():
             np.abs(clf_samme.predict_proba(iris.data) - prob_samme))
 
 
-def test_boston():
+@pytest.mark.parametrize('loss', ['linear', 'square', 'exponential'])
+def test_boston(loss):
     # Check consistency on dataset boston house prices.
-    reg = AdaBoostRegressor(random_state=0)
+    reg = AdaBoostRegressor(loss=loss, random_state=0)
     reg.fit(boston.data, boston.target)
     score = reg.score(boston.data, boston.target)
     assert score > 0.85
@@ -304,16 +310,6 @@ def test_base_estimator():
                          clf.fit, X_fail, y_fail)
 
 
-def test_sample_weight_missing():
-    from sklearn.cluster import KMeans
-
-    clf = AdaBoostClassifier(KMeans(), algorithm="SAMME")
-    assert_raises(ValueError, clf.fit, X, y_regr)
-
-    clf = AdaBoostRegressor(KMeans())
-    assert_raises(ValueError, clf.fit, X, y_regr)
-
-
 def test_sparse_classification():
     # Check classification with sparse input.
 
@@ -486,9 +482,6 @@ def test_multidimensional_X():
     Check that the AdaBoost estimators can work with n-dimensional
     data matrix
     """
-
-    from sklearn.dummy import DummyClassifier, DummyRegressor
-
     rng = np.random.RandomState(0)
 
     X = rng.randn(50, 3, 3)
@@ -505,6 +498,56 @@ def test_multidimensional_X():
     boost.predict(X)
 
 
+@pytest.mark.parametrize("algorithm", ['SAMME', 'SAMME.R'])
+def test_adaboostclassifier_without_sample_weight(algorithm):
+    X, y = iris.data, iris.target
+    base_estimator = NoSampleWeightWrapper(DummyClassifier())
+    clf = AdaBoostClassifier(
+        base_estimator=base_estimator, algorithm=algorithm
+    )
+    err_msg = ("{} doesn't support sample_weight"
+               .format(base_estimator.__class__.__name__))
+    with pytest.raises(ValueError, match=err_msg):
+        clf.fit(X, y)
+
+
+def test_adaboostregressor_sample_weight():
+    # check that giving sample weights has an influence on the error computed
+    # for a weak learner
+    rng = np.random.RandomState(42)
+    X = np.linspace(0, 100, num=1000)
+    y = (.8 * X + 0.2) + (rng.rand(X.shape[0]) * 0.0001)
+    X = X.reshape(-1, 1)
+
+    # add an arbitrary outlier
+    X[-1] *= 10
+    y[-1] = 10000
+
+    # random_state=0 ensures the underlying bootstrap uses the outlier
+    regr_no_outlier = AdaBoostRegressor(
+        base_estimator=LinearRegression(), n_estimators=1, random_state=0
+    )
+    regr_with_weight = clone(regr_no_outlier)
+    regr_with_outlier = clone(regr_no_outlier)
+
+    # fit 3 models:
+    # - a model containing the outlier
+    # - a model without the outlier
+    # - a model containing the outlier but with a null sample-weight
+    regr_with_outlier.fit(X, y)
+    regr_no_outlier.fit(X[:-1], y[:-1])
+    sample_weight = np.ones_like(y)
+    sample_weight[-1] = 0
+    regr_with_weight.fit(X, y, sample_weight=sample_weight)
+
+    score_with_outlier = regr_with_outlier.score(X[:-1], y[:-1])
+    score_no_outlier = regr_no_outlier.score(X[:-1], y[:-1])
+    score_with_weight = regr_with_weight.score(X[:-1], y[:-1])
+
+    assert score_with_outlier < score_no_outlier
+    assert score_with_outlier < score_with_weight
+    assert score_no_outlier == pytest.approx(score_with_weight)
+
 @pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
 def test_adaboost_consistent_predict(algorithm):
     # check that predict_proba and predict give consistent results
@@ -520,3 +563,17 @@ def test_adaboost_consistent_predict(algorithm):
         np.argmax(model.predict_proba(X_test), axis=1),
         model.predict(X_test)
     )
+
+
+@pytest.mark.parametrize(
+    'model, X, y',
+    [(AdaBoostClassifier(), iris.data, iris.target),
+     (AdaBoostRegressor(), boston.data, boston.target)]
+)
+def test_adaboost_negative_weight_error(model, X, y):
+    sample_weight = np.ones_like(y)
+    sample_weight[-1] = -10
+
+    err_msg = "sample_weight cannot contain negative weight"
+    with pytest.raises(ValueError, match=err_msg):
+        model.fit(X, y, sample_weight=sample_weight)
diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py
index 08198c455044e..3edcf8da53a95 100644
--- a/sklearn/utils/_mocking.py
+++ b/sklearn/utils/_mocking.py
@@ -135,3 +135,27 @@ def score(self, X=None, Y=None):
 
     def _more_tags(self):
         return {'_skip_test': True, 'X_types': ['1dlabel']}
+
+
+class NoSampleWeightWrapper(BaseEstimator):
+    """Wrap estimator which will not expose `sample_weight`.
+
+    Parameters
+    ----------
+    est : estimator, default=None
+        The estimator to wrap.
+    """
+    def __init__(self, est=None):
+        self.est = est
+
+    def fit(self, X, y):
+        return self.est.fit(X, y)
+
+    def predict(self, X):
+        return self.est.predict(X)
+
+    def predict_proba(self, X):
+        return self.est.predict_proba(X)
+
+    def _more_tags(self):
+        return {'_skip_test': True}  # pragma: no cover
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 77f0e0b0d1612..7f52b6c3f4e64 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -1061,8 +1061,8 @@ def _check_sample_weight(sample_weight, X, dtype=None):
         if dtype is None:
             dtype = [np.float64, np.float32]
         sample_weight = check_array(
-            sample_weight, accept_sparse=False,
-            ensure_2d=False, dtype=dtype, order="C"
+            sample_weight, accept_sparse=False, ensure_2d=False, dtype=dtype,
+            order="C"
        )
        if sample_weight.ndim != 1:
            raise ValueError("Sample weights must be 1D array or scalar")
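The behavioral core of the `_boost` change is easiest to see outside the diff. The following is a standalone NumPy sketch, not sklearn internals: the helper name `average_loss` and the toy data are illustrative only. It contrasts the old normalization, which divides by the maximum error over all samples, with the patched one, which masks out zero-weight samples first, so a null-weight outlier can no longer dominate `error_max` and crush every other normalized error toward zero.

import numpy as np


def average_loss(y_true, y_pred, sample_weight, mask_zero_weights):
    # AdaBoost.R2-style linear loss: absolute errors normalized by the
    # maximum error, then averaged with the normalized sample weights.
    error_vect = np.abs(y_pred - y_true)
    weight = sample_weight / sample_weight.sum()
    if mask_zero_weights:
        # Patched behavior: only samples with non-null weight define
        # the normalizer error_max.
        mask = weight > 0
        error_vect, weight = error_vect[mask], weight[mask]
    error_max = error_vect.max()
    if error_max != 0:
        error_vect = error_vect / error_max
    return (weight * error_vect).sum()


rng = np.random.RandomState(0)
y_true = rng.rand(100)
y_pred = y_true + 0.1 * rng.rand(100)
y_pred[-1] = 1e6          # huge error on the last sample ...
w = np.ones(100)
w[-1] = 0.                # ... which carries a null weight

# Old behavior: the zero-weight outlier sets error_max, so the loss is
# near zero and the weak learner looks almost perfect.
print(average_loss(y_true, y_pred, w, mask_zero_weights=False))
# New behavior: the outlier is ignored and the average loss is meaningful.
print(average_loss(y_true, y_pred, w, mask_zero_weights=True))

Normalizing the weights once in `fit` is also what lets the rewritten `_boost` pass `sample_weight` directly as `p` to `random_state.choice`, which requires probabilities that sum to one.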
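The validation change in `BaseWeightBoosting.fit` is observable from the public API. A quick check, assuming a scikit-learn build with this patch applied (the toy `X`, `y` below are made up for the demo): `None` weights are expanded to a uniform array and normalized, while any negative entry now fails fast with the new error message.

import numpy as np
from sklearn.ensemble import AdaBoostRegressor

X = np.arange(20, dtype=np.float64).reshape(-1, 1)
y = np.arange(20, dtype=np.float64)

# sample_weight=None is expanded by _check_sample_weight to a uniform
# array, then normalized, so this behaves exactly like unweighted data.
AdaBoostRegressor(n_estimators=2, random_state=0).fit(X, y)

# A negative entry is rejected before any boosting iteration runs.
w = np.ones_like(y)
w[0] = -1.
try:
    AdaBoostRegressor(n_estimators=2, random_state=0).fit(
        X, y, sample_weight=w)
except ValueError as exc:
    print(exc)  # sample_weight cannot contain negative weights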