MAINT Param validation: decorate all estimators with _fit_context #26473

Merged 7 commits on Jun 14, 2023

10 changes: 8 additions & 2 deletions sklearn/base.py
@@ -27,7 +27,7 @@
from .utils.validation import _num_features
from .utils.validation import _check_feature_names_in
from .utils.validation import _generate_get_feature_names_out
from .utils.validation import check_is_fitted
from .utils.validation import _is_fitted, check_is_fitted
from .utils._metadata_requests import _MetadataRequester
from .utils.validation import _get_feature_names
from .utils._estimator_html_repr import estimator_html_repr
@@ -1131,7 +1131,13 @@ def decorator(fit_method):
@functools.wraps(fit_method)
def wrapper(estimator, *args, **kwargs):
global_skip_validation = get_config()["skip_parameter_validation"]
if not global_skip_validation:

# we don't want to validate again for each call to partial_fit
partial_fit_and_fitted = (
fit_method.__name__ == "partial_fit" and _is_fitted(estimator)
)

if not global_skip_validation and not partial_fit_and_fitted:
Comment on lines +1135 to +1140

@avm19 (Contributor) commented on Sep 12, 2023:

How about computing partial_fit_and_fitted inside the if not global_skip_validation: block? I don't know whether _is_fitted(estimator) is always fast, because it may check all parameter names of the estimator or call estimator.__sklearn_is_fitted__() if present. It might be a good idea to avoid it when not necessary.

estimator._validate_params()

with config_context(
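A minimal sketch of the reordering avm19 suggests above, so that _is_fitted(estimator) is only evaluated when parameter validation can actually run. This is an illustration written against the wrapper shown in this hunk (fit_method, prefer_skip_nested_validation, get_config, config_context and _is_fitted all come from its enclosing scope in sklearn/base.py), not the merged code:

    @functools.wraps(fit_method)
    def wrapper(estimator, *args, **kwargs):
        global_skip_validation = get_config()["skip_parameter_validation"]

        if not global_skip_validation:
            # Only pay for _is_fitted when validation could run at all; repeated
            # partial_fit calls on an already fitted estimator still skip it.
            partial_fit_and_fitted = (
                fit_method.__name__ == "partial_fit" and _is_fitted(estimator)
            )
            if not partial_fit_and_fitted:
                estimator._validate_params()

        with config_context(
            skip_parameter_validation=(
                prefer_skip_nested_validation or global_skip_validation
            )
        ):
            return fit_method(estimator, *args, **kwargs)
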
7 changes: 5 additions & 2 deletions sklearn/calibration.py
@@ -25,6 +25,7 @@
RegressorMixin,
clone,
MetaEstimatorMixin,
_fit_context,
)
from .preprocessing import label_binarize, LabelEncoder
from .utils import (
@@ -318,6 +319,10 @@ def _get_estimator(self):

return estimator

@_fit_context(
# CalibratedClassifierCV.estimator is not validated yet
prefer_skip_nested_validation=False
)
def fit(self, X, y, sample_weight=None, **fit_params):
"""Fit the calibrated model.

@@ -341,8 +346,6 @@ def fit(self, X, y, sample_weight=None, **fit_params):
self : object
Returns an instance of self.
"""
self._validate_params()

check_classification_targets(y)
X, y = indexable(X, y)
if sample_weight is not None:
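The same mechanical change repeats through the remaining files: the explicit self._validate_params() call at the top of fit/partial_fit is removed and the method is decorated with @_fit_context instead. prefer_skip_nested_validation=True means that estimators and functions called inside fit may skip their own parameter validation because the outer estimator has already validated everything; estimators with a parameter that is not fully validated yet (CalibratedClassifierCV.estimator here, DBSCAN.metric and OPTICS.metric below) pass False so nested calls still validate. A rough sketch of how an estimator adopts the pattern (MyEstimator and its single constraint are hypothetical, not part of scikit-learn):

    from numbers import Integral

    from sklearn.base import BaseEstimator, _fit_context
    from sklearn.utils._param_validation import Interval


    class MyEstimator(BaseEstimator):
        # Declarative constraints consumed by _validate_params().
        _parameter_constraints: dict = {
            "n_iter": [Interval(Integral, 1, None, closed="left")],
        }

        def __init__(self, n_iter=10):
            self.n_iter = n_iter

        @_fit_context(prefer_skip_nested_validation=True)
        def fit(self, X, y=None):
            # No explicit self._validate_params() call needed anymore.
            self.n_iter_ = self.n_iter
            return self
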
4 changes: 2 additions & 2 deletions sklearn/cluster/_affinity_propagation.py
@@ -12,6 +12,7 @@

from ..exceptions import ConvergenceWarning
from ..base import BaseEstimator, ClusterMixin
from ..base import _fit_context
from ..utils import check_random_state
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.validation import check_is_fitted
@@ -469,6 +470,7 @@ def __init__(
def _more_tags(self):
return {"pairwise": self.affinity == "precomputed"}

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Fit the clustering from features, or affinity matrix.

@@ -488,8 +490,6 @@ def fit(self, X, y=None):
self
Returns the instance itself.
"""
self._validate_params()

if self.affinity == "precomputed":
accept_sparse = False
else:
5 changes: 3 additions & 2 deletions sklearn/cluster/_agglomerative.py
@@ -16,6 +16,7 @@
from scipy.sparse.csgraph import connected_components

from ..base import BaseEstimator, ClusterMixin, ClassNamePrefixFeaturesOutMixin
from ..base import _fit_context
from ..metrics.pairwise import paired_distances
from ..metrics.pairwise import _VALID_METRICS
from ..metrics import DistanceMetric
@@ -950,6 +951,7 @@ def __init__(
self.metric = metric
self.compute_distances = compute_distances

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Fit the hierarchical clustering from features, or distance matrix.

@@ -968,7 +970,6 @@ def fit(self, X, y=None):
self : object
Returns the fitted instance.
"""
self._validate_params()
X = self._validate_data(X, ensure_min_samples=2)
return self._fit(X)

@@ -1324,6 +1325,7 @@ def __init__(
)
self.pooling_func = pooling_func

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Fit the hierarchical clustering on the data.

@@ -1340,7 +1342,6 @@ def fit(self, X, y=None):
self : object
Returns the transformer.
"""
self._validate_params()
X = self._validate_data(X, ensure_min_features=2)
super()._fit(X.T)
self._n_features_out = self.n_clusters_
4 changes: 2 additions & 2 deletions sklearn/cluster/_bicluster.py
@@ -13,6 +13,7 @@

from . import KMeans, MiniBatchKMeans
from ..base import BaseEstimator, BiclusterMixin
from ..base import _fit_context
from ..utils import check_random_state
from ..utils import check_scalar

@@ -118,6 +119,7 @@ def __init__(
def _check_parameters(self, n_samples):
"""Validate parameters depending on the input data."""

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Create a biclustering for X.

@@ -134,8 +136,6 @@ def fit(self, X, y=None):
self : object
SpectralBiclustering instance.
"""
self._validate_params()

X = self._validate_data(X, accept_sparse="csr", dtype=np.float64)
self._check_parameters(X.shape[0])
self._fit(X)
8 changes: 3 additions & 5 deletions sklearn/cluster/_birch.py
@@ -16,6 +16,7 @@
ClusterMixin,
BaseEstimator,
ClassNamePrefixFeaturesOutMixin,
_fit_context,
)
from ..utils.extmath import row_norms
from ..utils._param_validation import Interval
@@ -501,6 +502,7 @@ def __init__(
self.compute_labels = compute_labels
self.copy = copy

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""
Build a CF Tree for the input data.
@@ -518,9 +520,6 @@ def fit(self, X, y=None):
self
Fitted estimator.
"""

self._validate_params()

return self._fit(X, partial=False)

def _fit(self, X, partial):
@@ -610,6 +609,7 @@ def _get_leaves(self):
leaf_ptr = leaf_ptr.next_leaf_
return leaves

@_fit_context(prefer_skip_nested_validation=True)
def partial_fit(self, X=None, y=None):
"""
Online learning. Prevents rebuilding of CFTree from scratch.
@@ -629,8 +629,6 @@ def partial_fit(self, X=None, y=None):
self
Fitted estimator.
"""
self._validate_params()

if X is None:
# Perform just the final global clustering step.
self._global_clustering()
4 changes: 2 additions & 2 deletions sklearn/cluster/_bisect_k_means.py
@@ -6,6 +6,7 @@
import numpy as np
import scipy.sparse as sp

from ..base import _fit_context
from ._kmeans import _BaseKMeans
from ._kmeans import _kmeans_single_elkan
from ._kmeans import _kmeans_single_lloyd
@@ -347,6 +348,7 @@ def _bisect(self, X, x_squared_norms, sample_weight, cluster_to_bisect):

cluster_to_bisect.split(best_labels, best_centers, scores)

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None, sample_weight=None):
"""Compute bisecting k-means clustering.

@@ -373,8 +375,6 @@ def fit(self, X, y=None, sample_weight=None):
self
Fitted estimator.
"""
self._validate_params()

X = self._validate_data(
X,
accept_sparse="csr",
7 changes: 5 additions & 2 deletions sklearn/cluster/_dbscan.py
@@ -16,6 +16,7 @@

from ..metrics.pairwise import _VALID_METRICS
from ..base import BaseEstimator, ClusterMixin
from ..base import _fit_context
from ..utils.validation import _check_sample_weight
from ..utils._param_validation import Interval, StrOptions
from ..neighbors import NearestNeighbors
@@ -338,6 +339,10 @@ def __init__(
self.p = p
self.n_jobs = n_jobs

@_fit_context(
# DBSCAN.metric is not validated yet
prefer_skip_nested_validation=False
)
def fit(self, X, y=None, sample_weight=None):
"""Perform DBSCAN clustering from features, or distance matrix.

@@ -363,8 +368,6 @@ def fit(self, X, y=None, sample_weight=None):
self : object
Returns a fitted instance of self.
"""
self._validate_params()

X = self._validate_data(X, accept_sparse="csr")

if sample_weight is not None:
11 changes: 4 additions & 7 deletions sklearn/cluster/_kmeans.py
@@ -23,6 +23,7 @@
ClusterMixin,
TransformerMixin,
ClassNamePrefixFeaturesOutMixin,
_fit_context,
)
from ..metrics.pairwise import euclidean_distances
from ..metrics.pairwise import _euclidean_distances
@@ -1448,6 +1449,7 @@ def _warn_mkl_vcomp(self, n_active_threads):
f" variable OMP_NUM_THREADS={n_active_threads}."
)

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None, sample_weight=None):
"""Compute k-means clustering.

@@ -1475,8 +1477,6 @@ def fit(self, X, y=None, sample_weight=None):
self : object
Fitted estimator.
"""
self._validate_params()

X = self._validate_data(
X,
accept_sparse="csr",
@@ -2057,6 +2057,7 @@ def _random_reassign(self):
return True
return False

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None, sample_weight=None):
"""Compute the centroids on X by chunking it into mini-batches.

@@ -2084,8 +2085,6 @@ def fit(self, X, y=None, sample_weight=None):
self : object
Fitted estimator.
"""
self._validate_params()

X = self._validate_data(
X,
accept_sparse="csr",
@@ -2214,6 +2213,7 @@ def fit(self, X, y=None, sample_weight=None):

return self

@_fit_context(prefer_skip_nested_validation=True)
def partial_fit(self, X, y=None, sample_weight=None):
"""Update k means estimate on a single mini-batch X.

@@ -2241,9 +2241,6 @@
"""
has_centers = hasattr(self, "cluster_centers_")

if not has_centers:
self._validate_params()

X = self._validate_data(
X,
accept_sparse="csr",
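Note how MiniBatchKMeans.partial_fit used to guard validation with if not has_centers:. With the _is_fitted check now living inside the _fit_context wrapper (see the base.py hunk above), the decorator gives the same behaviour: parameters are validated on the first partial_fit call only. A small usage sketch on synthetic data:

    import numpy as np
    from sklearn.cluster import MiniBatchKMeans

    rng = np.random.RandomState(0)
    X = rng.rand(1000, 5)

    mbk = MiniBatchKMeans(n_clusters=3, n_init="auto", random_state=0)
    for batch in np.array_split(X, 10):
        # Parameter validation runs on the first call; afterwards the
        # estimator is fitted, so later calls skip it.
        mbk.partial_fit(batch)
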
3 changes: 2 additions & 1 deletion sklearn/cluster/_mean_shift.py
@@ -24,6 +24,7 @@
from ..utils.parallel import delayed, Parallel
from ..utils import check_random_state, gen_batches, check_array
from ..base import BaseEstimator, ClusterMixin
from ..base import _fit_context
from ..neighbors import NearestNeighbors
from ..metrics.pairwise import pairwise_distances_argmin
from .._config import config_context
@@ -435,6 +436,7 @@ def __init__(
self.n_jobs = n_jobs
self.max_iter = max_iter

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Perform clustering.

@@ -451,7 +453,6 @@ def fit(self, X, y=None):
self : object
Fitted instance.
"""
self._validate_params()
X = self._validate_data(X)
bandwidth = self.bandwidth
if bandwidth is None:
7 changes: 5 additions & 2 deletions sklearn/cluster/_optics.py
@@ -24,6 +24,7 @@
from ..utils.validation import check_memory
from ..neighbors import NearestNeighbors
from ..base import BaseEstimator, ClusterMixin
from ..base import _fit_context
from ..metrics import pairwise_distances
from scipy.sparse import issparse, SparseEfficiencyWarning

@@ -288,6 +289,10 @@ def __init__(
self.memory = memory
self.n_jobs = n_jobs

@_fit_context(
# Optics.metric is not validated yet
prefer_skip_nested_validation=False
)
def fit(self, X, y=None):
"""Perform OPTICS clustering.

@@ -311,8 +316,6 @@ def fit(self, X, y=None):
self : object
Returns a fitted instance of self.
"""
self._validate_params()

dtype = bool if self.metric in PAIRWISE_BOOLEAN_FUNCTIONS else float
if dtype == bool and X.dtype != bool:
msg = (
4 changes: 2 additions & 2 deletions sklearn/cluster/_spectral.py
@@ -15,6 +15,7 @@
from scipy.sparse import csc_matrix

from ..base import BaseEstimator, ClusterMixin
from ..base import _fit_context
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils import check_random_state, as_float_array
from ..metrics.pairwise import pairwise_kernels, KERNEL_PARAMS
@@ -649,6 +650,7 @@ def __init__(
self.n_jobs = n_jobs
self.verbose = verbose

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Perform spectral clustering from features, or affinity matrix.

@@ -671,8 +673,6 @@ def fit(self, X, y=None):
self : object
A fitted instance of the estimator.
"""
self._validate_params()

X = self._validate_data(
X,
accept_sparse=["csr", "csc", "coo"],