Skip to content

ENH Generally avoid nested param validation #25815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

:mod:`sklearn`
..............

- |Feature| Added a new option `skip_parameter_validation` to the function
:func:`sklearn.set_config` and context manager :func:`sklearn.config_context` that
allows skipping the validation of the parameters passed to the estimators and public
functions. This can be useful to speed up the code but should be used with care
because it can lead to unexpected behaviors or raise obscure error messages when
setting invalid parameters.
:pr:`25815` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.base`
...................
Expand Down
28 changes: 28 additions & 0 deletions sklearn/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"enable_cython_pairwise_dist": True,
"array_api_dispatch": False,
"transform_output": "default",
"skip_parameter_validation": False,
}
_threadlocal = threading.local()

Expand Down Expand Up @@ -54,6 +55,7 @@ def set_config(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
skip_parameter_validation=None,
):
"""Set global scikit-learn configuration

Expand Down Expand Up @@ -134,6 +136,17 @@ def set_config(

.. versionadded:: 1.2

skip_parameter_validation : bool, default=None
If `True`, disable the validation of the hyper-parameters' types and values in
the fit method of estimators and for arguments passed to public helper
functions. It can save time in some situations but can lead to low level
crashes and exceptions with confusing error messages.

Note that for data parameters, such as `X` and `y`, only type validation is
skipped but validation with `check_array` will continue to run.

.. versionadded:: 1.3

See Also
--------
config_context : Context manager for global scikit-learn configuration.
Expand All @@ -160,6 +173,8 @@ def set_config(
local_config["array_api_dispatch"] = array_api_dispatch
if transform_output is not None:
local_config["transform_output"] = transform_output
if skip_parameter_validation is not None:
local_config["skip_parameter_validation"] = skip_parameter_validation


@contextmanager
Expand All @@ -173,6 +188,7 @@ def config_context(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
skip_parameter_validation=None,
):
"""Context manager for global scikit-learn configuration.

Expand Down Expand Up @@ -252,6 +268,17 @@ def config_context(

.. versionadded:: 1.2

skip_parameter_validation : bool, default=None
If `True`, disable the validation of the hyper-parameters' types and values in
the fit method of estimators and for arguments passed to public helper
functions. It can save time in some situations but can lead to low level
crashes and exceptions with confusing error messages.

Note that for data parameters, such as `X` and `y`, only type validation is
skipped but validation with `check_array` will continue to run.

.. versionadded:: 1.3

Yields
------
None.
Expand Down Expand Up @@ -289,6 +316,7 @@ def config_context(
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
array_api_dispatch=array_api_dispatch,
transform_output=transform_output,
skip_parameter_validation=skip_parameter_validation,
)

try:
Expand Down
46 changes: 45 additions & 1 deletion sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# License: BSD 3 clause

import copy
import functools
import warnings
from collections import defaultdict
import platform
Expand All @@ -13,7 +14,7 @@
import numpy as np

from . import __version__
from ._config import get_config
from ._config import get_config, config_context
from .utils import _IS_32BIT
from .utils._set_output import _SetOutputMixin
from .utils._tags import (
Expand Down Expand Up @@ -1089,3 +1090,46 @@ def is_outlier_detector(estimator):
True if estimator is an outlier detector and False otherwise.
"""
return getattr(estimator, "_estimator_type", None) == "outlier_detector"


def _fit_context(*, prefer_skip_nested_validation):
"""Decorator to run the fit methods of estimators within context managers.

Parameters
----------
prefer_skip_nested_validation : bool
If True, the validation of parameters of inner estimators or functions
called during fit will be skipped.

This is useful to avoid validating many times the parameters passed by the
user from the public facing API. It's also useful to avoid validating
parameters that we pass internally to inner functions that are guaranteed to
be valid by the test suite.

It should be set to True for most estimators, except for those that receive
non-validated objects as parameters, such as meta-estimators that are given
estimator objects.

Returns
-------
decorated_fit : method
The decorated fit method.
"""

def decorator(fit_method):
@functools.wraps(fit_method)
def wrapper(estimator, *args, **kwargs):
global_skip_validation = get_config()["skip_parameter_validation"]
if not global_skip_validation:
estimator._validate_params()

with config_context(
skip_parameter_validation=(
prefer_skip_nested_validation or global_skip_validation
)
):
return fit_method(estimator, *args, **kwargs)

return wrapper

return decorator
4 changes: 2 additions & 2 deletions sklearn/decomposition/_dict_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from joblib import effective_n_jobs

from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
from ..base import _fit_context
from ..utils import check_array, check_random_state, gen_even_slices, gen_batches
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._param_validation import validate_params
Expand Down Expand Up @@ -2318,6 +2319,7 @@ def _check_convergence(

return False

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Fit the model from data in X.

Expand All @@ -2335,8 +2337,6 @@ def fit(self, X, y=None):
self : object
Returns the instance itself.
"""
self._validate_params()

X = self._validate_data(
X, dtype=[np.float64, np.float32], order="C", copy=False
)
Expand Down
3 changes: 2 additions & 1 deletion sklearn/decomposition/_nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,8 @@ def _fit_multiplicative_update(
"W": ["array-like", None],
"H": ["array-like", None],
"update_H": ["boolean"],
}
},
prefer_skip_nested_validation=False,
)
def non_negative_factorization(
X,
Expand Down
3 changes: 3 additions & 0 deletions sklearn/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"skip_parameter_validation": False,
}

# Not using as a context manager affects nothing
Expand All @@ -35,6 +36,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"skip_parameter_validation": False,
}
assert get_config()["assume_finite"] is False

Expand Down Expand Up @@ -68,6 +70,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"skip_parameter_validation": False,
}

# No positional arguments
Expand Down
27 changes: 25 additions & 2 deletions sklearn/utils/_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scipy.sparse import issparse
from scipy.sparse import csr_matrix

from .._config import get_config, config_context
from .validation import _is_arraylike_not_scalar


Expand Down Expand Up @@ -142,7 +143,7 @@ def make_constraint(constraint):
raise ValueError(f"Unknown constraint type: {constraint}")


def validate_params(parameter_constraints):
def validate_params(parameter_constraints, *, prefer_skip_nested_validation=False):
"""Decorator to validate types and values of functions and methods.

Parameters
Expand All @@ -154,6 +155,19 @@ def validate_params(parameter_constraints):
Note that the *args and **kwargs parameters are not validated and must not be
present in the parameter_constraints dictionary.

prefer_skip_nested_validation : bool, default=False
If True, the validation of parameters of inner estimators or functions
called by the decorated function will be skipped.

This is useful to avoid validating many times the parameters passed by the
user from the public facing API. It's also useful to avoid validating
parameters that we pass internally to inner functions that are guaranteed to
be valid by the test suite.

It should be set to True for most functions, except for those that receive
non-validated objects as parameters or that are just wrappers around classes
because they only perform a partial validation.

Returns
-------
decorated_function : function or method
Expand All @@ -168,6 +182,10 @@ def decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
global_skip_validation = get_config()["skip_parameter_validation"]
if global_skip_validation:
return func(*args, **kwargs)

func_sig = signature(func)

# Map *args/**kwargs to the function signature
Expand All @@ -188,7 +206,12 @@ def wrapper(*args, **kwargs):
)

try:
return func(*args, **kwargs)
with config_context(
skip_parameter_validation=(
prefer_skip_nested_validation or global_skip_validation
)
):
return func(*args, **kwargs)
except InvalidParameterError as e:
# When the function is just a wrapper around an estimator, we allow
# the function to delegate validation to the estimator, but we replace
Expand Down
69 changes: 69 additions & 0 deletions sklearn/utils/tests/test_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy.sparse import csr_matrix
import pytest

from sklearn._config import config_context, get_config
from sklearn.base import BaseEstimator
from sklearn.model_selection import LeaveOneOut
from sklearn.utils import deprecated
Expand Down Expand Up @@ -672,3 +673,71 @@ def test_real_not_int():
assert not isinstance(1, RealNotInt)
assert isinstance(np.float64(1), RealNotInt)
assert not isinstance(np.int64(1), RealNotInt)


def test_skip_param_validation():
    """Check that param validation can be skipped using config_context."""

    @validate_params({"a": [int]})
    def f(a):
        pass

    bad_value = "1"  # a str where the constraint requires an int

    # Outside of the skipping context the wrong type is rejected.
    with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
        f(a=bad_value)

    # Inside the skipping context the same call does not raise.
    with config_context(skip_parameter_validation=True):
        f(a=bad_value)


@pytest.mark.parametrize("prefer_skip_nested_validation", [True, False])
def test_skip_nested_validation(prefer_skip_nested_validation):
    """Check that nested validation can be skipped."""

    @validate_params({"a": [int]})
    def f(a):
        pass

    @validate_params(
        {"b": [int]},
        prefer_skip_nested_validation=prefer_skip_nested_validation,
    )
    def g(b):
        # The inner call deliberately passes a value of the wrong type.
        return f(a="invalid_param_value")

    # The outer function's own parameters are always validated.
    with pytest.raises(InvalidParameterError, match="The 'b' parameter"):
        g(b="invalid_param_value")

    if not prefer_skip_nested_validation:
        # Nested validation stays active, so the bad inner argument is caught.
        with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
            g(b=1)
        return

    # Nested validation is skipped: the bad inner argument goes unnoticed.
    g(b=1)


@pytest.mark.parametrize(
    "skip_parameter_validation, prefer_skip_nested_validation, expected_skipped",
    [
        (True, True, True),
        (True, False, True),
        (False, True, True),
        (False, False, False),
    ],
)
def test_skip_nested_validation_and_config_context(
    skip_parameter_validation, prefer_skip_nested_validation, expected_skipped
):
    """Check interaction between global skip and local skip."""

    @validate_params(
        {"a": [int]}, prefer_skip_nested_validation=prefer_skip_nested_validation
    )
    def g(a):
        # Report the setting that is in effect inside the decorated function.
        config = get_config()
        return config["skip_parameter_validation"]

    with config_context(skip_parameter_validation=skip_parameter_validation):
        actual_skipped = g(1)

    # Per the parametrization, skipping is expected when either flag is set.
    assert actual_skipped == expected_skipped