-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
FIX support scalar values in fit_params in SearchCV #15863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0492836
1ae9d03
341f8e0
53d7a91
ef64f0b
d5f0ba0
9567b44
f340ab6
ffb7ce5
2b5b1db
ffbac6f
c0216dc
4693729
d7a2c19
52ecee4
be69ce0
9b71a9c
46c4b9f
71fab3f
9a85162
f41c808
c989c70
570dfa8
444c947
9f47b58
75bd0a9
c24f39d
63679fd
849615b
7837cdf
b98e194
3127d2b
a096a7d
18b1207
74d70e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -212,6 +212,26 @@ def check_consistent_length(*arrays): | |
" samples: %r" % [int(l) for l in lengths]) | ||
|
||
|
||
def _make_indexable(iterable): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can't have two functions Keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK but how do I reuse |
||
"""Ensure iterable supports indexing or convert to an indexable variant. | ||
|
||
Convert sparse matrices to csr and other non-indexable iterable to arrays. | ||
Let `None` and indexable objects (e.g. pandas dataframes) pass unchanged. | ||
|
||
Parameters | ||
---------- | ||
iterable : {list, dataframe, array, sparse} or None | ||
Object to be converted to an indexable iterable. | ||
""" | ||
if sp.issparse(iterable): | ||
return iterable.tocsr() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we have to convert to csr? I think if the estimator needs to convert the param, they'll do it themselves. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For backward-compatibility only (we were doing it before). I assume that |
||
elif hasattr(iterable, "__getitem__") or hasattr(iterable, "iloc"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry if I've missed something, but if the point is not to pass anything which implements There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So you mean something like: if sp.sparse(iterable):
# efficient indexing per rows
return iterable.csr()
elif hasattr(iterable, "iloc"):
# pandas series or dataframe
return iterable
elif hasattr(iterable, "__array_function__"):
# do not rely on array protocol
return np.asarray(iterable)
elif hasattr(iterable, "__getitem__"):
return iterable
return np.asarray(iterable) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah this looks better. I'd put the |
||
return iterable | ||
elif iterable is None: | ||
return iterable | ||
return np.array(iterable) | ||
|
||
|
||
def indexable(*iterables): | ||
"""Make arrays indexable for cross-validation. | ||
|
||
|
@@ -224,16 +244,7 @@ def indexable(*iterables): | |
*iterables : lists, dataframes, arrays, sparse matrices | ||
List of objects to ensure sliceability. | ||
""" | ||
result = [] | ||
for X in iterables: | ||
if sp.issparse(X): | ||
result.append(X.tocsr()) | ||
elif hasattr(X, "__getitem__") or hasattr(X, "iloc"): | ||
result.append(X) | ||
elif X is None: | ||
result.append(X) | ||
else: | ||
result.append(np.array(X)) | ||
result = [_make_indexable(X) for X in iterables] | ||
check_consistent_length(*result) | ||
return result | ||
|
||
|
@@ -1259,3 +1270,41 @@ def inner_f(*args, **kwargs): | |
kwargs.update({k: arg for k, arg in zip(all_args, args)}) | ||
return f(**kwargs) | ||
return inner_f | ||
|
||
|
||
def _check_fit_params(X, fit_params, indices=None): | ||
"""Check and validate the parameters passed during `fit`. | ||
|
||
Parameters | ||
---------- | ||
X : array-like of shape (n_samples, n_features) | ||
Data array. | ||
|
||
fit_params : dict | ||
Dictionary containing the parameters passed at fit. | ||
|
||
indices : array-like of shape (n_samples,), default=None | ||
Indices to be selected if the parameter has the same size as `X`. | ||
|
||
Returns | ||
------- | ||
fit_params_validated : dict | ||
Validated parameters. We ensure that the values support indexing. | ||
""" | ||
from . import _safe_indexing | ||
fit_params_validated = {} | ||
for param_key, param_value in fit_params.items(): | ||
if (not _is_arraylike(param_value) or | ||
_num_samples(param_value) != _num_samples(X)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so we also support non-sample-aligned fit params? I though we wanted to support only the scalars for now. Also, shouldn't we pass all non-scalars to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
At least it seems we were supporting it. This is the code used within the cross-validation originally. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No, we want to be backward compatible, which means supporting non-sample-aligned fit params There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair, but we should still make sure the non-sample-aligned params are also not implementing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let the downstream estimator deal with it for unaligned parameters I reckon. Pass them untouched There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For this specific PR I think we should just restore compatibility with the prior behavior. For the future API, I am not so sure but I think I would also be in favor of passing the params untouched? |
||
# Non-indexable pass-through (for now for backward-compatibility). | ||
# https://github.com/scikit-learn/scikit-learn/issues/15805 | ||
fit_params_validated[param_key] = param_value | ||
glemaitre marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
# Any other fit_params should support indexing | ||
# (e.g. for cross-validation). | ||
fit_params_validated[param_key] = _make_indexable(param_value) | ||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fit_params_validated[param_key] = _safe_indexing( | ||
fit_params_validated[param_key], indices | ||
) | ||
|
||
return fit_params_validated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do other things such as callable belong to "scalar values"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the change log entry understates the change, but that's okay, as this is relatively readable.