From c62106c72d51754d15c26d7ec5773655067dd3f9 Mon Sep 17 00:00:00 2001 From: Alihan Zihna Date: Wed, 10 Feb 2021 13:56:55 +0000 Subject: [PATCH 1/6] Add permutation imp and change dataset to digits --- .../plot_forest_importances_digits.py | 111 ++++++++++++++++++ .../ensemble/plot_forest_importances_faces.py | 48 -------- 2 files changed, 111 insertions(+), 48 deletions(-) create mode 100644 examples/ensemble/plot_forest_importances_digits.py delete mode 100644 examples/ensemble/plot_forest_importances_faces.py diff --git a/examples/ensemble/plot_forest_importances_digits.py b/examples/ensemble/plot_forest_importances_digits.py new file mode 100644 index 0000000000000..a04327aa93f70 --- /dev/null +++ b/examples/ensemble/plot_forest_importances_digits.py @@ -0,0 +1,111 @@ +""" +================================================= +Pixel importances with a parallel forest of trees +================================================= + +This example show the use of a forest of trees to evaluate +the importance of the pixels in an image classification task (digits) +based on impurity and permutation importance. +The hotter the pixel, the more important. + +The code below also illustrates how the construction and the computation +of the predictions can be parallelized within multiple jobs. +""" +print(__doc__) + +import matplotlib.pyplot as plt + +# %% +# Loading the data and model fitting +# ---------------------------------- +# We use the faces data from datasets submodules and split the dataset +# into training and testing subsets. Also, we'll set the number of cores +# to use for the tasks. +from sklearn.datasets import load_digits +from sklearn.model_selection import train_test_split + +# %% +# We select the number of cores to use to perform parallel fitting of +# the forest model. `-1` means use all available cores. +n_jobs = -1 + +# %% +# Load the faces dataset +data = load_digits() +X, y = data.data, data.target +X_train, X_test, y_train, y_test = train_test_split( + X, y, stratify=y, random_state=42) + +# %% +# A random forest classifier will be fitted to compute the feature importances. +from sklearn.ensemble import RandomForestClassifier + +forest = RandomForestClassifier( + n_estimators=750, n_jobs=n_jobs, random_state=42) + +forest.fit(X_train, y_train) + +# %% +# Feature importance based on mean decrease in impurity +# ----------------------------------------------------- +# Feature importances are provided by the fitted attribute +# `feature_importances_` and they are computed as the mean and standard +# deviation of accumulation of the impurity decrease within each tree. +# +# .. warning:: +# Impurity-based feature importances can be misleading for high cardinality +# features (many unique values). See :ref:`permutation_importance` as +# an alternative below. +import time + +start_time = time.time() +img_shape = data.images[0].shape +importances = forest.feature_importances_ +elapsed_time = time.time() - start_time + +print(f"Elapsed time to compute the importances: " + f"{elapsed_time:.3f} seconds") + +# %% +# Let's plot the impurity-based importance. +imp_reshaped = importances.reshape(img_shape) +plt.matshow(imp_reshaped, cmap=plt.cm.hot) +plt.title("Pixel importances using impurity values") +plt.colorbar() +plt.tight_layout() +plt.show() + +# %% +# Feature importance based on feature permutation +# ----------------------------------------------- +# Permutation feature importance overcomes limitations of the impurity-based +# feature importance: they do not have a bias toward high-cardinality features +# and can be computed on a left-out test set. +from sklearn.inspection import permutation_importance + +start_time = time.time() +result = permutation_importance( + forest, X_test, y_test, n_repeats=10, + random_state=42, n_jobs=n_jobs) +elapsed_time = time.time() - start_time +print(f"Elapsed time to compute the importances: " + f"{elapsed_time:.3f} seconds") + +# %% +# The computation for full permutation importance is more costly. Features are +# shuffled n times and the model refitted to estimate the importance of it. +# Please see :ref:`permutation_importance` for more details. We can now plot +# the importance ranking. + +plt.matshow(result.importances_mean.reshape(img_shape), cmap=plt.cm.hot) +plt.title("Pixel importances using permutation importance") +plt.colorbar() +plt.tight_layout() +plt.show() + +# %% +# We can see similar areas are detected using both methods. Although +# the importances vary. We can see that permutation importance gives lower importance +# values on any single pixel, which matches the intuition: The class of a +# digit seen on an image depends on values of many pixels together rather than +# a few pixels. diff --git a/examples/ensemble/plot_forest_importances_faces.py b/examples/ensemble/plot_forest_importances_faces.py deleted file mode 100644 index 6cea84ca4744c..0000000000000 --- a/examples/ensemble/plot_forest_importances_faces.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -================================================= -Pixel importances with a parallel forest of trees -================================================= - -This example shows the use of forests of trees to evaluate the impurity-based -importance of the pixels in an image classification task (faces). -The hotter the pixel, the more important. - -The code below also illustrates how the construction and the computation -of the predictions can be parallelized within multiple jobs. -""" -print(__doc__) - -from time import time -import matplotlib.pyplot as plt - -from sklearn.datasets import fetch_olivetti_faces -from sklearn.ensemble import ExtraTreesClassifier - -# Number of cores to use to perform parallel fitting of the forest model -n_jobs = 1 - -# Load the faces dataset -data = fetch_olivetti_faces() -X, y = data.data, data.target - -mask = y < 5 # Limit to 5 classes -X = X[mask] -y = y[mask] - -# Build a forest and compute the pixel importances -print("Fitting ExtraTreesClassifier on faces data with %d cores..." % n_jobs) -t0 = time() -forest = ExtraTreesClassifier(n_estimators=1000, - max_features=128, - n_jobs=n_jobs, - random_state=0) - -forest.fit(X, y) -print("done in %0.3fs" % (time() - t0)) -importances = forest.feature_importances_ -importances = importances.reshape(data.images[0].shape) - -# Plot pixel importances -plt.matshow(importances, cmap=plt.cm.hot) -plt.title("Pixel importances with forests of trees") -plt.show() From 0dca5ac30f5fa0260f0f21ac184b8668e2ba42f2 Mon Sep 17 00:00:00 2001 From: Alihan Zihna Date: Wed, 10 Feb 2021 14:08:07 +0000 Subject: [PATCH 2/6] Fix flake8 errors --- .../ensemble/plot_forest_importances_digits.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/ensemble/plot_forest_importances_digits.py b/examples/ensemble/plot_forest_importances_digits.py index a04327aa93f70..30bd15212eb5c 100644 --- a/examples/ensemble/plot_forest_importances_digits.py +++ b/examples/ensemble/plot_forest_importances_digits.py @@ -20,12 +20,12 @@ # ---------------------------------- # We use the faces data from datasets submodules and split the dataset # into training and testing subsets. Also, we'll set the number of cores -# to use for the tasks. +# to use for the tasks. from sklearn.datasets import load_digits from sklearn.model_selection import train_test_split # %% -# We select the number of cores to use to perform parallel fitting of +# We select the number of cores to use to perform parallel fitting of # the forest model. `-1` means use all available cores. n_jobs = -1 @@ -85,7 +85,7 @@ start_time = time.time() result = permutation_importance( - forest, X_test, y_test, n_repeats=10, + forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=n_jobs) elapsed_time = time.time() - start_time print(f"Elapsed time to compute the importances: " @@ -105,7 +105,7 @@ # %% # We can see similar areas are detected using both methods. Although -# the importances vary. We can see that permutation importance gives lower importance -# values on any single pixel, which matches the intuition: The class of a -# digit seen on an image depends on values of many pixels together rather than -# a few pixels. +# the importances vary. We can see that permutation importance gives lower +# importance values on any single pixel, which matches the intuition: +# The class of a digit seen on an image depends on values of many pixels +# together rather than a few pixels. From 730535947a278f55062d23239a800118ed9e49f3 Mon Sep 17 00:00:00 2001 From: Alihan Zihna Date: Thu, 11 Feb 2021 11:33:51 +0000 Subject: [PATCH 3/6] Remove tight layout from plots --- examples/ensemble/plot_forest_importances_digits.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/ensemble/plot_forest_importances_digits.py b/examples/ensemble/plot_forest_importances_digits.py index 30bd15212eb5c..564f1166658ef 100644 --- a/examples/ensemble/plot_forest_importances_digits.py +++ b/examples/ensemble/plot_forest_importances_digits.py @@ -72,7 +72,6 @@ plt.matshow(imp_reshaped, cmap=plt.cm.hot) plt.title("Pixel importances using impurity values") plt.colorbar() -plt.tight_layout() plt.show() # %% @@ -100,7 +99,6 @@ plt.matshow(result.importances_mean.reshape(img_shape), cmap=plt.cm.hot) plt.title("Pixel importances using permutation importance") plt.colorbar() -plt.tight_layout() plt.show() # %% From 8e156a4f47870bfa621f6ceba31a55f72497811e Mon Sep 17 00:00:00 2001 From: Alihan Zihna Date: Tue, 23 Feb 2021 20:48:20 +0000 Subject: [PATCH 4/6] Reintroduce faces dataset and add the MDI usage warning --- .../plot_forest_importances_digits.py | 109 ------------------ .../ensemble/plot_forest_importances_faces.py | 90 +++++++++++++++ 2 files changed, 90 insertions(+), 109 deletions(-) delete mode 100644 examples/ensemble/plot_forest_importances_digits.py create mode 100644 examples/ensemble/plot_forest_importances_faces.py diff --git a/examples/ensemble/plot_forest_importances_digits.py b/examples/ensemble/plot_forest_importances_digits.py deleted file mode 100644 index 564f1166658ef..0000000000000 --- a/examples/ensemble/plot_forest_importances_digits.py +++ /dev/null @@ -1,109 +0,0 @@ -""" -================================================= -Pixel importances with a parallel forest of trees -================================================= - -This example show the use of a forest of trees to evaluate -the importance of the pixels in an image classification task (digits) -based on impurity and permutation importance. -The hotter the pixel, the more important. - -The code below also illustrates how the construction and the computation -of the predictions can be parallelized within multiple jobs. -""" -print(__doc__) - -import matplotlib.pyplot as plt - -# %% -# Loading the data and model fitting -# ---------------------------------- -# We use the faces data from datasets submodules and split the dataset -# into training and testing subsets. Also, we'll set the number of cores -# to use for the tasks. -from sklearn.datasets import load_digits -from sklearn.model_selection import train_test_split - -# %% -# We select the number of cores to use to perform parallel fitting of -# the forest model. `-1` means use all available cores. -n_jobs = -1 - -# %% -# Load the faces dataset -data = load_digits() -X, y = data.data, data.target -X_train, X_test, y_train, y_test = train_test_split( - X, y, stratify=y, random_state=42) - -# %% -# A random forest classifier will be fitted to compute the feature importances. -from sklearn.ensemble import RandomForestClassifier - -forest = RandomForestClassifier( - n_estimators=750, n_jobs=n_jobs, random_state=42) - -forest.fit(X_train, y_train) - -# %% -# Feature importance based on mean decrease in impurity -# ----------------------------------------------------- -# Feature importances are provided by the fitted attribute -# `feature_importances_` and they are computed as the mean and standard -# deviation of accumulation of the impurity decrease within each tree. -# -# .. warning:: -# Impurity-based feature importances can be misleading for high cardinality -# features (many unique values). See :ref:`permutation_importance` as -# an alternative below. -import time - -start_time = time.time() -img_shape = data.images[0].shape -importances = forest.feature_importances_ -elapsed_time = time.time() - start_time - -print(f"Elapsed time to compute the importances: " - f"{elapsed_time:.3f} seconds") - -# %% -# Let's plot the impurity-based importance. -imp_reshaped = importances.reshape(img_shape) -plt.matshow(imp_reshaped, cmap=plt.cm.hot) -plt.title("Pixel importances using impurity values") -plt.colorbar() -plt.show() - -# %% -# Feature importance based on feature permutation -# ----------------------------------------------- -# Permutation feature importance overcomes limitations of the impurity-based -# feature importance: they do not have a bias toward high-cardinality features -# and can be computed on a left-out test set. -from sklearn.inspection import permutation_importance - -start_time = time.time() -result = permutation_importance( - forest, X_test, y_test, n_repeats=10, - random_state=42, n_jobs=n_jobs) -elapsed_time = time.time() - start_time -print(f"Elapsed time to compute the importances: " - f"{elapsed_time:.3f} seconds") - -# %% -# The computation for full permutation importance is more costly. Features are -# shuffled n times and the model refitted to estimate the importance of it. -# Please see :ref:`permutation_importance` for more details. We can now plot -# the importance ranking. - -plt.matshow(result.importances_mean.reshape(img_shape), cmap=plt.cm.hot) -plt.title("Pixel importances using permutation importance") -plt.colorbar() -plt.show() - -# %% -# We can see similar areas are detected using both methods. Although -# the importances vary. We can see that permutation importance gives lower -# importance values on any single pixel, which matches the intuition: -# The class of a digit seen on an image depends on values of many pixels -# together rather than a few pixels. diff --git a/examples/ensemble/plot_forest_importances_faces.py b/examples/ensemble/plot_forest_importances_faces.py new file mode 100644 index 0000000000000..67a2297a29834 --- /dev/null +++ b/examples/ensemble/plot_forest_importances_faces.py @@ -0,0 +1,90 @@ +""" +================================================= +Pixel importances with a parallel forest of trees +================================================= + +This example show the use of a forest of trees to evaluate +the impurity based importance of the pixels in an image +classification task on the faces dataset. +The hotter the pixel, the more important it is. + +The code below also illustrates how the construction and the computation +of the predictions can be parallelized within multiple jobs. +""" +# %% +print(__doc__) + +import matplotlib.pyplot as plt + +# %% +# Loading the data and model fitting +# ---------------------------------- +# We use the faces data from datasets submodules when using impurity-based +# feature importance. It is not possible to evaluate the importance +# on a separate test set but for this example, we are interested +# in representing the information learned from the full dataset. +# Also, we'll set the number of cores to use for the tasks. +from sklearn.datasets import fetch_olivetti_faces + +# %% +# We select the number of cores to use to perform parallel fitting of +# the forest model. `-1` means use all available cores. +n_jobs = -1 + +# %% +# Load the faces dataset +data = fetch_olivetti_faces() +X, y = data.data, data.target + +# %% +# Limit the dataset to 5 classes. +mask = y < 5 +X = X[mask] +y = y[mask] + +# %% +# A random forest classifier will be fitted to compute the feature importances. +from sklearn.ensemble import RandomForestClassifier + +forest = RandomForestClassifier( + n_estimators=750, n_jobs=n_jobs, random_state=42) + +forest.fit(X, y) + +# %% +# Feature importance based on mean decrease in impurity (MDI) +# ----------------------------------------------------------- +# Feature importances are provided by the fitted attribute +# `feature_importances_` and they are computed as the mean and standard +# deviation of accumulation of the impurity decrease within each tree. +# +# .. warning:: +# Impurity-based feature importances can be misleading for high cardinality +# features (many unique values). See :ref:`permutation_importance` as +# an alternative. +import time + +start_time = time.time() +img_shape = data.images[0].shape +importances = forest.feature_importances_ +elapsed_time = time.time() - start_time + +print(f"Elapsed time to compute the importances: " + f"{elapsed_time:.3f} seconds") + +# %% +# Let's plot the impurity-based importance. +imp_reshaped = importances.reshape(img_shape) +plt.matshow(imp_reshaped, cmap=plt.cm.hot) +plt.title("Pixel importances using impurity values") +plt.colorbar() +plt.show() + +# %% +# The limitations of MDI is not a problem for this dataset because: +# * All features are homogeneous and will not suffer the cardinality bias +# * We are only interested to represent knowledge of the forest acquired +# on the training set. +# +# If these two conditions are not met, it is recommended to instead use +# the :func:`inspection.permutation_importance`. From 71e544c6e8b9332c55eb407fe37943b8d9304334 Mon Sep 17 00:00:00 2001 From: Alihan Zihna Date: Wed, 24 Feb 2021 09:09:33 +0000 Subject: [PATCH 5/6] Fix bullet points and the reference --- examples/ensemble/plot_forest_importances_faces.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/ensemble/plot_forest_importances_faces.py b/examples/ensemble/plot_forest_importances_faces.py index 67a2297a29834..046be0530b72a 100644 --- a/examples/ensemble/plot_forest_importances_faces.py +++ b/examples/ensemble/plot_forest_importances_faces.py @@ -82,9 +82,10 @@ # %% # The limitations of MDI is not a problem for this dataset because: -# * All features are homogeneous and will not suffer the cardinality bias -# * We are only interested to represent knowledge of the forest acquired -# on the training set. +# +# 1. All features are homogeneous and will not suffer the cardinality bias +# 2. We are only interested to represent knowledge of the forest acquired +# on the training set. # # If these two conditions are not met, it is recommended to instead use -# the :func:`inspection.permutation_importance`. +# the :func:`~sklearn.inspection.permutation_importance`. From a6d6f5cb5ad01fe3257f8526104206a0d772469b Mon Sep 17 00:00:00 2001 From: Alihan Zihna Date: Wed, 26 May 2021 17:47:05 +0000 Subject: [PATCH 6/6] assert_raises to raises --- sklearn/utils/tests/test_estimator_checks.py | 154 +++++++++---------- 1 file changed, 77 insertions(+), 77 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 4792f50f2baef..301ba2ffd6776 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -12,8 +12,7 @@ from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils import deprecated from sklearn.utils._testing import ( - assert_raises, - assert_raises_regex, + raises, assert_warns, ignore_warnings, MinimalClassifier, @@ -413,7 +412,8 @@ def test_not_an_array_array_function(): raise SkipTest("array_function protocol not supported in numpy <1.17") not_array = _NotAnArray(np.ones(10)) msg = "Don't want to call array_function sum!" - assert_raises_regex(TypeError, msg, np.sum, not_array) + with raises(TypeError, match=msg): + np.sum(not_array) # always returns True assert np.may_share_memory(not_array, None) @@ -437,92 +437,93 @@ def test_check_estimator(): # check that we have a set_params and can clone msg = "Passing a class was deprecated" - assert_raises_regex(TypeError, msg, check_estimator, object) + with raises(TypeError, match=msg): + check_estimator(object) msg = ( "Parameter 'p' of estimator 'HasMutableParameters' is of type " "object which is not allowed" ) # check that the "default_constructible" test checks for mutable parameters check_estimator(HasImmutableParameters()) # should pass - assert_raises_regex( - AssertionError, msg, check_estimator, HasMutableParameters() - ) + with raises(AssertionError, match=msg): + check_estimator(HasMutableParameters()) # check that values returned by get_params match set_params msg = "get_params result does not match what was passed to set_params" - assert_raises_regex(AssertionError, msg, check_estimator, - ModifiesValueInsteadOfRaisingError()) + with raises(AssertionError, match=msg): + check_estimator(ModifiesValueInsteadOfRaisingError()) assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams()) - assert_raises_regex(AssertionError, msg, check_estimator, - ModifiesAnotherValue()) + with raises(AssertionError, match=msg): + check_estimator(ModifiesAnotherValue()) # check that we have a fit method msg = "object has no attribute 'fit'" - assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator()) + with raises(AttributeError, match=msg): + check_estimator(BaseEstimator()) # check that fit does input validation msg = "Did not raise" - assert_raises_regex(AssertionError, msg, check_estimator, - BaseBadClassifier()) + with raises(AssertionError, match=msg): + check_estimator(BaseBadClassifier()) # check that sample_weights in fit accepts pandas.Series type try: from pandas import Series # noqa msg = ("Estimator NoSampleWeightPandasSeriesType raises error if " "'sample_weight' parameter is of type pandas.Series") - assert_raises_regex( - ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType()) + with raises(ValueError, match=msg): + check_estimator(NoSampleWeightPandasSeriesType()) except ImportError: pass # check that predict does input validation (doesn't accept dicts in input) msg = "Estimator doesn't check for NaN and inf in predict" - assert_raises_regex(AssertionError, msg, check_estimator, - NoCheckinPredict()) + with raises(AssertionError, match=msg): + check_estimator(NoCheckinPredict()) # check that estimator state does not change # at transform/predict/predict_proba time msg = 'Estimator changes __dict__ during predict' - assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict()) + with raises(AssertionError, match=msg): + check_estimator(ChangesDict()) # check that `fit` only changes attribures that # are private (start with an _ or end with a _). msg = ('Estimator ChangesWrongAttribute should not change or mutate ' 'the parameter wrong_attribute from 0 to 1 during fit.') - assert_raises_regex(AssertionError, msg, - check_estimator, ChangesWrongAttribute()) + with raises(AssertionError, match=msg): + check_estimator(ChangesWrongAttribute()) check_estimator(ChangesUnderscoreAttribute()) # check that `fit` doesn't add any public attribute msg = (r'Estimator adds public attribute\(s\) during the fit method.' ' Estimators are only allowed to add private attributes' ' either started with _ or ended' ' with _ but wrong_attribute added') - assert_raises_regex(AssertionError, msg, - check_estimator, SetsWrongAttribute()) + with raises(AssertionError, match=msg): + check_estimator(SetsWrongAttribute()) # check for sample order invariance name = NotInvariantSampleOrder.__name__ method = 'predict' msg = ("{method} of {name} is not invariant when applied to a dataset" "with different sample order.").format(method=method, name=name) - assert_raises_regex(AssertionError, msg, - check_estimator, NotInvariantSampleOrder()) + with raises(AssertionError, match=msg): + check_estimator(NotInvariantSampleOrder()) # check for invariant method name = NotInvariantPredict.__name__ method = 'predict' msg = ("{method} of {name} is not invariant when applied " "to a subset.").format(method=method, name=name) - assert_raises_regex(AssertionError, msg, - check_estimator, NotInvariantPredict()) + with raises(AssertionError, match=msg): + check_estimator(NotInvariantPredict()) # check for sparse matrix input handling name = NoSparseClassifier.__name__ msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name - assert_raises_regex( - AssertionError, msg, check_estimator, NoSparseClassifier() - ) + with raises(AssertionError, match=msg): + check_estimator(NoSparseClassifier()) # Large indices test on bad estimator msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to ' r'support \S{3}_64 matrix, and is not failing gracefully.*') - assert_raises_regex(AssertionError, msg, check_estimator, - LargeSparseNotSupportedClassifier()) + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier()) # does error on binary_only untagged estimator msg = 'Only 2 classes are supported' - assert_raises_regex(ValueError, msg, check_estimator, - UntaggedBinaryClassifier()) + with raises(ValueError, match=msg): + check_estimator(UntaggedBinaryClassifier()) # non-regression test for estimators transforming to sparse data check_estimator(SparseTransformer()) @@ -537,8 +538,8 @@ def test_check_estimator(): # Check regressor with requires_positive_y estimator tag msg = 'negative y values not supported!' - assert_raises_regex(ValueError, msg, check_estimator, - RequiresPositiveYRegressor()) + with raises(ValueError, match=msg): + check_estimator(RequiresPositiveYRegressor()) # Does not raise error on classifier with poor_score tag check_estimator(PoorScoreLogisticRegression()) @@ -547,7 +548,8 @@ def test_check_estimator(): def test_check_outlier_corruption(): # should raise AssertionError decision = np.array([0., 1., 1.5, 2.]) - assert_raises(AssertionError, check_outlier_corruption, 1, 2, decision) + with raises(AssertionError): + check_outlier_corruption(1, 2, decision) # should pass decision = np.array([0., 1., 1., 2.]) check_outlier_corruption(1, 2, decision) @@ -555,8 +557,8 @@ def test_check_outlier_corruption(): def test_check_estimator_transformer_no_mixin(): # check that TransformerMixin is not required for transformer tests to run - assert_raises_regex(AttributeError, '.*fit_transform.*', - check_estimator, BadTransformerWithoutMixin()) + with raises(AttributeError, '.*fit_transform.*'): + check_estimator(BadTransformerWithoutMixin()) def test_check_estimator_clones(): @@ -593,8 +595,8 @@ def test_check_estimators_unfitted(): # check that a ValueError/AttributeError is raised when calling predict # on an unfitted estimator msg = "Did not raise" - assert_raises_regex(AssertionError, msg, check_estimators_unfitted, - "estimator", NoSparseClassifier()) + with raises(AssertionError, match=msg): + check_estimators_unfitted("estimator", NoSparseClassifier()) # check that CorrectNotFittedError inherit from either ValueError # or AttributeError @@ -610,19 +612,22 @@ class NonConformantEstimatorNoParamSet(BaseEstimator): def __init__(self, you_should_set_this_=None): pass - assert_raises_regex(AssertionError, - "Estimator estimator_name should not set any" - " attribute apart from parameters during init." - r" Found attributes \['you_should_not_set_this_'\].", - check_no_attributes_set_in_init, - 'estimator_name', - NonConformantEstimatorPrivateSet()) - assert_raises_regex(AttributeError, - "Estimator estimator_name should store all " - "parameters as an attribute during init.", - check_no_attributes_set_in_init, - 'estimator_name', - NonConformantEstimatorNoParamSet()) + msg = ( + "Estimator estimator_name should not set any" + " attribute apart from parameters during init." + r" Found attributes \['you_should_not_set_this_'\]." + ) + with raises(AssertionError, match=msg): + check_no_attributes_set_in_init('estimator_name', + NonConformantEstimatorPrivateSet()) + + msg = ( + "Estimator estimator_name should store all parameters as an attribute" + " during init" + ) + with raises(AttributeError, match=msg): + check_no_attributes_set_in_init('estimator_name', + NonConformantEstimatorNoParamSet()) def test_check_estimator_pairwise(): @@ -639,32 +644,24 @@ def test_check_estimator_pairwise(): def test_check_classifier_data_not_an_array(): - assert_raises_regex(AssertionError, - 'Not equal to tolerance', - check_classifier_data_not_an_array, - 'estimator_name', - EstimatorInconsistentForPandas()) + with raises(AssertionError, match='Not equal to tolerance'): + check_classifier_data_not_an_array('estimator_name', + EstimatorInconsistentForPandas()) def test_check_regressor_data_not_an_array(): - assert_raises_regex(AssertionError, - 'Not equal to tolerance', - check_regressor_data_not_an_array, - 'estimator_name', - EstimatorInconsistentForPandas()) + with raises(AssertionError, match='Not equal to tolerance'): + check_regressor_data_not_an_array('estimator_name', + EstimatorInconsistentForPandas()) def test_check_estimator_get_tags_default_keys(): estimator = EstimatorMissingDefaultTags() err_msg = (r"EstimatorMissingDefaultTags._get_tags\(\) is missing entries" r" for the following default tags: {'allow_nan'}") - assert_raises_regex( - AssertionError, - err_msg, - check_estimator_get_tags_default_keys, - estimator.__class__.__name__, - estimator, - ) + with raises(AssertionError, match=err_msg): + check_estimator_get_tags_default_keys(estimator.__class__.__name__, + estimator) # noop check when _get_tags is not available estimator = MinimalTransformer() @@ -688,12 +685,15 @@ def run_tests_without_pytest(): def test_check_class_weight_balanced_linear_classifier(): # check that ill-computed balanced weights raises an exception - assert_raises_regex(AssertionError, - "Classifier estimator_name is not computing" - " class_weight=balanced properly.", - check_class_weight_balanced_linear_classifier, - 'estimator_name', - BadBalancedWeightsClassifier) + msg = ( + "Classifier estimator_name is not computing class_weight=balanced " + "properly" + ) + with raises(AssertionError, match=msg): + check_class_weight_balanced_linear_classifier( + 'estimator_name', + BadBalancedWeightsClassifier + ) def test_all_estimators_all_public():