From a68194be052e781e01151a5c009996774ac22931 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 9 Nov 2020 16:59:51 +0100 Subject: [PATCH 01/50] TST reintroduce _safe_tags for estimator not inheriting from BaseEstimator --- sklearn/utils/estimator_checks.py | 97 ++++++++++++++++++++----------- 1 file changed, 64 insertions(+), 33 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 3cd19967ba9c1..98fd46a0b776d 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -32,6 +32,7 @@ from ..base import ( clone, ClusterMixin, + _DEFAULT_TAGS, is_classifier, is_regressor, is_outlier_detector, @@ -66,9 +67,39 @@ CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD'] +def _safe_tags(estimator, key=None): + """Safely get estimator tags for common checks. + + :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. + However, if a compatible estimator does not inherit from this base class, + we should default to the default tag. + + Parameters + ---------- + estimator : estimator object + The estimator from which to get the tag. + key : str, default=None + Tag name to get. By default (`None`), all tags are returned. + + Returns + ------- + tags : dict + The estimator tags. + """ + if hasattr(estimator, "_get_tags"): + if key is not None: + return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + tags = estimator._get_tags() + return {key: tags.get(key, _DEFAULT_TAGS[key]) + for key in _DEFAULT_TAGS.keys()} + if key is not None: + return _DEFAULT_TAGS[key] + return _DEFAULT_TAGS + + def _yield_checks(estimator): name = estimator.__class__.__name__ - tags = estimator._get_tags() + tags = _safe_tags(estimator) pairwise = _is_pairwise(estimator) yield check_no_attributes_set_in_init @@ -116,7 +147,7 @@ def _yield_checks(estimator): def _yield_classifier_checks(classifier): - tags = classifier._get_tags() + tags = _safe_tags(classifier) # test classifiers can handle non-array data and pandas objects yield check_classifier_data_not_an_array @@ -170,7 +201,7 @@ def check_supervised_y_no_nan(name, estimator_orig, strict_mode=True): def _yield_regressor_checks(regressor): - tags = regressor._get_tags() + tags = _safe_tags(regressor) # TODO: test with intercept # TODO: test with multiple responses # basic testing @@ -196,7 +227,7 @@ def _yield_regressor_checks(regressor): def _yield_transformer_checks(transformer): - tags = transformer._get_tags() + tags = _safe_tags(transformer) # All transformers should either deal with sparse data or raise an # exception with type TypeError and an intelligible error message if not tags["no_validation"]: @@ -206,7 +237,7 @@ def _yield_transformer_checks(transformer): if tags["preserves_dtype"]: yield check_transformer_preserve_dtypes yield partial(check_transformer_general, readonly_memmap=True) - if not transformer._get_tags()["stateless"]: + if not _safe_tags(transformer, key="stateless"): yield check_transformers_unfitted # Dependent on external solvers and hence accessing the iter # param is non-trivial. @@ -243,13 +274,13 @@ def _yield_outliers_checks(estimator): # test outlier detectors can handle non-array data yield check_classifier_data_not_an_array # test if NotFittedError is raised - if estimator._get_tags()["requires_fit"]: + if _safe_tags(estimator, key="requires_fit"): yield check_estimators_unfitted def _yield_all_checks(estimator): name = estimator.__class__.__name__ - tags = estimator._get_tags() + tags = _safe_tags(estimator) if "2darray" not in tags["X_types"]: warnings.warn("Can't test estimator {} which requires input " " of type {}".format(name, tags["X_types"]), @@ -421,7 +452,7 @@ def _should_be_skipped_or_marked(estimator, check, strict_mode): check_name = (check.func.__name__ if isinstance(check, partial) else check.__name__) - xfail_checks = estimator._get_tags()['_xfail_checks'] or {} + xfail_checks = _safe_tags(estimator, key='_xfail_checks') or {} if check_name in xfail_checks: return True, xfail_checks[check_name] @@ -772,7 +803,7 @@ def check_estimator_sparse_data(name, estimator_orig, strict_mode=True): with ignore_warnings(category=FutureWarning): estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) for matrix_format, X in _generate_sparse_matrix(X_csr): # catch deprecation warnings with ignore_warnings(category=FutureWarning): @@ -829,7 +860,7 @@ def check_sample_weights_pandas_series(name, estimator_orig, strict_mode=True): X = pd.DataFrame(_pairwise_estimator_convert_X(X, estimator_orig)) y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2]) weights = pd.Series([1] * 12) - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): y = pd.DataFrame(y) try: estimator.fit(X, y, sample_weight=weights) @@ -854,7 +885,7 @@ def check_sample_weights_not_an_array(name, estimator_orig, strict_mode=True): X = _NotAnArray(_pairwise_estimator_convert_X(X, estimator_orig)) y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2]) weights = _NotAnArray([1] * 12) - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): y = _NotAnArray(y.data.reshape(-1, 1)) estimator.fit(X, y, sample_weight=weights) @@ -959,7 +990,7 @@ def check_dtype_object(name, estimator_orig, strict_mode=True): rng = np.random.RandomState(0) X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig) X = X.astype(object) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) y = (X[:, 0] * 4).astype(int) estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) @@ -1179,7 +1210,7 @@ def check_methods_sample_order_invariance( X = 3 * rnd.uniform(size=(20, 3)) X = _pairwise_estimator_convert_X(X, estimator_orig) y = X[:, 0].astype(np.int) - if estimator_orig._get_tags()['binary_only']: + if _safe_tags(estimator_orig, key='binary_only'): y[y == 2] = 1 estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) @@ -1368,7 +1399,7 @@ def _check_transformer(name, transformer_orig, X, y, strict_mode=True): X_pred2 = transformer.transform(X) X_pred3 = transformer.fit_transform(X, y=y_) - if transformer_orig._get_tags()['non_deterministic']: + if _safe_tags(transformer_orig, key='non_deterministic'): msg = name + ' is non deterministic' raise SkipTest(msg) if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple): @@ -1399,7 +1430,7 @@ def _check_transformer(name, transformer_orig, X, y, strict_mode=True): # raises error on malformed input for transform if hasattr(X, 'shape') and \ - not transformer._get_tags()["stateless"] and \ + not _safe_tags(transformer, key="stateless") and \ X.ndim == 2 and X.shape[1] > 1: # If it's not an array, it does not have a 'T' property @@ -1414,7 +1445,7 @@ def _check_transformer(name, transformer_orig, X, y, strict_mode=True): @ignore_warnings def check_pipeline_consistency(name, estimator_orig, strict_mode=True): - if estimator_orig._get_tags()['non_deterministic']: + if _safe_tags(estimator_orig, key='non_deterministic'): msg = name + ' is non deterministic' raise SkipTest(msg) @@ -1508,7 +1539,7 @@ def check_transformer_preserve_dtypes( X -= X.min() X = _pairwise_estimator_convert_X(X, transformer_orig) - for dtype in transformer_orig._get_tags()["preserves_dtype"]: + for dtype in _safe_tags(transformer_orig, key="preserves_dtype"): X_cast = X.astype(dtype) transformer = clone(transformer_orig) set_random_state(transformer) @@ -1634,7 +1665,7 @@ def check_estimators_pickle(name, estimator_orig, strict_mode=True): X -= X.min() X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) # include NaN values when the estimator should deal with them if tags['allow_nan']: # set randomly 10 elements to np.nan @@ -1696,7 +1727,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig, @ignore_warnings(category=FutureWarning) def check_classifier_multioutput(name, estimator, strict_mode=True): n_samples, n_labels, n_classes = 42, 5, 3 - tags = estimator._get_tags() + tags = _safe_tags(estimator) estimator = clone(estimator) X, y = make_multilabel_classification(random_state=42, n_samples=n_samples, @@ -1804,7 +1835,7 @@ def check_clustering(name, clusterer_orig, readonly_memmap=False, pred = clusterer.labels_ assert pred.shape == (n_samples,) assert adjusted_rand_score(pred, y) > 0.4 - if clusterer._get_tags()['non_deterministic']: + if _safe_tags(clusterer, key='non_deterministic'): return set_random_state(clusterer) with warnings.catch_warnings(record=True): @@ -1896,7 +1927,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False, X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b]) problems = [(X_b, y_b)] - tags = classifier_orig._get_tags() + tags = _safe_tags(classifier_orig) if not tags['binary_only']: problems.append((X_m, y_m)) @@ -2187,7 +2218,7 @@ def check_estimators_unfitted(name, estimator_orig, strict_mode=True): @ignore_warnings(category=FutureWarning) def check_supervised_y_2d(name, estimator_orig, strict_mode=True): - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) rnd = np.random.RandomState(0) n_samples = 30 X = _pairwise_estimator_convert_X( @@ -2291,7 +2322,7 @@ def check_classifiers_classes(name, classifier_orig, strict_mode=True): y_names_binary = np.take(labels_binary, y_binary) problems = [(X_binary, y_binary, y_names_binary)] - if not classifier_orig._get_tags()['binary_only']: + if not _safe_tags(classifier_orig, key='binary_only'): problems.append((X_multiclass, y_multiclass, y_names_multiclass)) for X, y, y_names in problems: @@ -2377,7 +2408,7 @@ def check_regressors_train(name, regressor_orig, readonly_memmap=False, # TODO: find out why PLS and CCA fail. RANSAC is random # and furthermore assumes the presence of outliers, hence # skipped - if not regressor._get_tags()["poor_score"]: + if not _safe_tags(regressor, key="poor_score"): assert regressor.score(X, y_) > 0.5 @@ -2402,7 +2433,7 @@ def check_regressors_no_decision_function(name, regressor_orig, @ignore_warnings(category=FutureWarning) def check_class_weight_classifiers(name, classifier_orig, strict_mode=True): - if classifier_orig._get_tags()['binary_only']: + if _safe_tags(classifier_orig, key='binary_only'): problems = [2] else: problems = [2, 3] @@ -2441,7 +2472,7 @@ def check_class_weight_classifiers(name, classifier_orig, strict_mode=True): y_pred = classifier.predict(X_test) # XXX: Generally can use 0.89 here. On Windows, LinearSVC gets # 0.88 (Issue #9111) - if not classifier_orig._get_tags()['poor_score']: + if not _safe_tags(classifier_orig, key='poor_score'): assert np.mean(y_pred == 0) > 0.87 @@ -2761,16 +2792,16 @@ def param_filter(p): def _enforce_estimator_tags_y(estimator, y): # Estimators with a `requires_positive_y` tag only accept strictly positive # data - if estimator._get_tags()["requires_positive_y"]: + if _safe_tags(estimator, key="requires_positive_y"): # Create strictly positive y. The minimal increment above 0 is 1, as # y could be of integer dtype. y += 1 + abs(y.min()) # Estimators with a `binary_only` tag only accept up to two unique y values - if estimator._get_tags()["binary_only"] and y.size > 0: + if _safe_tags(estimator, key="binary_only") and y.size > 0: y = np.where(y == y.flat[0], y, y.flat[0] + 1) # Estimators in mono_output_task_error raise ValueError if y is of 1-D # Convert into a 2-D y for those estimators. - if estimator._get_tags()["multioutput_only"]: + if _safe_tags(estimator, key="multioutput_only"): return np.reshape(y, (-1, 1)) return y @@ -2782,11 +2813,11 @@ def _enforce_estimator_tags_x(estimator, X): X = X.dot(X.T) # Estimators with `1darray` in `X_types` tag only accept # X of shape (`n_samples`,) - if '1darray' in estimator._get_tags()['X_types']: + if '1darray' in _safe_tags(estimator, key='X_types'): X = X[:, 0] # Estimators with a `requires_positive_X` tag only accept # strictly positive data - if estimator._get_tags()['requires_positive_X']: + if _safe_tags(estimator, key='requires_positive_X'): X -= X.min() return X @@ -2928,7 +2959,7 @@ def check_classifiers_regression_target(name, estimator_orig, X = X + 1 + abs(X.min(axis=0)) # be sure that X is non-negative e = clone(estimator_orig) msg = "Unknown label type: " - if not e._get_tags()["no_validation"]: + if not _safe_tags(e, keyy="no_validation"): with raises(ValueError, match=msg): e.fit(X, y) @@ -3145,7 +3176,7 @@ def check_requires_y_none(name, estimator_orig, strict_mode=True): def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): # Make sure that n_features_in are checked after fitting - tags = estimator_orig._get_tags() + tags = _safe_tags(estimator_orig) if "2darray" not in tags["X_types"] or tags["no_validation"]: return From 36f1c5c68f3ff16a5b9fa936e100de2bc7a59ffa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 9 Nov 2020 17:06:29 +0100 Subject: [PATCH 02/50] typo --- sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 98fd46a0b776d..ac30f66d41866 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2959,7 +2959,7 @@ def check_classifiers_regression_target(name, estimator_orig, X = X + 1 + abs(X.min(axis=0)) # be sure that X is non-negative e = clone(estimator_orig) msg = "Unknown label type: " - if not _safe_tags(e, keyy="no_validation"): + if not _safe_tags(e, key="no_validation"): with raises(ValueError, match=msg): e.fit(X, y) From 9e540141319126dce275b76a996679d45495de2e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 11 Nov 2020 13:20:26 +0100 Subject: [PATCH 03/50] TST implement minimal classifier --- sklearn/utils/tests/test_estimator_checks.py | 41 +++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index ecbf7cb7be7f4..9a069224f88ba 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -23,6 +23,7 @@ from sklearn.utils.estimator_checks import check_regressor_data_not_an_array from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption +from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.fixes import np_version, parse_version from sklearn.ensemble import RandomForestClassifier from sklearn.linear_model import LinearRegression, SGDClassifier @@ -418,8 +419,8 @@ def test_check_estimator(): # check that we have a set_params and can clone msg = "Passing a class was deprecated" assert_raises_regex(TypeError, msg, check_estimator, object) - msg = "object has no attribute '_get_tags'" - assert_raises_regex(AttributeError, msg, check_estimator, object()) + # msg = "object has no attribute '_get_tags'" + # assert_raises_regex(AttributeError, msg, check_estimator, object()) msg = ( "Parameter 'p' of estimator 'HasMutableParameters' is of type " "object which is not allowed" @@ -620,6 +621,42 @@ def test_check_estimator_pairwise(): check_estimator(est) +class MinimalEstimator: + + # Our minimal required supposed that the following are implemented + _get_param_names = BaseEstimator._get_param_names # used by get_params + set_params = BaseEstimator.set_params + get_params = BaseEstimator.get_params + __setstate__ = BaseEstimator.__setstate__ + __getstate__ = BaseEstimator.__getstate__ + + def fit(self, X, y): + return self + + +class MinimalClassifier(MinimalEstimator): + + def fit(self, X, y): + self.classes_ = np.unique(y) + return super().fit(X, y) + + def predict_proba(self, X): + proba_shape = (len(X), self.classes_.size) + y_proba = np.zeros(shape=proba_shape, dtype=np.float64) + y_proba[:, 0] = 1.0 + return y_proba + + def predict(self, X): + y_proba = self.predict_proba(X) + y_pred = y_proba.argmax(axis=1) + return self.classes_[y_pred] + + +@parametrize_with_checks([MinimalClassifier()], strict_mode=False) +def test_check_estimator_minimal(estimator, check): + check(estimator) + + def test_check_classifier_data_not_an_array(): assert_raises_regex(AssertionError, 'Not equal to tolerance', From 8f571acc46f33c8bd42cef035d5f6a41d27adb31 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 11:23:44 +0100 Subject: [PATCH 04/50] Add future minimal tests --- sklearn/tests/test_common.py | 151 ++++++++++++++++++- sklearn/utils/estimator_checks.py | 5 +- sklearn/utils/tests/test_estimator_checks.py | 2 - 3 files changed, 154 insertions(+), 4 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 0a34f30765862..25624d934fa1d 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -45,7 +45,13 @@ parametrize_with_checks, check_n_features_in_after_fitting, ) -from sklearn.utils.validation import check_non_negative, check_array +from sklearn.multiclass import check_classification_targets +from sklearn.utils.validation import ( + check_array, + check_is_fitted, + check_non_negative, + check_X_y, +) def test_all_estimator_no_base_class(): @@ -370,3 +376,146 @@ def test_search_cv(estimator, check, request): def test_check_n_features_in_after_fitting(estimator): _set_checking_parameters(estimator) check_n_features_in_after_fitting(estimator.__class__.__name__, estimator) + + +class MinimalClassifier: + """Minimal classifier implementation with inheriting from BaseEstimator.""" + _estimator_type = "classifier" + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, **params): + return {} + + def set_params(self, deep=True): + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y): + X, y = check_X_y(X, y) + check_classification_targets(y) + self.n_features_in_ = X.shape[1] + self.classes_, counts = np.unique(y, return_counts=True) + self._most_frequent_class_idx = counts.argmax() + return self + + def predict_proba(self, X): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + proba_shape = (X.shape[0], self.classes_.size) + y_proba = np.zeros(shape=proba_shape, dtype=np.float64) + y_proba[:, self._most_frequent_class_idx] = 1.0 + return y_proba + + def predict(self, X): + y_proba = self.predict_proba(X) + y_pred = y_proba.argmax(axis=1) + return self.classes_[y_pred] + + def score(self, X, y): + from sklearn.metrics import accuracy_score + return accuracy_score(y, self.predict(X)) + + +class MinimalRegressor: + """Minimal regressor implementation with inheriting from BaseEstimator.""" + _estimator_type = "regressor" + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, **params): + return {} + + def set_params(self, deep=True): + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y): + X, y = check_X_y(X, y) + self.n_features_in_ = X.shape[1] + self._mean = np.mean(y) + return self + + def predict(self, X): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + return np.ones(shape=(X.shape[0],)) * self._mean + + def score(self, X, y): + from sklearn.metrics import r2_score + return r2_score(y, self.predict(X)) + + +class MinimalTransformer: + """Minimal transformer implementation with inheriting from + BaseEstimator.""" + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, **params): + return {} + + def set_params(self, deep=True): + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y=None): + X = check_array(X) + self.n_features_in_ = X.shape[1] + return self + + def transform(self, X, y=None): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + return X + + def inverse_transform(self, X, y=None): + return self.transform(X) + + def fit_transform(self, X, y=None): + return self.fit(X, y).transform(X, y) + + +# FIXME: hopefully in 0.25 +@pytest.mark.skip( + reason=("This test is currently failing because checks are granular " + "enough. Once checks are split with some kind of only API tests, " + "this test should enabled.") +) +@parametrize_with_checks( + [MinimalClassifier(), MinimalRegressor(), MinimalTransformer()], +) +def test_check_estimator_minimal(estimator, check): + # Check that third-party library can run tests without inheriting from + # BaseEstimator. + check(estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 5bf109000dc8c..c34008a76c3ef 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1682,7 +1682,10 @@ def check_estimators_pickle(name, estimator_orig, strict_mode=True): # pickle and unpickle! pickled_estimator = pickle.dumps(estimator) - if estimator.__module__.startswith('sklearn.'): + module_name = estimator.__module__ + if module_name.startswith('sklearn.') and "test_" not in module_name: + # strict check for sklearn estimators that are not implemented in test + # modules. assert b"version" in pickled_estimator unpickled_estimator = pickle.loads(pickled_estimator) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 9a069224f88ba..d0ceb7a2791a3 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -419,8 +419,6 @@ def test_check_estimator(): # check that we have a set_params and can clone msg = "Passing a class was deprecated" assert_raises_regex(TypeError, msg, check_estimator, object) - # msg = "object has no attribute '_get_tags'" - # assert_raises_regex(AttributeError, msg, check_estimator, object()) msg = ( "Parameter 'p' of estimator 'HasMutableParameters' is of type " "object which is not allowed" From 88b01f10db155bbf6172cbb84e8cc873ce7a156d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 12:08:49 +0100 Subject: [PATCH 05/50] refactor --- sklearn/base.py | 26 ++-------- sklearn/tests/test_common.py | 82 ++++++++++++++++++++++++------- sklearn/utils/__init__.py | 54 +++++++++++++++++++- sklearn/utils/estimator_checks.py | 36 ++------------ 4 files changed, 125 insertions(+), 73 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 96abc511b6125..1480f47c5c3d1 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -14,33 +14,15 @@ from . import __version__ from ._config import get_config -from .utils import _IS_32BIT +from .utils import ( + _DEFAULT_TAGS, + _IS_32BIT, +) from .utils.validation import check_X_y from .utils.validation import check_array from .utils._estimator_html_repr import estimator_html_repr from .utils.validation import _deprecate_positional_args -_DEFAULT_TAGS = { - 'non_deterministic': False, - 'requires_positive_X': False, - 'requires_positive_y': False, - 'X_types': ['2darray'], - 'poor_score': False, - 'no_validation': False, - 'multioutput': False, - "allow_nan": False, - 'stateless': False, - 'multilabel': False, - '_skip_test': False, - '_xfail_checks': False, - 'multioutput_only': False, - 'binary_only': False, - 'requires_fit': True, - 'preserves_dtype': [np.float64], - 'requires_y': False, - 'pairwise': False, - } - @_deprecate_positional_args def clone(estimator, *, safe=True): diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 25624d934fa1d..757fbc4cc91f5 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -382,15 +382,23 @@ class MinimalClassifier: """Minimal classifier implementation with inheriting from BaseEstimator.""" _estimator_type = "classifier" + def __init__(self, param=None): + self.param = param + def __repr__(self): # Only required when using pytest-xdist to get an id not associated # with the memory location return self.__class__.__name__ - def get_params(self, **params): - return {} + def get_params(self, deep=True): + return {"param": self.param} - def set_params(self, deep=True): + def set_params(self, **params): + valid_params = self.get_params() + for key, value in params.items(): + if key not in valid_params: + raise ValueError("Wrong params") + setattr(self, key, value) return self def __getstate__(self): @@ -431,15 +439,23 @@ class MinimalRegressor: """Minimal regressor implementation with inheriting from BaseEstimator.""" _estimator_type = "regressor" + def __init__(self, param=None): + self.param = param + def __repr__(self): # Only required when using pytest-xdist to get an id not associated # with the memory location return self.__class__.__name__ - def get_params(self, **params): - return {} + def get_params(self, deep=True): + return {"param": self.param} - def set_params(self, deep=True): + def set_params(self, **params): + valid_params = self.get_params() + for key, value in params.items(): + if key not in valid_params: + raise ValueError("Wrong params") + setattr(self, key, value) return self def __getstate__(self): @@ -470,15 +486,23 @@ class MinimalTransformer: """Minimal transformer implementation with inheriting from BaseEstimator.""" + def __init__(self, param=None): + self.param = param + def __repr__(self): # Only required when using pytest-xdist to get an id not associated # with the memory location return self.__class__.__name__ - def get_params(self, **params): - return {} + def get_params(self, deep=True): + return {"param": self.param} - def set_params(self, deep=True): + def set_params(self, **params): + valid_params = self.get_params() + for key, value in params.items(): + if key not in valid_params: + raise ValueError("Wrong params") + setattr(self, key, value) return self def __getstate__(self): @@ -506,16 +530,38 @@ def fit_transform(self, X, y=None): return self.fit(X, y).transform(X, y) +def _generate_minimal_compatible_instances(): + """Generate instance containing estimators from minimal class compatible + implementation.""" + for SearchCV, (Estimator, param_grid) in zip( + [GridSearchCV, RandomizedSearchCV], + [ + (MinimalRegressor, {"param": [1, 10]}), + (MinimalClassifier, {"param": [1, 10]}), + ], + ): + yield SearchCV(Estimator(), param_grid) + + for SearchCV, (Estimator, param_grid) in zip( + [GridSearchCV, RandomizedSearchCV], + [ + (MinimalRegressor, {"param": [1, 10]}), + (MinimalClassifier, {"param": [1, 10]}), + ], + ): + yield SearchCV( + make_pipeline(MinimalTransformer(), Estimator()), param_grid + ).set_params(error_score="raise") + + # FIXME: hopefully in 0.25 -@pytest.mark.skip( - reason=("This test is currently failing because checks are granular " - "enough. Once checks are split with some kind of only API tests, " - "this test should enabled.") -) -@parametrize_with_checks( - [MinimalClassifier(), MinimalRegressor(), MinimalTransformer()], -) -def test_check_estimator_minimal(estimator, check): +# @pytest.mark.skip( +# reason=("This test is currently failing because checks are granular " +# "enough. Once checks are split with some kind of only API tests, " +# "this test should enabled.") +# ) +@parametrize_with_checks(list(_generate_minimal_compatible_instances())) +def test_minimal_class_implementation_checks(estimator, check): # Check that third-party library can run tests without inheriting from # BaseEstimator. check(estimator) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 9d542102b1dda..6d1c52406f112 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -43,6 +43,26 @@ parallel_backend = _joblib.parallel_backend register_parallel_backend = _joblib.register_parallel_backend +_DEFAULT_TAGS = { + 'non_deterministic': False, + 'requires_positive_X': False, + 'requires_positive_y': False, + 'X_types': ['2darray'], + 'poor_score': False, + 'no_validation': False, + 'multioutput': False, + "allow_nan": False, + 'stateless': False, + 'multilabel': False, + '_skip_test': False, + '_xfail_checks': False, + 'multioutput_only': False, + 'binary_only': False, + 'requires_fit': True, + 'preserves_dtype': [np.float64], + 'requires_y': False, + 'pairwise': False, +} __all__ = ["murmurhash3_32", "as_float_array", "assert_all_finite", "check_array", @@ -53,7 +73,9 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning", "estimator_html_repr" + "DataConversionWarning", "estimator_html_repr", + "_DEFAULT_TAGS", + "_safe_tags", ] IS_PYPY = platform.python_implementation() == 'PyPy' @@ -1182,3 +1204,33 @@ def is_abstract(c): # itemgetter is used to ensure the sort does not extend to the 2nd item of # the tuple return sorted(set(estimators), key=itemgetter(0)) + + +def _safe_tags(estimator, key=None): + """Safely get estimator tags for common checks. + + :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. + However, if a compatible estimator does not inherit from this base class, + we should default to the default tag. + + Parameters + ---------- + estimator : estimator object + The estimator from which to get the tag. + key : str, default=None + Tag name to get. By default (`None`), all tags are returned. + + Returns + ------- + tags : dict + The estimator tags. + """ + if hasattr(estimator, "_get_tags"): + if key is not None: + return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + tags = estimator._get_tags() + return {key: tags.get(key, _DEFAULT_TAGS[key]) + for key in _DEFAULT_TAGS.keys()} + if key is not None: + return _DEFAULT_TAGS[key] + return _DEFAULT_TAGS diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index c34008a76c3ef..fdc388f21f95b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -32,7 +32,6 @@ from ..base import ( clone, ClusterMixin, - _DEFAULT_TAGS, is_classifier, is_regressor, is_outlier_detector, @@ -52,7 +51,10 @@ from ..model_selection._validation import _safe_split from ..metrics.pairwise import (rbf_kernel, linear_kernel, pairwise_distances) -from .import shuffle +from .import ( + shuffle, + _safe_tags, +) from .validation import has_fit_parameter, _num_samples from ..preprocessing import StandardScaler from ..preprocessing import scale @@ -67,36 +69,6 @@ CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD'] -def _safe_tags(estimator, key=None): - """Safely get estimator tags for common checks. - - :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. - However, if a compatible estimator does not inherit from this base class, - we should default to the default tag. - - Parameters - ---------- - estimator : estimator object - The estimator from which to get the tag. - key : str, default=None - Tag name to get. By default (`None`), all tags are returned. - - Returns - ------- - tags : dict - The estimator tags. - """ - if hasattr(estimator, "_get_tags"): - if key is not None: - return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) - tags = estimator._get_tags() - return {key: tags.get(key, _DEFAULT_TAGS[key]) - for key in _DEFAULT_TAGS.keys()} - if key is not None: - return _DEFAULT_TAGS[key] - return _DEFAULT_TAGS - - def _yield_checks(estimator): name = estimator.__class__.__name__ tags = _safe_tags(estimator) From 05d42255d008f06f0e5f695122d8b194b77721b7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 12:14:34 +0100 Subject: [PATCH 06/50] fix change _get_tags in search and pipeline --- sklearn/model_selection/_search.py | 8 +++++--- sklearn/pipeline.py | 9 ++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index eb282535dd4d5..c028872b2fbe0 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -33,7 +33,10 @@ from ._validation import _normalize_score_results from ..exceptions import NotFittedError from joblib import Parallel -from ..utils import check_random_state +from ..utils import ( + check_random_state, + _safe_tags, +) from ..utils.random import sample_without_replacement from ..utils.validation import indexable, check_is_fitted, _check_fit_params from ..utils.validation import _deprecate_positional_args @@ -433,9 +436,8 @@ def _estimator_type(self): def _more_tags(self): # allows cross-validation to see 'precomputed' metrics - estimator_tags = self.estimator._get_tags() return { - 'pairwise': estimator_tags.get('pairwise', False), + 'pairwise': _safe_tags(self.estimator, "pairwise"), "_xfail_checks": {"check_supervised_y_2d": "DataConversionWarning not caught"}, } diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 8d738d4b90fff..597f3e2e88dba 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -19,7 +19,11 @@ from .base import clone, TransformerMixin from .utils._estimator_html_repr import _VisualBlock from .utils.metaestimators import if_delegate_has_method -from .utils import Bunch, _print_elapsed_time +from .utils import ( + Bunch, + _print_elapsed_time, + _safe_tags, +) from .utils.deprecation import deprecated from .utils.validation import check_memory from .utils.validation import _deprecate_positional_args @@ -623,8 +627,7 @@ def classes_(self): def _more_tags(self): # check if first estimator expects pairwise input - estimator_tags = self.steps[0][1]._get_tags() - return {'pairwise': estimator_tags.get('pairwise', False)} + return {'pairwise': _safe_tags(self.steps[0][1], "pairwise")} # TODO: Remove in 0.26 # mypy error: Decorated property not supported From 5630266f8d0eced74dbf27da414456e8101916bf Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 12:22:50 +0100 Subject: [PATCH 07/50] fix nested param name --- sklearn/tests/test_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 757fbc4cc91f5..80f2a1e33a92c 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -545,8 +545,8 @@ def _generate_minimal_compatible_instances(): for SearchCV, (Estimator, param_grid) in zip( [GridSearchCV, RandomizedSearchCV], [ - (MinimalRegressor, {"param": [1, 10]}), - (MinimalClassifier, {"param": [1, 10]}), + (MinimalRegressor, {"minimalregressor__param": [1, 10]}), + (MinimalClassifier, {"minimalclassifier__param": [1, 10]}), ], ): yield SearchCV( From 508d5e0f6c43d37bdd371aafdb6e2da3406b11cf Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 12:23:49 +0100 Subject: [PATCH 08/50] skip test --- sklearn/tests/test_common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 80f2a1e33a92c..4e61146082c9b 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -555,11 +555,11 @@ def _generate_minimal_compatible_instances(): # FIXME: hopefully in 0.25 -# @pytest.mark.skip( -# reason=("This test is currently failing because checks are granular " -# "enough. Once checks are split with some kind of only API tests, " -# "this test should enabled.") -# ) +@pytest.mark.skip( + reason=("This test is currently failing because checks are granular " + "enough. Once checks are split with some kind of only API tests, " + "this test should enabled.") +) @parametrize_with_checks(list(_generate_minimal_compatible_instances())) def test_minimal_class_implementation_checks(estimator, check): # Check that third-party library can run tests without inheriting from From 610d64525032c4e4a557280417c9d4d68b421d57 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 12:28:39 +0100 Subject: [PATCH 09/50] upadte multiclass --- sklearn/multiclass.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 50bf83d4eaa43..f083186c5860c 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -45,7 +45,10 @@ from .base import _is_pairwise from .preprocessing import LabelBinarizer from .metrics.pairwise import euclidean_distances -from .utils import check_random_state +from .utils import ( + check_random_state, + _safe_tags, +) from .utils.deprecation import deprecated from .utils.validation import _num_samples from .utils.validation import check_is_fitted @@ -499,8 +502,7 @@ def _pairwise(self): def _more_tags(self): """Indicate if wrapped estimator is using a precomputed Gram matrix""" - estimator_tags = self.estimator._get_tags() - return {'pairwise': estimator_tags.get('pairwise', False)} + return {'pairwise': _safe_tags(self.estimator, key="pairwise")} @property def _first_estimator(self): @@ -780,8 +782,12 @@ def _pairwise(self): def _more_tags(self): """Indicate if wrapped estimator is using a precomputed Gram matrix""" - estimator_tags = self.estimator._get_tags() - return {'pairwise': estimator_tags.get('pairwise', True)} + if (hasattr(self.estimator, '_get_tags') and + callable(self.estimator._get_tags)): + pairwise_tag = self.estimator._get_tags().get('pairwise', True) + else: + pairwise_tag = True + return {'pairwise': pairwise_tag} class OutputCodeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): From 7d0a4f6d9fd876178c3f391d32ed31fdd5f314fc Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 12:32:09 +0100 Subject: [PATCH 10/50] fix feature selection --- sklearn/feature_selection/_from_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index 5fb519a2bd798..f604aa71a04e9 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -283,5 +283,9 @@ def n_features_in_(self): return self.estimator_.n_features_in_ def _more_tags(self): - estimator_tags = self.estimator._get_tags() - return {'allow_nan': estimator_tags.get('allow_nan', True)} + if (hasattr(self.estimator, '_get_tags') and + callable(self.estimator._get_tags)): + allow_nan_tag = self.estimator._get_tags().get('pairwise', True) + else: + allow_nan_tag = True + return {'allow_nan': allow_nan_tag} From eeaf7b0779dfb76abf00a112e8c25d2eb2dd8438 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 12:43:57 +0100 Subject: [PATCH 11/50] add default overwrite in safe_tag --- sklearn/base.py | 7 ++----- sklearn/feature_selection/_base.py | 19 +++++++++++++------ sklearn/feature_selection/_from_model.py | 11 +++++------ sklearn/multiclass.py | 10 ++++------ sklearn/utils/__init__.py | 10 ++++++++-- 5 files changed, 32 insertions(+), 25 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 1480f47c5c3d1..b850e7668cea2 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -17,6 +17,7 @@ from .utils import ( _DEFAULT_TAGS, _IS_32BIT, + _safe_tags, ) from .utils.validation import check_X_y from .utils.validation import check_array @@ -840,11 +841,7 @@ def _is_pairwise(estimator): warnings.filterwarnings('ignore', category=FutureWarning) has_pairwise_attribute = hasattr(estimator, '_pairwise') pairwise_attribute = getattr(estimator, '_pairwise', False) - - if hasattr(estimator, '_get_tags') and callable(estimator._get_tags): - pairwise_tag = estimator._get_tags().get('pairwise', False) - else: - pairwise_tag = False + pairwise_tag = _safe_tags(estimator, key="pairwise") if has_pairwise_attribute: if pairwise_attribute != pairwise_tag: diff --git a/sklearn/feature_selection/_base.py b/sklearn/feature_selection/_base.py index a5d752cb3f4b6..ae0047cf8f0c9 100644 --- a/sklearn/feature_selection/_base.py +++ b/sklearn/feature_selection/_base.py @@ -12,9 +12,12 @@ from scipy.sparse import issparse, csc_matrix from ..base import TransformerMixin -from ..utils import check_array -from ..utils import safe_mask -from ..utils import safe_sqr +from ..utils import ( + check_array, + safe_mask, + safe_sqr, + _safe_tags, +) class SelectorMixin(TransformerMixin, metaclass=ABCMeta): @@ -74,9 +77,13 @@ def transform(self, X): X_r : array of shape [n_samples, n_selected_features] The input samples with only the selected features. """ - tags = self._get_tags() - X = check_array(X, dtype=None, accept_sparse='csr', - force_all_finite=not tags.get('allow_nan', True)) + force_all_finite = not _safe_tags(self, key="allow_nan", default=True) + X = check_array( + X, + dtype=None, + accept_sparse="csr", + force_all_finite=force_all_finite, + ) mask = self.get_support() if not mask.any(): warn("No features were selected: either the data is" diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index f604aa71a04e9..a19fc78cee7ac 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -7,6 +7,7 @@ from ._base import SelectorMixin from ._base import _get_feature_importances from ..base import BaseEstimator, clone, MetaEstimatorMixin +from ..utils import _safe_tags from ..utils.validation import check_is_fitted from ..exceptions import NotFittedError @@ -283,9 +284,7 @@ def n_features_in_(self): return self.estimator_.n_features_in_ def _more_tags(self): - if (hasattr(self.estimator, '_get_tags') and - callable(self.estimator._get_tags)): - allow_nan_tag = self.estimator._get_tags().get('pairwise', True) - else: - allow_nan_tag = True - return {'allow_nan': allow_nan_tag} + return { + 'allow_nan': + _safe_tags(self.estimator, key="pairwise", default=True) + } diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index f083186c5860c..f24199cc4b054 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -782,12 +782,10 @@ def _pairwise(self): def _more_tags(self): """Indicate if wrapped estimator is using a precomputed Gram matrix""" - if (hasattr(self.estimator, '_get_tags') and - callable(self.estimator._get_tags)): - pairwise_tag = self.estimator._get_tags().get('pairwise', True) - else: - pairwise_tag = True - return {'pairwise': pairwise_tag} + return { + 'pairwise': + _safe_tags(self.estimator, key="pairwise", default=True) + } class OutputCodeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 6d1c52406f112..ef24e8478d912 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1206,7 +1206,7 @@ def is_abstract(c): return sorted(set(estimators), key=itemgetter(0)) -def _safe_tags(estimator, key=None): +def _safe_tags(estimator, key=None, default=None): """Safely get estimator tags for common checks. :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. @@ -1219,6 +1219,10 @@ def _safe_tags(estimator, key=None): The estimator from which to get the tag. key : str, default=None Tag name to get. By default (`None`), all tags are returned. + default : list of {str, dtype} or bool, default=None + When `key is not None`, if the tag was not set in the estimator, the + default value set in `sklearn.utils._DEFAULT_TAGS` will be returned. + `default` allows to overwrite the default value. Returns ------- @@ -1232,5 +1236,7 @@ def _safe_tags(estimator, key=None): return {key: tags.get(key, _DEFAULT_TAGS[key]) for key in _DEFAULT_TAGS.keys()} if key is not None: - return _DEFAULT_TAGS[key] + if default is None: + return _DEFAULT_TAGS[key] + return default return _DEFAULT_TAGS From fdf1011316945bd3560e6ecfc25091bfe4e8cade Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 16:33:48 +0100 Subject: [PATCH 12/50] iter --- sklearn/feature_selection/_rfe.py | 20 +++++++++++--------- sklearn/feature_selection/_sequential.py | 9 +++++---- sklearn/feature_selection/tests/test_rfe.py | 5 +---- sklearn/utils/__init__.py | 2 ++ 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index 9e6912792e837..a7f0a0c47a3c2 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -10,6 +10,7 @@ import numbers from joblib import Parallel, effective_n_jobs +from ..utils import _safe_tags from ..utils.metaestimators import if_delegate_has_method from ..utils.metaestimators import _safe_split from ..utils.validation import check_is_fitted @@ -187,11 +188,11 @@ def _fit(self, X, y, step_score=None): # and is used when implementing RFECV # self.scores_ will not be calculated when calling _fit through fit - tags = self._get_tags() + force_all_finite = not _safe_tags(self, key="allow_nan", default=True) X, y = self._validate_data( X, y, accept_sparse="csc", ensure_min_features=2, - force_all_finite=not tags.get('allow_nan', True), + force_all_finite=force_all_finite, multi_output=True ) error_msg = ("n_features_to_select must be either None, a " @@ -371,11 +372,12 @@ def predict_log_proba(self, X): return self.estimator_.predict_log_proba(self.transform(X)) def _more_tags(self): - estimator_tags = self.estimator._get_tags() - return {'poor_score': True, - 'allow_nan': estimator_tags.get('allow_nan', True), - 'requires_y': True, - } + return { + 'poor_score': True, + 'allow_nan': + _safe_tags(self.estimator, key='allow_nan', default=True), + 'requires_y': True, + } class RFECV(RFE): @@ -556,10 +558,10 @@ def fit(self, X, y, groups=None): .. versionadded:: 0.20 """ - tags = self._get_tags() + force_all_finite = not _safe_tags(self, key="allow_nan", default=True) X, y = self._validate_data( X, y, accept_sparse="csr", ensure_min_features=2, - force_all_finite=not tags.get('allow_nan', True), + force_all_finite=force_all_finite, multi_output=True ) diff --git a/sklearn/feature_selection/_sequential.py b/sklearn/feature_selection/_sequential.py index 545cde6a5cfef..9c66befc9903d 100644 --- a/sklearn/feature_selection/_sequential.py +++ b/sklearn/feature_selection/_sequential.py @@ -7,6 +7,7 @@ from ._base import SelectorMixin from ..base import BaseEstimator, MetaEstimatorMixin, clone +from ..utils import _safe_tags from ..utils.validation import check_is_fitted from ..model_selection import cross_val_score @@ -129,11 +130,11 @@ def fit(self, X, y): self : object """ - tags = self._get_tags() + force_all_finite = not _safe_tags(self, key="allow_nan", default=True) X, y = self._validate_data( X, y, accept_sparse="csc", ensure_min_features=2, - force_all_finite=not tags.get('allow_nan', True), + force_all_finite=force_all_finite, multi_output=True ) n_features = X.shape[1] @@ -207,8 +208,8 @@ def _get_support_mask(self): return self.support_ def _more_tags(self): - estimator_tags = self.estimator._get_tags() return { - 'allow_nan': estimator_tags.get('allow_nan', True), + 'allow_nan': + _safe_tags(self.estimator, key="allow_nan", default=True), 'requires_y': True, } diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index 6cbaeffb12997..d0511b1b5bfb6 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -448,10 +448,7 @@ def test_rfe_importance_getter_validation(importance_getter, err_type, model.fit(X, y) -@pytest.mark.parametrize("cv", [ - None, - 5 -]) +@pytest.mark.parametrize("cv", [None, 5]) def test_rfe_allow_nan_inf_in_x(cv): iris = load_iris() X = iris.data diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index ef24e8478d912..10ae1d121fe9f 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1231,6 +1231,8 @@ def _safe_tags(estimator, key=None, default=None): """ if hasattr(estimator, "_get_tags"): if key is not None: + if default is not None: + return default return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) tags = estimator._get_tags() return {key: tags.get(key, _DEFAULT_TAGS[key]) From b9b2331cae9eca5cbbd56a9d11777ea4597869bf Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 17:09:55 +0100 Subject: [PATCH 13/50] iter --- sklearn/feature_selection/_from_model.py | 2 +- sklearn/feature_selection/tests/test_from_model.py | 5 +++-- sklearn/impute/tests/test_knn.py | 3 ++- sklearn/linear_model/_glm/tests/test_glm.py | 3 ++- sklearn/linear_model/tests/test_coordinate_descent.py | 3 ++- sklearn/model_selection/tests/test_search.py | 3 ++- sklearn/neighbors/_base.py | 10 +++++++--- sklearn/preprocessing/tests/test_data.py | 7 +++++-- sklearn/preprocessing/tests/test_encoders.py | 3 ++- sklearn/tests/test_base.py | 11 ++++++----- sklearn/tests/test_docstring_parameters.py | 9 ++++++--- sklearn/tests/test_multiclass.py | 11 +++++++---- sklearn/utils/__init__.py | 8 +++----- 13 files changed, 48 insertions(+), 30 deletions(-) diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index a19fc78cee7ac..a2d1c41e8825c 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -286,5 +286,5 @@ def n_features_in_(self): def _more_tags(self): return { 'allow_nan': - _safe_tags(self.estimator, key="pairwise", default=True) + _safe_tags(self.estimator, key="allow_nan", default=True) } diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 37b5c105e1daa..167dced716872 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -1,6 +1,7 @@ import pytest import numpy as np +from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_allclose @@ -364,11 +365,11 @@ def test_transform_accepts_nan_inf(): def test_allow_nan_tag_comes_from_estimator(): allow_nan_est = NaNTag() model = SelectFromModel(estimator=allow_nan_est) - assert model._get_tags()['allow_nan'] is True + assert _safe_tags(model, key='allow_nan') is True no_nan_est = NoNaNTag() model = SelectFromModel(estimator=no_nan_est) - assert model._get_tags()['allow_nan'] is False + assert _safe_tags(model, key='allow_nan') is False def _pca_importances(pca_estimator): diff --git a/sklearn/impute/tests/test_knn.py b/sklearn/impute/tests/test_knn.py index 68c4d9f3cc54a..b418c704c15e7 100644 --- a/sklearn/impute/tests/test_knn.py +++ b/sklearn/impute/tests/test_knn.py @@ -6,6 +6,7 @@ from sklearn.metrics.pairwise import nan_euclidean_distances from sklearn.metrics.pairwise import pairwise_distances from sklearn.neighbors import KNeighborsRegressor +from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_allclose @@ -638,4 +639,4 @@ def test_knn_imputer_distance_weighted_not_enough_neighbors(na, @pytest.mark.parametrize("na, allow_nan", [(-1, False), (np.nan, True)]) def test_knn_tags(na, allow_nan): knn = KNNImputer(missing_values=na) - assert knn._get_tags()["allow_nan"] == allow_nan + assert _safe_tags(knn, key="allow_nan") == allow_nan diff --git a/sklearn/linear_model/_glm/tests/test_glm.py b/sklearn/linear_model/_glm/tests/test_glm.py index ece8f09c76acd..287e6169eacef 100644 --- a/sklearn/linear_model/_glm/tests/test_glm.py +++ b/sklearn/linear_model/_glm/tests/test_glm.py @@ -26,6 +26,7 @@ from sklearn.linear_model import Ridge from sklearn.exceptions import ConvergenceWarning from sklearn.model_selection import train_test_split +from sklearn.utils import _safe_tags @pytest.fixture(scope="module") @@ -428,4 +429,4 @@ def test_tweedie_regression_family(regression_data): ], ) def test_tags(estimator, value): - assert estimator._get_tags()['requires_positive_y'] is value + assert _safe_tags(estimator, key='requires_positive_y') is value diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 0c504dfbd8c6c..7b5eaaad09bac 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -16,6 +16,7 @@ from sklearn.preprocessing import StandardScaler from sklearn.exceptions import ConvergenceWarning +from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_almost_equal @@ -315,7 +316,7 @@ def test_model_pipeline_same_as_normalize_true(LinearModel, params): LinearModel(normalize=False, fit_intercept=True, **params) ) - is_multitask = model_normalize._get_tags().get("multioutput_only", False) + is_multitask = _safe_tags(model_normalize, key="multioutput_only") # prepare the data n_samples, n_features = 100, 2 diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 308d927911eaf..aea01c080e8fe 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -13,6 +13,7 @@ import scipy.sparse as sp import pytest +from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_raises from sklearn.utils._testing import assert_warns from sklearn.utils._testing import assert_warns_message @@ -1954,7 +1955,7 @@ def _more_tags(self): est = TestEstimator() attr_message = "BaseSearchCV pairwise tag must match estimator" cv = GridSearchCV(est, {'n_neighbors': [10]}) - assert pairwise == cv._get_tags()['pairwise'], attr_message + assert pairwise == _safe_tags(cv, key='pairwise'), attr_message # TODO: Remove in 0.26 diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 187e41c242405..5fe1a7a2ce07e 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -23,8 +23,12 @@ from ..base import is_classifier from ..metrics import pairwise_distances_chunked from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS -from ..utils import check_array, gen_even_slices -from ..utils import _to_object_array +from ..utils import ( + check_array, + gen_even_slices, + _safe_tags, + _to_object_array, +) from ..utils.deprecation import deprecated from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted @@ -355,7 +359,7 @@ def _check_algorithm_metric(self): raise ValueError("p must be greater than one for minkowski metric") def _fit(self, X, y=None): - if self._get_tags()["requires_y"]: + if _safe_tags(self, key="requires_y"): if not isinstance(X, (KDTree, BallTree, NeighborsBase)): X, y = self._validate_data(X, y, accept_sparse="csr", multi_output=True) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 4fef462b9d849..0ce7762129acf 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -54,7 +54,10 @@ from sklearn.pipeline import Pipeline from sklearn.model_selection import cross_val_predict from sklearn.svm import SVR -from sklearn.utils import shuffle +from sklearn.utils import ( + shuffle, + _safe_tags, +) from sklearn import datasets @@ -2244,7 +2247,7 @@ def test_cv_pipeline_precomputed(): pipeline = Pipeline([("kernel_centerer", kcent), ("svr", SVR())]) # did the pipeline set the pairwise attribute? - assert pipeline._get_tags()['pairwise'] + assert _safe_tags(pipeline, key='pairwise') # TODO: Remove in 0.26 msg = r"Attribute _pairwise was deprecated in version 0\.24" diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index 213aa85047574..ad8c4f53f2745 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -7,6 +7,7 @@ import pytest from sklearn.exceptions import NotFittedError +from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import _convert_container @@ -827,7 +828,7 @@ def test_categories(density, drop): @pytest.mark.parametrize('Encoder', [OneHotEncoder, OrdinalEncoder]) def test_encoders_has_categorical_tags(Encoder): - assert 'categorical' in Encoder()._get_tags()['X_types'] + assert 'categorical' in _safe_tags(Encoder(), key='X_types') @pytest.mark.parametrize('input_dtype', ['O', 'U']) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index b8d78a96d8e85..ceafdb279897d 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -23,6 +23,7 @@ from sklearn import datasets from sklearn.base import TransformerMixin +from sklearn.utils import _safe_tags from sklearn.utils._mocking import MockDataFrame from sklearn import config_context import pickle @@ -487,17 +488,17 @@ def test_tag_inheritance(): nan_tag_est = NaNTag() no_nan_tag_est = NoNaNTag() - assert nan_tag_est._get_tags()['allow_nan'] - assert not no_nan_tag_est._get_tags()['allow_nan'] + assert _safe_tags(nan_tag_est, key='allow_nan') + assert not _safe_tags(no_nan_tag_est, key='allow_nan') redefine_tags_est = OverrideTag() - assert not redefine_tags_est._get_tags()['allow_nan'] + assert not _safe_tags(redefine_tags_est, key='allow_nan') diamond_tag_est = DiamondOverwriteTag() - assert diamond_tag_est._get_tags()['allow_nan'] + assert _safe_tags(diamond_tag_est, key='allow_nan') inherit_diamond_tag_est = InheritDiamondOverwriteTag() - assert inherit_diamond_tag_est._get_tags()['allow_nan'] + assert _safe_tags(inherit_diamond_tag_est, key='allow_nan') def test_raises_on_get_params_non_attribute(): diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py index 8d8399f0cf4da..9c8940bb88753 100644 --- a/sklearn/tests/test_docstring_parameters.py +++ b/sklearn/tests/test_docstring_parameters.py @@ -12,7 +12,10 @@ import numpy as np import sklearn -from sklearn.utils import IS_PYPY +from sklearn.utils import ( + IS_PYPY, + _safe_tags, +) from sklearn.utils._testing import check_docstring_parameters from sklearn.utils._testing import _get_func_name from sklearn.utils._testing import ignore_warnings @@ -227,9 +230,9 @@ def test_fit_docstring_attributes(name, Estimator): y = _enforce_estimator_tags_y(est, y) X = _enforce_estimator_tags_x(est, X) - if '1dlabels' in est._get_tags()['X_types']: + if '1dlabels' in _safe_tags(est, key='X_types'): est.fit(y) - elif '2dlabels' in est._get_tags()['X_types']: + elif '2dlabels' in _safe_tags(est, key='X_types'): est.fit(np.c_[y, y]) else: est.fit(X, y) diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index dbf574db6fe4d..766acf53081d6 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -17,8 +17,11 @@ from sklearn.multiclass import OutputCodeClassifier from sklearn.utils.multiclass import (check_classification_targets, type_of_target) -from sklearn.utils import check_array -from sklearn.utils import shuffle +from sklearn.utils import ( + check_array, + shuffle, + _safe_tags, +) from sklearn.metrics import precision_score from sklearn.metrics import recall_score @@ -794,10 +797,10 @@ def test_pairwise_tag(MultiClassClassifier): clf_notprecomputed = svm.SVC() ovr_false = MultiClassClassifier(clf_notprecomputed) - assert not ovr_false._get_tags()['pairwise'] + assert not _safe_tags(ovr_false, key='pairwise') ovr_true = MultiClassClassifier(clf_precomputed) - assert ovr_true._get_tags()['pairwise'] + assert _safe_tags(ovr_true, key='pairwise') # TODO: Remove in 0.26 diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 10ae1d121fe9f..94001bf95702d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1231,14 +1231,12 @@ def _safe_tags(estimator, key=None, default=None): """ if hasattr(estimator, "_get_tags"): if key is not None: - if default is not None: - return default - return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + default = _DEFAULT_TAGS[key] if default is None else default + return estimator._get_tags().get(key, default) tags = estimator._get_tags() return {key: tags.get(key, _DEFAULT_TAGS[key]) for key in _DEFAULT_TAGS.keys()} if key is not None: - if default is None: - return _DEFAULT_TAGS[key] + default = _DEFAULT_TAGS[key] if default is None else default return default return _DEFAULT_TAGS From 4e1f93b10a582c5d97301e8ffff1305f3a33c04d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 17:32:12 +0100 Subject: [PATCH 14/50] TST safe_tags --- sklearn/utils/tests/test_utils.py | 37 +++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 44e448841cef0..9b31e5f21f2d6 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -8,12 +8,14 @@ import numpy as np import scipy.sparse as sp +from sklearn.base import BaseEstimator from sklearn.utils._testing import (assert_array_equal, assert_allclose_dense_sparse, assert_warns_message, assert_no_warnings, _convert_container) from sklearn.utils import check_random_state +from sklearn.utils import _DEFAULT_TAGS from sklearn.utils import _determine_key_type from sklearn.utils import deprecated from sklearn.utils import gen_batches @@ -22,6 +24,7 @@ from sklearn.utils import safe_mask from sklearn.utils import column_or_1d from sklearn.utils import _safe_indexing +from sklearn.utils import _safe_tags from sklearn.utils import shuffle from sklearn.utils import gen_even_slices from sklearn.utils import _message_with_time, _print_elapsed_time @@ -693,3 +696,37 @@ def test_to_object_array(sequence): assert isinstance(out, np.ndarray) assert out.dtype.kind == 'O' assert out.ndim == 1 + + +class NoTags: + pass + + +class BaseEstimatorNotATag(BaseEstimator): + def _get_tags(self): + tags = super()._get_tags().copy() + del tags["allow_nan"] + return tags + + +@pytest.mark.parametrize( + "estimator, key, default, expected_tags", + [ + (NoTags(), None, None, _DEFAULT_TAGS), + (NoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + (NoTags(), "allow_nan", True, True), + (BaseEstimator(), None, None, _DEFAULT_TAGS), + (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + (BaseEstimator(), "allow_nan", True, False), + (BaseEstimatorNotATag(), None, None, _DEFAULT_TAGS), + ( + BaseEstimatorNotATag(), + "allow_nan", + None, + _DEFAULT_TAGS["allow_nan"], + ), + (BaseEstimatorNotATag(), "allow_nan", True, True), + ], +) +def test_safe_tags(estimator, key, default, expected_tags): + assert _safe_tags(estimator, key=key, default=default) == expected_tags From 45ada6b75f4c5186cec9d6861d708f80593a3458 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 17:39:43 +0100 Subject: [PATCH 15/50] whoops --- sklearn/tests/test_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index fc4304c33a420..6ca7e09c05931 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -14,6 +14,7 @@ from inspect import isgenerator from functools import partial +import numpy as np import pytest from sklearn.utils import all_estimators From a51d75bc3054c5efc614f6dad8b57bccaa71ee48 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 17:44:20 +0100 Subject: [PATCH 16/50] TST add test for passthtough in pipeline --- sklearn/tests/test_pipeline.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index bd88f4acd03c3..d9df891d90128 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1264,3 +1264,11 @@ def test_feature_union_warns_unknown_transformer_weight(): union = FeatureUnion(transformer_list, transformer_weights=weights) with pytest.raises(ValueError, match=expected_msg): union.fit(X, y) + + +@pytest.mark.parametrize('passthrough', [None, 'passthrough']) +def test_pipeline_get_tags_none(passthrough): + """Checks that tags are set correctly when passthrough is None or + 'passthrough'""" + pipe = make_pipeline(passthrough, SVC()) + assert not pipe._get_tags()['pairwise'] From 9874b7c9d9df6b06b4f2cdfe9407a10d5cf51805 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Nov 2020 17:52:26 +0100 Subject: [PATCH 17/50] remove outdated test --- sklearn/utils/tests/test_estimator_checks.py | 37 -------------------- 1 file changed, 37 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index d0ceb7a2791a3..0ca547ee8299d 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -23,7 +23,6 @@ from sklearn.utils.estimator_checks import check_regressor_data_not_an_array from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption -from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.fixes import np_version, parse_version from sklearn.ensemble import RandomForestClassifier from sklearn.linear_model import LinearRegression, SGDClassifier @@ -619,42 +618,6 @@ def test_check_estimator_pairwise(): check_estimator(est) -class MinimalEstimator: - - # Our minimal required supposed that the following are implemented - _get_param_names = BaseEstimator._get_param_names # used by get_params - set_params = BaseEstimator.set_params - get_params = BaseEstimator.get_params - __setstate__ = BaseEstimator.__setstate__ - __getstate__ = BaseEstimator.__getstate__ - - def fit(self, X, y): - return self - - -class MinimalClassifier(MinimalEstimator): - - def fit(self, X, y): - self.classes_ = np.unique(y) - return super().fit(X, y) - - def predict_proba(self, X): - proba_shape = (len(X), self.classes_.size) - y_proba = np.zeros(shape=proba_shape, dtype=np.float64) - y_proba[:, 0] = 1.0 - return y_proba - - def predict(self, X): - y_proba = self.predict_proba(X) - y_pred = y_proba.argmax(axis=1) - return self.classes_[y_pred] - - -@parametrize_with_checks([MinimalClassifier()], strict_mode=False) -def test_check_estimator_minimal(estimator, check): - check(estimator) - - def test_check_classifier_data_not_an_array(): assert_raises_regex(AssertionError, 'Not equal to tolerance', From 007fa0953b62cfcde06c523765d9cce6f44625fb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 11:45:54 +0100 Subject: [PATCH 18/50] revert _safe_tags when inheriting from BaseEstimator --- sklearn/feature_selection/_rfe.py | 8 ++++---- sklearn/feature_selection/_sequential.py | 5 ++--- .../feature_selection/tests/test_from_model.py | 5 ++--- sklearn/impute/tests/test_knn.py | 3 +-- sklearn/linear_model/_glm/tests/test_glm.py | 17 ++++++++--------- .../tests/test_coordinate_descent.py | 5 ++--- sklearn/model_selection/tests/test_search.py | 3 +-- sklearn/neighbors/_base.py | 3 +-- sklearn/preprocessing/tests/test_data.py | 7 ++----- sklearn/preprocessing/tests/test_encoders.py | 3 +-- sklearn/tests/test_base.py | 11 +++++------ sklearn/tests/test_docstring_parameters.py | 9 +++------ sklearn/tests/test_multiclass.py | 5 ++--- 13 files changed, 34 insertions(+), 50 deletions(-) diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index a7f0a0c47a3c2..034d591a7299d 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -188,11 +188,11 @@ def _fit(self, X, y, step_score=None): # and is used when implementing RFECV # self.scores_ will not be calculated when calling _fit through fit - force_all_finite = not _safe_tags(self, key="allow_nan", default=True) + tags = self._get_tags() X, y = self._validate_data( X, y, accept_sparse="csc", ensure_min_features=2, - force_all_finite=force_all_finite, + force_all_finite=not tags.get("allow_nan", True), multi_output=True ) error_msg = ("n_features_to_select must be either None, a " @@ -558,10 +558,10 @@ def fit(self, X, y, groups=None): .. versionadded:: 0.20 """ - force_all_finite = not _safe_tags(self, key="allow_nan", default=True) + tags = self._get_tags() X, y = self._validate_data( X, y, accept_sparse="csr", ensure_min_features=2, - force_all_finite=force_all_finite, + force_all_finite=not tags.get('allow_nan', True), multi_output=True ) diff --git a/sklearn/feature_selection/_sequential.py b/sklearn/feature_selection/_sequential.py index 9c66befc9903d..17a9fc2a5693a 100644 --- a/sklearn/feature_selection/_sequential.py +++ b/sklearn/feature_selection/_sequential.py @@ -129,12 +129,11 @@ def fit(self, X, y): ------- self : object """ - - force_all_finite = not _safe_tags(self, key="allow_nan", default=True) + tags = self._get_tags() X, y = self._validate_data( X, y, accept_sparse="csc", ensure_min_features=2, - force_all_finite=force_all_finite, + force_all_finite=not tags.get("allow_nan", True), multi_output=True ) n_features = X.shape[1] diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 167dced716872..37b5c105e1daa 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -1,7 +1,6 @@ import pytest import numpy as np -from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_allclose @@ -365,11 +364,11 @@ def test_transform_accepts_nan_inf(): def test_allow_nan_tag_comes_from_estimator(): allow_nan_est = NaNTag() model = SelectFromModel(estimator=allow_nan_est) - assert _safe_tags(model, key='allow_nan') is True + assert model._get_tags()['allow_nan'] is True no_nan_est = NoNaNTag() model = SelectFromModel(estimator=no_nan_est) - assert _safe_tags(model, key='allow_nan') is False + assert model._get_tags()['allow_nan'] is False def _pca_importances(pca_estimator): diff --git a/sklearn/impute/tests/test_knn.py b/sklearn/impute/tests/test_knn.py index b418c704c15e7..68c4d9f3cc54a 100644 --- a/sklearn/impute/tests/test_knn.py +++ b/sklearn/impute/tests/test_knn.py @@ -6,7 +6,6 @@ from sklearn.metrics.pairwise import nan_euclidean_distances from sklearn.metrics.pairwise import pairwise_distances from sklearn.neighbors import KNeighborsRegressor -from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_allclose @@ -639,4 +638,4 @@ def test_knn_imputer_distance_weighted_not_enough_neighbors(na, @pytest.mark.parametrize("na, allow_nan", [(-1, False), (np.nan, True)]) def test_knn_tags(na, allow_nan): knn = KNNImputer(missing_values=na) - assert _safe_tags(knn, key="allow_nan") == allow_nan + assert knn._get_tags()["allow_nan"] == allow_nan diff --git a/sklearn/linear_model/_glm/tests/test_glm.py b/sklearn/linear_model/_glm/tests/test_glm.py index 287e6169eacef..0ddbcf36c7465 100644 --- a/sklearn/linear_model/_glm/tests/test_glm.py +++ b/sklearn/linear_model/_glm/tests/test_glm.py @@ -26,7 +26,6 @@ from sklearn.linear_model import Ridge from sklearn.exceptions import ConvergenceWarning from sklearn.model_selection import train_test_split -from sklearn.utils import _safe_tags @pytest.fixture(scope="module") @@ -420,13 +419,13 @@ def test_tweedie_regression_family(regression_data): @pytest.mark.parametrize( - 'estimator, value', - [ - (PoissonRegressor(), True), - (GammaRegressor(), True), - (TweedieRegressor(power=1.5), True), - (TweedieRegressor(power=0), False) - ], + 'estimator, value', + [ + (PoissonRegressor(), True), + (GammaRegressor(), True), + (TweedieRegressor(power=1.5), True), + (TweedieRegressor(power=0), False), + ], ) def test_tags(estimator, value): - assert _safe_tags(estimator, key='requires_positive_y') is value + assert estimator.get_tags()['requires_positive_y'] is value diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 7b5eaaad09bac..f9a7efc987699 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -16,7 +16,6 @@ from sklearn.preprocessing import StandardScaler from sklearn.exceptions import ConvergenceWarning -from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_almost_equal @@ -301,7 +300,7 @@ def test_lasso_cv_positive_constraint(): (Lars, {}), (LinearRegression, {}), (LassoLarsIC, {})] - ) +) def test_model_pipeline_same_as_normalize_true(LinearModel, params): # Test that linear models (LinearModel) set with normalize set to True are # doing the same as the same linear model preceeded by StandardScaler @@ -316,7 +315,7 @@ def test_model_pipeline_same_as_normalize_true(LinearModel, params): LinearModel(normalize=False, fit_intercept=True, **params) ) - is_multitask = _safe_tags(model_normalize, key="multioutput_only") + is_multitask = model_normalize._get_tags()["multioutput_only"] # prepare the data n_samples, n_features = 100, 2 diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index aea01c080e8fe..308d927911eaf 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -13,7 +13,6 @@ import scipy.sparse as sp import pytest -from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_raises from sklearn.utils._testing import assert_warns from sklearn.utils._testing import assert_warns_message @@ -1955,7 +1954,7 @@ def _more_tags(self): est = TestEstimator() attr_message = "BaseSearchCV pairwise tag must match estimator" cv = GridSearchCV(est, {'n_neighbors': [10]}) - assert pairwise == _safe_tags(cv, key='pairwise'), attr_message + assert pairwise == cv._get_tags()['pairwise'], attr_message # TODO: Remove in 0.26 diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 5fe1a7a2ce07e..1e666043347cf 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -26,7 +26,6 @@ from ..utils import ( check_array, gen_even_slices, - _safe_tags, _to_object_array, ) from ..utils.deprecation import deprecated @@ -359,7 +358,7 @@ def _check_algorithm_metric(self): raise ValueError("p must be greater than one for minkowski metric") def _fit(self, X, y=None): - if _safe_tags(self, key="requires_y"): + if self._get_tags()["requires_y"]: if not isinstance(X, (KDTree, BallTree, NeighborsBase)): X, y = self._validate_data(X, y, accept_sparse="csr", multi_output=True) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 0ce7762129acf..4fef462b9d849 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -54,10 +54,7 @@ from sklearn.pipeline import Pipeline from sklearn.model_selection import cross_val_predict from sklearn.svm import SVR -from sklearn.utils import ( - shuffle, - _safe_tags, -) +from sklearn.utils import shuffle from sklearn import datasets @@ -2247,7 +2244,7 @@ def test_cv_pipeline_precomputed(): pipeline = Pipeline([("kernel_centerer", kcent), ("svr", SVR())]) # did the pipeline set the pairwise attribute? - assert _safe_tags(pipeline, key='pairwise') + assert pipeline._get_tags()['pairwise'] # TODO: Remove in 0.26 msg = r"Attribute _pairwise was deprecated in version 0\.24" diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index ad8c4f53f2745..213aa85047574 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -7,7 +7,6 @@ import pytest from sklearn.exceptions import NotFittedError -from sklearn.utils import _safe_tags from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import _convert_container @@ -828,7 +827,7 @@ def test_categories(density, drop): @pytest.mark.parametrize('Encoder', [OneHotEncoder, OrdinalEncoder]) def test_encoders_has_categorical_tags(Encoder): - assert 'categorical' in _safe_tags(Encoder(), key='X_types') + assert 'categorical' in Encoder()._get_tags()['X_types'] @pytest.mark.parametrize('input_dtype', ['O', 'U']) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index ceafdb279897d..b8d78a96d8e85 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -23,7 +23,6 @@ from sklearn import datasets from sklearn.base import TransformerMixin -from sklearn.utils import _safe_tags from sklearn.utils._mocking import MockDataFrame from sklearn import config_context import pickle @@ -488,17 +487,17 @@ def test_tag_inheritance(): nan_tag_est = NaNTag() no_nan_tag_est = NoNaNTag() - assert _safe_tags(nan_tag_est, key='allow_nan') - assert not _safe_tags(no_nan_tag_est, key='allow_nan') + assert nan_tag_est._get_tags()['allow_nan'] + assert not no_nan_tag_est._get_tags()['allow_nan'] redefine_tags_est = OverrideTag() - assert not _safe_tags(redefine_tags_est, key='allow_nan') + assert not redefine_tags_est._get_tags()['allow_nan'] diamond_tag_est = DiamondOverwriteTag() - assert _safe_tags(diamond_tag_est, key='allow_nan') + assert diamond_tag_est._get_tags()['allow_nan'] inherit_diamond_tag_est = InheritDiamondOverwriteTag() - assert _safe_tags(inherit_diamond_tag_est, key='allow_nan') + assert inherit_diamond_tag_est._get_tags()['allow_nan'] def test_raises_on_get_params_non_attribute(): diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py index 9c8940bb88753..8f09f2baf902f 100644 --- a/sklearn/tests/test_docstring_parameters.py +++ b/sklearn/tests/test_docstring_parameters.py @@ -12,10 +12,7 @@ import numpy as np import sklearn -from sklearn.utils import ( - IS_PYPY, - _safe_tags, -) +from sklearn.utils import IS_PYPY from sklearn.utils._testing import check_docstring_parameters from sklearn.utils._testing import _get_func_name from sklearn.utils._testing import ignore_warnings @@ -230,9 +227,9 @@ def test_fit_docstring_attributes(name, Estimator): y = _enforce_estimator_tags_y(est, y) X = _enforce_estimator_tags_x(est, X) - if '1dlabels' in _safe_tags(est, key='X_types'): + if "1dlabels" in est._get_tags()["X_types"]: est.fit(y) - elif '2dlabels' in _safe_tags(est, key='X_types'): + elif "2dlabels" in est._get_tags()["X_types"]: est.fit(np.c_[y, y]) else: est.fit(X, y) diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 766acf53081d6..b60a85c7bde00 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -20,7 +20,6 @@ from sklearn.utils import ( check_array, shuffle, - _safe_tags, ) from sklearn.metrics import precision_score @@ -797,10 +796,10 @@ def test_pairwise_tag(MultiClassClassifier): clf_notprecomputed = svm.SVC() ovr_false = MultiClassClassifier(clf_notprecomputed) - assert not _safe_tags(ovr_false, key='pairwise') + assert not ovr_false._get_tags()["pairwise"] ovr_true = MultiClassClassifier(clf_precomputed) - assert _safe_tags(ovr_true, key='pairwise') + assert ovr_true._get_tags()["pairwise"] # TODO: Remove in 0.26 From 6839f581e141ccbc2bad01d4de93974126c4a7de Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 12:36:22 +0100 Subject: [PATCH 19/50] add test _safe_tags --- sklearn/utils/__init__.py | 48 +++++++++++++++++++------- sklearn/utils/tests/test_utils.py | 56 ++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 94001bf95702d..366523907c810 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1217,12 +1217,16 @@ def _safe_tags(estimator, key=None, default=None): ---------- estimator : estimator object The estimator from which to get the tag. + key : str, default=None Tag name to get. By default (`None`), all tags are returned. + default : list of {str, dtype} or bool, default=None - When `key is not None`, if the tag was not set in the estimator, the - default value set in `sklearn.utils._DEFAULT_TAGS` will be returned. - `default` allows to overwrite the default value. + For estimator not implementing `_get_tags`, `default` allows to define + the `default` value of a tag if it is not present in `_DEFAULT_TAGS` or + to overwrite the value in `_DEFAULT_TAGS` if it the tag is defined. + When `default is None`, no default values nor overwriting will take + place. Returns ------- @@ -1231,12 +1235,32 @@ def _safe_tags(estimator, key=None, default=None): """ if hasattr(estimator, "_get_tags"): if key is not None: - default = _DEFAULT_TAGS[key] if default is None else default - return estimator._get_tags().get(key, default) - tags = estimator._get_tags() - return {key: tags.get(key, _DEFAULT_TAGS[key]) - for key in _DEFAULT_TAGS.keys()} - if key is not None: - default = _DEFAULT_TAGS[key] if default is None else default - return default - return _DEFAULT_TAGS + try: + return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + except KeyError as exc: + raise ValueError( + f"The key {key} is neither defined in _more_tags() in the " + f"class {repr(estimator)} nor a default estimator key in " + f"_DEFAULT_TAGS: {repr(_DEFAULT_TAGS.keys())}" + ) from exc + else: + tags = estimator._get_tags() + return { + key: tags.get(key, _DEFAULT_TAGS[key]) + for key in _DEFAULT_TAGS.keys() + } + else: + if key is not None: + if default is None: + try: + default = _DEFAULT_TAGS[key] + except KeyError as exc: + raise ValueError( + f"The key {key} is not a default tags defined in " + f"_DEFAULT_TAGS and thus no default values are " + f"available. Use the parameter default if you want to " + f"define a default value." + ) from exc + return default + else: + return _DEFAULT_TAGS diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 9b31e5f21f2d6..ae8c8c386b2ff 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -698,35 +698,51 @@ def test_to_object_array(sequence): assert out.ndim == 1 -class NoTags: +class EstimatorNoTags: pass -class BaseEstimatorNotATag(BaseEstimator): - def _get_tags(self): - tags = super()._get_tags().copy() - del tags["allow_nan"] - return tags +@pytest.mark.parametrize( + "estimator, err_msg", + [ + (BaseEstimator(), "The key xxx is neither defined in _more_tags"), + (EstimatorNoTags(), "The key xxx is not a default tags defined"), + ], +) +def test_safe_tags_error(estimator, err_msg): + # Check that safe_tags raises error in ambiguous case. + with pytest.raises(ValueError, match=err_msg): + _safe_tags(estimator, key="xxx") + + +@pytest.mark.parametrize( + "estimator, key, expected_results", + [ + (BaseEstimator(), None, _DEFAULT_TAGS), + (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), + ], +) +def test_safe_tags_implement_get_tags(estimator, key, expected_results): + assert _safe_tags(estimator, key=key) == expected_results @pytest.mark.parametrize( - "estimator, key, default, expected_tags", + "estimator, key, default, expected_results", [ - (NoTags(), None, None, _DEFAULT_TAGS), - (NoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - (NoTags(), "allow_nan", True, True), - (BaseEstimator(), None, None, _DEFAULT_TAGS), - (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - (BaseEstimator(), "allow_nan", True, False), - (BaseEstimatorNotATag(), None, None, _DEFAULT_TAGS), + (EstimatorNoTags(), None, None, _DEFAULT_TAGS), + (EstimatorNoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + # Overwrite the default tags value ( - BaseEstimatorNotATag(), + EstimatorNoTags(), "allow_nan", - None, - _DEFAULT_TAGS["allow_nan"], + not _DEFAULT_TAGS["allow_nan"], + not _DEFAULT_TAGS["allow_nan"], ), - (BaseEstimatorNotATag(), "allow_nan", True, True), + # Define a default value for unknown tags + (EstimatorNoTags(), "xxx", True, True), ], ) -def test_safe_tags(estimator, key, default, expected_tags): - assert _safe_tags(estimator, key=key, default=default) == expected_tags +def test_safe_tags_no_get_tags(estimator, key, default, expected_results): + # check the behaviour of _safe_tags when an estimator does not implement + # _get_tags + assert _safe_tags(estimator, key=key, default=default) == expected_results From 5d737305f2db8486485e4d923c9e04172c2fe2e0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 12:50:13 +0100 Subject: [PATCH 20/50] iter --- sklearn/utils/__init__.py | 29 ++++++++++++++++------------- sklearn/utils/tests/test_utils.py | 27 ++++++++++++++++++--------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 366523907c810..1e205ce985200 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1222,11 +1222,10 @@ def _safe_tags(estimator, key=None, default=None): Tag name to get. By default (`None`), all tags are returned. default : list of {str, dtype} or bool, default=None - For estimator not implementing `_get_tags`, `default` allows to define - the `default` value of a tag if it is not present in `_DEFAULT_TAGS` or - to overwrite the value in `_DEFAULT_TAGS` if it the tag is defined. - When `default is None`, no default values nor overwriting will take - place. + `default` allows to define the `default` value of a tag if it is not + present in `_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` + if it the tag is defined. When `default is None`, no default values nor + overwriting will take place. Returns ------- @@ -1235,14 +1234,18 @@ def _safe_tags(estimator, key=None, default=None): """ if hasattr(estimator, "_get_tags"): if key is not None: - try: - return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) - except KeyError as exc: - raise ValueError( - f"The key {key} is neither defined in _more_tags() in the " - f"class {repr(estimator)} nor a default estimator key in " - f"_DEFAULT_TAGS: {repr(_DEFAULT_TAGS.keys())}" - ) from exc + if default is None: + try: + return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + except KeyError as exc: + raise ValueError( + f"The key {key} is neither defined in _more_tags() in " + f"the class {repr(estimator)} nor a default estimator " + f"key in _DEFAULT_TAGS. Use the parameter default if " + f"you want to define a default value." + ) from exc + else: + return default else: tags = estimator._get_tags() return { diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index ae8c8c386b2ff..1c00f1191e949 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -715,22 +715,24 @@ def test_safe_tags_error(estimator, err_msg): _safe_tags(estimator, key="xxx") -@pytest.mark.parametrize( - "estimator, key, expected_results", - [ - (BaseEstimator(), None, _DEFAULT_TAGS), - (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), - ], -) -def test_safe_tags_implement_get_tags(estimator, key, expected_results): - assert _safe_tags(estimator, key=key) == expected_results +# @pytest.mark.parametrize( +# "estimator, key, default, expected_results", +# [ +# (BaseEstimator(), None, _DEFAULT_TAGS), +# (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), +# ], +# ) +# def test_safe_tags_implement_get_tags(estimator, key, expected_results): +# assert _safe_tags(estimator, key=key) == expected_results @pytest.mark.parametrize( "estimator, key, default, expected_results", [ (EstimatorNoTags(), None, None, _DEFAULT_TAGS), + (BaseEstimator(), None, None, _DEFAULT_TAGS), (EstimatorNoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), # Overwrite the default tags value ( EstimatorNoTags(), @@ -738,8 +740,15 @@ def test_safe_tags_implement_get_tags(estimator, key, expected_results): not _DEFAULT_TAGS["allow_nan"], not _DEFAULT_TAGS["allow_nan"], ), + ( + BaseEstimator(), + "allow_nan", + not _DEFAULT_TAGS["allow_nan"], + not _DEFAULT_TAGS["allow_nan"], + ), # Define a default value for unknown tags (EstimatorNoTags(), "xxx", True, True), + (BaseEstimator(), "xxx", True, True), ], ) def test_safe_tags_no_get_tags(estimator, key, default, expected_results): From 5021fbc2389154460c6b9eedd483cbfe88f63945 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 12:54:40 +0100 Subject: [PATCH 21/50] iter --- sklearn/utils/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 1e205ce985200..8c625ba0fddd8 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1245,6 +1245,9 @@ def _safe_tags(estimator, key=None, default=None): f"you want to define a default value." ) from exc else: + tags = estimator._get_tags() + if key in tags: + raise ValueError("xxxx") return default else: tags = estimator._get_tags() From ddf6c796a3f29b2dc15270e49d3bf1df21bdce3b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 13:06:17 +0100 Subject: [PATCH 22/50] iter --- sklearn/utils/__init__.py | 8 +++---- sklearn/utils/tests/test_utils.py | 37 ++++++++++++++++--------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 8c625ba0fddd8..e7e2c7c8a06f4 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1225,7 +1225,8 @@ def _safe_tags(estimator, key=None, default=None): `default` allows to define the `default` value of a tag if it is not present in `_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` if it the tag is defined. When `default is None`, no default values nor - overwriting will take place. + overwriting will take place. If `default is not None` but that the + tag is defined, `default` will be discarded. Returns ------- @@ -1245,10 +1246,7 @@ def _safe_tags(estimator, key=None, default=None): f"you want to define a default value." ) from exc else: - tags = estimator._get_tags() - if key in tags: - raise ValueError("xxxx") - return default + return estimator._get_tags().get(key, default) else: tags = estimator._get_tags() return { diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 1c00f1191e949..21aa5bb9688ba 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -702,10 +702,15 @@ class EstimatorNoTags: pass +class EstimatorOwnTags: + def _get_tags(self): + return {} + + @pytest.mark.parametrize( "estimator, err_msg", [ - (BaseEstimator(), "The key xxx is neither defined in _more_tags"), + (BaseEstimator(), "The key xxx is neither defined"), (EstimatorNoTags(), "The key xxx is not a default tags defined"), ], ) @@ -715,40 +720,36 @@ def test_safe_tags_error(estimator, err_msg): _safe_tags(estimator, key="xxx") -# @pytest.mark.parametrize( -# "estimator, key, default, expected_results", -# [ -# (BaseEstimator(), None, _DEFAULT_TAGS), -# (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), -# ], -# ) -# def test_safe_tags_implement_get_tags(estimator, key, expected_results): -# assert _safe_tags(estimator, key=key) == expected_results - - @pytest.mark.parametrize( "estimator, key, default, expected_results", [ (EstimatorNoTags(), None, None, _DEFAULT_TAGS), - (BaseEstimator(), None, None, _DEFAULT_TAGS), (EstimatorNoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - # Overwrite the default tags value ( EstimatorNoTags(), "allow_nan", not _DEFAULT_TAGS["allow_nan"], not _DEFAULT_TAGS["allow_nan"], ), + (EstimatorNoTags(), "xxx", True, True), + (BaseEstimator(), None, None, _DEFAULT_TAGS), + (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), ( BaseEstimator(), "allow_nan", not _DEFAULT_TAGS["allow_nan"], - not _DEFAULT_TAGS["allow_nan"], + _DEFAULT_TAGS["allow_nan"], ), - # Define a default value for unknown tags - (EstimatorNoTags(), "xxx", True, True), (BaseEstimator(), "xxx", True, True), + (EstimatorOwnTags(), None, None, _DEFAULT_TAGS), + (EstimatorOwnTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + ( + EstimatorOwnTags(), + "allow_nan", + not _DEFAULT_TAGS["allow_nan"], + not _DEFAULT_TAGS["allow_nan"], + ), + (EstimatorOwnTags(), "xxx", True, True), ], ) def test_safe_tags_no_get_tags(estimator, key, default, expected_results): From 369e7b72f7c9cc9cfab3db21853193a95501a87f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 13:22:04 +0100 Subject: [PATCH 23/50] fix --- sklearn/linear_model/_glm/tests/test_glm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/_glm/tests/test_glm.py b/sklearn/linear_model/_glm/tests/test_glm.py index 0ddbcf36c7465..d6fc4e14b12fa 100644 --- a/sklearn/linear_model/_glm/tests/test_glm.py +++ b/sklearn/linear_model/_glm/tests/test_glm.py @@ -428,4 +428,4 @@ def test_tweedie_regression_family(regression_data): ], ) def test_tags(estimator, value): - assert estimator.get_tags()['requires_positive_y'] is value + assert estimator._get_tags()['requires_positive_y'] is value From 61cdfcab7ec482337e7e6ba2cfb68312c309b3ca Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 14:59:39 +0100 Subject: [PATCH 24/50] mark as xfail --- sklearn/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 6ca7e09c05931..52f25b977cf70 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -483,7 +483,7 @@ def _generate_minimal_compatible_instances(): # FIXME: hopefully in 0.25 -@pytest.mark.skip( +@pytest.mark.xfail( reason=("This test is currently failing because checks are granular " "enough. Once checks are split with some kind of only API tests, " "this test should enabled.") From 5ba8c0c4792cac11b18696563c81c8a03240ce84 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 15:56:48 +0100 Subject: [PATCH 25/50] cover transformer set_params --- sklearn/tests/test_common.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 52f25b977cf70..cd48db1a4a929 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -473,8 +473,14 @@ def _generate_minimal_compatible_instances(): for SearchCV, (Estimator, param_grid) in zip( [GridSearchCV, RandomizedSearchCV], [ - (MinimalRegressor, {"minimalregressor__param": [1, 10]}), - (MinimalClassifier, {"minimalclassifier__param": [1, 10]}), + (MinimalRegressor, { + "minimaltransformer__param": [1, 10], + "minimalregressor__param": [1, 10], + }), + (MinimalClassifier, { + "minimaltransformer__param": [1, 10], + "minimalclassifier__param": [1, 10], + }), ], ): yield SearchCV( From 41bc206893205ec1460b483243a8606c7609d285 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 15:59:52 +0100 Subject: [PATCH 26/50] reduce check in set_params --- sklearn/tests/test_common.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index cd48db1a4a929..adefc106edce8 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -322,10 +322,7 @@ def get_params(self, deep=True): return {"param": self.param} def set_params(self, **params): - valid_params = self.get_params() for key, value in params.items(): - if key not in valid_params: - raise ValueError("Wrong params") setattr(self, key, value) return self @@ -379,10 +376,7 @@ def get_params(self, deep=True): return {"param": self.param} def set_params(self, **params): - valid_params = self.get_params() for key, value in params.items(): - if key not in valid_params: - raise ValueError("Wrong params") setattr(self, key, value) return self @@ -426,10 +420,7 @@ def get_params(self, deep=True): return {"param": self.param} def set_params(self, **params): - valid_params = self.get_params() for key, value in params.items(): - if key not in valid_params: - raise ValueError("Wrong params") setattr(self, key, value) return self @@ -451,9 +442,6 @@ def transform(self, X, y=None): raise ValueError return X - def inverse_transform(self, X, y=None): - return self.transform(X) - def fit_transform(self, X, y=None): return self.fit(X, y).transform(X, y) From 739d084b6bdaeb491c4bf4dcf10218393ad03ac8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 21:48:19 +0100 Subject: [PATCH 27/50] MNT move _sage_tags in _tags module --- sklearn/base.py | 4 +- sklearn/feature_selection/_base.py | 2 +- sklearn/feature_selection/_from_model.py | 2 +- sklearn/feature_selection/_rfe.py | 3 +- sklearn/feature_selection/_sequential.py | 2 +- sklearn/model_selection/_search.py | 6 +- sklearn/multiclass.py | 6 +- sklearn/pipeline.py | 2 +- sklearn/utils/__init__.py | 90 +----------------------- sklearn/utils/_tags.py | 86 ++++++++++++++++++++++ sklearn/utils/estimator_checks.py | 6 +- sklearn/utils/tests/test_tags.py | 67 ++++++++++++++++++ sklearn/utils/tests/test_utils.py | 63 ----------------- 13 files changed, 168 insertions(+), 171 deletions(-) create mode 100644 sklearn/utils/_tags.py create mode 100644 sklearn/utils/tests/test_tags.py diff --git a/sklearn/base.py b/sklearn/base.py index b850e7668cea2..3d49ec4fe96f6 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -14,9 +14,9 @@ from . import __version__ from ._config import get_config -from .utils import ( +from .utils import _IS_32BIT +from .utils._tags import ( _DEFAULT_TAGS, - _IS_32BIT, _safe_tags, ) from .utils.validation import check_X_y diff --git a/sklearn/feature_selection/_base.py b/sklearn/feature_selection/_base.py index ae0047cf8f0c9..705a7bcc15515 100644 --- a/sklearn/feature_selection/_base.py +++ b/sklearn/feature_selection/_base.py @@ -16,8 +16,8 @@ check_array, safe_mask, safe_sqr, - _safe_tags, ) +from ..utils._tags import _safe_tags class SelectorMixin(TransformerMixin, metaclass=ABCMeta): diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index a2d1c41e8825c..730f6fac6833e 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -7,7 +7,7 @@ from ._base import SelectorMixin from ._base import _get_feature_importances from ..base import BaseEstimator, clone, MetaEstimatorMixin -from ..utils import _safe_tags +from ..utils._tags import _safe_tags from ..utils.validation import check_is_fitted from ..exceptions import NotFittedError diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index 034d591a7299d..a76623659631f 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -10,9 +10,10 @@ import numbers from joblib import Parallel, effective_n_jobs -from ..utils import _safe_tags + from ..utils.metaestimators import if_delegate_has_method from ..utils.metaestimators import _safe_split +from ..utils._tags import _safe_tags from ..utils.validation import check_is_fitted from ..utils.validation import _deprecate_positional_args from ..utils.fixes import delayed diff --git a/sklearn/feature_selection/_sequential.py b/sklearn/feature_selection/_sequential.py index 17a9fc2a5693a..6edeee74238a5 100644 --- a/sklearn/feature_selection/_sequential.py +++ b/sklearn/feature_selection/_sequential.py @@ -7,7 +7,7 @@ from ._base import SelectorMixin from ..base import BaseEstimator, MetaEstimatorMixin, clone -from ..utils import _safe_tags +from ..utils._tags import _safe_tags from ..utils.validation import check_is_fitted from ..model_selection import cross_val_score diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index c028872b2fbe0..51f43debf78ed 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -33,11 +33,9 @@ from ._validation import _normalize_score_results from ..exceptions import NotFittedError from joblib import Parallel -from ..utils import ( - check_random_state, - _safe_tags, -) +from ..utils import check_random_state from ..utils.random import sample_without_replacement +from ..utils._tags import _safe_tags from ..utils.validation import indexable, check_is_fitted, _check_fit_params from ..utils.validation import _deprecate_positional_args from ..utils.metaestimators import if_delegate_has_method diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index f24199cc4b054..3209b0fc274a3 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -45,11 +45,9 @@ from .base import _is_pairwise from .preprocessing import LabelBinarizer from .metrics.pairwise import euclidean_distances -from .utils import ( - check_random_state, - _safe_tags, -) +from .utils import check_random_state from .utils.deprecation import deprecated +from .utils._tags import _safe_tags from .utils.validation import _num_samples from .utils.validation import check_is_fitted from .utils.validation import check_X_y, check_array diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 597f3e2e88dba..6df8cddc476c4 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -22,9 +22,9 @@ from .utils import ( Bunch, _print_elapsed_time, - _safe_tags, ) from .utils.deprecation import deprecated +from .utils._tags import _safe_tags from .utils.validation import check_memory from .utils.validation import _deprecate_positional_args from .utils.fixes import delayed diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index e7e2c7c8a06f4..ca2be9d14fe29 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -43,27 +43,6 @@ parallel_backend = _joblib.parallel_backend register_parallel_backend = _joblib.register_parallel_backend -_DEFAULT_TAGS = { - 'non_deterministic': False, - 'requires_positive_X': False, - 'requires_positive_y': False, - 'X_types': ['2darray'], - 'poor_score': False, - 'no_validation': False, - 'multioutput': False, - "allow_nan": False, - 'stateless': False, - 'multilabel': False, - '_skip_test': False, - '_xfail_checks': False, - 'multioutput_only': False, - 'binary_only': False, - 'requires_fit': True, - 'preserves_dtype': [np.float64], - 'requires_y': False, - 'pairwise': False, -} - __all__ = ["murmurhash3_32", "as_float_array", "assert_all_finite", "check_array", "check_random_state", @@ -73,10 +52,7 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning", "estimator_html_repr", - "_DEFAULT_TAGS", - "_safe_tags", - ] + "DataConversionWarning", "estimator_html_repr"] IS_PYPY = platform.python_implementation() == 'PyPy' _IS_32BIT = 8 * struct.calcsize("P") == 32 @@ -1204,67 +1180,3 @@ def is_abstract(c): # itemgetter is used to ensure the sort does not extend to the 2nd item of # the tuple return sorted(set(estimators), key=itemgetter(0)) - - -def _safe_tags(estimator, key=None, default=None): - """Safely get estimator tags for common checks. - - :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. - However, if a compatible estimator does not inherit from this base class, - we should default to the default tag. - - Parameters - ---------- - estimator : estimator object - The estimator from which to get the tag. - - key : str, default=None - Tag name to get. By default (`None`), all tags are returned. - - default : list of {str, dtype} or bool, default=None - `default` allows to define the `default` value of a tag if it is not - present in `_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` - if it the tag is defined. When `default is None`, no default values nor - overwriting will take place. If `default is not None` but that the - tag is defined, `default` will be discarded. - - Returns - ------- - tags : dict - The estimator tags. - """ - if hasattr(estimator, "_get_tags"): - if key is not None: - if default is None: - try: - return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) - except KeyError as exc: - raise ValueError( - f"The key {key} is neither defined in _more_tags() in " - f"the class {repr(estimator)} nor a default estimator " - f"key in _DEFAULT_TAGS. Use the parameter default if " - f"you want to define a default value." - ) from exc - else: - return estimator._get_tags().get(key, default) - else: - tags = estimator._get_tags() - return { - key: tags.get(key, _DEFAULT_TAGS[key]) - for key in _DEFAULT_TAGS.keys() - } - else: - if key is not None: - if default is None: - try: - default = _DEFAULT_TAGS[key] - except KeyError as exc: - raise ValueError( - f"The key {key} is not a default tags defined in " - f"_DEFAULT_TAGS and thus no default values are " - f"available. Use the parameter default if you want to " - f"define a default value." - ) from exc - return default - else: - return _DEFAULT_TAGS diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py new file mode 100644 index 0000000000000..efbbb15ffae72 --- /dev/null +++ b/sklearn/utils/_tags.py @@ -0,0 +1,86 @@ +import numpy as np + +_DEFAULT_TAGS = { + 'non_deterministic': False, + 'requires_positive_X': False, + 'requires_positive_y': False, + 'X_types': ['2darray'], + 'poor_score': False, + 'no_validation': False, + 'multioutput': False, + "allow_nan": False, + 'stateless': False, + 'multilabel': False, + '_skip_test': False, + '_xfail_checks': False, + 'multioutput_only': False, + 'binary_only': False, + 'requires_fit': True, + 'preserves_dtype': [np.float64], + 'requires_y': False, + 'pairwise': False, +} + + +def _safe_tags(estimator, key=None, default=None): + """Safely get estimator tags for common checks. + + :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. + However, if a compatible estimator does not inherit from this base class, + we should default to the default tag. + + Parameters + ---------- + estimator : estimator object + The estimator from which to get the tag. + + key : str, default=None + Tag name to get. By default (`None`), all tags are returned. + + default : list of {str, dtype} or bool, default=None + `default` allows to define the `default` value of a tag if it is not + present in `_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` + if it the tag is defined. When `default is None`, no default values nor + overwriting will take place. If `default is not None` but that the + tag is defined, `default` will be discarded. + + Returns + ------- + tags : dict + The estimator tags. + """ + if hasattr(estimator, "_get_tags"): + if key is not None: + if default is None: + try: + return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + except KeyError as exc: + raise ValueError( + f"The key {key} is neither defined in _more_tags() in " + f"the class {repr(estimator)} nor a default estimator " + f"key in _DEFAULT_TAGS. Use the parameter default if " + f"you want to define a default value." + ) from exc + else: + return estimator._get_tags().get(key, default) + else: + tags = estimator._get_tags() + return { + key: tags.get(key, _DEFAULT_TAGS[key]) + for key in _DEFAULT_TAGS.keys() + } + else: + if key is not None: + if default is None: + try: + default = _DEFAULT_TAGS[key] + except KeyError as exc: + raise ValueError( + f"The key {key} is not a default tags defined in " + f"_DEFAULT_TAGS and thus no default values are " + f"available. Use the parameter default if you want to " + f"define a default value." + ) from exc + return default + else: + return _DEFAULT_TAGS diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index ed3fb67ccae47..eeab54054c57e 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -51,10 +51,8 @@ from ..model_selection._validation import _safe_split from ..metrics.pairwise import (rbf_kernel, linear_kernel, pairwise_distances) -from .import ( - shuffle, - _safe_tags, -) +from .import shuffle +from ._tags import _safe_tags from .validation import has_fit_parameter, _num_samples from ..preprocessing import StandardScaler from ..preprocessing import scale diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py new file mode 100644 index 0000000000000..638ede4e39cbe --- /dev/null +++ b/sklearn/utils/tests/test_tags.py @@ -0,0 +1,67 @@ +import pytest + +from sklearn.base import BaseEstimator +from sklearn.utils._tags import ( + _DEFAULT_TAGS, + _safe_tags, +) + + +class EstimatorNoTags: + pass + + +class EstimatorOwnTags: + def _get_tags(self): + return {} + + +@pytest.mark.parametrize( + "estimator, err_msg", + [ + (BaseEstimator(), "The key xxx is neither defined"), + (EstimatorNoTags(), "The key xxx is not a default tags defined"), + ], +) +def test_safe_tags_error(estimator, err_msg): + # Check that safe_tags raises error in ambiguous case. + with pytest.raises(ValueError, match=err_msg): + _safe_tags(estimator, key="xxx") + + +@pytest.mark.parametrize( + "estimator, key, default, expected_results", + [ + (EstimatorNoTags(), None, None, _DEFAULT_TAGS), + (EstimatorNoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + ( + EstimatorNoTags(), + "allow_nan", + not _DEFAULT_TAGS["allow_nan"], + not _DEFAULT_TAGS["allow_nan"], + ), + (EstimatorNoTags(), "xxx", True, True), + (BaseEstimator(), None, None, _DEFAULT_TAGS), + (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + ( + BaseEstimator(), + "allow_nan", + not _DEFAULT_TAGS["allow_nan"], + _DEFAULT_TAGS["allow_nan"], + ), + (BaseEstimator(), "xxx", True, True), + (EstimatorOwnTags(), None, None, _DEFAULT_TAGS), + (EstimatorOwnTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + ( + EstimatorOwnTags(), + "allow_nan", + not _DEFAULT_TAGS["allow_nan"], + not _DEFAULT_TAGS["allow_nan"], + ), + (EstimatorOwnTags(), "xxx", True, True), + ], +) +def test_safe_tags_no_get_tags(estimator, key, default, expected_results): + # check the behaviour of _safe_tags when an estimator does not implement + # _get_tags + assert _safe_tags(estimator, key=key, default=default) == expected_results diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 21aa5bb9688ba..44e448841cef0 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -8,14 +8,12 @@ import numpy as np import scipy.sparse as sp -from sklearn.base import BaseEstimator from sklearn.utils._testing import (assert_array_equal, assert_allclose_dense_sparse, assert_warns_message, assert_no_warnings, _convert_container) from sklearn.utils import check_random_state -from sklearn.utils import _DEFAULT_TAGS from sklearn.utils import _determine_key_type from sklearn.utils import deprecated from sklearn.utils import gen_batches @@ -24,7 +22,6 @@ from sklearn.utils import safe_mask from sklearn.utils import column_or_1d from sklearn.utils import _safe_indexing -from sklearn.utils import _safe_tags from sklearn.utils import shuffle from sklearn.utils import gen_even_slices from sklearn.utils import _message_with_time, _print_elapsed_time @@ -696,63 +693,3 @@ def test_to_object_array(sequence): assert isinstance(out, np.ndarray) assert out.dtype.kind == 'O' assert out.ndim == 1 - - -class EstimatorNoTags: - pass - - -class EstimatorOwnTags: - def _get_tags(self): - return {} - - -@pytest.mark.parametrize( - "estimator, err_msg", - [ - (BaseEstimator(), "The key xxx is neither defined"), - (EstimatorNoTags(), "The key xxx is not a default tags defined"), - ], -) -def test_safe_tags_error(estimator, err_msg): - # Check that safe_tags raises error in ambiguous case. - with pytest.raises(ValueError, match=err_msg): - _safe_tags(estimator, key="xxx") - - -@pytest.mark.parametrize( - "estimator, key, default, expected_results", - [ - (EstimatorNoTags(), None, None, _DEFAULT_TAGS), - (EstimatorNoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - ( - EstimatorNoTags(), - "allow_nan", - not _DEFAULT_TAGS["allow_nan"], - not _DEFAULT_TAGS["allow_nan"], - ), - (EstimatorNoTags(), "xxx", True, True), - (BaseEstimator(), None, None, _DEFAULT_TAGS), - (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - ( - BaseEstimator(), - "allow_nan", - not _DEFAULT_TAGS["allow_nan"], - _DEFAULT_TAGS["allow_nan"], - ), - (BaseEstimator(), "xxx", True, True), - (EstimatorOwnTags(), None, None, _DEFAULT_TAGS), - (EstimatorOwnTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - ( - EstimatorOwnTags(), - "allow_nan", - not _DEFAULT_TAGS["allow_nan"], - not _DEFAULT_TAGS["allow_nan"], - ), - (EstimatorOwnTags(), "xxx", True, True), - ], -) -def test_safe_tags_no_get_tags(estimator, key, default, expected_results): - # check the behaviour of _safe_tags when an estimator does not implement - # _get_tags - assert _safe_tags(estimator, key=key, default=default) == expected_results From 04b00187caf2b071f1d7021dd8c7f1af0f31fb1b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 26 Nov 2020 22:33:15 +0100 Subject: [PATCH 28/50] TST/DOC force estimator to have default tags when implementing _get_tgas --- doc/developers/develop.rst | 14 +++++--- sklearn/feature_selection/tests/test_rfe.py | 3 -- sklearn/utils/_tags.py | 36 ++++++++------------- sklearn/utils/estimator_checks.py | 18 ++++++++++- sklearn/utils/tests/test_tags.py | 10 ------ 5 files changed, 40 insertions(+), 41 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index b7b5d2ac0316f..20d5862b2d251 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -512,11 +512,15 @@ of estimators that allow programmatic inspection of their capabilities, such as sparse matrix support, supported output types and supported methods. The estimator tags are a dictionary returned by the method ``_get_tags()``. These tags are used by the common tests and the -:func:`sklearn.utils.estimator_checks.check_estimator` function to decide what -tests to run and what input data is appropriate. Tags can depend on estimator -parameters or even system architecture and can in general only be determined at -runtime. The default values for the estimator tags are defined in the -``BaseEstimator`` class. +:func:`~sklearn.utils.estimator_checks.check_estimator` function and +:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator to +decide what tests to run and what input data is appropriate. Tags can depend on +estimator parameters or even system architecture and can in general only be +determined at runtime. The default values for the estimator tags are defined in +the :class:`~sklearn.base.BaseEstimator` class. If rolling your own estimator +without inheriting from :class:`~sklearn.base.BaseEstimator`, you will need to +implement the `_get_tags()` function. Besides, this function needs to at least +return the tags defined below. The current set of estimator tags are: diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index d0511b1b5bfb6..553c709f8983a 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -56,9 +56,6 @@ def get_params(self, deep=True): def set_params(self, **params): return self - def _get_tags(self): - return {} - def test_rfe_features_importance(): generator = check_random_state(0) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index efbbb15ffae72..639ef5d455f48 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -38,11 +38,11 @@ def _safe_tags(estimator, key=None, default=None): Tag name to get. By default (`None`), all tags are returned. default : list of {str, dtype} or bool, default=None - `default` allows to define the `default` value of a tag if it is not - present in `_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` - if it the tag is defined. When `default is None`, no default values nor - overwriting will take place. If `default is not None` but that the - tag is defined, `default` will be discarded. + When `esimator.get_tags()` is not implemented, default` allows to + define the default value of a tag if it is not present in + `_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` if it the + tag is defined. When `default is None`, no default values nor + overwriting will take place. Returns ------- @@ -51,24 +51,16 @@ def _safe_tags(estimator, key=None, default=None): """ if hasattr(estimator, "_get_tags"): if key is not None: - if default is None: - try: - return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) - except KeyError as exc: - raise ValueError( - f"The key {key} is neither defined in _more_tags() in " - f"the class {repr(estimator)} nor a default estimator " - f"key in _DEFAULT_TAGS. Use the parameter default if " - f"you want to define a default value." - ) from exc - else: - return estimator._get_tags().get(key, default) + try: + return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) + except KeyError as exc: + raise ValueError( + f"The key {key} is neither defined in _more_tags() in " + f"the class {repr(estimator)} nor a default tag key in " + f"_DEFAULT_TAGS." + ) from exc else: - tags = estimator._get_tags() - return { - key: tags.get(key, _DEFAULT_TAGS[key]) - for key in _DEFAULT_TAGS.keys() - } + return {**_DEFAULT_TAGS, **estimator._get_tags()} else: if key is not None: if default is None: diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index eeab54054c57e..f6eb914c3873f 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -52,7 +52,10 @@ from ..metrics.pairwise import (rbf_kernel, linear_kernel, pairwise_distances) from .import shuffle -from ._tags import _safe_tags +from ._tags import ( + _DEFAULT_TAGS, + _safe_tags, +) from .validation import has_fit_parameter, _num_samples from ..preprocessing import StandardScaler from ..preprocessing import scale @@ -115,6 +118,7 @@ def _yield_checks(estimator): # give the same answer as before. yield check_estimators_pickle + yield check_estimator_get_tags_default_keys def _yield_classifier_checks(classifier): tags = _safe_tags(classifier) @@ -3151,3 +3155,15 @@ def check_n_features_in_after_fitting(name, estimator_orig): with raises(ValueError, match=msg): estimator.partial_fit(X_bad, y) + + +def check_estimator_get_tags_default_keys(name, estimator_orig): + # check that if _get_tags is implemented, it contains all keys from + # _DEFAULT_KEYS + estimator = clone(estimator_orig) + if not hasattr(estimator, "_get_tags"): + return + + tags_keys = set(estimator._get_tags().keys()) + default_tags_keys = set(_DEFAULT_TAGS.keys()) + assert tags_keys.intersection(default_tags_keys) == default_tags_keys diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index 638ede4e39cbe..fa802ca77f75a 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -49,16 +49,6 @@ def test_safe_tags_error(estimator, err_msg): not _DEFAULT_TAGS["allow_nan"], _DEFAULT_TAGS["allow_nan"], ), - (BaseEstimator(), "xxx", True, True), - (EstimatorOwnTags(), None, None, _DEFAULT_TAGS), - (EstimatorOwnTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - ( - EstimatorOwnTags(), - "allow_nan", - not _DEFAULT_TAGS["allow_nan"], - not _DEFAULT_TAGS["allow_nan"], - ), - (EstimatorOwnTags(), "xxx", True, True), ], ) def test_safe_tags_no_get_tags(estimator, key, default, expected_results): From 18368eac5a7c5959a06f7a82c71a715caa8bc32e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 27 Nov 2020 15:00:48 +0100 Subject: [PATCH 29/50] update documentation --- doc/developers/develop.rst | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 20d5862b2d251..2e51519c609f2 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -517,10 +517,19 @@ tags are used by the common tests and the decide what tests to run and what input data is appropriate. Tags can depend on estimator parameters or even system architecture and can in general only be determined at runtime. The default values for the estimator tags are defined in -the :class:`~sklearn.base.BaseEstimator` class. If rolling your own estimator -without inheriting from :class:`~sklearn.base.BaseEstimator`, you will need to -implement the `_get_tags()` function. Besides, this function needs to at least -return the tags defined below. +the :class:`~sklearn.base.BaseEstimator` class. + +When running integration checks, both +:func:`~sklearn.utils.estimator_checks.check_estimator` function and +:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator +are retrieving necessary tag values from the following manner: + +* if your estimator inherits from :class:`~sklearn.base.BaseEstimator`, + it will use `_get_tags`. The tags values will corresponds to either the + default values or the values specified in `_more_tags()` that overwrite the + defaults; +* if your estimator does not inherit from :class:`~sklearn.base.BaseEstimator`, + the tag values will be set to the defaults. The current set of estimator tags are: @@ -623,8 +632,9 @@ X_types (default=['2darray']) of the ``'sparse'`` tag. -To override the tags of a child class, one must define the `_more_tags()` -method and return a dict with the desired tags, e.g:: +When inheriting from :class:`~sklearn.base.BaseEstimator`, the tags of a child +class can be overridden, by defining the `_more_tags()` method and return a +dict with the desired tags, e.g:: class MyMultiOutputEstimator(BaseEstimator): @@ -632,6 +642,11 @@ method and return a dict with the desired tags, e.g:: return {'multioutput_only': True, 'non_deterministic': True} +If rolling your own estimator without inheriting from +:class:`~sklearn.base.BaseEstimator`, you will need to implement the +`_get_tags()` method. Besides, this function needs to at least return the +tags defined above. + In addition to the tags, estimators also need to declare any non-optional parameters to ``__init__`` in the ``_required_parameters`` class attribute, which is a list or tuple. If ``_required_parameters`` is only From 2e08d2a0884bc95218490067fcf92e49f4b142e2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 27 Nov 2020 19:45:51 +0000 Subject: [PATCH 30/50] slight rework of developers docs --- doc/developers/develop.rst | 41 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 2e51519c609f2..983d9fa9e67d1 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -511,26 +511,15 @@ Scikit-learn introduced estimator tags in version 0.21. These are annotations of estimators that allow programmatic inspection of their capabilities, such as sparse matrix support, supported output types and supported methods. The estimator tags are a dictionary returned by the method ``_get_tags()``. These -tags are used by the common tests and the -:func:`~sklearn.utils.estimator_checks.check_estimator` function and -:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator to -decide what tests to run and what input data is appropriate. Tags can depend on -estimator parameters or even system architecture and can in general only be +tags are used in the common checks run by the +:func:`~sklearn.utils.estimator_checks.check_estimator` function and the +:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator. +Tags determine which checks to run and what input data is appropriate. Tags +can depend on estimator parameters or even system architecture and can in +general only be determined at runtime. The default values for the estimator tags are defined in the :class:`~sklearn.base.BaseEstimator` class. -When running integration checks, both -:func:`~sklearn.utils.estimator_checks.check_estimator` function and -:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator -are retrieving necessary tag values from the following manner: - -* if your estimator inherits from :class:`~sklearn.base.BaseEstimator`, - it will use `_get_tags`. The tags values will corresponds to either the - default values or the values specified in `_more_tags()` that overwrite the - defaults; -* if your estimator does not inherit from :class:`~sklearn.base.BaseEstimator`, - the tag values will be set to the defaults. - The current set of estimator tags are: allow_nan (default=False) @@ -631,16 +620,28 @@ X_types (default=['2darray']) ``'categorical'`` data. For now, the test for sparse data do not make use of the ``'sparse'`` tag. +It is unlikely that the default values for each tag will suit the needs of +your specific estimator. There are two ways to override the defaults in your +own estimator: -When inheriting from :class:`~sklearn.base.BaseEstimator`, the tags of a child -class can be overridden, by defining the `_more_tags()` method and return a -dict with the desired tags, e.g:: +* If your estimator inherits from :class:`~sklearn.base.BaseEstimator`, which + is recommended, you can define a `_more_tags()` method which returns a dict + with the desired overridden tags. For example:: class MyMultiOutputEstimator(BaseEstimator): def _more_tags(self): return {'multioutput_only': True, 'non_deterministic': True} + + Any tag that is not in `_more_tags()` will just default to the value + documented above. + +* If your estimator does not inherit from :class:`~sklearn.base.BaseEstimator`, + you will need to implement a `_get_tags()` method which returns a dict, + similar to `_more_tags()`. Note however that **all tags must be present in + the dict**. If any of the keys documented above is not present in the + output of `_get_tags()`, an error might occur. If rolling your own estimator without inheriting from :class:`~sklearn.base.BaseEstimator`, you will need to implement the From 4aeae955a9290938931f23f4b9156fcbd5350038 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 27 Nov 2020 19:48:04 +0000 Subject: [PATCH 31/50] didn't remove some stuff --- doc/developers/develop.rst | 5 ----- 1 file changed, 5 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 983d9fa9e67d1..5adde4ba610c3 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -643,11 +643,6 @@ own estimator: the dict**. If any of the keys documented above is not present in the output of `_get_tags()`, an error might occur. -If rolling your own estimator without inheriting from -:class:`~sklearn.base.BaseEstimator`, you will need to implement the -`_get_tags()` method. Besides, this function needs to at least return the -tags defined above. - In addition to the tags, estimators also need to declare any non-optional parameters to ``__init__`` in the ``_required_parameters`` class attribute, which is a list or tuple. If ``_required_parameters`` is only From 15410b3a737dc52bb062e2961480c0f2b109ff7e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 10:51:01 +0100 Subject: [PATCH 32/50] move test around --- doc/developers/develop.rst | 6 +- sklearn/model_selection/tests/test_search.py | 79 +++++++- sklearn/tests/test_common.py | 190 ------------------- sklearn/tests/test_pipeline.py | 44 ++++- sklearn/utils/_testing.py | 167 ++++++++++++++++ sklearn/utils/tests/test_estimator_checks.py | 28 ++- 6 files changed, 299 insertions(+), 215 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 5adde4ba610c3..0e83d27eab3d7 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -516,9 +516,7 @@ tags are used in the common checks run by the :func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator. Tags determine which checks to run and what input data is appropriate. Tags can depend on estimator parameters or even system architecture and can in -general only be -determined at runtime. The default values for the estimator tags are defined in -the :class:`~sklearn.base.BaseEstimator` class. +general only be determined at runtime. The current set of estimator tags are: @@ -633,7 +631,7 @@ own estimator: def _more_tags(self): return {'multioutput_only': True, 'non_deterministic': True} - + Any tag that is not in `_more_tags()` will just default to the value documented above. diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 308d927911eaf..7ab57149775d2 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -13,15 +13,21 @@ import scipy.sparse as sp import pytest -from sklearn.utils._testing import assert_raises -from sklearn.utils._testing import assert_warns -from sklearn.utils._testing import assert_warns_message -from sklearn.utils._testing import assert_raise_message -from sklearn.utils._testing import assert_array_equal -from sklearn.utils._testing import assert_array_almost_equal -from sklearn.utils._testing import assert_allclose -from sklearn.utils._testing import assert_almost_equal -from sklearn.utils._testing import ignore_warnings +from sklearn.utils.estimator_checks import parametrize_with_checks +from sklearn.utils._testing import ( + assert_raises, + assert_warns, + assert_warns_message, + assert_raise_message, + assert_array_equal, + assert_array_almost_equal, + assert_allclose, + assert_almost_equal, + ignore_warnings, + MinimalClassifier, + MinimalRegressor, + MinimalTransformer, +) from sklearn.utils._mocking import CheckingClassifier, MockDataFrame from scipy.stats import bernoulli, expon, uniform @@ -65,7 +71,7 @@ from sklearn.metrics import confusion_matrix from sklearn.metrics.pairwise import euclidean_distances from sklearn.impute import SimpleImputer -from sklearn.pipeline import Pipeline +from sklearn.pipeline import Pipeline, make_pipeline from sklearn.linear_model import Ridge, SGDClassifier, LinearRegression from sklearn.experimental import enable_hist_gradient_boosting # noqa from sklearn.ensemble import HistGradientBoostingClassifier @@ -2079,3 +2085,56 @@ def _fit_param_callable(): 'scalar_param': 42, } model.fit(X_train, y_train, **fit_params) + + +def _generate_search_cv_using_minimal_compatible_instances(): + """Generate instance containing estimators from minimal class compatible + implementation that should be supported by `SearhCV`.""" + for SearchCV, (Estimator, param_grid) in zip( + [GridSearchCV, RandomizedSearchCV], + [ + (MinimalRegressor, {"param": [1, 10]}), + (MinimalClassifier, {"param": [1, 10]}), + ], + ): + yield SearchCV(Estimator(), param_grid) + + for SearchCV, (Estimator, param_grid) in zip( + [GridSearchCV, RandomizedSearchCV], + [ + ( + MinimalRegressor, + { + "minimaltransformer__param": [1, 10], + "minimalregressor__param": [1, 10], + }, + ), + ( + MinimalClassifier, + { + "minimaltransformer__param": [1, 10], + "minimalclassifier__param": [1, 10], + }, + ), + ], + ): + yield SearchCV( + make_pipeline(MinimalTransformer(), Estimator()), param_grid + ).set_params(error_score="raise") + + +# FIXME: hopefully in 0.25 +@pytest.mark.xfail( + reason=( + "This test is currently failing because checks are granular " + "enough. Once checks are split with some kind of only API tests, " + "this test should enabled." + ) +) +@parametrize_with_checks( + list(_generate_search_cv_using_minimal_compatible_instances()) +) +def test_search_cv_using_minimal_compatible_estimator(estimator, check): + # Check that third-party library can run tests without inheriting from + # BaseEstimator. + check(estimator) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index adefc106edce8..730f1135b833a 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -14,7 +14,6 @@ from inspect import isgenerator from functools import partial -import numpy as np import pytest from sklearn.utils import all_estimators @@ -43,12 +42,6 @@ parametrize_with_checks, check_n_features_in_after_fitting, ) -from sklearn.multiclass import check_classification_targets -from sklearn.utils.validation import ( - check_array, - check_is_fitted, - check_X_y, -) def test_all_estimator_no_base_class(): @@ -304,186 +297,3 @@ def test_search_cv(estimator, check, request): def test_check_n_features_in_after_fitting(estimator): _set_checking_parameters(estimator) check_n_features_in_after_fitting(estimator.__class__.__name__, estimator) - - -class MinimalClassifier: - """Minimal classifier implementation with inheriting from BaseEstimator.""" - _estimator_type = "classifier" - - def __init__(self, param=None): - self.param = param - - def __repr__(self): - # Only required when using pytest-xdist to get an id not associated - # with the memory location - return self.__class__.__name__ - - def get_params(self, deep=True): - return {"param": self.param} - - def set_params(self, **params): - for key, value in params.items(): - setattr(self, key, value) - return self - - def __getstate__(self): - return self.__dict__.copy() - - def __setstate__(self, state): - self.__dict__.update(state) - - def fit(self, X, y): - X, y = check_X_y(X, y) - check_classification_targets(y) - self.n_features_in_ = X.shape[1] - self.classes_, counts = np.unique(y, return_counts=True) - self._most_frequent_class_idx = counts.argmax() - return self - - def predict_proba(self, X): - check_is_fitted(self) - X = check_array(X) - if X.shape[1] != self.n_features_in_: - raise ValueError - proba_shape = (X.shape[0], self.classes_.size) - y_proba = np.zeros(shape=proba_shape, dtype=np.float64) - y_proba[:, self._most_frequent_class_idx] = 1.0 - return y_proba - - def predict(self, X): - y_proba = self.predict_proba(X) - y_pred = y_proba.argmax(axis=1) - return self.classes_[y_pred] - - def score(self, X, y): - from sklearn.metrics import accuracy_score - return accuracy_score(y, self.predict(X)) - - -class MinimalRegressor: - """Minimal regressor implementation with inheriting from BaseEstimator.""" - _estimator_type = "regressor" - - def __init__(self, param=None): - self.param = param - - def __repr__(self): - # Only required when using pytest-xdist to get an id not associated - # with the memory location - return self.__class__.__name__ - - def get_params(self, deep=True): - return {"param": self.param} - - def set_params(self, **params): - for key, value in params.items(): - setattr(self, key, value) - return self - - def __getstate__(self): - return self.__dict__.copy() - - def __setstate__(self, state): - self.__dict__.update(state) - - def fit(self, X, y): - X, y = check_X_y(X, y) - self.n_features_in_ = X.shape[1] - self._mean = np.mean(y) - return self - - def predict(self, X): - check_is_fitted(self) - X = check_array(X) - if X.shape[1] != self.n_features_in_: - raise ValueError - return np.ones(shape=(X.shape[0],)) * self._mean - - def score(self, X, y): - from sklearn.metrics import r2_score - return r2_score(y, self.predict(X)) - - -class MinimalTransformer: - """Minimal transformer implementation with inheriting from - BaseEstimator.""" - - def __init__(self, param=None): - self.param = param - - def __repr__(self): - # Only required when using pytest-xdist to get an id not associated - # with the memory location - return self.__class__.__name__ - - def get_params(self, deep=True): - return {"param": self.param} - - def set_params(self, **params): - for key, value in params.items(): - setattr(self, key, value) - return self - - def __getstate__(self): - return self.__dict__.copy() - - def __setstate__(self, state): - self.__dict__.update(state) - - def fit(self, X, y=None): - X = check_array(X) - self.n_features_in_ = X.shape[1] - return self - - def transform(self, X, y=None): - check_is_fitted(self) - X = check_array(X) - if X.shape[1] != self.n_features_in_: - raise ValueError - return X - - def fit_transform(self, X, y=None): - return self.fit(X, y).transform(X, y) - - -def _generate_minimal_compatible_instances(): - """Generate instance containing estimators from minimal class compatible - implementation.""" - for SearchCV, (Estimator, param_grid) in zip( - [GridSearchCV, RandomizedSearchCV], - [ - (MinimalRegressor, {"param": [1, 10]}), - (MinimalClassifier, {"param": [1, 10]}), - ], - ): - yield SearchCV(Estimator(), param_grid) - - for SearchCV, (Estimator, param_grid) in zip( - [GridSearchCV, RandomizedSearchCV], - [ - (MinimalRegressor, { - "minimaltransformer__param": [1, 10], - "minimalregressor__param": [1, 10], - }), - (MinimalClassifier, { - "minimaltransformer__param": [1, 10], - "minimalclassifier__param": [1, 10], - }), - ], - ): - yield SearchCV( - make_pipeline(MinimalTransformer(), Estimator()), param_grid - ).set_params(error_score="raise") - - -# FIXME: hopefully in 0.25 -@pytest.mark.xfail( - reason=("This test is currently failing because checks are granular " - "enough. Once checks are split with some kind of only API tests, " - "this test should enabled.") -) -@parametrize_with_checks(list(_generate_minimal_compatible_instances())) -def test_minimal_class_implementation_checks(estimator, check): - # Check that third-party library can run tests without inheriting from - # BaseEstimator. - check(estimator) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index d9df891d90128..82ea76894e432 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -12,14 +12,20 @@ from scipy import sparse import joblib -from sklearn.utils._testing import assert_raises -from sklearn.utils._testing import assert_raises_regex -from sklearn.utils._testing import assert_raise_message -from sklearn.utils._testing import assert_allclose -from sklearn.utils._testing import assert_array_equal -from sklearn.utils._testing import assert_array_almost_equal -from sklearn.utils._testing import assert_no_warnings +from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.fixes import parse_version +from sklearn.utils._testing import ( + assert_raises, + assert_raises_regex, + assert_raise_message, + assert_allclose, + assert_array_equal, + assert_array_almost_equal, + assert_no_warnings, + MinimalClassifier, + MinimalRegressor, + MinimalTransformer, +) from sklearn.base import clone, BaseEstimator, TransformerMixin from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union @@ -1272,3 +1278,27 @@ def test_pipeline_get_tags_none(passthrough): 'passthrough'""" pipe = make_pipeline(passthrough, SVC()) assert not pipe._get_tags()['pairwise'] + + +def _generate_pipeline_using_minimal_compatible_instances(): + """Generate pipeline containing estimators from minimal class compatible + implementation.""" + for Predictor in [MinimalRegressor, MinimalClassifier]: + yield make_pipeline(MinimalTransformer(), Predictor()) + + +# FIXME: hopefully in 0.25 +@pytest.mark.xfail( + reason=( + "This test is currently failing because checks are granular " + "enough. Once checks are split with some kind of only API tests, " + "this test should enabled." + ) +) +@parametrize_with_checks( + list(_generate_pipeline_using_minimal_compatible_instances()) +) +def test_pipeline_using_minimal_compatible_estimator(estimator, check): + # check that we can pass minimal estimator implementation within a + # pipeline. + check(estimator) diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 0f340967b0cec..ccace45277ce2 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -49,6 +49,12 @@ import sklearn from sklearn.utils import IS_PYPY, _IS_32BIT +from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.validation import ( + check_array, + check_is_fitted, + check_X_y, +) __all__ = ["assert_raises", @@ -866,3 +872,164 @@ def __exit__(self, exc_type, exc_value, _): self.raised_and_matched = True return True + + +class MinimalClassifier: + """Minimal classifier implementation with inheriting from BaseEstimator. + + This estimator should be tested with: + + * `check_estimator` in `test_estimator_checks.py`; + * within a `Pipeline` in `test_pipeline.py`; + * within a `SearchCV` in `test_search.py`. + """ + _estimator_type = "classifier" + + def __init__(self, param=None): + self.param = param + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, deep=True): + return {"param": self.param} + + def set_params(self, **params): + for key, value in params.items(): + setattr(self, key, value) + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y): + X, y = check_X_y(X, y) + check_classification_targets(y) + self.n_features_in_ = X.shape[1] + self.classes_, counts = np.unique(y, return_counts=True) + self._most_frequent_class_idx = counts.argmax() + return self + + def predict_proba(self, X): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + proba_shape = (X.shape[0], self.classes_.size) + y_proba = np.zeros(shape=proba_shape, dtype=np.float64) + y_proba[:, self._most_frequent_class_idx] = 1.0 + return y_proba + + def predict(self, X): + y_proba = self.predict_proba(X) + y_pred = y_proba.argmax(axis=1) + return self.classes_[y_pred] + + def score(self, X, y): + from sklearn.metrics import accuracy_score + return accuracy_score(y, self.predict(X)) + + +class MinimalRegressor: + """Minimal regressor implementation with inheriting from BaseEstimator. + + This estimator should be tested with: + + * `check_estimator` in `test_estimator_checks.py`; + * within a `Pipeline` in `test_pipeline.py`; + * within a `SearchCV` in `test_search.py`. + """ + _estimator_type = "regressor" + + def __init__(self, param=None): + self.param = param + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, deep=True): + return {"param": self.param} + + def set_params(self, **params): + for key, value in params.items(): + setattr(self, key, value) + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y): + X, y = check_X_y(X, y) + self.n_features_in_ = X.shape[1] + self._mean = np.mean(y) + return self + + def predict(self, X): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + return np.ones(shape=(X.shape[0],)) * self._mean + + def score(self, X, y): + from sklearn.metrics import r2_score + return r2_score(y, self.predict(X)) + + +class MinimalTransformer: + """Minimal transformer implementation with inheriting from + BaseEstimator. + + This estimator should be tested with: + + * `check_estimator` in `test_estimator_checks.py`; + * within a `Pipeline` in `test_pipeline.py`; + * within a `SearchCV` in `test_search.py`. + """ + + def __init__(self, param=None): + self.param = param + + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location + return self.__class__.__name__ + + def get_params(self, deep=True): + return {"param": self.param} + + def set_params(self, **params): + for key, value in params.items(): + setattr(self, key, value) + return self + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + def fit(self, X, y=None): + X = check_array(X) + self.n_features_in_ = X.shape[1] + return self + + def transform(self, X, y=None): + check_is_fitted(self) + X = check_array(X) + if X.shape[1] != self.n_features_in_: + raise ValueError + return X + + def fit_transform(self, X, y=None): + return self.fit(X, y).transform(X, y) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 0ca547ee8299d..c67d60bc10b90 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -7,10 +7,16 @@ from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils import deprecated -from sklearn.utils._testing import (assert_raises_regex, - ignore_warnings, - assert_warns, assert_raises, - SkipTest) +from sklearn.utils._testing import ( + assert_raises, + assert_raises_regex, + assert_warns, + ignore_warnings, + MinimalClassifier, + MinimalRegressor, + MinimalTransformer, + SkipTest, +) from sklearn.utils.estimator_checks import check_estimator, _NotAnArray from sklearn.utils.estimator_checks \ import check_class_weight_balanced_linear_classifier @@ -675,3 +681,17 @@ def test_xfail_ignored_in_check_estimator(): # Make sure checks marked as xfail are just ignored and not run by # check_estimator(), but still raise a warning. assert_warns(SkipTestWarning, check_estimator, NuSVC()) + + +# FIXME: this test should be uncommented when the checks will be enough +# granular. In 0.24, these tests fail due to low estimator performance. +def test_minimal_class_implementation_checks(estimator, check): + # Check that third-party library can run tests without inheriting from + # BaseEstimator. + # FIXME + raise SkipTest + minimal_estimators = [ + MinimalTransformer(), MinimalRegressor(), MinimalClassifier() + ] + for estimator in minimal_estimators: + check_estimator(estimator) From 7eaab7163e2ca4007e872d3c3579e4d1087ce631 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 11:01:21 +0100 Subject: [PATCH 33/50] first pass on Nicolas comments --- sklearn/tests/test_pipeline.py | 6 ++++-- sklearn/utils/_tags.py | 17 +++++++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 82ea76894e432..dc00ba04ec1f6 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1274,8 +1274,10 @@ def test_feature_union_warns_unknown_transformer_weight(): @pytest.mark.parametrize('passthrough', [None, 'passthrough']) def test_pipeline_get_tags_none(passthrough): - """Checks that tags are set correctly when passthrough is None or - 'passthrough'""" + # Checks that tags are set correctly when the first transformer is None or + # 'passthrough' + # Non-regression test for: + # https://github.com/scikit-learn/scikit-learn/issues/18815 pipe = make_pipeline(passthrough, SVC()) assert not pipe._get_tags()['pairwise'] diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 639ef5d455f48..8d2b15964bd0b 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -23,11 +23,16 @@ def _safe_tags(estimator, key=None, default=None): - """Safely get estimator tags for common checks. + """Safely get estimator tags. :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. - However, if a compatible estimator does not inherit from this base class, - we should default to the default tag. + However, if an estimator does not inherit from this base class, we should + default to the default tag. + + For scikit-learn built-in estimators, we should still rely on + `self._get_tags()`. `_safe_tags(est)` should be used when we are not sure + where `est` comes from: typically `_safe_tags(self.base_estimator)` where + `self` is a meta-estimator, or in the common checks. Parameters ---------- @@ -46,8 +51,8 @@ def _safe_tags(estimator, key=None, default=None): Returns ------- - tags : dict - The estimator tags. + tags : dict or tag value + The estimator tags. A single value is returned if `key` is not None. """ if hasattr(estimator, "_get_tags"): if key is not None: @@ -60,7 +65,7 @@ def _safe_tags(estimator, key=None, default=None): f"_DEFAULT_TAGS." ) from exc else: - return {**_DEFAULT_TAGS, **estimator._get_tags()} + return estimator._get_tags() else: if key is not None: if default is None: From 93dc09928d7b5818c0fab4af4a453502362e5294 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 11:28:25 +0100 Subject: [PATCH 34/50] add test for check --- sklearn/utils/_tags.py | 2 +- sklearn/utils/_testing.py | 18 ---------------- sklearn/utils/estimator_checks.py | 9 ++++++-- sklearn/utils/tests/test_estimator_checks.py | 22 ++++++++++++++++++++ 4 files changed, 30 insertions(+), 21 deletions(-) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 8d2b15964bd0b..37753123009c4 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -27,7 +27,7 @@ def _safe_tags(estimator, key=None, default=None): :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. However, if an estimator does not inherit from this base class, we should - default to the default tag. + fall-back to the default tags. For scikit-learn built-in estimators, we should still rely on `self._get_tags()`. `_safe_tags(est)` should be used when we are not sure diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index ccace45277ce2..4e29f063b22e0 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -901,12 +901,6 @@ def set_params(self, **params): setattr(self, key, value) return self - def __getstate__(self): - return self.__dict__.copy() - - def __setstate__(self, state): - self.__dict__.update(state) - def fit(self, X, y): X, y = check_X_y(X, y) check_classification_targets(y) @@ -962,12 +956,6 @@ def set_params(self, **params): setattr(self, key, value) return self - def __getstate__(self): - return self.__dict__.copy() - - def __setstate__(self, state): - self.__dict__.update(state) - def fit(self, X, y): X, y = check_X_y(X, y) self.n_features_in_ = X.shape[1] @@ -1013,12 +1001,6 @@ def set_params(self, **params): setattr(self, key, value) return self - def __getstate__(self): - return self.__dict__.copy() - - def __setstate__(self, state): - self.__dict__.update(state) - def fit(self, X, y=None): X = check_array(X) self.n_features_in_ = X.shape[1] diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index f6eb914c3873f..7925487bad522 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1614,7 +1614,9 @@ def check_estimators_pickle(name, estimator_orig): # pickle and unpickle! pickled_estimator = pickle.dumps(estimator) module_name = estimator.__module__ - if module_name.startswith('sklearn.') and "test_" not in module_name: + if module_name.startswith('sklearn.') and not ( + "test_" in module_name or module_name.endswith("_testing") + ): # strict check for sklearn estimators that are not implemented in test # modules. assert b"version" in pickled_estimator @@ -3166,4 +3168,7 @@ def check_estimator_get_tags_default_keys(name, estimator_orig): tags_keys = set(estimator._get_tags().keys()) default_tags_keys = set(_DEFAULT_TAGS.keys()) - assert tags_keys.intersection(default_tags_keys) == default_tags_keys + assert tags_keys.intersection(default_tags_keys) == default_tags_keys, ( + f"{name}._get_tags() is missing entries for the following default tags" + f": {default_tags_keys - tags_keys.intersection(default_tags_keys)}" + ) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index c67d60bc10b90..25f40d0b546bf 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -27,6 +27,8 @@ from sklearn.utils.estimator_checks import check_no_attributes_set_in_init from sklearn.utils.estimator_checks import check_classifier_data_not_an_array from sklearn.utils.estimator_checks import check_regressor_data_not_an_array +from sklearn.utils.estimator_checks import \ + check_estimator_get_tags_default_keys from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption from sklearn.utils.fixes import np_version, parse_version @@ -374,6 +376,13 @@ def _more_tags(self): return {'binary_only': True} +class EstimatorMissingDefaultTags(BaseEstimator): + def _get_tags(self): + tags = super()._get_tags().copy() + del tags["allow_nan"] + return tags + + class RequiresPositiveYRegressor(LinearRegression): def fit(self, X, y): @@ -640,6 +649,19 @@ def test_check_regressor_data_not_an_array(): EstimatorInconsistentForPandas()) +def test_check_estimator_get_tags_default_keys(): + estimator = EstimatorMissingDefaultTags() + err_msg = (r"EstimatorMissingDefaultTags._get_tags\(\) is missing entries" + r" for the following default tags: {'allow_nan'}") + assert_raises_regex( + AssertionError, + err_msg, + check_estimator_get_tags_default_keys, + estimator.__class__.__name__, + estimator, + ) + + def run_tests_without_pytest(): """Runs the tests in this file without using pytest. """ From d53adca4c674fbc92f87ba2b31bf4bf2cbebbf4d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 11:31:55 +0100 Subject: [PATCH 35/50] doc --- doc/developers/develop.rst | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 0e83d27eab3d7..b3a0d0a96df38 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -632,14 +632,16 @@ own estimator: return {'multioutput_only': True, 'non_deterministic': True} - Any tag that is not in `_more_tags()` will just default to the value - documented above. + Any tag that is not in `_more_tags()` will just fall-back to the default + values documented above. * If your estimator does not inherit from :class:`~sklearn.base.BaseEstimator`, - you will need to implement a `_get_tags()` method which returns a dict, - similar to `_more_tags()`. Note however that **all tags must be present in - the dict**. If any of the keys documented above is not present in the - output of `_get_tags()`, an error might occur. + you will need to implement a `_get_tags()` method which returns a dict that + should contains all the necessary tags for that estimator, including the + default tags typically defined in :class:`~sklearn.base.BaseEstimator` and + other scikit-learn mixin classes. Note however that **all tags must be + present in the dict**. If any of the keys documented above is not present in + the output of `_get_tags()`, an error might occur. In addition to the tags, estimators also need to declare any non-optional parameters to ``__init__`` in the ``_required_parameters`` class attribute, From 3b671d8d2097e95b30b500ab8daaec4caa2c4215 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 11:35:30 +0100 Subject: [PATCH 36/50] less diff --- sklearn/tests/test_docstring_parameters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py index 8f09f2baf902f..8d8399f0cf4da 100644 --- a/sklearn/tests/test_docstring_parameters.py +++ b/sklearn/tests/test_docstring_parameters.py @@ -227,9 +227,9 @@ def test_fit_docstring_attributes(name, Estimator): y = _enforce_estimator_tags_y(est, y) X = _enforce_estimator_tags_x(est, X) - if "1dlabels" in est._get_tags()["X_types"]: + if '1dlabels' in est._get_tags()['X_types']: est.fit(y) - elif "2dlabels" in est._get_tags()["X_types"]: + elif '2dlabels' in est._get_tags()['X_types']: est.fit(np.c_[y, y]) else: est.fit(X, y) From 7130ff699db9cab551f434f219a8f9a85b0c4a35 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 11:54:55 +0100 Subject: [PATCH 37/50] remove useless parametre --- sklearn/utils/tests/test_estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 25f40d0b546bf..f20e375abcdb8 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -707,7 +707,7 @@ def test_xfail_ignored_in_check_estimator(): # FIXME: this test should be uncommented when the checks will be enough # granular. In 0.24, these tests fail due to low estimator performance. -def test_minimal_class_implementation_checks(estimator, check): +def test_minimal_class_implementation_checks(): # Check that third-party library can run tests without inheriting from # BaseEstimator. # FIXME From c9f6af415e0b1079ad77f30e5460af43bde06c60 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 13:44:23 +0100 Subject: [PATCH 38/50] add comment --- sklearn/feature_selection/_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/feature_selection/_base.py b/sklearn/feature_selection/_base.py index 705a7bcc15515..8993312bd2cec 100644 --- a/sklearn/feature_selection/_base.py +++ b/sklearn/feature_selection/_base.py @@ -77,6 +77,8 @@ def transform(self, X): X_r : array of shape [n_samples, n_selected_features] The input samples with only the selected features. """ + # note: we use _safe_tags instead of _get_tags because this is a + # public Mixin. force_all_finite = not _safe_tags(self, key="allow_nan", default=True) X = check_array( X, From 642c3053cba70ac8686bf5c79dd59432a921135d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 18:59:48 +0100 Subject: [PATCH 39/50] Update sklearn/utils/tests/test_estimator_checks.py Co-authored-by: Olivier Grisel --- sklearn/utils/tests/test_estimator_checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index f20e375abcdb8..34b35f4670915 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -705,8 +705,8 @@ def test_xfail_ignored_in_check_estimator(): assert_warns(SkipTestWarning, check_estimator, NuSVC()) -# FIXME: this test should be uncommented when the checks will be enough -# granular. In 0.24, these tests fail due to low estimator performance. +# FIXME: this test should be uncommented when the checks will be granular +# enough. In 0.24, these tests fail due to low estimator performance. def test_minimal_class_implementation_checks(): # Check that third-party library can run tests without inheriting from # BaseEstimator. From 1fd2e5700c036dff04e064ae2e8304e96e942c11 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 20:28:54 +0100 Subject: [PATCH 40/50] iter --- sklearn/model_selection/tests/test_search.py | 84 ++++++++------------ sklearn/tests/test_pipeline.py | 40 ++++++---- sklearn/utils/tests/test_tags.py | 17 ++-- 3 files changed, 61 insertions(+), 80 deletions(-) diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 7ab57149775d2..73c5965b879a4 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -13,7 +13,6 @@ import scipy.sparse as sp import pytest -from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils._testing import ( assert_raises, assert_warns, @@ -33,7 +32,7 @@ from scipy.stats import bernoulli, expon, uniform from sklearn.base import BaseEstimator, ClassifierMixin -from sklearn.base import clone +from sklearn.base import clone, is_classifier from sklearn.exceptions import NotFittedError from sklearn.datasets import make_classification from sklearn.datasets import make_blobs @@ -69,9 +68,11 @@ from sklearn.metrics import make_scorer from sklearn.metrics import roc_auc_score from sklearn.metrics import confusion_matrix +from sklearn.metrics import r2_score +from sklearn.metrics import accuracy_score from sklearn.metrics.pairwise import euclidean_distances from sklearn.impute import SimpleImputer -from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.pipeline import Pipeline from sklearn.linear_model import Ridge, SGDClassifier, LinearRegression from sklearn.experimental import enable_hist_gradient_boosting # noqa from sklearn.ensemble import HistGradientBoostingClassifier @@ -2087,54 +2088,33 @@ def _fit_param_callable(): model.fit(X_train, y_train, **fit_params) -def _generate_search_cv_using_minimal_compatible_instances(): - """Generate instance containing estimators from minimal class compatible - implementation that should be supported by `SearhCV`.""" - for SearchCV, (Estimator, param_grid) in zip( - [GridSearchCV, RandomizedSearchCV], - [ - (MinimalRegressor, {"param": [1, 10]}), - (MinimalClassifier, {"param": [1, 10]}), - ], - ): - yield SearchCV(Estimator(), param_grid) - - for SearchCV, (Estimator, param_grid) in zip( - [GridSearchCV, RandomizedSearchCV], - [ - ( - MinimalRegressor, - { - "minimaltransformer__param": [1, 10], - "minimalregressor__param": [1, 10], - }, - ), - ( - MinimalClassifier, - { - "minimaltransformer__param": [1, 10], - "minimalclassifier__param": [1, 10], - }, - ), - ], - ): - yield SearchCV( - make_pipeline(MinimalTransformer(), Estimator()), param_grid - ).set_params(error_score="raise") - - -# FIXME: hopefully in 0.25 -@pytest.mark.xfail( - reason=( - "This test is currently failing because checks are granular " - "enough. Once checks are split with some kind of only API tests, " - "this test should enabled." - ) -) -@parametrize_with_checks( - list(_generate_search_cv_using_minimal_compatible_instances()) -) -def test_search_cv_using_minimal_compatible_estimator(estimator, check): +# FIXME: Replace this test with a full `check_estimator` once we have API only +# checks. +@pytest.mark.filterwarnings("ignore:The total space of parameters 4 is") +@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV]) +@pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier]) +def test_search_cv_using_minimal_compatible_estimator(SearchCV, Predictor): # Check that third-party library can run tests without inheriting from # BaseEstimator. - check(estimator) + rng = np.random.RandomState(0) + X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20) + + model = Pipeline([ + ("transformer", MinimalTransformer()), ("predictor", Predictor()) + ]) + + params = { + "transformer__param": [1, 10], "predictor__parama": [1, 10], + } + search = SearchCV(model, params, error_score="raise") + search.fit(X, y) + + assert search.best_params_.keys() == params.keys() + + y_pred = search.predict(X) + if is_classifier(search): + assert_array_equal(y_pred, 1) + assert search.score(X, y) == pytest.approx(accuracy_score(y, y_pred)) + else: + assert_allclose(y_pred, y.mean()) + assert search.score(X, y) == pytest.approx(r2_score(y, y_pred)) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index dc00ba04ec1f6..6ba83c1a975df 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -12,7 +12,6 @@ from scipy import sparse import joblib -from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.fixes import parse_version from sklearn.utils._testing import ( assert_raises, @@ -27,12 +26,13 @@ MinimalTransformer, ) -from sklearn.base import clone, BaseEstimator, TransformerMixin +from sklearn.base import clone, is_classifier, BaseEstimator, TransformerMixin from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union from sklearn.svm import SVC from sklearn.neighbors import LocalOutlierFactor from sklearn.linear_model import LogisticRegression, Lasso from sklearn.linear_model import LinearRegression +from sklearn.metrics import accuracy_score, r2_score from sklearn.cluster import KMeans from sklearn.feature_selection import SelectKBest, f_classif from sklearn.dummy import DummyRegressor @@ -1289,18 +1289,24 @@ def _generate_pipeline_using_minimal_compatible_instances(): yield make_pipeline(MinimalTransformer(), Predictor()) -# FIXME: hopefully in 0.25 -@pytest.mark.xfail( - reason=( - "This test is currently failing because checks are granular " - "enough. Once checks are split with some kind of only API tests, " - "this test should enabled." - ) -) -@parametrize_with_checks( - list(_generate_pipeline_using_minimal_compatible_instances()) -) -def test_pipeline_using_minimal_compatible_estimator(estimator, check): - # check that we can pass minimal estimator implementation within a - # pipeline. - check(estimator) +# FIXME: Replace this test with a full `check_estimator` once we have API only +# checks. +@pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier]) +def test_search_cv_using_minimal_compatible_estimator(Predictor): + # Check that third-party library can run tests without inheriting from + # BaseEstimator. + rng = np.random.RandomState(0) + X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20) + + model = Pipeline([ + ("transformer", MinimalTransformer()), ("predictor", Predictor()) + ]) + model.fit(X, y) + + y_pred = model.predict(X) + if is_classifier(model): + assert_array_equal(y_pred, 1) + assert model.score(X, y) == pytest.approx(accuracy_score(y, y_pred)) + else: + assert_allclose(y_pred, y.mean()) + assert model.score(X, y) == pytest.approx(r2_score(y, y_pred)) diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index fa802ca77f75a..9930b6000251d 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -7,20 +7,15 @@ ) -class EstimatorNoTags: +class NoTagsEstimator: pass -class EstimatorOwnTags: - def _get_tags(self): - return {} - - @pytest.mark.parametrize( "estimator, err_msg", [ (BaseEstimator(), "The key xxx is neither defined"), - (EstimatorNoTags(), "The key xxx is not a default tags defined"), + (NoTagsEstimator(), "The key xxx is not a default tags defined"), ], ) def test_safe_tags_error(estimator, err_msg): @@ -32,15 +27,15 @@ def test_safe_tags_error(estimator, err_msg): @pytest.mark.parametrize( "estimator, key, default, expected_results", [ - (EstimatorNoTags(), None, None, _DEFAULT_TAGS), - (EstimatorNoTags(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), + (NoTagsEstimator(), None, None, _DEFAULT_TAGS), + (NoTagsEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), ( - EstimatorNoTags(), + NoTagsEstimator(), "allow_nan", not _DEFAULT_TAGS["allow_nan"], not _DEFAULT_TAGS["allow_nan"], ), - (EstimatorNoTags(), "xxx", True, True), + (NoTagsEstimator(), "xxx", True, True), (BaseEstimator(), None, None, _DEFAULT_TAGS), (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), ( From eb9c41b4a6f4d62fdf8ca8e46ec81dc9ae295e6b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Nov 2020 20:34:26 +0100 Subject: [PATCH 41/50] PEP8 --- sklearn/model_selection/tests/test_search.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 73c5965b879a4..b1194600c530d 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -69,7 +69,6 @@ from sklearn.metrics import roc_auc_score from sklearn.metrics import confusion_matrix from sklearn.metrics import r2_score -from sklearn.metrics import accuracy_score from sklearn.metrics.pairwise import euclidean_distances from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline From 82270753f52d4a1a8be405cd978d5fef8c231d9b Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 1 Dec 2020 10:57:30 +0100 Subject: [PATCH 42/50] Rephrase test comment [ci skip] --- sklearn/tests/test_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 6ba83c1a975df..f60afce94f6c5 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1293,8 +1293,8 @@ def _generate_pipeline_using_minimal_compatible_instances(): # checks. @pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier]) def test_search_cv_using_minimal_compatible_estimator(Predictor): - # Check that third-party library can run tests without inheriting from - # BaseEstimator. + # Check that third-party library estimators can be part of a pipeline + # and tuned by grid-search without inheriting from BaseEstimator. rng = np.random.RandomState(0) X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20) From 07dace19312668f876c41680e030c0d4e4b1f7e2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 1 Dec 2020 16:49:01 +0100 Subject: [PATCH 43/50] iter --- sklearn/utils/_tags.py | 28 ++++++++++------------------ sklearn/utils/tests/test_tags.py | 28 ++++++++-------------------- 2 files changed, 18 insertions(+), 38 deletions(-) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index 37753123009c4..c713bdf0cf083 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -22,7 +22,7 @@ } -def _safe_tags(estimator, key=None, default=None): +def _safe_tags(estimator, key=None): """Safely get estimator tags. :class:`~sklearn.BaseEstimator` provides the estimator tags machinery. @@ -42,13 +42,6 @@ def _safe_tags(estimator, key=None, default=None): key : str, default=None Tag name to get. By default (`None`), all tags are returned. - default : list of {str, dtype} or bool, default=None - When `esimator.get_tags()` is not implemented, default` allows to - define the default value of a tag if it is not present in - `_DEFAULT_TAGS` or to overwrite the value in `_DEFAULT_TAGS` if it the - tag is defined. When `default is None`, no default values nor - overwriting will take place. - Returns ------- tags : dict or tag value @@ -68,16 +61,15 @@ def _safe_tags(estimator, key=None, default=None): return estimator._get_tags() else: if key is not None: - if default is None: - try: - default = _DEFAULT_TAGS[key] - except KeyError as exc: - raise ValueError( - f"The key {key} is not a default tags defined in " - f"_DEFAULT_TAGS and thus no default values are " - f"available. Use the parameter default if you want to " - f"define a default value." - ) from exc + try: + default = _DEFAULT_TAGS[key] + except KeyError as exc: + raise ValueError( + f"The key {key} is not a default tags defined in " + f"_DEFAULT_TAGS and thus no default values are " + f"available. Use the parameter default if you want to " + f"define a default value." + ) from exc return default else: return _DEFAULT_TAGS diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index 9930b6000251d..0007eed35b20c 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -25,28 +25,16 @@ def test_safe_tags_error(estimator, err_msg): @pytest.mark.parametrize( - "estimator, key, default, expected_results", + "estimator, key, expected_results", [ - (NoTagsEstimator(), None, None, _DEFAULT_TAGS), - (NoTagsEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - ( - NoTagsEstimator(), - "allow_nan", - not _DEFAULT_TAGS["allow_nan"], - not _DEFAULT_TAGS["allow_nan"], - ), - (NoTagsEstimator(), "xxx", True, True), - (BaseEstimator(), None, None, _DEFAULT_TAGS), - (BaseEstimator(), "allow_nan", None, _DEFAULT_TAGS["allow_nan"]), - ( - BaseEstimator(), - "allow_nan", - not _DEFAULT_TAGS["allow_nan"], - _DEFAULT_TAGS["allow_nan"], - ), + (NoTagsEstimator(), None, _DEFAULT_TAGS), + (NoTagsEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), + (BaseEstimator(), None, _DEFAULT_TAGS), + (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), + (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), ], ) -def test_safe_tags_no_get_tags(estimator, key, default, expected_results): +def test_safe_tags_no_get_tags(estimator, key, expected_results): # check the behaviour of _safe_tags when an estimator does not implement # _get_tags - assert _safe_tags(estimator, key=key, default=default) == expected_results + assert _safe_tags(estimator, key=key) == expected_results From d76b684de1154e91e367188b8ecb0a75c65dd107 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 1 Dec 2020 17:33:26 +0100 Subject: [PATCH 44/50] iter --- sklearn/feature_selection/_base.py | 5 +-- sklearn/feature_selection/_from_model.py | 3 +- sklearn/feature_selection/_rfe.py | 3 +- sklearn/feature_selection/_sequential.py | 3 +- sklearn/multiclass.py | 3 +- sklearn/utils/_tags.py | 40 ++++++++++-------------- sklearn/utils/tests/test_tags.py | 11 +++++-- 7 files changed, 30 insertions(+), 38 deletions(-) diff --git a/sklearn/feature_selection/_base.py b/sklearn/feature_selection/_base.py index 8993312bd2cec..45c6fc9454069 100644 --- a/sklearn/feature_selection/_base.py +++ b/sklearn/feature_selection/_base.py @@ -77,14 +77,11 @@ def transform(self, X): X_r : array of shape [n_samples, n_selected_features] The input samples with only the selected features. """ - # note: we use _safe_tags instead of _get_tags because this is a - # public Mixin. - force_all_finite = not _safe_tags(self, key="allow_nan", default=True) X = check_array( X, dtype=None, accept_sparse="csr", - force_all_finite=force_all_finite, + force_all_finite=not _safe_tags(self, key="allow_nan"), ) mask = self.get_support() if not mask.any(): diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index 730f6fac6833e..4b96804fbcc45 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -285,6 +285,5 @@ def n_features_in_(self): def _more_tags(self): return { - 'allow_nan': - _safe_tags(self.estimator, key="allow_nan", default=True) + 'allow_nan': _safe_tags(self.estimator, key="allow_nan") } diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index a76623659631f..16519dfba6761 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -375,8 +375,7 @@ def predict_log_proba(self, X): def _more_tags(self): return { 'poor_score': True, - 'allow_nan': - _safe_tags(self.estimator, key='allow_nan', default=True), + 'allow_nan': _safe_tags(self.estimator, key='allow_nan'), 'requires_y': True, } diff --git a/sklearn/feature_selection/_sequential.py b/sklearn/feature_selection/_sequential.py index 6edeee74238a5..271bc0062ef6b 100644 --- a/sklearn/feature_selection/_sequential.py +++ b/sklearn/feature_selection/_sequential.py @@ -208,7 +208,6 @@ def _get_support_mask(self): def _more_tags(self): return { - 'allow_nan': - _safe_tags(self.estimator, key="allow_nan", default=True), + 'allow_nan': _safe_tags(self.estimator, key="allow_nan"), 'requires_y': True, } diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 3209b0fc274a3..182a412f8313f 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -781,8 +781,7 @@ def _pairwise(self): def _more_tags(self): """Indicate if wrapped estimator is using a precomputed Gram matrix""" return { - 'pairwise': - _safe_tags(self.estimator, key="pairwise", default=True) + 'pairwise': _safe_tags(self.estimator, key="pairwise") } diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index c713bdf0cf083..e0baf6f23259d 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -48,28 +48,20 @@ def _safe_tags(estimator, key=None): The estimator tags. A single value is returned if `key` is not None. """ if hasattr(estimator, "_get_tags"): - if key is not None: - try: - return estimator._get_tags().get(key, _DEFAULT_TAGS[key]) - except KeyError as exc: - raise ValueError( - f"The key {key} is neither defined in _more_tags() in " - f"the class {repr(estimator)} nor a default tag key in " - f"_DEFAULT_TAGS." - ) from exc - else: - return estimator._get_tags() + tags_provider = "_get_tags()" + tags = estimator._get_tags() + elif hasattr(estimator, "_more_tags"): + tags_provider = "_more_tags()" + tags = {**_DEFAULT_TAGS, **estimator._more_tags()} else: - if key is not None: - try: - default = _DEFAULT_TAGS[key] - except KeyError as exc: - raise ValueError( - f"The key {key} is not a default tags defined in " - f"_DEFAULT_TAGS and thus no default values are " - f"available. Use the parameter default if you want to " - f"define a default value." - ) from exc - return default - else: - return _DEFAULT_TAGS + tags_provider = "_DEFAULT_TAGS" + tags = _DEFAULT_TAGS + + if key is not None: + if key not in tags: + raise ValueError( + f"The key {key} is not defined in {tags_provider} for the " + f"class {repr(estimator)}." + ) + return tags[key] + return tags diff --git a/sklearn/utils/tests/test_tags.py b/sklearn/utils/tests/test_tags.py index 0007eed35b20c..f96a4947164c3 100644 --- a/sklearn/utils/tests/test_tags.py +++ b/sklearn/utils/tests/test_tags.py @@ -11,11 +11,16 @@ class NoTagsEstimator: pass +class MoreTagsEstimator: + def _more_tags(self): + return {"allow_nan": True} + + @pytest.mark.parametrize( "estimator, err_msg", [ - (BaseEstimator(), "The key xxx is neither defined"), - (NoTagsEstimator(), "The key xxx is not a default tags defined"), + (BaseEstimator(), "The key xxx is not defined in _get_tags"), + (NoTagsEstimator(), "The key xxx is not defined in _DEFAULT_TAGS"), ], ) def test_safe_tags_error(estimator, err_msg): @@ -29,6 +34,8 @@ def test_safe_tags_error(estimator, err_msg): [ (NoTagsEstimator(), None, _DEFAULT_TAGS), (NoTagsEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), + (MoreTagsEstimator(), None, {**_DEFAULT_TAGS, **{"allow_nan": True}}), + (MoreTagsEstimator(), "allow_nan", True), (BaseEstimator(), None, _DEFAULT_TAGS), (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), (BaseEstimator(), "allow_nan", _DEFAULT_TAGS["allow_nan"]), From e8fa827a5718358ea39f87bb40d3b69cbc683747 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 1 Dec 2020 17:42:30 +0100 Subject: [PATCH 45/50] iter --- sklearn/feature_selection/_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/feature_selection/_base.py b/sklearn/feature_selection/_base.py index 45c6fc9454069..60f891b69e2b7 100644 --- a/sklearn/feature_selection/_base.py +++ b/sklearn/feature_selection/_base.py @@ -77,6 +77,8 @@ def transform(self, X): X_r : array of shape [n_samples, n_selected_features] The input samples with only the selected features. """ + # note: we use _safe_tags instead of _get_tags because this is a + # public Mixin. X = check_array( X, dtype=None, From ed269680beee2968193a81c3eb47e650aeedd320 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 1 Dec 2020 18:11:11 +0100 Subject: [PATCH 46/50] update doc --- doc/developers/develop.rst | 42 ++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index b3a0d0a96df38..24ec353085b4d 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -618,30 +618,24 @@ X_types (default=['2darray']) ``'categorical'`` data. For now, the test for sparse data do not make use of the ``'sparse'`` tag. -It is unlikely that the default values for each tag will suit the needs of -your specific estimator. There are two ways to override the defaults in your -own estimator: - -* If your estimator inherits from :class:`~sklearn.base.BaseEstimator`, which - is recommended, you can define a `_more_tags()` method which returns a dict - with the desired overridden tags. For example:: - - class MyMultiOutputEstimator(BaseEstimator): - - def _more_tags(self): - return {'multioutput_only': True, - 'non_deterministic': True} - - Any tag that is not in `_more_tags()` will just fall-back to the default - values documented above. - -* If your estimator does not inherit from :class:`~sklearn.base.BaseEstimator`, - you will need to implement a `_get_tags()` method which returns a dict that - should contains all the necessary tags for that estimator, including the - default tags typically defined in :class:`~sklearn.base.BaseEstimator` and - other scikit-learn mixin classes. Note however that **all tags must be - present in the dict**. If any of the keys documented above is not present in - the output of `_get_tags()`, an error might occur. +It is unlikely that the default values for each tag will suit the needs of your +specific estimator. Additional tags can be created or default tags can be +overridden by defining a `_more_tags()` method which returns a dict with the +desired overridden tags or new tags. For example:: + +class MyMultiOutputEstimator(BaseEstimator): + + def _more_tags(self): + return {'multioutput_only': True, + 'non_deterministic': True} + +Any tag that is not in `_more_tags()` will just fall-back to the default values +documented above. + +Even if it is not recommended, it is possible to override the method +`_get_tags()`. Note however that **all tags must be present in the dict**. If +any of the keys documented above is not present in the output of `_get_tags()`, +an error will occur. In addition to the tags, estimators also need to declare any non-optional parameters to ``__init__`` in the ``_required_parameters`` class attribute, From b8ecc41a5fe4d9fad44335b50c3dd2c155a20d3c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 1 Dec 2020 18:46:38 +0100 Subject: [PATCH 47/50] fix test --- sklearn/feature_selection/tests/test_rfe.py | 3 +++ .../model_selection/tests/test_validation.py | 12 ------------ sklearn/tests/test_base.py | 17 +---------------- 3 files changed, 4 insertions(+), 28 deletions(-) diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index 553c709f8983a..9e6dfdbbd593a 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -56,6 +56,9 @@ def get_params(self, deep=True): def set_params(self, **params): return self + def _more_tags(self): + return {"allow_nan": True} + def test_rfe_features_importance(): generator = check_random_state(0) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index e8fdbc1dd32df..4437a7a4cb35c 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -1985,15 +1985,3 @@ def _more_tags(self): "Set the estimator tags of your estimator instead") with pytest.warns(FutureWarning, match=msg): cross_validate(svm, linear_kernel, y, cv=2) - - # the _pairwise attribute is present and set to True while the pairwise - # tag is not present - class NoEstimatorTagSVM(SVC): - def _get_tags(self): - tags = super()._get_tags() - del tags['pairwise'] - return tags - - svm = NoEstimatorTagSVM(kernel='precomputed') - with pytest.warns(FutureWarning, match=msg): - cross_validate(svm, linear_kernel, y, cv=2) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index b8d78a96d8e85..0c07db459d128 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -558,24 +558,9 @@ class IncorrectTagPCA(KernelPCA): with pytest.warns(FutureWarning, match=msg): assert not _is_pairwise(pca) - # the _pairwise attribute is present and set to False while the pairwise - # tag is not present - class FalsePairwise(BaseEstimator): - _pairwise = False - - def _get_tags(self): - tags = super()._get_tags() - del tags['pairwise'] - return tags - - false_pairwise = FalsePairwise() - with pytest.warns(None) as record: - assert not _is_pairwise(false_pairwise) - assert not record - # the _pairwise attribute is present and set to True while pairwise tag is # not present - class TruePairwise(FalsePairwise): + class TruePairwise(BaseEstimator): _pairwise = True true_pairwise = TruePairwise() From bb1079113a9c99a914db1b12e2bcefbb4513c2ee Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 1 Dec 2020 18:52:10 +0100 Subject: [PATCH 48/50] doc --- doc/developers/develop.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 24ec353085b4d..08ce24933dd8e 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -623,11 +623,11 @@ specific estimator. Additional tags can be created or default tags can be overridden by defining a `_more_tags()` method which returns a dict with the desired overridden tags or new tags. For example:: -class MyMultiOutputEstimator(BaseEstimator): + class MyMultiOutputEstimator(BaseEstimator): - def _more_tags(self): - return {'multioutput_only': True, - 'non_deterministic': True} + def _more_tags(self): + return {'multioutput_only': True, + 'non_deterministic': True} Any tag that is not in `_more_tags()` will just fall-back to the default values documented above. From b19137d6025b7582a658bdd0442912f8dca664cf Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 2 Dec 2020 09:16:15 +0100 Subject: [PATCH 49/50] answer ogrisel comments --- sklearn/tests/test_pipeline.py | 7 ------- sklearn/utils/_tags.py | 2 +- sklearn/utils/_testing.py | 26 ++------------------------ 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index f60afce94f6c5..7989394d0a65e 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1282,13 +1282,6 @@ def test_pipeline_get_tags_none(passthrough): assert not pipe._get_tags()['pairwise'] -def _generate_pipeline_using_minimal_compatible_instances(): - """Generate pipeline containing estimators from minimal class compatible - implementation.""" - for Predictor in [MinimalRegressor, MinimalClassifier]: - yield make_pipeline(MinimalTransformer(), Predictor()) - - # FIXME: Replace this test with a full `check_estimator` once we have API only # checks. @pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier]) diff --git a/sklearn/utils/_tags.py b/sklearn/utils/_tags.py index e0baf6f23259d..ac908ec63ce82 100644 --- a/sklearn/utils/_tags.py +++ b/sklearn/utils/_tags.py @@ -61,7 +61,7 @@ def _safe_tags(estimator, key=None): if key not in tags: raise ValueError( f"The key {key} is not defined in {tags_provider} for the " - f"class {repr(estimator)}." + f"class {estimator.__class__.__name__}." ) return tags[key] return tags diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 4e29f063b22e0..779e7b6574e3e 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -888,11 +888,6 @@ class MinimalClassifier: def __init__(self, param=None): self.param = param - def __repr__(self): - # Only required when using pytest-xdist to get an id not associated - # with the memory location - return self.__class__.__name__ - def get_params(self, deep=True): return {"param": self.param} @@ -904,7 +899,6 @@ def set_params(self, **params): def fit(self, X, y): X, y = check_X_y(X, y) check_classification_targets(y) - self.n_features_in_ = X.shape[1] self.classes_, counts = np.unique(y, return_counts=True) self._most_frequent_class_idx = counts.argmax() return self @@ -912,8 +906,6 @@ def fit(self, X, y): def predict_proba(self, X): check_is_fitted(self) X = check_array(X) - if X.shape[1] != self.n_features_in_: - raise ValueError proba_shape = (X.shape[0], self.classes_.size) y_proba = np.zeros(shape=proba_shape, dtype=np.float64) y_proba[:, self._most_frequent_class_idx] = 1.0 @@ -943,11 +935,6 @@ class MinimalRegressor: def __init__(self, param=None): self.param = param - def __repr__(self): - # Only required when using pytest-xdist to get an id not associated - # with the memory location - return self.__class__.__name__ - def get_params(self, deep=True): return {"param": self.param} @@ -958,15 +945,13 @@ def set_params(self, **params): def fit(self, X, y): X, y = check_X_y(X, y) - self.n_features_in_ = X.shape[1] + self.is_fitted_ = True self._mean = np.mean(y) return self def predict(self, X): check_is_fitted(self) X = check_array(X) - if X.shape[1] != self.n_features_in_: - raise ValueError return np.ones(shape=(X.shape[0],)) * self._mean def score(self, X, y): @@ -988,11 +973,6 @@ class MinimalTransformer: def __init__(self, param=None): self.param = param - def __repr__(self): - # Only required when using pytest-xdist to get an id not associated - # with the memory location - return self.__class__.__name__ - def get_params(self, deep=True): return {"param": self.param} @@ -1003,14 +983,12 @@ def set_params(self, **params): def fit(self, X, y=None): X = check_array(X) - self.n_features_in_ = X.shape[1] + self.is_fitted_ = True return self def transform(self, X, y=None): check_is_fitted(self) X = check_array(X) - if X.shape[1] != self.n_features_in_: - raise ValueError return X def fit_transform(self, X, y=None): From 754539f98478155a20d9cab06e55574474444ea7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 2 Dec 2020 09:22:07 +0100 Subject: [PATCH 50/50] more coverage --- sklearn/utils/tests/test_estimator_checks.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 34b35f4670915..8fabe5f91ea31 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -661,6 +661,12 @@ def test_check_estimator_get_tags_default_keys(): estimator, ) + # noop check when _get_tags is not available + estimator = MinimalTransformer() + check_estimator_get_tags_default_keys( + estimator.__class__.__name__, estimator + ) + def run_tests_without_pytest(): """Runs the tests in this file without using pytest.