FIX support scalar values in fit_params in SearchCV #15863

Merged: 35 commits, Dec 31, 2019
Commits (35)
0492836
support a scalar fit param
adrinjalali Dec 11, 2019
1ae9d03
pep8
adrinjalali Dec 11, 2019
341f8e0
TST add test for desired behavior
glemaitre Dec 23, 2019
53d7a91
FIX introduce _check_fit_params to validate parameters
glemaitre Dec 23, 2019
ef64f0b
DOC update whats new
glemaitre Dec 23, 2019
d5f0ba0
Merge remote-tracking branch 'origin/master' into pr/adrinjalali/15863
glemaitre Dec 23, 2019
9567b44
TST tests both grid-search and randomize-search
glemaitre Dec 23, 2019
f340ab6
PEP8
glemaitre Dec 23, 2019
ffb7ce5
DOC revert unecessary change
glemaitre Dec 23, 2019
2b5b1db
TST add test for _check_fit_params
glemaitre Dec 23, 2019
ffbac6f
olivier comments
glemaitre Dec 23, 2019
c0216dc
TST fixes
glemaitre Dec 23, 2019
4693729
DOC whats new
glemaitre Dec 23, 2019
d7a2c19
DOC whats new
glemaitre Dec 23, 2019
52ecee4
TST revert type of error
glemaitre Dec 23, 2019
be69ce0
add olivier suggestions
glemaitre Dec 23, 2019
9b71a9c
address olivier comments
glemaitre Dec 23, 2019
46c4b9f
address thomas comments
glemaitre Dec 23, 2019
71fab3f
PEP8
glemaitre Dec 23, 2019
9a85162
comments olivier
glemaitre Dec 23, 2019
f41c808
TST fix test by passing X
glemaitre Dec 23, 2019
c989c70
avoid to call twice tocsr
glemaitre Dec 23, 2019
570dfa8
add case column/row sparse in check_fit_param
glemaitre Dec 23, 2019
444c947
provide optional indices
glemaitre Dec 23, 2019
9f47b58
TST check content when indexing params
glemaitre Dec 23, 2019
75bd0a9
PEP8
glemaitre Dec 23, 2019
c24f39d
TST update tests to check identity
glemaitre Dec 23, 2019
63679fd
stupid fix
glemaitre Dec 23, 2019
849615b
use a distribution in RandomizedSearchCV
glemaitre Dec 24, 2019
7837cdf
MNT add lightgbm to one of the CI build
glemaitre Dec 24, 2019
b98e194
move to another build
glemaitre Dec 24, 2019
3127d2b
do not install dependencies lightgbm
glemaitre Dec 24, 2019
a096a7d
MNT comments on the CI setup
glemaitre Dec 24, 2019
18b1207
address some comments
glemaitre Dec 27, 2019
74d70e7
Test fit_params compat without dependency on lightgbm
ogrisel Dec 31, 2019
Changes from all commits
3 changes: 1 addition & 2 deletions azure-pipelines.yml
@@ -71,10 +71,9 @@ jobs:
JOBLIB_VERSION: '0.12.3'
COVERAGE: 'true'
# Linux environment to test the latest available dependencies and MKL.
# It runs tests requiring pandas and PyAMG.
# It runs tests requiring lightgbm, pandas and PyAMG.
pylatest_pip_openblas_pandas:
DISTRIB: 'conda-pip-latest'
# FIXME: pinned until SciPy wheels are available for Python 3.8
PYTHON_VERSION: '3.8'
PYTEST_VERSION: '4.6.2'
COVERAGE: 'true'
2 changes: 2 additions & 0 deletions build_tools/azure/install.sh
@@ -92,6 +92,8 @@ elif [[ "$DISTRIB" == "conda-pip-latest" ]]; then
python -m pip install numpy scipy cython joblib
python -m pip install pytest==$PYTEST_VERSION pytest-cov pytest-xdist
python -m pip install pandas matplotlib pyamg
# do not install dependencies for lightgbm since it requires scikit-learn
python -m pip install lightgbm --no-deps
fi

if [[ "$COVERAGE" == "true" ]]; then
9 changes: 9 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -50,6 +50,15 @@ Changelog
value of the ``zero_division`` keyword argument. :pr:`15879`
by :user:`Bibhash Chandra Mitra <Bibyutatsu>`.

:mod:`sklearn.model_selection`
..............................

- |Fix| :class:`model_selection.GridSearchCV` and
:class:`model_selection.RandomizedSearchCV` accept scalar values provided in
Member: Do other things such as callable belong to "scalar values"?

Member: I think the change log entry understates the change, but that's okay, as this is relatively readable.

`fit_params`. Change in 0.22 was breaking backward compatibility.
:pr:`15863` by :user:`Adrin Jalali <adrinjalali>` and
:user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.utils`
....................

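For illustration only (not part of the diff): a minimal sketch of the usage this changelog entry restores, mirroring the non-regression test added below. The estimator and its scalar fit parameter `r` are hypothetical stand-ins, not scikit-learn API.

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV


class MyEstimator(BaseEstimator, ClassifierMixin):
    """Toy classifier whose fit() accepts a scalar keyword argument."""

    def __init__(self, a=1.0):
        self.a = a

    def fit(self, X, y, r=None):
        # the scalar fit param should be passed through untouched,
        # not sliced per CV split
        self.r_ = r
        return self

    def predict(self, X):
        return np.zeros(len(X), dtype=int)


X, y = make_classification(random_state=0)
search = GridSearchCV(MyEstimator(), {"a": [0.1, 1.0]}, cv=2)
search.fit(X, y, r=42)              # worked in <=0.21, broke in 0.22.0
print(search.best_estimator_.r_)    # 42
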
6 changes: 2 additions & 4 deletions sklearn/model_selection/_search.py
@@ -33,7 +33,7 @@
from ..utils import check_random_state
from ..utils.fixes import MaskedArray
from ..utils.random import sample_without_replacement
from ..utils.validation import indexable, check_is_fitted
from ..utils.validation import indexable, check_is_fitted, _check_fit_params
from ..utils.metaestimators import if_delegate_has_method
from ..metrics._scorer import _check_multimetric_scoring
from ..metrics import check_scoring
@@ -648,9 +648,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
refit_metric = 'score'

X, y, groups = indexable(X, y, groups)
# make sure fit_params are sliceable
fit_params_values = indexable(*fit_params.values())
fit_params = dict(zip(fit_params.keys(), fit_params_values))
fit_params = _check_fit_params(X, fit_params)

n_splits = cv.get_n_splits(X, y, groups)

17 changes: 3 additions & 14 deletions sklearn/model_selection/_validation.py
@@ -23,6 +23,7 @@
from ..base import is_classifier, clone
from ..utils import (indexable, check_random_state, _safe_indexing,
_message_with_time)
from ..utils.validation import _check_fit_params
from ..utils.validation import _is_arraylike, _num_samples
from ..utils.metaestimators import _safe_split
from ..metrics import check_scoring
@@ -489,8 +490,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,

# Adjust length of sample weights
fit_params = fit_params if fit_params is not None else {}
fit_params = {k: _index_param_value(X, v, train)
for k, v in fit_params.items()}
fit_params = _check_fit_params(X, fit_params, train)

train_scores = {}
if parameters is not None:
@@ -830,8 +830,7 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
"""
# Adjust length of sample weights
fit_params = fit_params if fit_params is not None else {}
fit_params = {k: _index_param_value(X, v, train)
for k, v in fit_params.items()}
fit_params = _check_fit_params(X, fit_params, train)

X_train, y_train = _safe_split(estimator, X, y, train)
X_test, _ = _safe_split(estimator, X, y, test, train)
@@ -937,16 +936,6 @@ def _check_is_permutation(indices, n_samples):
return True


def _index_param_value(X, v, indices):
"""Private helper function for parameter value indexing."""
if not _is_arraylike(v) or _num_samples(v) != _num_samples(X):
# pass through: skip indexing
return v
if sp.issparse(v):
v = v.tocsr()
return _safe_indexing(v, indices)


def permutation_test_score(estimator, X, y, groups=None, cv=None,
n_permutations=100, n_jobs=None, random_state=0,
verbose=0, scoring=None):
110 changes: 89 additions & 21 deletions sklearn/model_selection/tests/test_search.py
@@ -27,7 +27,7 @@

from scipy.stats import bernoulli, expon, uniform

from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import clone
from sklearn.exceptions import NotFittedError
from sklearn.datasets import make_classification
@@ -36,6 +36,7 @@

from sklearn.model_selection import fit_grid_point
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedShuffleSplit
@@ -218,33 +219,25 @@ def test_grid_search_pipeline_steps():
assert not hasattr(param_grid['regressor'][1], 'coef_')


def check_hyperparameter_searcher_with_fit_params(klass, **klass_kwargs):
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
def test_SearchCV_with_fit_params(SearchCV):
X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)
clf = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
searcher = klass(clf, {'foo_param': [1, 2, 3]}, cv=2, **klass_kwargs)
searcher = SearchCV(
clf, {'foo_param': [1, 2, 3]}, cv=2, error_score="raise"
)

# The CheckingClassifier generates an assertion error if
# a parameter is missing or has length != len(X).
assert_raise_message(AssertionError,
"Expected fit parameter(s) ['eggs'] not seen.",
searcher.fit, X, y, spam=np.ones(10))
assert_raise_message(
ValueError,
"Found input variables with inconsistent numbers of samples: [",
searcher.fit, X, y, spam=np.ones(1),
eggs=np.zeros(10))
searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))
err_msg = r"Expected fit parameter\(s\) \['eggs'\] not seen."
with pytest.raises(AssertionError, match=err_msg):
searcher.fit(X, y, spam=np.ones(10))


def test_grid_search_with_fit_params():
check_hyperparameter_searcher_with_fit_params(GridSearchCV,
error_score='raise')


def test_random_search_with_fit_params():
check_hyperparameter_searcher_with_fit_params(RandomizedSearchCV, n_iter=1,
error_score='raise')
err_msg = "Fit parameter spam has length 1; expected"
with pytest.raises(AssertionError, match=err_msg):
searcher.fit(X, y, spam=np.ones(1), eggs=np.zeros(10))
searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))


@ignore_warnings
@@ -1846,3 +1839,78 @@ def test_search_cv__pairwise_property_equivalence_of_precomputed():

attr_message = "GridSearchCV not identical with precomputed metric"
assert (preds_original == preds_precomputed).all(), attr_message


@pytest.mark.parametrize(
"SearchCV, param_search",
[(GridSearchCV, {'a': [0.1, 0.01]}),
(RandomizedSearchCV, {'a': uniform(1, 3)})]
)
def test_scalar_fit_param(SearchCV, param_search):
# unofficially sanctioned tolerance for scalar values in fit_params
# non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/15805
class TestEstimator(BaseEstimator, ClassifierMixin):
def __init__(self, a=None):
self.a = a

def fit(self, X, y, r=None):
self.r_ = r

def predict(self, X):
return np.zeros(shape=(len(X)))

model = SearchCV(TestEstimator(), param_search)
X, y = make_classification(random_state=42)
model.fit(X, y, r=42)
assert model.best_estimator_.r_ == 42


@pytest.mark.parametrize(
"SearchCV, param_search",
[(GridSearchCV, {'alpha': [0.1, 0.01]}),
(RandomizedSearchCV, {'alpha': uniform(0.01, 0.1)})]
)
def test_scalar_fit_param_compat(SearchCV, param_search):
# check support for scalar values in fit_params, for instance in LightGBM
# that do not exactly respect the scikit-learn API contract but that we do
# not want to break without an explicit deprecation cycle and API
# recommendations for implementing early stopping with a user provided
# validation set. non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/15805
X_train, X_valid, y_train, y_valid = train_test_split(
*make_classification(random_state=42), random_state=42
)

class _FitParamClassifier(SGDClassifier):

def fit(self, X, y, sample_weight=None, tuple_of_arrays=None,
scalar_param=None, callable_param=None):
super().fit(X, y, sample_weight=sample_weight)
assert scalar_param > 0
assert callable(callable_param)

# The tuple of arrays should be preserved as tuple.
assert isinstance(tuple_of_arrays, tuple)
assert tuple_of_arrays[0].ndim == 2
assert tuple_of_arrays[1].ndim == 1
return self

def _fit_param_callable():
pass

model = SearchCV(
_FitParamClassifier(), param_search
)

# NOTE: `fit_params` should be data dependent (e.g. `sample_weight`) which
# is not the case for the following parameters. But this abuse is common in
# popular third-party libraries and we should tolerate this behavior for
# now and be careful not to break support for those without following
# proper deprecation cycle.
fit_params = {
'tuple_of_arrays': (X_valid, y_valid),
'callable_param': _fit_param_callable,
'scalar_param': 42,
}
model.fit(X_train, y_train, **fit_params)
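
For context, the third-party pattern that test_scalar_fit_param_compat mimics looks roughly like the following sketch. It assumes lightgbm is installed and that its scikit-learn wrapper (pre-4.0 API) accepts the eval_set and early_stopping_rounds fit keywords; those names come from lightgbm, not from this PR.

from lightgbm import LGBMClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split

X_train, X_valid, y_train, y_valid = train_test_split(
    *make_classification(random_state=42), random_state=42
)

search = GridSearchCV(LGBMClassifier(), {"n_estimators": [5, 10]}, cv=2)
# eval_set is a one-element list of (X, y) tuples and early_stopping_rounds is
# a scalar; neither is aligned with the training samples, so the search has to
# pass both through to LGBMClassifier.fit untouched.
search.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    early_stopping_rounds=5,
)
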
30 changes: 30 additions & 0 deletions sklearn/utils/tests/test_validation.py
@@ -32,6 +32,7 @@
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.datasets import make_blobs
from sklearn.utils import _safe_indexing
from sklearn.utils.validation import (
has_fit_parameter,
check_is_fitted,
@@ -46,6 +47,7 @@
_check_sample_weight,
_allclose_dense_sparse,
FLOAT_DTYPES)
from sklearn.utils.validation import _check_fit_params

import sklearn

@@ -1053,3 +1055,31 @@ def __init__(self, a=1, b=1, *, c=1, d=1):
with pytest.warns(FutureWarning,
match=r"Pass c=3, d=4 as keyword args"):
A2(1, 2, 3, 4)


@pytest.mark.parametrize("indices", [None, [1, 3]])
def test_check_fit_params(indices):
X = np.random.randn(4, 2)
fit_params = {
'list': [1, 2, 3, 4],
'array': np.array([1, 2, 3, 4]),
'sparse-col': sp.csc_matrix([1, 2, 3, 4]).T,
'sparse-row': sp.csc_matrix([1, 2, 3, 4]),
'scalar-int': 1,
'scalar-str': 'xxx',
'None': None,
}
result = _check_fit_params(X, fit_params, indices)
indices_ = indices if indices is not None else list(range(X.shape[0]))

for key in ['sparse-row', 'scalar-int', 'scalar-str', 'None']:
assert result[key] is fit_params[key]

assert result['list'] == _safe_indexing(fit_params['list'], indices_)
assert_array_equal(
result['array'], _safe_indexing(fit_params['array'], indices_)
)
assert_allclose_dense_sparse(
result['sparse-col'],
_safe_indexing(fit_params['sparse-col'], indices_)
)
69 changes: 59 additions & 10 deletions sklearn/utils/validation.py
@@ -212,6 +212,26 @@ def check_consistent_length(*arrays):
" samples: %r" % [int(l) for l in lengths])


def _make_indexable(iterable):
Member: we can't have two functions indexable and _make_indexable, especially not if both are used outside of this module. Keep indexable as the only API and make this local to indexable, e.g. def _check(iterable):

Member: OK, but how do I reuse _check in _check_fit_params? The idea was to avoid code duplication in indexable and _check_fit_params.

"""Ensure iterable supports indexing or convert to an indexable variant.

Convert sparse matrices to csr and other non-indexable iterable to arrays.
Let `None` and indexable objects (e.g. pandas dataframes) pass unchanged.

Parameters
----------
iterable : {list, dataframe, array, sparse} or None
Object to be converted to an indexable iterable.
"""
if sp.issparse(iterable):
return iterable.tocsr()
Member Author: do we have to convert to csr? I think if the estimator needs to convert the param, they'll do it themselves.

Member: For backward compatibility only (we were doing it before). I assume csr is a good default, since we only convert to csr when the number of samples in the array matches X, meaning we should be efficient at taking rows.

elif hasattr(iterable, "__getitem__") or hasattr(iterable, "iloc"):
Member Author: Sorry if I've missed something, but if the point is not to pass anything which implements __array_function__, shouldn't we test for that instead? An object may implement that protocol and also implement __getitem__, can it not?

Member: So you mean something like:

if sp.issparse(iterable):
    # efficient indexing per row
    return iterable.tocsr()
elif hasattr(iterable, "iloc"):
    # pandas series or dataframe
    return iterable
elif hasattr(iterable, "__array_function__"):
    # do not rely on the array protocol
    return np.asarray(iterable)
elif hasattr(iterable, "__getitem__"):
    return iterable
return np.asarray(iterable)

Member Author: yeah, this looks better. I'd put the __array_function__ condition on top or right after sp.issparse though.

return iterable
elif iterable is None:
return iterable
return np.array(iterable)


def indexable(*iterables):
"""Make arrays indexable for cross-validation.

@@ -224,16 +244,7 @@ def indexable(*iterables):
*iterables : lists, dataframes, arrays, sparse matrices
List of objects to ensure sliceability.
"""
result = []
for X in iterables:
if sp.issparse(X):
result.append(X.tocsr())
elif hasattr(X, "__getitem__") or hasattr(X, "iloc"):
result.append(X)
elif X is None:
result.append(X)
else:
result.append(np.array(X))
result = [_make_indexable(X) for X in iterables]
check_consistent_length(*result)
return result

@@ -1259,3 +1270,41 @@ def inner_f(*args, **kwargs):
kwargs.update({k: arg for k, arg in zip(all_args, args)})
return f(**kwargs)
return inner_f


def _check_fit_params(X, fit_params, indices=None):
"""Check and validate the parameters passed during `fit`.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Data array.

fit_params : dict
Dictionary containing the parameters passed at fit.

indices : array-like of shape (n_samples,), default=None
Indices to be selected if the parameter has the same size as `X`.

Returns
-------
fit_params_validated : dict
Validated parameters. We ensure that the values support indexing.
"""
from . import _safe_indexing
fit_params_validated = {}
for param_key, param_value in fit_params.items():
if (not _is_arraylike(param_value) or
_num_samples(param_value) != _num_samples(X)):
Member Author: so we also support non-sample-aligned fit params? I thought we wanted to support only the scalars for now. Also, shouldn't we pass all non-scalars to _make_indexable?

Member: > so we also support non-sample-aligned fit params
At least it seems we were supporting it. This is the code used within the cross-validation originally.

Member: > so we also support non-sample-aligned fit params? I thought we wanted to support only the scalars for now.
No, we want to be backward compatible, which means supporting non-sample-aligned fit params.

Member Author: Fair, but we should still make sure the non-sample-aligned params are also not implementing __array_function__, I think.

Member: Let the downstream estimator deal with it for unaligned parameters, I reckon. Pass them untouched.

Member: For this specific PR I think we should just restore compatibility with the prior behavior. For the future API, I am not so sure, but I think I would also be in favor of passing the params untouched?

# Non-indexable pass-through (for now for backward-compatibility).
# https://github.com/scikit-learn/scikit-learn/issues/15805
fit_params_validated[param_key] = param_value
else:
# Any other fit_params should support indexing
# (e.g. for cross-validation).
fit_params_validated[param_key] = _make_indexable(param_value)
fit_params_validated[param_key] = _safe_indexing(
fit_params_validated[param_key], indices
)

return fit_params_validated
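
A quick sketch (not part of the diff) of how the private helper added above behaves: values aligned with X are made indexable and sliced with the given indices, while everything else is passed through unchanged.

import numpy as np
from sklearn.utils.validation import _check_fit_params

X = np.zeros((4, 2))
fit_params = {
    "sample_weight": np.array([1.0, 2.0, 3.0, 4.0]),  # same length as X -> sliced
    "scalar_param": 42,                                # scalar -> passed through
    "callable_param": print,                           # callable -> passed through
}
checked = _check_fit_params(X, fit_params, indices=[0, 2])
print(checked["sample_weight"])             # [1. 3.]
print(checked["scalar_param"])              # 42
print(checked["callable_param"] is print)   # True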