-
-
Notifications
You must be signed in to change notification settings - Fork 26.2k
Callbacks API #16925
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Callbacks API #16925
Changes from all commits
b98f4c1
1fcc07e
7365435
0c5cfaa
7732ece
f236274
69e7255
0e78233
40135f1
9cf1272
3ce8771
6012f6d
1c5cd13
6ba0fe9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# License: BSD 3 clause | ||
from typing import List, Callable, Optional | ||
from abc import ABC, abstractmethod | ||
|
||
import numpy as np | ||
|
||
CALLBACK_PARAM_TYPES = { | ||
"n_iter": int, | ||
"max_iter": int, | ||
"loss": (float, dict), | ||
"score": (float, dict), | ||
"validation_loss": (float, dict), | ||
"validation_score": (float, dict), | ||
"coef": np.ndarray, | ||
"intercept": (np.ndarray, float), | ||
} | ||
|
||
|
||
def _check_callback_params(**kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we let each callback independently validate its data? My question might not make sense but I don't see this being used anywhere except in the tests There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, absolutely each callback validates its data. But we also need to enforce that callbacks do follow the documented API in tests. For instance, that no undocumented parameters are passed etc which requires this function. Third party callbacks could also use this validations function, similarly to how we expose |
||
invalid_params = [] | ||
invalid_types = [] | ||
for key, val in kwargs.items(): | ||
if key not in CALLBACK_PARAM_TYPES: | ||
invalid_params.append(key) | ||
else: | ||
val_types = CALLBACK_PARAM_TYPES[key] | ||
if not isinstance(val, val_types): | ||
invalid_types.append(f"{key}={val} is not of type {val_types}") | ||
msg = "" | ||
if invalid_params: | ||
msg += ("Invalid callback parameters: {}, must be one of {}. ").format( | ||
", ".join(invalid_params), ", ".join(CALLBACK_PARAM_TYPES.keys()) | ||
) | ||
if invalid_types: | ||
msg += "Invalid callback parameters: " + ", ".join(invalid_types) | ||
if msg: | ||
raise ValueError(msg) | ||
|
||
|
||
def _eval_callbacks( | ||
callbacks: Optional[List[Callable]], method="on_iter_end", **kwargs | ||
) -> None: | ||
if callbacks is None: | ||
return | ||
Comment on lines
+43
to
+44
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we rely on Or maybe you are anticipating a future where There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we switched to the case when |
||
|
||
for callback in callbacks: | ||
getattr(callback, method)(**kwargs) | ||
|
||
|
||
class BaseCallback(ABC): | ||
@abstractmethod | ||
def on_fit_begin(self, estimator, X, y) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def on_iter_end(self, **kwargs) -> None: | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
|
||
from . import __version__ | ||
from ._config import get_config | ||
from sklearn._callbacks import BaseCallback | ||
from .utils import _IS_32BIT | ||
from .utils.validation import check_X_y | ||
from .utils.validation import check_array | ||
|
@@ -84,6 +85,11 @@ def clone(estimator, *, safe=True): | |
new_object = klass(**new_object_params) | ||
params_set = new_object.get_params(deep=False) | ||
|
||
# copy callbacks | ||
if hasattr(estimator, "_callbacks"): | ||
# TODO: do we need to use the recusive setter here? | ||
new_object._callbacks = estimator._callbacks | ||
|
||
# quick sanity check of the parameters of the clone | ||
for name in new_object_params: | ||
param1 = new_object_params[name] | ||
|
@@ -406,6 +412,7 @@ def _validate_data(self, X, y=None, reset=True, | |
out : {ndarray, sparse matrix} or tuple of these | ||
The validated input. A tuple is returned if `y` is not None. | ||
""" | ||
self._eval_callbacks('on_fit_begin', X=X, y=y) | ||
|
||
if y is None: | ||
if self._get_tags()['requires_y']: | ||
|
@@ -433,6 +440,46 @@ def _validate_data(self, X, y=None, reset=True, | |
|
||
return out | ||
|
||
def _set_callbacks(self, callbacks, deep: bool = True): | ||
"""Set callbacks for the estimator. | ||
|
||
Parameters | ||
---------- | ||
callbacks : callback or list of callbacks | ||
the callbacks to set. | ||
|
||
deep: bool=True | ||
If True, in the case of meta-estmators, callbacks are also set | ||
recursively for all child estimators. | ||
""" | ||
if isinstance(callbacks, BaseCallback): | ||
self._callbacks = [callbacks] | ||
else: | ||
self._callbacks = callbacks | ||
|
||
if not deep: | ||
return | ||
|
||
# set callbacks recursively | ||
for attr_name in getattr(self, "_required_parameters", []): | ||
# likely a meta-estimator | ||
if attr_name in ['steps', 'transformers']: | ||
for attr in getattr(self, attr_name): | ||
if isinstance(attr, BaseEstimator): | ||
attr._set_callbacks(callbacks) | ||
elif (hasattr(attr, '__len__') | ||
and len(attr) >= 2 | ||
and isinstance(attr[1], BaseEstimator)): | ||
attr[1]._set_callbacks(callbacks) | ||
|
||
def _eval_callbacks(self, method='on_iter_end', **kwargs): | ||
"""Call callbacks, e.g. in each iteration of an iterative solver""" | ||
from ._callbacks import _eval_callbacks | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why lazy import? |
||
|
||
callbacks = getattr(self, '_callbacks', None) | ||
|
||
_eval_callbacks(callbacks, method=method, estimator=self, **kwargs) | ||
|
||
@property | ||
def _repr_html_(self): | ||
"""HTML representation of estimator. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sneaky :p
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I don't see the point in manually formatting code anymore for new files. It shouldn't hurt even if we are not using everywhere..