Skip to content

MAINT move _estimator_has function to utils #29319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into main from nb/est_refactor
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9b9da78
move _estimator_has to utils
nithish08 Jun 20, 2024
373ce79
Merge remote-tracking branch 'upstream/main' into nb/est_refactor
nithish08 Jun 20, 2024
bc45e95
Merge branch 'main' into nb/est_refactor
nithish08 Jun 20, 2024
b6cffa2
Merge branch 'main' into nb/est_refactor
nithish08 Jun 21, 2024
067d95b
update the _estimator_has function and add unittest
nithish08 Jun 23, 2024
bb5b951
Merge remote-tracking branch 'origin/nb/est_refactor' into nb/est_ref…
nithish08 Jun 23, 2024
55afdc0
Update sklearn/utils/validation.py
nithish08 Jun 23, 2024
27ba6b8
Update sklearn/utils/validation.py
nithish08 Jun 23, 2024
3dc544d
update _estimator_has pytests and refactor
nithish08 Jun 29, 2024
fdb02e0
Merge remote-tracking branch 'upstream/main' into nb/est_refactor
nithish08 Jun 29, 2024
70ce573
pull changes from main branch
nithish08 Jun 29, 2024
86fa443
Merge branch 'main' into nb/est_refactor
nithish08 Jul 1, 2024
2ee2615
Merge branch 'main' into nb/est_refactor
nithish08 Jul 2, 2024
3896d16
Merge branch 'main' into nb/est_refactor
nithish08 Jul 3, 2024
54a842a
Merge branch 'main' into nb/est_refactor
nithish08 Jul 4, 2024
ebf8dab
Update sklearn/utils/validation.py
nithish08 Jul 5, 2024
ef4b89a
Update sklearn/utils/validation.py
nithish08 Jul 5, 2024
d9c150e
Update sklearn/utils/validation.py
nithish08 Jul 5, 2024
88aedee
Merge branch 'main' into nb/est_refactor
nithish08 Jul 5, 2024
fa7b06e
Merge remote-tracking branch 'upstream/main' into nb/est_refactor
nithish08 Jul 6, 2024
4a75e99
Merge branch 'nb/est_refactor' of github.com:nithish08/scikit-learn i…
nithish08 Jul 6, 2024
a8ed3c8
parameterize test_estimator_has pytests
nithish08 Jul 7, 2024
5d938bb
add docstring and comments for _estimator_has function and testcases
nithish08 Jul 7, 2024
a50e3c7
add comments for _estimator_has pytest
nithish08 Jul 7, 2024
fc50c40
refactor _estimator_has function
nithish08 Jul 7, 2024
a054f92
refactor test_estimator_has function
nithish08 Jul 7, 2024
1c9627d
Update sklearn/utils/validation.py
nithish08 Jul 8, 2024
34883b1
Merge branch 'main' into nb/est_refactor
nithish08 Jul 8, 2024
b28c64a
refactor test_estimator_has function
nithish08 Jul 9, 2024
f914e58
Merge branch 'main' into nb/est_refactor
nithish08 Jul 10, 2024
0cef783
Merge branch 'main' into nb/est_refactor
nithish08 Jul 10, 2024
0d87c28
Merge branch 'main' into nb/est_refactor
nithish08 Jul 11, 2024
398f319
Guillaume's comments
adrinjalali Oct 31, 2024
c74d47f
Merge remote-tracking branch 'upstream/main' into nb/est_refactor
adrinjalali Oct 31, 2024
7c6f4f9
Merge branch 'main' into nb/est_refactor
nithish08 Nov 1, 2024
17bf5bc
Merge remote-tracking branch 'upstream/main' into nb/est_refactor
adrinjalali Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 4 additions & 17 deletions sklearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_check_method_params,
_check_sample_weight,
_deprecate_positional_args,
_estimator_has,
check_is_fitted,
has_fit_parameter,
validate_data,
Expand Down Expand Up @@ -269,22 +270,6 @@ def _parallel_predict_regression(estimators, estimators_features, X):
)


def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the first fitted estimator if available, otherwise we
check the estimator attribute.
"""

def check(self):
if hasattr(self, "estimators_"):
return hasattr(self.estimators_[0], attr)
else: # self.estimator is not None
return hasattr(self.estimator, attr)

return check


class BaseBagging(BaseEnsemble, metaclass=ABCMeta):
"""Base class for Bagging meta-estimator.

Expand Down Expand Up @@ -1033,7 +1018,9 @@ def predict_log_proba(self, X):

return log_proba

@available_if(_estimator_has("decision_function"))
@available_if(
_estimator_has("decision_function", delegates=("estimators_", "estimator"))
)
def decision_function(self, X):
"""Average of the decision functions of the base classifiers.

Expand Down
44 changes: 20 additions & 24 deletions sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,13 @@
_check_feature_names_in,
_check_response_method,
_deprecate_positional_args,
_estimator_has,
check_is_fitted,
column_or_1d,
)
from ._base import _BaseHeterogeneousEnsemble, _fit_single_estimator


def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the fitted `final_estimator_` if available, otherwise we check the
unfitted `final_estimator`. We raise the original `AttributeError` if `attr` does
not exist. This function is used together with `available_if`.
"""

def check(self):
if hasattr(self, "final_estimator_"):
getattr(self.final_estimator_, attr)
else:
getattr(self.final_estimator, attr)

return True

return check


class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta):
"""Base class for stacking method."""

Expand Down Expand Up @@ -364,7 +346,9 @@ def get_feature_names_out(self, input_features=None):

return np.asarray(meta_names, dtype=object)

@available_if(_estimator_has("predict"))
@available_if(
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
)
def predict(self, X, **predict_params):
"""Predict target for X.

Expand Down Expand Up @@ -732,7 +716,9 @@ def fit(self, X, y, *, sample_weight=None, **fit_params):
fit_params["sample_weight"] = sample_weight
return super().fit(X, y_encoded, **fit_params)

@available_if(_estimator_has("predict"))
@available_if(
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
)
def predict(self, X, **predict_params):
"""Predict target for X.

Expand Down Expand Up @@ -785,7 +771,11 @@ def predict(self, X, **predict_params):
y_pred = self._label_encoder.inverse_transform(y_pred)
return y_pred

@available_if(_estimator_has("predict_proba"))
@available_if(
_estimator_has(
"predict_proba", delegates=("final_estimator_", "final_estimator")
)
)
def predict_proba(self, X):
"""Predict class probabilities for `X` using the final estimator.

Expand All @@ -809,7 +799,11 @@ def predict_proba(self, X):
y_pred = np.array([preds[:, 0] for preds in y_pred]).T
return y_pred

@available_if(_estimator_has("decision_function"))
@available_if(
_estimator_has(
"decision_function", delegates=("final_estimator_", "final_estimator")
)
)
def decision_function(self, X):
"""Decision function for samples in `X` using the final estimator.

Expand Down Expand Up @@ -1125,7 +1119,9 @@ def fit_transform(self, X, y, *, sample_weight=None, **fit_params):
fit_params["sample_weight"] = sample_weight
return super().fit_transform(X, y, **fit_params)

@available_if(_estimator_has("predict"))
@available_if(
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
)
def predict(self, X, **predict_params):
"""Predict target for X.

Expand Down
20 changes: 1 addition & 19 deletions sklearn/feature_selection/_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..utils.metaestimators import available_if
from ..utils.validation import (
_check_feature_names,
_estimator_has,
_num_features,
check_is_fitted,
check_scalar,
Expand Down Expand Up @@ -76,25 +77,6 @@ def _calculate_threshold(estimator, importances, threshold):
return threshold


def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the fitted `estimator_` if available, otherwise we check the
unfitted `estimator`. We raise the original `AttributeError` if `attr` does
not exist. This function is used together with `available_if`.
"""

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
else:
getattr(self.estimator, attr)

return True

return check


class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
"""Meta-transformer for selecting features based on importance weights.

Expand Down
20 changes: 1 addition & 19 deletions sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..utils.validation import (
_check_method_params,
_deprecate_positional_args,
_estimator_has,
check_is_fitted,
validate_data,
)
Expand Down Expand Up @@ -64,25 +65,6 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer, routed_params):
return rfe.step_scores_, rfe.step_n_features_


def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the fitted `estimator_` if available, otherwise we check the
unfitted `estimator`. We raise the original `AttributeError` if `attr` does
not exist. This function is used together with `available_if`.
"""

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
else:
getattr(self.estimator, attr)

return True

return check


class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
"""Feature ranking with recursive feature elimination.

Expand Down
18 changes: 1 addition & 17 deletions sklearn/model_selection/_classification_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..utils.parallel import Parallel, delayed
from ..utils.validation import (
_check_method_params,
_estimator_has,
_num_samples,
check_is_fitted,
indexable,
Expand All @@ -50,23 +51,6 @@ def _check_is_fitted(estimator):
check_is_fitted(estimator, "estimator_")


def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the fitted estimator if available, otherwise we
check the unfitted estimator.
"""

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
else:
getattr(self.estimator, attr)
return True

return check


class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
"""Base class for binary classifiers that set a non-default decision threshold.

Expand Down
18 changes: 9 additions & 9 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def _check_refit(search_cv, attr):
)


def _estimator_has(attr):
def _search_estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

Calling a prediction method will only be available if `refit=True`. In
Expand Down Expand Up @@ -555,7 +555,7 @@ def score(self, X, y=None, **params):
score = score[self.refit]
return score

@available_if(_estimator_has("score_samples"))
@available_if(_search_estimator_has("score_samples"))
def score_samples(self, X):
"""Call score_samples on the estimator with the best found parameters.

Expand All @@ -578,7 +578,7 @@ def score_samples(self, X):
check_is_fitted(self)
return self.best_estimator_.score_samples(X)

@available_if(_estimator_has("predict"))
@available_if(_search_estimator_has("predict"))
def predict(self, X):
"""Call predict on the estimator with the best found parameters.

Expand All @@ -600,7 +600,7 @@ def predict(self, X):
check_is_fitted(self)
return self.best_estimator_.predict(X)

@available_if(_estimator_has("predict_proba"))
@available_if(_search_estimator_has("predict_proba"))
def predict_proba(self, X):
"""Call predict_proba on the estimator with the best found parameters.

Expand All @@ -623,7 +623,7 @@ def predict_proba(self, X):
check_is_fitted(self)
return self.best_estimator_.predict_proba(X)

@available_if(_estimator_has("predict_log_proba"))
@available_if(_search_estimator_has("predict_log_proba"))
def predict_log_proba(self, X):
"""Call predict_log_proba on the estimator with the best found parameters.

Expand All @@ -646,7 +646,7 @@ def predict_log_proba(self, X):
check_is_fitted(self)
return self.best_estimator_.predict_log_proba(X)

@available_if(_estimator_has("decision_function"))
@available_if(_search_estimator_has("decision_function"))
def decision_function(self, X):
"""Call decision_function on the estimator with the best found parameters.

Expand All @@ -669,7 +669,7 @@ def decision_function(self, X):
check_is_fitted(self)
return self.best_estimator_.decision_function(X)

@available_if(_estimator_has("transform"))
@available_if(_search_estimator_has("transform"))
def transform(self, X):
"""Call transform on the estimator with the best found parameters.

Expand All @@ -691,7 +691,7 @@ def transform(self, X):
check_is_fitted(self)
return self.best_estimator_.transform(X)

@available_if(_estimator_has("inverse_transform"))
@available_if(_search_estimator_has("inverse_transform"))
def inverse_transform(self, X=None, Xt=None):
"""Call inverse_transform on the estimator with the best found params.

Expand Down Expand Up @@ -746,7 +746,7 @@ def classes_(self):

Only available when `refit=True` and the estimator is a classifier.
"""
_estimator_has("classes_")(self)
_search_estimator_has("classes_")(self)
return self.best_estimator_.classes_

def _run_search(self, evaluate_candidates):
Expand Down
21 changes: 1 addition & 20 deletions sklearn/semi_supervised/_self_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,14 @@
process_routing,
)
from ..utils.metaestimators import available_if
from ..utils.validation import check_is_fitted, validate_data
from ..utils.validation import _estimator_has, check_is_fitted, validate_data

__all__ = ["SelfTrainingClassifier"]

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause


def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.

First, we check the fitted `estimator_` if available, otherwise we check
the unfitted `estimator`. We raise the original `AttributeError` if
`attr` does not exist. This function is used together with `available_if`.
"""

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
else:
getattr(self.estimator, attr)

return True

return check


class SelfTrainingClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
"""Self-training classifier.

Expand Down
Loading