diff --git a/benchmarks/bench_20newsgroups.py b/benchmarks/bench_20newsgroups.py index 44a117f1ad42d..a559bc59b5f8a 100644 --- a/benchmarks/bench_20newsgroups.py +++ b/benchmarks/bench_20newsgroups.py @@ -21,7 +21,7 @@ "extra_trees": ExtraTreesClassifier(max_features="sqrt", min_samples_split=10), "logistic_regression": LogisticRegression(), "naive_bayes": MultinomialNB(), - "adaboost": AdaBoostClassifier(n_estimators=10, algorithm="SAMME"), + "adaboost": AdaBoostClassifier(n_estimators=10), } diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 3a2c85d138bfc..8a466b24b9732 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -1709,7 +1709,7 @@ learners:: >>> from sklearn.ensemble import AdaBoostClassifier >>> X, y = load_iris(return_X_y=True) - >>> clf = AdaBoostClassifier(n_estimators=100, algorithm="SAMME",) + >>> clf = AdaBoostClassifier(n_estimators=100) >>> scores = cross_val_score(clf, X, y, cv=5) >>> scores.mean() 0.9... diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 1a3a89c1156e2..d92e7d726d15a 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -246,6 +246,10 @@ Changelog right child node as the tree is traversed. :pr:`28268` by :user:`Adam Li <adam2392>`. +- |API| The parameter `algorithm` of :class:`ensemble.AdaBoostClassifier` is deprecated + and will be removed in 1.8. + :pr:`29997` by :user:`Jérémie du Boisberranger <jeremiedbb>`. + :mod:`sklearn.impute` ..................... 
diff --git a/examples/classification/plot_classifier_comparison.py b/examples/classification/plot_classifier_comparison.py index 5747d00ba7950..7028eaa70e029 100644 --- a/examples/classification/plot_classifier_comparison.py +++ b/examples/classification/plot_classifier_comparison.py @@ -64,7 +64,7 @@ max_depth=5, n_estimators=10, max_features=1, random_state=42 ), MLPClassifier(alpha=1, max_iter=1000, random_state=42), - AdaBoostClassifier(algorithm="SAMME", random_state=42), + AdaBoostClassifier(random_state=42), GaussianNB(), QuadraticDiscriminantAnalysis(), ] diff --git a/examples/ensemble/plot_adaboost_multiclass.py b/examples/ensemble/plot_adaboost_multiclass.py index a18ff4e09c7bb..e0c30ae1586b6 100644 --- a/examples/ensemble/plot_adaboost_multiclass.py +++ b/examples/ensemble/plot_adaboost_multiclass.py @@ -80,7 +80,6 @@ adaboost_clf = AdaBoostClassifier( estimator=weak_learner, n_estimators=n_estimators, - algorithm="SAMME", random_state=42, ).fit(X_train, y_train) diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index 5d1554eb754d4..c499a9f6dc44b 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -39,10 +39,7 @@ y = np.concatenate((y1, -y2 + 1)) # Create and fit an AdaBoosted decision tree -bdt = AdaBoostClassifier( - DecisionTreeClassifier(max_depth=1), algorithm="SAMME", n_estimators=200 -) - +bdt = AdaBoostClassifier(DecisionTreeClassifier(max_depth=1), n_estimators=200) bdt.fit(X, y) plot_colors = "br" diff --git a/examples/ensemble/plot_forest_iris.py b/examples/ensemble/plot_forest_iris.py index 78a28e521ff90..1342872bb4d37 100644 --- a/examples/ensemble/plot_forest_iris.py +++ b/examples/ensemble/plot_forest_iris.py @@ -74,11 +74,7 @@ DecisionTreeClassifier(max_depth=None), RandomForestClassifier(n_estimators=n_estimators), ExtraTreesClassifier(n_estimators=n_estimators), - AdaBoostClassifier( - DecisionTreeClassifier(max_depth=3), - 
n_estimators=n_estimators, - algorithm="SAMME", - ), + AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=n_estimators), ] for pair in ([0, 1], [0, 2], [2, 3]): diff --git a/sklearn/ensemble/_weight_boosting.py b/sklearn/ensemble/_weight_boosting.py index 290360622100a..3569a85b5fc3c 100644 --- a/sklearn/ensemble/_weight_boosting.py +++ b/sklearn/ensemble/_weight_boosting.py @@ -24,7 +24,6 @@ from numbers import Integral, Real import numpy as np -from scipy.special import xlogy from ..base import ( ClassifierMixin, @@ -36,7 +35,7 @@ from ..metrics import accuracy_score, r2_score from ..tree import DecisionTreeClassifier, DecisionTreeRegressor from ..utils import _safe_indexing, check_random_state -from ..utils._param_validation import HasMethods, Interval, StrOptions +from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions from ..utils.extmath import softmax, stable_cumsum from ..utils.metadata_routing import ( _raise_for_unsupported_routing, @@ -375,16 +374,12 @@ class AdaBoostClassifier( a trade-off between the `learning_rate` and `n_estimators` parameters. Values must be in the range `(0.0, inf)`. - algorithm : {'SAMME', 'SAMME.R'}, default='SAMME.R' - If 'SAMME.R' then use the SAMME.R real boosting algorithm. - ``estimator`` must support calculation of class probabilities. - If 'SAMME' then use the SAMME discrete boosting algorithm. - The SAMME.R algorithm typically converges faster than SAMME, - achieving a lower test error with fewer boosting iterations. + algorithm : {'SAMME'}, default='SAMME' + Use the SAMME discrete boosting algorithm. - .. deprecated:: 1.4 - `"SAMME.R"` is deprecated and will be removed in version 1.6. - '"SAMME"' will become the default. + .. deprecated:: 1.6 + `algorithm` is deprecated and will be removed in version 1.8. This + estimator only implements the 'SAMME' algorithm. 
random_state : int, RandomState instance or None, default=None Controls the random seed given at each `estimator` at each @@ -470,9 +465,9 @@ class AdaBoostClassifier( >>> X, y = make_classification(n_samples=1000, n_features=4, ... n_informative=2, n_redundant=0, ... random_state=0, shuffle=False) - >>> clf = AdaBoostClassifier(n_estimators=100, algorithm="SAMME", random_state=0) + >>> clf = AdaBoostClassifier(n_estimators=100, random_state=0) >>> clf.fit(X, y) - AdaBoostClassifier(algorithm='SAMME', n_estimators=100, random_state=0) + AdaBoostClassifier(n_estimators=100, random_state=0) >>> clf.predict([[0, 0, 0, 0]]) array([1]) >>> clf.score(X, y) @@ -487,23 +482,19 @@ class AdaBoostClassifier( refer to :ref:`sphx_glr_auto_examples_ensemble_plot_adaboost_twoclass.py`. """ - # TODO(1.6): Modify _parameter_constraints for "algorithm" to only check - # for "SAMME" + # TODO(1.8): remove "algorithm" entry _parameter_constraints: dict = { **BaseWeightBoosting._parameter_constraints, - "algorithm": [ - StrOptions({"SAMME", "SAMME.R"}), - ], + "algorithm": [StrOptions({"SAMME"}), Hidden(StrOptions({"deprecated"}))], } - # TODO(1.6): Change default "algorithm" value to "SAMME" def __init__( self, estimator=None, *, n_estimators=50, learning_rate=1.0, - algorithm="SAMME.R", + algorithm="deprecated", random_state=None, ): super().__init__( @@ -519,43 +510,23 @@ def _validate_estimator(self): """Check the estimator and set the estimator_ attribute.""" super()._validate_estimator(default=DecisionTreeClassifier(max_depth=1)) - # TODO(1.6): Remove, as "SAMME.R" value for "algorithm" param will be - # removed in 1.6 - # SAMME-R requires predict_proba-enabled base estimators - if self.algorithm != "SAMME": + if self.algorithm != "deprecated": warnings.warn( - ( - "The SAMME.R algorithm (the default) is deprecated and will be" - " removed in 1.6. Use the SAMME algorithm to circumvent this" - " warning." - ), + "The parameter 'algorithm' is deprecated in 1.6 and has no effect. 
" + "It will be removed in version 1.8.", FutureWarning, ) - if not hasattr(self.estimator_, "predict_proba"): - raise TypeError( - "AdaBoostClassifier with algorithm='SAMME.R' requires " - "that the weak learner supports the calculation of class " - "probabilities with a predict_proba method.\n" - "Please change the base estimator or set " - "algorithm='SAMME' instead." - ) if not has_fit_parameter(self.estimator_, "sample_weight"): raise ValueError( f"{self.estimator.__class__.__name__} doesn't support sample_weight." ) - # TODO(1.6): Redefine the scope of the `_boost` and `_boost_discrete` - # functions to be the same since SAMME will be the default value for the - # "algorithm" parameter in version 1.6. Thus, a distinguishing function is - # no longer needed. (Or adjust code here, if another algorithm, shall be - # used instead of SAMME.R.) def _boost(self, iboost, X, y, sample_weight, random_state): """Implement a single boost. - Perform a single boost according to the real multi-class SAMME.R - algorithm or to the discrete SAMME algorithm and return the updated - sample weights. + Perform a single boost according to the discrete SAMME algorithm and return the + updated sample weights. Parameters ---------- @@ -589,75 +560,6 @@ def _boost(self, iboost, X, y, sample_weight, random_state): The classification error for the current boost. If None then boosting has terminated early. """ - if self.algorithm == "SAMME.R": - return self._boost_real(iboost, X, y, sample_weight, random_state) - - else: # elif self.algorithm == "SAMME": - return self._boost_discrete(iboost, X, y, sample_weight, random_state) - - # TODO(1.6): Remove function. The `_boost_real` function won't be used any - # longer, because the SAMME.R algorithm will be deprecated in 1.6. 
- def _boost_real(self, iboost, X, y, sample_weight, random_state): - """Implement a single boost using the SAMME.R real algorithm.""" - estimator = self._make_estimator(random_state=random_state) - - estimator.fit(X, y, sample_weight=sample_weight) - - y_predict_proba = estimator.predict_proba(X) - - if iboost == 0: - self.classes_ = getattr(estimator, "classes_", None) - self.n_classes_ = len(self.classes_) - - y_predict = self.classes_.take(np.argmax(y_predict_proba, axis=1), axis=0) - - # Instances incorrectly classified - incorrect = y_predict != y - - # Error fraction - estimator_error = np.mean(np.average(incorrect, weights=sample_weight, axis=0)) - - # Stop if classification is perfect - if estimator_error <= 0: - return sample_weight, 1.0, 0.0 - - # Construct y coding as described in Zhu et al [2]: - # - # y_k = 1 if c == k else -1 / (K - 1) - # - # where K == n_classes_ and c, k in [0, K) are indices along the second - # axis of the y coding with c being the index corresponding to the true - # class label. - n_classes = self.n_classes_ - classes = self.classes_ - y_codes = np.array([-1.0 / (n_classes - 1), 1.0]) - y_coding = y_codes.take(classes == y[:, np.newaxis]) - - # Displace zero probabilities so the log is defined. - # Also fix negative elements which may occur with - # negative sample weights. 
- proba = y_predict_proba # alias for readability - np.clip(proba, np.finfo(proba.dtype).eps, None, out=proba) - - # Boost weight using multi-class AdaBoost SAMME.R alg - estimator_weight = ( - -1.0 - * self.learning_rate - * ((n_classes - 1.0) / n_classes) - * xlogy(y_coding, y_predict_proba).sum(axis=1) - ) - - # Only boost the weights if it will fit again - if not iboost == self.n_estimators - 1: - # Only boost positive weights - sample_weight *= np.exp( - estimator_weight * ((sample_weight > 0) | (estimator_weight < 0)) - ) - - return sample_weight, 1.0, estimator_error - - def _boost_discrete(self, iboost, X, y, sample_weight, random_state): - """Implement a single boost using the SAMME discrete algorithm.""" estimator = self._make_estimator(random_state=random_state) estimator.fit(X, y, sample_weight=sample_weight) @@ -789,21 +691,17 @@ class in ``classes_``, respectively. n_classes = self.n_classes_ classes = self.classes_[:, np.newaxis] - # TODO(1.6): Remove, because "algorithm" param will be deprecated in 1.6 - if self.algorithm == "SAMME.R": - # The weights are all 1. for SAMME.R - pred = sum( - _samme_proba(estimator, n_classes, X) for estimator in self.estimators_ - ) - else: # self.algorithm == "SAMME" - pred = sum( - np.where( - (estimator.predict(X) == classes).T, - w, - -1 / (n_classes - 1) * w, - ) - for estimator, w in zip(self.estimators_, self.estimator_weights_) + if n_classes == 1: + return np.zeros_like(X, shape=(X.shape[0], 1)) + + pred = sum( + np.where( + (estimator.predict(X) == classes).T, + w, + -1 / (n_classes - 1) * w, ) + for estimator, w in zip(self.estimators_, self.estimator_weights_) + ) pred /= self.estimator_weights_.sum() if n_classes == 2: @@ -844,17 +742,11 @@ class in ``classes_``, respectively. for weight, estimator in zip(self.estimator_weights_, self.estimators_): norm += weight - # TODO(1.6): Remove, because "algorithm" param will be deprecated in - # 1.6 - if self.algorithm == "SAMME.R": - # The weights are all 1. 
for SAMME.R - current_pred = _samme_proba(estimator, n_classes, X) - else: # elif self.algorithm == "SAMME": - current_pred = np.where( - (estimator.predict(X) == classes).T, - weight, - -1 / (n_classes - 1) * weight, - ) + current_pred = np.where( + (estimator.predict(X) == classes).T, + weight, + -1 / (n_classes - 1) * weight, + ) if pred is None: pred = current_pred diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 44f28792a717e..f5386804d77d7 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -965,7 +965,7 @@ def test_bagging_with_metadata_routing(model): "model", [ BaggingClassifier( - estimator=AdaBoostClassifier(n_estimators=1, algorithm="SAMME"), + estimator=AdaBoostClassifier(n_estimators=1), n_estimators=1, ), BaggingRegressor(estimator=AdaBoostRegressor(n_estimators=1), n_estimators=1), diff --git a/sklearn/ensemble/tests/test_weight_boosting.py b/sklearn/ensemble/tests/test_weight_boosting.py index 251139de62940..55825c438d76b 100755 --- a/sklearn/ensemble/tests/test_weight_boosting.py +++ b/sklearn/ensemble/tests/test_weight_boosting.py @@ -20,7 +20,6 @@ assert_allclose, assert_array_almost_equal, assert_array_equal, - assert_array_less, ) from sklearn.utils.fixes import ( COO_CONTAINERS, @@ -87,18 +86,13 @@ def test_oneclass_adaboost_proba(): # In response to issue #7501 # https://github.com/scikit-learn/scikit-learn/issues/7501 y_t = np.ones(len(X)) - clf = AdaBoostClassifier(algorithm="SAMME").fit(X, y_t) + clf = AdaBoostClassifier().fit(X, y_t) assert_array_almost_equal(clf.predict_proba(X), np.ones((len(X), 1))) -# TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") -@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"]) -def test_classification_toy(algorithm): 
+def test_classification_toy(): # Check classification on a toy dataset. - clf = AdaBoostClassifier(algorithm=algorithm, random_state=0) + clf = AdaBoostClassifier(random_state=0) clf.fit(X, y_class) assert_array_equal(clf.predict(T), y_t_class) assert_array_equal(np.unique(np.asarray(y_t_class)), clf.classes_) @@ -113,42 +107,26 @@ def test_regression_toy(): assert_array_equal(clf.predict(T), y_t_regr) -# TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") def test_iris(): # Check consistency on dataset iris. classes = np.unique(iris.target) - clf_samme = prob_samme = None - for alg in ["SAMME", "SAMME.R"]: - clf = AdaBoostClassifier(algorithm=alg) - clf.fit(iris.data, iris.target) + clf = AdaBoostClassifier() + clf.fit(iris.data, iris.target) - assert_array_equal(classes, clf.classes_) - proba = clf.predict_proba(iris.data) - if alg == "SAMME": - clf_samme = clf - prob_samme = proba - assert proba.shape[1] == len(classes) - assert clf.decision_function(iris.data).shape[1] == len(classes) - - score = clf.score(iris.data, iris.target) - assert score > 0.9, "Failed with algorithm %s and score = %f" % (alg, score) - - # Check we used multiple estimators - assert len(clf.estimators_) > 1 - # Check for distinct random states (see issue #7408) - assert len(set(est.random_state for est in clf.estimators_)) == len( - clf.estimators_ - ) + assert_array_equal(classes, clf.classes_) + proba = clf.predict_proba(iris.data) + + assert proba.shape[1] == len(classes) + assert clf.decision_function(iris.data).shape[1] == len(classes) - # Somewhat hacky regression test: prior to - # ae7adc880d624615a34bafdb1d75ef67051b8200, - # predict_proba returned SAMME.R values for SAMME. 
- clf_samme.algorithm = "SAMME.R" - assert_array_less(0, np.abs(clf_samme.predict_proba(iris.data) - prob_samme)) + score = clf.score(iris.data, iris.target) + assert score > 0.9, f"Failed with {score = }" + + # Check we used multiple estimators + assert len(clf.estimators_) > 1 + # Check for distinct random states (see issue #7408) + assert len(set(est.random_state for est in clf.estimators_)) == len(clf.estimators_) @pytest.mark.parametrize("loss", ["linear", "square", "exponential"]) @@ -165,18 +143,13 @@ def test_diabetes(loss): assert len(set(est.random_state for est in reg.estimators_)) == len(reg.estimators_) -# TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") -@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"]) -def test_staged_predict(algorithm): +def test_staged_predict(): # Check staged predictions. rng = np.random.RandomState(0) iris_weights = rng.randint(10, size=iris.target.shape) diabetes_weights = rng.randint(10, size=diabetes.target.shape) - clf = AdaBoostClassifier(algorithm=algorithm, n_estimators=10) + clf = AdaBoostClassifier(n_estimators=10) clf.fit(iris.data, iris.target, sample_weight=iris_weights) predictions = clf.predict(iris.data) @@ -222,7 +195,6 @@ def test_gridsearch(): parameters = { "n_estimators": (1, 2), "estimator__max_depth": (1, 2), - "algorithm": ("SAMME", "SAMME.R"), } clf = GridSearchCV(boost, parameters) clf.fit(iris.data, iris.target) @@ -234,25 +206,20 @@ def test_gridsearch(): clf.fit(diabetes.data, diabetes.target) -# TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") def test_pickle(): # Check pickability. 
import pickle # Adaboost classifier - for alg in ["SAMME", "SAMME.R"]: - obj = AdaBoostClassifier(algorithm=alg) - obj.fit(iris.data, iris.target) - score = obj.score(iris.data, iris.target) - s = pickle.dumps(obj) + obj = AdaBoostClassifier() + obj.fit(iris.data, iris.target) + score = obj.score(iris.data, iris.target) + s = pickle.dumps(obj) - obj2 = pickle.loads(s) - assert type(obj2) == obj.__class__ - score2 = obj2.score(iris.data, iris.target) - assert score == score2 + obj2 = pickle.loads(s) + assert type(obj2) == obj.__class__ + score2 = obj2.score(iris.data, iris.target) + assert score == score2 # Adaboost regressor obj = AdaBoostRegressor(random_state=0) @@ -266,10 +233,6 @@ def test_pickle(): assert score == score2 -# TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") def test_importances(): # Check variable importances. X, y = datasets.make_classification( @@ -282,14 +245,13 @@ def test_importances(): random_state=1, ) - for alg in ["SAMME", "SAMME.R"]: - clf = AdaBoostClassifier(algorithm=alg) + clf = AdaBoostClassifier() - clf.fit(X, y) - importances = clf.feature_importances_ + clf.fit(X, y) + importances = clf.feature_importances_ - assert importances.shape[0] == 10 - assert (importances[:3, np.newaxis] >= importances[3:]).all() + assert importances.shape[0] == 10 + assert (importances[:3, np.newaxis] >= importances[3:]).all() def test_adaboost_classifier_sample_weight_error(): @@ -306,10 +268,10 @@ def test_estimator(): # XXX doesn't work with y_class because RF doesn't support classes_ # Shouldn't AdaBoost run a LabelBinarizer? 
- clf = AdaBoostClassifier(RandomForestClassifier(), algorithm="SAMME") + clf = AdaBoostClassifier(RandomForestClassifier()) clf.fit(X, y_regr) - clf = AdaBoostClassifier(SVC(), algorithm="SAMME") + clf = AdaBoostClassifier(SVC()) clf.fit(X, y_class) from sklearn.ensemble import RandomForestRegressor @@ -323,14 +285,14 @@ def test_estimator(): # Check that an empty discrete ensemble fails in fit, not predict. X_fail = [[1, 1], [1, 1], [1, 1], [1, 1]] y_fail = ["foo", "bar", 1, 2] - clf = AdaBoostClassifier(SVC(), algorithm="SAMME") + clf = AdaBoostClassifier(SVC()) with pytest.raises(ValueError, match="worse than random"): clf.fit(X_fail, y_fail) def test_sample_weights_infinite(): msg = "Sample weights have reached infinite values" - clf = AdaBoostClassifier(n_estimators=30, learning_rate=23.0, algorithm="SAMME") + clf = AdaBoostClassifier(n_estimators=30, learning_rate=23.0) with pytest.warns(UserWarning, match=msg): clf.fit(iris.data, iris.target) @@ -375,14 +337,12 @@ def fit(self, X, y, sample_weight=None): sparse_classifier = AdaBoostClassifier( estimator=CustomSVC(probability=True), random_state=1, - algorithm="SAMME", ).fit(X_train_sparse, y_train) # Trained on dense format dense_classifier = AdaBoostClassifier( estimator=CustomSVC(probability=True), random_state=1, - algorithm="SAMME", ).fit(X_train, y_train) # predict @@ -530,9 +490,7 @@ def test_multidimensional_X(): yc = rng.choice([0, 1], 51) yr = rng.randn(51) - boost = AdaBoostClassifier( - DummyClassifier(strategy="most_frequent"), algorithm="SAMME" - ) + boost = AdaBoostClassifier(DummyClassifier(strategy="most_frequent")) boost.fit(X, yc) boost.predict(X) boost.predict_proba(X) @@ -542,15 +500,10 @@ def test_multidimensional_X(): boost.predict(X) -# TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") 
-@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"]) -def test_adaboostclassifier_without_sample_weight(algorithm): +def test_adaboostclassifier_without_sample_weight(): X, y = iris.data, iris.target estimator = NoSampleWeightWrapper(DummyClassifier()) - clf = AdaBoostClassifier(estimator=estimator, algorithm=algorithm) + clf = AdaBoostClassifier(estimator=estimator) err_msg = "{} doesn't support sample_weight".format(estimator.__class__.__name__) with pytest.raises(ValueError, match=err_msg): clf.fit(X, y) @@ -594,19 +547,14 @@ def test_adaboostregressor_sample_weight(): assert score_no_outlier == pytest.approx(score_with_weight) -# TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") -@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"]) -def test_adaboost_consistent_predict(algorithm): +def test_adaboost_consistent_predict(): # check that predict_proba and predict give consistent results # regression test for: # https://github.com/scikit-learn/scikit-learn/issues/14084 X_train, X_test, y_train, y_test = train_test_split( *datasets.load_digits(return_X_y=True), random_state=42 ) - model = AdaBoostClassifier(algorithm=algorithm, random_state=42) + model = AdaBoostClassifier(random_state=42) model.fit(X_train, y_train) assert_array_equal( @@ -642,19 +590,12 @@ def test_adaboost_numerically_stable_feature_importance_with_small_weights(): y = rng.choice([0, 1], size=1000) sample_weight = np.ones_like(y) * 1e-263 tree = DecisionTreeClassifier(max_depth=10, random_state=12) - ada_model = AdaBoostClassifier( - estimator=tree, n_estimators=20, algorithm="SAMME", random_state=12 - ) + ada_model = AdaBoostClassifier(estimator=tree, n_estimators=20, random_state=12) ada_model.fit(X, y, sample_weight=sample_weight) assert np.isnan(ada_model.feature_importances_).sum() == 0 -# 
TODO(1.6): remove "@pytest.mark.filterwarnings" as SAMME.R will be removed -# and substituted with the SAMME algorithm as a default; also re-write test to -# only consider "SAMME" -@pytest.mark.filterwarnings("ignore:The SAMME.R algorithm") -@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"]) -def test_adaboost_decision_function(algorithm, global_random_seed): +def test_adaboost_decision_function(global_random_seed): """Check that the decision function respects the symmetric constraint for weak learners. @@ -665,26 +606,22 @@ def test_adaboost_decision_function(algorithm, global_random_seed): X, y = datasets.make_classification( n_classes=n_classes, n_clusters_per_class=1, random_state=global_random_seed ) - clf = AdaBoostClassifier( - n_estimators=1, random_state=global_random_seed, algorithm=algorithm - ).fit(X, y) + clf = AdaBoostClassifier(n_estimators=1, random_state=global_random_seed).fit(X, y) y_score = clf.decision_function(X) assert_allclose(y_score.sum(axis=1), 0, atol=1e-8) - if algorithm == "SAMME": - # With a single learner, we expect to have a decision function in - # {1, - 1 / (n_classes - 1)}. - assert set(np.unique(y_score)) == {1, -1 / (n_classes - 1)} + # With a single learner, we expect to have a decision function in + # {1, - 1 / (n_classes - 1)}. + assert set(np.unique(y_score)) == {1, -1 / (n_classes - 1)} # We can assert the same for staged_decision_function since we have a single learner for y_score in clf.staged_decision_function(X): assert_allclose(y_score.sum(axis=1), 0, atol=1e-8) - if algorithm == "SAMME": - # With a single learner, we expect to have a decision function in - # {1, - 1 / (n_classes - 1)}. - assert set(np.unique(y_score)) == {1, -1 / (n_classes - 1)} + # With a single learner, we expect to have a decision function in + # {1, - 1 / (n_classes - 1)}. 
+ assert set(np.unique(y_score)) == {1, -1 / (n_classes - 1)} clf.set_params(n_estimators=5).fit(X, y) @@ -695,11 +632,8 @@ def test_adaboost_decision_function(algorithm, global_random_seed): assert_allclose(y_score.sum(axis=1), 0, atol=1e-8) -# TODO(1.6): remove -def test_deprecated_samme_r_algorithm(): - adaboost_clf = AdaBoostClassifier(n_estimators=1) - with pytest.warns( - FutureWarning, - match=re.escape("The SAMME.R algorithm (the default) is deprecated"), - ): +# TODO(1.8): remove +def test_deprecated_algorithm(): + adaboost_clf = AdaBoostClassifier(n_estimators=1, algorithm="SAMME") + with pytest.warns(FutureWarning, match="The parameter 'algorithm' is deprecated"): adaboost_clf.fit(X, y_class)