diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 186040b32ebd8..4fb185cf8a846 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -5,9 +5,9 @@ Developing scikit-learn estimators ================================== Whether you are proposing an estimator for inclusion in scikit-learn, -developing a separate package compatible with scikit-learn, or -implementing custom components for your own projects, this chapter -details how to develop objects that safely interact with scikit-learn +developing a separate package compatible with scikit-learn, or +implementing custom components for your own projects, this chapter +details how to develop objects that safely interact with scikit-learn Pipelines and model selection tools. .. currentmodule:: sklearn @@ -576,10 +576,10 @@ closed-form solutions. Coding guidelines ================= -The following are some guidelines on how new code should be written for -inclusion in scikit-learn, and which may be appropriate to adopt in external -projects. Of course, there are special cases and there will be exceptions to -these rules. However, following these rules when submitting new code makes +The following are some guidelines on how new code should be written for +inclusion in scikit-learn, and which may be appropriate to adopt in external +projects. Of course, there are special cases and there will be exceptions to +these rules. However, following these rules when submitting new code makes the review easier so new code can be integrated in less time. Uniformly formatted code makes it easier to share code ownership. The @@ -709,3 +709,95 @@ The reason for this setup is reproducibility: when an estimator is ``fit`` twice to the same data, it should produce an identical model both times, hence the validation in ``fit``, not ``__init__``. + +Estimator callbacks +=================== + +.. note:: Callbacks are currently experimental and the API can change without + notice. + +To add (optional) support of callbacks, for instance to support progress +bars or monitoring convergence, the estimator must implement the following +points: + +- At the beginning of ``fit`` either explicitly call + ``self._eval_callbacks(method='on_fit_begin', X=X, y=y)`` + or use ``self._validate_data(X, y)`` which makes this call internally. +- For iterative solvers call ``self._eval_callbacks(n_iter=.., **kwargs)`` at + each iteration, where ``kwargs`` keys must be part of supported callback + arguments (cf. list below). + +User defined callbacks must extend the ``sklearn._callbacks.BaseCallback`` +abstract base class and define the following methods, + +.. code:: python + + from sklearn._callbacks import BaseCallback + + class MyCallback(BaseCallback): + + def on_fit_begin(self, estimator, X, y): + ... + + def on_iter_end(self, **kwargs): + ... + + +For more details, see the Callback API documentation below. + +Callback API +------------ + +on_fit_begin method +^^^^^^^^^^^^^^^^^^^ + +Parameters +"""""""""" + +X: array_like + Training data + +y: array_like + Target values + +Returns +""""""" + +Any. The return value is ignored. + +on_iter_end method +^^^^^^^^^^^^^^^^^^ + +Parameters +"""""""""" + +n_iter: int + current iteration number for iterative solvers. + +max_iter: int + maximum number of iterations for iterative solvers. If the estimator + has a ``max_iter`` init parameter, this will be inferred. + +loss: float or ordered dict + cost function value or error at a given iteration. When ordered dict, + multiple loss functions can given, with the default loss being the first + element. Lower is better. + +score: float or ordered dict + same as ``loss`` parameter, but for evaluation metrics. Higher is better. + +validation_loss: float or ordered dict + cost function value or error at a given iteration, evaluated on the + validation set. + +validation_score: float or ordered dict + same as ``validation_loss`` parameter, but for evaluation metrics. Higher is + better. + +coef: ndarray + coefficients of linear models. + +Returns +""""""" + +Any. The return value is ignored. diff --git a/pyproject.toml b/pyproject.toml index c302626101c71..3520b3fe0f3c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,3 +13,6 @@ requires = [ "numpy==1.17.3; python_version>='3.8' and platform_system=='AIX'", "scipy>=0.19.1", ] + +[tool.black] +line-length = 79 diff --git a/sklearn/_callbacks.py b/sklearn/_callbacks.py new file mode 100644 index 0000000000000..b3687124e9b7e --- /dev/null +++ b/sklearn/_callbacks.py @@ -0,0 +1,57 @@ +# License: BSD 3 clause +from typing import List, Callable, Optional +from abc import ABC, abstractmethod + +import numpy as np + +CALLBACK_PARAM_TYPES = { + "n_iter": int, + "max_iter": int, + "loss": (float, dict), + "score": (float, dict), + "validation_loss": (float, dict), + "validation_score": (float, dict), + "coef": np.ndarray, + "intercept": (np.ndarray, float), +} + + +def _check_callback_params(**kwargs): + invalid_params = [] + invalid_types = [] + for key, val in kwargs.items(): + if key not in CALLBACK_PARAM_TYPES: + invalid_params.append(key) + else: + val_types = CALLBACK_PARAM_TYPES[key] + if not isinstance(val, val_types): + invalid_types.append(f"{key}={val} is not of type {val_types}") + msg = "" + if invalid_params: + msg += ("Invalid callback parameters: {}, must be one of {}. ").format( + ", ".join(invalid_params), ", ".join(CALLBACK_PARAM_TYPES.keys()) + ) + if invalid_types: + msg += "Invalid callback parameters: " + ", ".join(invalid_types) + if msg: + raise ValueError(msg) + + +def _eval_callbacks( + callbacks: Optional[List[Callable]], method="on_iter_end", **kwargs +) -> None: + if callbacks is None: + return + + for callback in callbacks: + getattr(callback, method)(**kwargs) + + +class BaseCallback(ABC): + @abstractmethod + def on_fit_begin(self, estimator, X, y) -> None: + pass + + @abstractmethod + def on_iter_end(self, **kwargs) -> None: + pass diff --git a/sklearn/base.py b/sklearn/base.py index 46398baabfd3a..380c0358fe17a 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -14,6 +14,7 @@ from . import __version__ from ._config import get_config +from sklearn._callbacks import BaseCallback from .utils import _IS_32BIT from .utils.validation import check_X_y from .utils.validation import check_array @@ -84,6 +85,11 @@ def clone(estimator, *, safe=True): new_object = klass(**new_object_params) params_set = new_object.get_params(deep=False) + # copy callbacks + if hasattr(estimator, "_callbacks"): + # TODO: do we need to use the recusive setter here? + new_object._callbacks = estimator._callbacks + # quick sanity check of the parameters of the clone for name in new_object_params: param1 = new_object_params[name] @@ -406,6 +412,7 @@ def _validate_data(self, X, y=None, reset=True, out : {ndarray, sparse matrix} or tuple of these The validated input. A tuple is returned if `y` is not None. """ + self._eval_callbacks('on_fit_begin', X=X, y=y) if y is None: if self._get_tags()['requires_y']: @@ -433,6 +440,46 @@ def _validate_data(self, X, y=None, reset=True, return out + def _set_callbacks(self, callbacks, deep: bool = True): + """Set callbacks for the estimator. + + Parameters + ---------- + callbacks : callback or list of callbacks + the callbacks to set. + + deep: bool=True + If True, in the case of meta-estmators, callbacks are also set + recursively for all child estimators. + """ + if isinstance(callbacks, BaseCallback): + self._callbacks = [callbacks] + else: + self._callbacks = callbacks + + if not deep: + return + + # set callbacks recursively + for attr_name in getattr(self, "_required_parameters", []): + # likely a meta-estimator + if attr_name in ['steps', 'transformers']: + for attr in getattr(self, attr_name): + if isinstance(attr, BaseEstimator): + attr._set_callbacks(callbacks) + elif (hasattr(attr, '__len__') + and len(attr) >= 2 + and isinstance(attr[1], BaseEstimator)): + attr[1]._set_callbacks(callbacks) + + def _eval_callbacks(self, method='on_iter_end', **kwargs): + """Call callbacks, e.g. in each iteration of an iterative solver""" + from ._callbacks import _eval_callbacks + + callbacks = getattr(self, '_callbacks', None) + + _eval_callbacks(callbacks, method=method, estimator=self, **kwargs) + @property def _repr_html_(self): """HTML representation of estimator. diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index f148633021a97..7f4a8f713a5c6 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -516,6 +516,7 @@ def fit_transform(self, X, y=None): sparse matrices. """ + self._eval_callbacks(method='on_fit_begin', X=X, y=y) # TODO: this should be `feature_names_in_` when we start having it if hasattr(X, "columns"): self._feature_names_in = np.asarray(X.columns) diff --git a/sklearn/decomposition/_factor_analysis.py b/sklearn/decomposition/_factor_analysis.py index cc0178b70e447..e202dde6f2aba 100644 --- a/sklearn/decomposition/_factor_analysis.py +++ b/sklearn/decomposition/_factor_analysis.py @@ -236,6 +236,7 @@ def my_svd(X): old_ll = ll psi = np.maximum(var - np.sum(W ** 2, axis=0), SMALL) + self._eval_callbacks(n_iter=i) else: warnings.warn('FactorAnalysis did not converge.' + ' You might want' + diff --git a/sklearn/decomposition/_incremental_pca.py b/sklearn/decomposition/_incremental_pca.py index c544145a7922d..940be32d209d0 100644 --- a/sklearn/decomposition/_incremental_pca.py +++ b/sklearn/decomposition/_incremental_pca.py @@ -206,12 +206,14 @@ def fit(self, X, y=None): else: self.batch_size_ = self.batch_size - for batch in gen_batches(n_samples, self.batch_size_, - min_batch_size=self.n_components or 0): + for n_batch, batch in enumerate( + gen_batches(n_samples, self.batch_size_, + min_batch_size=self.n_components or 0)): X_batch = X[batch] if sparse.issparse(X_batch): X_batch = X_batch.toarray() self.partial_fit(X_batch, check_input=False) + self._eval_callbacks(n_iter=n_batch) return self diff --git a/sklearn/decomposition/_lda.py b/sklearn/decomposition/_lda.py index a6e253aab1e6e..84dadc93c7eac 100644 --- a/sklearn/decomposition/_lda.py +++ b/sklearn/decomposition/_lda.py @@ -464,6 +464,7 @@ def _em_step(self, X, total_samples, batch_update, parallel=None): self.exp_dirichlet_component_ = np.exp( _dirichlet_expectation_2d(self.components_)) self.n_batch_iter_ += 1 + self._eval_callbacks() return def _more_tags(self): diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index c2405a7dea5fc..2f72efcac043a 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -19,6 +19,7 @@ from ..utils import check_random_state, check_array from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm from ..utils.validation import check_is_fitted, check_non_negative +from .._callbacks import _eval_callbacks from ..utils.validation import _deprecate_positional_args EPSILON = np.finfo(np.float32).eps @@ -426,7 +427,8 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle, def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0, l1_reg_H=0, l2_reg_W=0, l2_reg_H=0, update_H=True, - verbose=0, shuffle=False, random_state=None): + verbose=0, shuffle=False, random_state=None, + callbacks=None): """Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent The objective function is minimized with an alternating minimization of W @@ -522,6 +524,10 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0, if verbose: print("violation:", violation / violation_init) + _eval_callbacks(callbacks, n_iter=n_iter, + tol=violation/violation_init, + error=violation) + if violation / violation_init <= tol: if verbose: print("Converged at iteration", n_iter + 1) @@ -710,7 +716,7 @@ def _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H, l2_reg_H, gamma): def _fit_multiplicative_update(X, W, H, beta_loss='frobenius', max_iter=200, tol=1e-4, l1_reg_W=0, l1_reg_H=0, l2_reg_W=0, l2_reg_H=0, - update_H=True, verbose=0): + update_H=True, verbose=0, callbacks=None): """Compute Non-negative Matrix Factorization with Multiplicative Update The objective function is _beta_divergence(X, WH) and is minimized with an @@ -828,6 +834,9 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius', print("Epoch %02d reached after %.3f seconds, error: %f" % (n_iter, iter_time - start_time, error)) + _eval_callbacks(callbacks, n_iter=n_iter, error=error, + tol=(previous_error - error) / error_at_init) + if (previous_error - error) / error_at_init < tol: break previous_error = error @@ -847,7 +856,7 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *, beta_loss='frobenius', tol=1e-4, max_iter=200, alpha=0., l1_ratio=0., regularization=None, random_state=None, - verbose=0, shuffle=False): + verbose=0, shuffle=False, callbacks=None): r"""Compute Non-negative Matrix Factorization (NMF) Find two non-negative matrices (W, H) whose product approximates the non- @@ -1062,12 +1071,13 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *, update_H=update_H, verbose=verbose, shuffle=shuffle, - random_state=random_state) + random_state=random_state, + callbacks=callbacks) elif solver == 'mu': W, H, n_iter = _fit_multiplicative_update(X, W, H, beta_loss, max_iter, tol, l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H, update_H, - verbose) + verbose, callbacks=callbacks) else: raise ValueError("Invalid solver parameter '%s'." % solver) @@ -1286,7 +1296,7 @@ def fit_transform(self, X, y=None, W=None, H=None): tol=self.tol, max_iter=self.max_iter, alpha=self.alpha, l1_ratio=self.l1_ratio, regularization='both', random_state=self.random_state, verbose=self.verbose, - shuffle=self.shuffle) + shuffle=self.shuffle, callbacks=getattr(self, "_callbacks", None)) self.reconstruction_err_ = _beta_divergence(X, W, H, self.beta_loss, square_root=True) @@ -1335,7 +1345,7 @@ def transform(self, X): beta_loss=self.beta_loss, tol=self.tol, max_iter=self.max_iter, alpha=self.alpha, l1_ratio=self.l1_ratio, regularization='both', random_state=self.random_state, verbose=self.verbose, - shuffle=self.shuffle) + shuffle=self.shuffle, callbacks=getattr(self, '_callbacks', None)) return W diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index fa305d273e857..458010c1973e5 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -491,6 +491,8 @@ def _fit_full(self, X, n_components): explained_variance_ratio_[:n_components] self.singular_values_ = singular_values_[:n_components] + self._eval_callbacks() + return U, S, Vt def _fit_truncated(self, X, n_components, svd_solver): @@ -537,12 +539,17 @@ def _fit_truncated(self, X, n_components, svd_solver): # flip eigenvectors' sign to enforce deterministic output U, Vt = svd_flip(U[:, ::-1], Vt[::-1]) + self._eval_callbacks() + elif svd_solver == 'randomized': # sign flipping is done inside - U, S, Vt = randomized_svd(X, n_components=n_components, - n_iter=self.iterated_power, - flip_sign=True, - random_state=random_state) + U, S, Vt = randomized_svd( + X, n_components=n_components, + n_iter=self.iterated_power, + flip_sign=True, + random_state=random_state, + callbacks=getattr(self, '_callbacks', []) + ) self.n_samples_, self.n_features_ = n_samples, n_features self.components_ = Vt diff --git a/sklearn/decomposition/_truncated_svd.py b/sklearn/decomposition/_truncated_svd.py index ef723ac25d4d3..50207cd19652b 100644 --- a/sklearn/decomposition/_truncated_svd.py +++ b/sklearn/decomposition/_truncated_svd.py @@ -170,6 +170,7 @@ def fit_transform(self, X, y=None): # conventions, so reverse its outputs. Sigma = Sigma[::-1] U, VT = svd_flip(U[:, ::-1], VT[::-1]) + self._eval_callbacks() elif self.algorithm == "randomized": k = self.n_components @@ -179,7 +180,9 @@ def fit_transform(self, X, y=None): " got %d >= %d" % (k, n_features)) U, Sigma, VT = randomized_svd(X, self.n_components, n_iter=self.n_iter, - random_state=random_state) + random_state=random_state, + callbacks=getattr( + self, '_callbacks', [])) else: raise ValueError("unknown algorithm %r" % self.algorithm) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index db03b06328799..f132e0d7e53d0 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -34,6 +34,7 @@ from ..utils.fixes import _joblib_parallel_args from ..model_selection import check_cv from ..metrics import get_scorer +from .._callbacks import _eval_callbacks _LOGISTIC_SOLVER_CONVERGENCE_MSG = ( @@ -481,7 +482,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True, intercept_scaling=1., multi_class='auto', random_state=None, check_input=True, max_squared_sum=None, sample_weight=None, - l1_ratio=None): + l1_ratio=None, callbacks=None): """Compute a Logistic Regression model for a list of regularization parameters. @@ -757,7 +758,8 @@ def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1] opt_res = optimize.minimize( func, w0, method="L-BFGS-B", jac=True, args=(X, target, 1. / C, sample_weight), - options={"iprint": iprint, "gtol": tol, "maxiter": max_iter} + options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}, + callback=lambda xk: _eval_callbacks(callbacks, coef=xk) ) n_iter_i = _check_optimize_result( solver, opt_res, max_iter, @@ -766,12 +768,14 @@ def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1] elif solver == 'newton-cg': args = (X, target, 1. / C, sample_weight) w0, n_iter_i = _newton_cg(hess, func, grad, w0, args=args, - maxiter=max_iter, tol=tol) + maxiter=max_iter, tol=tol, + callbacks=callbacks) elif solver == 'liblinear': coef_, intercept_, n_iter_i, = _fit_liblinear( X, target, C, fit_intercept, intercept_scaling, None, penalty, dual, verbose, max_iter, tol, random_state, sample_weight=sample_weight) + _eval_callbacks(callbacks) if fit_intercept: w0 = np.concatenate([coef_.ravel(), intercept_]) else: @@ -798,7 +802,8 @@ def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1] X, target, sample_weight, loss, alpha, beta, max_iter, tol, verbose, random_state, False, max_squared_sum, warm_start_sag, - is_saga=(solver == 'saga')) + is_saga=(solver == 'saga'), + callbacks=callbacks) else: raise ValueError("solver must be one of {'liblinear', 'lbfgs', " @@ -826,7 +831,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10, dual=False, intercept_scaling=1., multi_class='auto', random_state=None, max_squared_sum=None, sample_weight=None, - l1_ratio=None): + l1_ratio=None, callbacks=None): """Computes scores across logistic_regression_path Parameters @@ -969,7 +974,8 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10, multi_class=multi_class, tol=tol, verbose=verbose, dual=dual, penalty=penalty, intercept_scaling=intercept_scaling, random_state=random_state, check_input=False, - max_squared_sum=max_squared_sum, sample_weight=sample_weight) + max_squared_sum=max_squared_sum, sample_weight=sample_weight, + callbacks=callbacks) log_reg = LogisticRegression(solver=solver, multi_class=multi_class) @@ -1413,7 +1419,8 @@ def fit(self, X, y, sample_weight=None): class_weight=self.class_weight, check_input=False, random_state=self.random_state, coef=warm_start_coef_, penalty=penalty, max_squared_sum=max_squared_sum, - sample_weight=sample_weight) + sample_weight=sample_weight, + callbacks=getattr(self, '_callbacks', None)) for class_, warm_start_coef_ in zip(classes_, warm_start_coef)) fold_coefs_, _, n_iter_ = zip(*fold_coefs_) @@ -1894,7 +1901,8 @@ def fit(self, X, y, sample_weight=None): random_state=self.random_state, max_squared_sum=max_squared_sum, sample_weight=sample_weight, - l1_ratio=l1_ratio + l1_ratio=l1_ratio, + callbacks=getattr(self, '_callbacks', None) ) for label in iter_encoded_labels for train, test in folds diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index ca3fba196d6d3..62cbd7231c97c 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -20,6 +20,7 @@ from ._base import LinearClassifierMixin, LinearModel, _rescale_data from ._sag import sag_solver from ..base import RegressorMixin, MultiOutputMixin, is_classifier +from .._callbacks import _eval_callbacks from ..utils.extmath import safe_sparse_dot from ..utils.extmath import row_norms from ..utils import check_array @@ -36,7 +37,7 @@ def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3, verbose=0, - X_offset=None, X_scale=None): + X_offset=None, X_scale=None, callbacks=None): def _get_rescaled_operator(X): @@ -103,6 +104,7 @@ def _mv(x): # old scipy coefs[i], info = sp_linalg.cg(C, y_column, maxiter=max_iter, tol=tol) + _eval_callbacks(callbacks, n_iter=i, coef=coefs) if info < 0: raise ValueError("Failed with error code %d" % info) @@ -380,7 +382,8 @@ def ridge_regression(X, y, alpha, *, sample_weight=None, solver='auto', def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto', max_iter=None, tol=1e-3, verbose=0, random_state=None, return_n_iter=False, return_intercept=False, - X_scale=None, X_offset=None, check_input=True): + X_scale=None, X_offset=None, check_input=True, + callbacks=None): has_sw = sample_weight is not None @@ -451,7 +454,8 @@ def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto', tol=tol, verbose=verbose, X_offset=X_offset, - X_scale=X_scale) + X_scale=X_scale, + callbacks=callbacks) elif solver == 'lsqr': coef, n_iter = _solve_lsqr(X, y, alpha, max_iter, tol) @@ -486,7 +490,8 @@ def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto', coef_, n_iter_, _ = sag_solver( X, target.ravel(), sample_weight, 'squared', alpha_i, 0, max_iter, tol, verbose, random_state, False, max_squared_sum, - init, is_saga=solver == 'saga') + init, is_saga=solver == 'saga', + callbacks=callbacks) if return_intercept: coef[i] = coef_[:-1] intercept[i] = coef_[-1] @@ -578,7 +583,8 @@ def fit(self, X, y, sample_weight=None): X, y, alpha=self.alpha, sample_weight=sample_weight, max_iter=self.max_iter, tol=self.tol, solver='sag', random_state=self.random_state, return_n_iter=True, - return_intercept=True, check_input=False) + return_intercept=True, check_input=False, + callbacks=getattr(self, '_callbacks', None)) # add the offset which was subtracted by _preprocess_data self.intercept_ += y_offset @@ -594,7 +600,8 @@ def fit(self, X, y, sample_weight=None): X, y, alpha=self.alpha, sample_weight=sample_weight, max_iter=self.max_iter, tol=self.tol, solver=solver, random_state=self.random_state, return_n_iter=True, - return_intercept=False, check_input=False, **params) + return_intercept=False, check_input=False, + callbacks=getattr(self, '_callbacks', None), **params) self._set_intercept(X_offset, y_offset, X_scale) return self diff --git a/sklearn/linear_model/_sag.py b/sklearn/linear_model/_sag.py index caa9b2d133003..27120f805a59c 100644 --- a/sklearn/linear_model/_sag.py +++ b/sklearn/linear_model/_sag.py @@ -89,8 +89,7 @@ def get_auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept, def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0., max_iter=1000, tol=0.001, verbose=0, random_state=None, check_input=True, max_squared_sum=None, - warm_start_mem=None, - is_saga=False): + warm_start_mem=None, is_saga=False, callbacks=None): """SAG solver for Ridge and LogisticRegression SAG stands for Stochastic Average Gradient: the gradient of the loss is @@ -323,7 +322,8 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0., intercept_sum_gradient, intercept_decay, is_saga, - verbose) + verbose, + callbacks) if n_iter_ == max_iter: warnings.warn("The max_iter was reached which means " diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 8508340e3b329..8be35c9cb9030 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -53,6 +53,7 @@ from libc.time cimport time, time_t from ._sgd_fast cimport LossFunction from ._sgd_fast cimport Log, SquaredLoss +from .._callbacks import _eval_callbacks from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 @@ -239,7 +240,8 @@ def sag{{name}}(SequentialDataset{{name}} dataset, np.ndarray[{{c_type}}, ndim=1, mode='c'] intercept_sum_gradient_init, double intercept_decay, bint saga, - bint verbose): + bint verbose, + callbacks=None): """Stochastic Average Gradient (SAG) and SAGA solvers. Used in Ridge and LogisticRegression. @@ -534,6 +536,10 @@ def sag{{name}}(SequentialDataset{{name}} dataset, elif verbose: printf('Epoch %d, change: %.8f\n', n_iter + 1, max_change / max_weight) + with gil: + if callbacks is not None: + _eval_callbacks(callbacks, n_iter=n_iter, coef=weights_array, + intercept=intercept_array) n_iter += 1 # We do the error treatment here based on error code in status to avoid # re-acquiring the GIL within the cython code, which slows the computation diff --git a/sklearn/linear_model/_sgd_fast.pyx b/sklearn/linear_model/_sgd_fast.pyx index ab1a274d37c8f..e8ad86f5017f3 100644 --- a/sklearn/linear_model/_sgd_fast.pyx +++ b/sklearn/linear_model/_sgd_fast.pyx @@ -23,6 +23,7 @@ cdef extern from "_sgd_fast_helpers.h": from ..utils._weight_vector cimport WeightVector from ..utils._seq_dataset cimport SequentialDataset64 as SequentialDataset +from .._callbacks import _eval_callbacks np.import_array() @@ -353,7 +354,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, double power_t, double t=1.0, double intercept_decay=1.0, - int average=0): + int average=0, + callbacks=None): """SGD for generic loss functions and penalties with optional averaging Parameters @@ -609,6 +611,10 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, if early_stopping: with gil: score = validation_score_cb(weights, intercept) + _eval_callbacks(callbacks, + n_iter=epoch, + loss=sumloss, + validation_score=score) if tol > -INFINITY and score < best_score + tol: no_improvement_count += 1 else: @@ -624,6 +630,9 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, if sumloss < best_loss: best_loss = sumloss + with gil: + _eval_callbacks(callbacks, n_iter=epoch, loss=sumloss) + # if there is no improvement several times in a row if no_improvement_count >= n_iter_no_change: if learning_rate == ADAPTIVE and eta > 1e-6: diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index 9b35377cbbb72..539a4c5fbc500 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -353,7 +353,7 @@ def _prepare_fit_binary(est, y, i): def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter, pos_weight, neg_weight, sample_weight, validation_mask=None, - random_state=None): + random_state=None, callbacks=None): """Fit a single binary classifier. The i'th class is considered the "positive" class. @@ -435,7 +435,8 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter, est.early_stopping, validation_score_cb, int(est.n_iter_no_change), max_iter, tol, int(est.fit_intercept), int(est.verbose), int(est.shuffle), seed, pos_weight, neg_weight, learning_rate_type, - est.eta0, est.power_t, est.t_, intercept_decay, est.average) + est.eta0, est.power_t, est.t_, intercept_decay, est.average, + callbacks) if est.average: if len(est.classes_) == 2: @@ -576,12 +577,15 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None, def _fit_binary(self, X, y, alpha, C, sample_weight, learning_rate, max_iter): """Fit a binary classifier on X and y. """ - coef, intercept, n_iter_ = fit_binary(self, 1, X, y, alpha, C, - learning_rate, max_iter, - self._expanded_class_weight[1], - self._expanded_class_weight[0], - sample_weight, - random_state=self.random_state) + coef, intercept, n_iter_ = fit_binary( + self, 1, X, y, alpha, C, + learning_rate, max_iter, + self._expanded_class_weight[1], + self._expanded_class_weight[0], + sample_weight, + random_state=self.random_state, + callbacks=getattr(self, '_callbacks', None) + ) self.t_ += n_iter_ * X.shape[0] self.n_iter_ = n_iter_ @@ -623,7 +627,8 @@ def _fit_multiclass(self, X, y, alpha, C, learning_rate, max_iter, self._expanded_class_weight[i], 1., sample_weight, validation_mask=validation_mask, - random_state=seed) + random_state=seed, + callbacks=getattr(self, '_callbacks', None)) for i, seed in enumerate(seeds)) # take the maximum of n_iter_ over every binary fit @@ -1331,7 +1336,8 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate, 1.0, 1.0, learning_rate_type, self.eta0, self.power_t, self.t_, - intercept_decay, self.average) + intercept_decay, self.average, + getattr(self, '_callbacks', None)) self.t_ += self.n_iter_ * X.shape[0] diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index ca9546696e94d..7918165d29d8b 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -326,6 +326,7 @@ def fit(self, X, y=None, **fit_params): This estimator """ fit_params_steps = self._check_fit_params(**fit_params) + self._eval_callbacks(method='on_fit_begin', X=X, y=y) Xt = self._fit(X, y, **fit_params_steps) with _print_elapsed_time('Pipeline', self._log_message(len(self.steps) - 1)): diff --git a/sklearn/tests/test_callbacks.py b/sklearn/tests/test_callbacks.py new file mode 100644 index 0000000000000..5f8908766b3ee --- /dev/null +++ b/sklearn/tests/test_callbacks.py @@ -0,0 +1,149 @@ +# License: BSD 3 clause + +import warnings + +import pytest + +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression +from sklearn.preprocessing import StandardScaler, MinMaxScaler +from sklearn.pipeline import make_pipeline +from sklearn.compose import make_column_transformer +from sklearn.exceptions import ConvergenceWarning +from sklearn.base import is_classifier, is_regressor, ClusterMixin, clone +from sklearn._callbacks import BaseCallback, _check_callback_params +from sklearn.utils import all_estimators +from sklearn.utils._testing import set_random_state +from sklearn.utils.estimator_checks import ( + _construct_instance, + _enforce_estimator_tags_y, +) + + +@pytest.fixture(scope="module") +def iris(): + X, y = load_iris(return_X_y=True) + return X, y.reshape(-1, 1) + + +def test_check_callback_params(): + _check_callback_params(n_iter=2, max_iter=10, loss=0.1) + + msg = "Invalid callback parameters: a, must be one of n_iter.*" + with pytest.raises(ValueError, match=msg): + _check_callback_params(a=0) + + msg = "Invalid callback parameters: max_iter=1.0 is not of type .*int" + with pytest.raises(ValueError, match=msg): + _check_callback_params(max_iter=1.0) + + +def _supported_estimators(): + for (name, Estimator) in all_estimators(): + if name.startswith("_"): + continue + # need to make appropriate 1D test data + if name == "IsotonicRegression": + continue + + if ( + is_classifier(Estimator) + or is_regressor(Estimator) + or issubclass(Estimator, ClusterMixin) + ): + yield name, Estimator + + +class CheckCallback(BaseCallback): + def __init__(self): + self.n_calls = 0 + self.n_fit_calls = 0 + + def on_fit_begin(self, estimator, X, y): + self.n_fit_calls += 1 + + def on_iter_end(self, **kwargs): + self.n_calls += 1 + _check_callback_params(**kwargs) + + +@pytest.mark.parametrize("name, Estimator", _supported_estimators()) +def test_callback(name, Estimator, iris): + estimator = _construct_instance(Estimator) + + tags = estimator._get_tags() + + callback = CheckCallback() + estimator._set_callbacks([callback]) + if tags.get("X_types", []) == ["string"]: + X = ["some document", "another document"] + y = None + else: + X, y = iris + y = _enforce_estimator_tags_y(estimator, y) + set_random_state(estimator, 0) + + assert callback.n_calls == 0 + estimator.fit(X, y) + if callback.n_fit_calls == 0: + pytest.skip("callbacks not implemented") + else: + assert callback.n_fit_calls >= 1 + + +def check_has_callback(est, callback): + assert hasattr(est, "_callbacks") and est._callbacks is not None + assert est._callbacks[0] is callback + return True + + +def test_set_callbacks_clone(): + # Check that clone preserves callbacks + est = StandardScaler() + callback = CheckCallback() + est._set_callbacks(callback) + check_has_callback(est, callback) + + est2 = clone(est) + check_has_callback(est2, callback) + + +def test_set_callbacks(): + # Check that callbacks are set recursively for meta-estimators + + X, y = load_iris(return_X_y=True) + + # check simple pipeline (recursive) + callback = CheckCallback() + pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=3)) + pipe._set_callbacks(callback) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=ConvergenceWarning) + pipe.fit(X, y) + check_has_callback(pipe, callback) + check_has_callback(pipe.named_steps["standardscaler"], callback) + check_has_callback(pipe.named_steps["logisticregression"], callback) + + # check simple pipeline (non recursive) + callback = CheckCallback() + pipe = make_pipeline(StandardScaler()) + pipe._set_callbacks(callback, deep=False) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=ConvergenceWarning) + pipe.fit(X, y) + check_has_callback(pipe, callback) + assert not hasattr(pipe.named_steps["standardscaler"], "_callbacks") + + # check column transformer + callback = CheckCallback() + pipe = make_column_transformer( + (StandardScaler(), [0, 1]), (MinMaxScaler(), [2, 3]), + ) + + pipe._set_callbacks(callback) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=ConvergenceWarning) + pipe.fit(X, y) + check_has_callback(pipe, callback) + check_has_callback(pipe.named_transformers_["standardscaler"], callback) + check_has_callback(pipe.named_transformers_["minmaxscaler"], callback) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 9a92387034cbc..e467ebcc4300b 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -20,6 +20,7 @@ from ._logistic_sigmoid import _log_logistic_sigmoid from .sparsefuncs_fast import csr_row_norms from .validation import check_array +from .._callbacks import _eval_callbacks from .validation import _deprecate_positional_args @@ -160,7 +161,7 @@ def safe_sparse_dot(a, b, *, dense_output=False): @_deprecate_positional_args def randomized_range_finder(A, *, size, n_iter, power_iteration_normalizer='auto', - random_state=None): + random_state=None, callbacks=None): """Computes an orthonormal matrix whose range approximates the range of A. Parameters @@ -190,6 +191,9 @@ def randomized_range_finder(A, *, size, n_iter, Pass an int for reproducible results across multiple function calls. See :term:`Glossary `. + callbacks : list of BaseCallback, default=None + optional callbacks + Returns ------- Q : 2D array @@ -236,6 +240,8 @@ def randomized_range_finder(A, *, size, n_iter, Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic') Q, _ = linalg.qr(safe_sparse_dot(A.T, Q), mode='economic') + _eval_callbacks(callbacks, n_iter=i) + # Sample the range of A using by linear projection of Q # Extract an orthonormal basis Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic') @@ -245,7 +251,7 @@ def randomized_range_finder(A, *, size, n_iter, @_deprecate_positional_args def randomized_svd(M, n_components, *, n_oversamples=10, n_iter='auto', power_iteration_normalizer='auto', transpose='auto', - flip_sign=True, random_state=0): + flip_sign=True, random_state=0, callbacks=None): """Computes a truncated randomized SVD Parameters @@ -302,6 +308,9 @@ def randomized_svd(M, n_components, *, n_oversamples=10, n_iter='auto', Pass an int for reproducible results across multiple function calls. See :term:`Glossary `. + callbacks : list of BaseCallback, default=None + optional callbacks + Notes ----- This algorithm finds a (usually very good) approximate truncated @@ -348,7 +357,7 @@ def randomized_svd(M, n_components, *, n_oversamples=10, n_iter='auto', Q = randomized_range_finder( M, size=n_random, n_iter=n_iter, power_iteration_normalizer=power_iteration_normalizer, - random_state=random_state) + random_state=random_state, callbacks=callbacks) # project M to the (k + p) dimensional space using the basis vectors B = safe_sparse_dot(Q.T, M) diff --git a/sklearn/utils/optimize.py b/sklearn/utils/optimize.py index 9be6a787dd7ff..77a9e7bc428ca 100644 --- a/sklearn/utils/optimize.py +++ b/sklearn/utils/optimize.py @@ -18,6 +18,7 @@ from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1 from ..exceptions import ConvergenceWarning +from .._callbacks import _eval_callbacks class _LineSearchError(RuntimeError): @@ -112,7 +113,8 @@ def _cg(fhess_p, fgrad, maxiter, tol): def _newton_cg(grad_hess, func, grad, x0, args=(), tol=1e-4, - maxiter=100, maxinner=200, line_search=True, warn=True): + maxiter=100, maxinner=200, line_search=True, warn=True, + callbacks=None): """ Minimization of scalar function of one or more variables using the Newton-CG algorithm. @@ -196,6 +198,8 @@ def _newton_cg(grad_hess, func, grad, x0, args=(), tol=1e-4, break xk = xk + alphak * xsupi # upcast if necessary + if callbacks is not None: + _eval_callbacks(callbacks, n_iter=k, coef=xk, tol=np.max(absgrad)) k += 1 if warn and k >= maxiter: