From 7a6743a25999ae5cb88f3b71bf47230e9afbf032 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 9 Apr 2022 17:41:54 +0200 Subject: [PATCH 1/2] ENH save memory with LinearLoss --- sklearn/linear_model/_glm/glm.py | 18 ++++- sklearn/linear_model/_linear_loss.py | 109 ++++++++++++++++++++++++++- sklearn/linear_model/_logistic.py | 71 +++++++++++++---- 3 files changed, 178 insertions(+), 20 deletions(-) diff --git a/sklearn/linear_model/_glm/glm.py b/sklearn/linear_model/_glm/glm.py index 68aa4ea0df22c..6b1b0182cb7a5 100644 --- a/sklearn/linear_model/_glm/glm.py +++ b/sklearn/linear_model/_glm/glm.py @@ -271,6 +271,12 @@ def fit(self, X, y, sample_weight=None): else: coef = np.zeros(n_features, dtype=loss_dtype) + # To save some memory, we preallocate a ndarray used as per row loss and + # gradient inside od LinearLoss, e.g. by LinearLoss.base_loss.gradient (and + # others). + per_sample_loss_out = np.empty_like(y) + per_sample_gradient_out = np.empty_like(y) + # Algorithms for optimization: # Note again that our losses implement 1/2 * deviance. if solver == "lbfgs": @@ -289,7 +295,17 @@ def fit(self, X, y, sample_weight=None): "gtol": self.tol, "ftol": 1e3 * np.finfo(float).eps, }, - args=(X, y, sample_weight, l2_reg_strength, n_threads), + args=( + X, + y, + sample_weight, + l2_reg_strength, + n_threads, + { + "per_sample_loss_out": per_sample_loss_out, + "per_sample_gradient_out": per_sample_gradient_out, + }, + ), ) self.n_iter_ = _check_optimize_result("lbfgs", opt_res) coef = opt_res.x diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 64a99325dcd7a..55f2b4bc71273 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -112,7 +112,16 @@ def _w_intercept_raw(self, coef, X): return weights, intercept, raw_prediction - def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1): + def loss( + self, + coef, + X, + y, + sample_weight=None, + l2_reg_strength=0.0, + n_threads=1, + temporary_array_dict=None, + ): """Compute the loss as sum over point-wise losses. Parameters @@ -132,6 +141,12 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. + temporary_array_dict: None or dict, default=None + Providing such an array can save a little memory. Possible entry: + + "per_sample_loss_out" : C-contiguous array of shape (n_samples,) + A location into which the per sample loss is stored. If None, a new + array might be created. Returns ------- @@ -139,11 +154,16 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) Sum of losses per sample plus penalty. """ weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + if temporary_array_dict is not None: + per_sample_loss_out = temporary_array_dict.get("per_sample_loss_out", None) + else: + per_sample_loss_out = None loss = self.base_loss.loss( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, + loss_out=per_sample_loss_out, n_threads=n_threads, ) loss = loss.sum() @@ -152,7 +172,14 @@ def loss(self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1) return loss + 0.5 * l2_reg_strength * norm2_w def loss_gradient( - self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + self, + coef, + X, + y, + sample_weight=None, + l2_reg_strength=0.0, + n_threads=1, + temporary_array_dict=None, ): """Computes the sum of loss and gradient w.r.t. 
coef. @@ -173,6 +200,17 @@ def loss_gradient( L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. + temporary_array_dict: None or dict, default=None + Providing such arrays can save a little memory. Possible entry: + + "per_sample_loss_out" : C-contiguous array of shape (n_samples,) + A location into which the per sample loss is stored. If None, a new + array might be created. + + "per_sample_gradient_out" : None or C-contiguous array of shape + (n_samples,) or array of shape (n_samples, n_classes) + A location into which the per sample gradient is stored. If None, a + new array might be created. Returns ------- @@ -185,11 +223,21 @@ def loss_gradient( n_features, n_classes = X.shape[1], self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + if temporary_array_dict is not None: + per_sample_loss_out = temporary_array_dict.get("per_sample_loss_out", None) + per_sample_gradient_out = temporary_array_dict.get( + "per_sample_gradient_out", None + ) + else: + per_sample_loss_out = None + per_sample_gradient_out = None loss, grad_per_sample = self.base_loss.loss_gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, + loss_out=per_sample_loss_out, + gradient_out=per_sample_gradient_out, n_threads=n_threads, ) loss = loss.sum() @@ -213,7 +261,14 @@ def loss_gradient( return loss, grad def gradient( - self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + self, + coef, + X, + y, + sample_weight=None, + l2_reg_strength=0.0, + n_threads=1, + temporary_array_dict=None, ): """Computes the gradient w.r.t. coef. @@ -234,6 +289,13 @@ def gradient( L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. + temporary_array_dict: None or dict, default=None + Providing such an array can save a little memory. Possible entry: + + "per_sample_gradient_out" : None or C-contiguous array of shape + (n_samples,) or array of shape (n_samples, n_classes) + A location into which the per sample gradient is stored. If None, a new + array might be created. Returns ------- @@ -243,11 +305,18 @@ def gradient( n_features, n_classes = X.shape[1], self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + if temporary_array_dict is not None: + per_sample_gradient_out = temporary_array_dict.get( + "per_sample_gradient_out", None + ) + else: + per_sample_gradient_out = None grad_per_sample = self.base_loss.gradient( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, + gradient_out=per_sample_gradient_out, n_threads=n_threads, ) @@ -269,7 +338,14 @@ def gradient( return grad def gradient_hessian_product( - self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1 + self, + coef, + X, + y, + sample_weight=None, + l2_reg_strength=0.0, + n_threads=1, + temporary_array_dict=None, ): """Computes gradient and hessp (hessian product function) w.r.t. coef. @@ -290,6 +366,17 @@ def gradient_hessian_product( L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. + temporary_array_dict: None or dict, default=None + Providing such an array can save a little memory. Possible entry: + + "per_sample_gradient_out" : None or C-contiguous array of shape + (n_samples,) or array of shape (n_samples, n_classes) + A location into which the per sample gradient is stored. If None, a new + array might be created. 
+ "per_sample_hessian_out" : None or C-contiguous array of shape (n_samples,) + or array of shape (n_samples, n_classes) + A location into which the per sample hessian is stored. If None, a new + array might be created. Returns ------- @@ -303,12 +390,24 @@ def gradient_hessian_product( (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) + if temporary_array_dict is not None: + per_sample_gradient_out = temporary_array_dict.get( + "per_sample_gradient_out", None + ) + per_sample_hessian_out = temporary_array_dict.get( + "per_sample_hessian_out", None + ) + else: + per_sample_gradient_out = None + per_sample_hessian_out = None if not self.base_loss.is_multiclass: gradient, hessian = self.base_loss.gradient_hessian( y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, + gradient_out=per_sample_gradient_out, + hessian_out=per_sample_hessian_out, n_threads=n_threads, ) grad = np.empty_like(coef, dtype=weights.dtype) @@ -356,6 +455,8 @@ def hessp(s): y_true=y, raw_prediction=raw_prediction, sample_weight=sample_weight, + gradient_out=per_sample_gradient_out, + proba_out=per_sample_hessian_out, n_threads=n_threads, ) grad = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F") diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 658aa12fcd535..469ff158ae1de 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -397,34 +397,54 @@ def _logistic_regression_path( # reconstructs the 2d-array via w0.reshape((n_classes, -1), order="F"). # As w0 is F-contiguous, ravel(order="F") also avoids a copy. w0 = w0.ravel(order="F") - loss = LinearModelLoss( + linear_loss = LinearModelLoss( base_loss=HalfMultinomialLoss(n_classes=classes.size), fit_intercept=fit_intercept, ) target = Y_multi - if solver in "lbfgs": - func = loss.loss_gradient - elif solver == "newton-cg": - func = loss.loss - grad = loss.gradient - hess = loss.gradient_hessian_product # hess = [gradient, hessp] warm_start_sag = {"coef": w0.T} else: target = y_bin if solver == "lbfgs": - loss = LinearModelLoss( + linear_loss = LinearModelLoss( base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept ) - func = loss.loss_gradient elif solver == "newton-cg": - loss = LinearModelLoss( + linear_loss = LinearModelLoss( base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept ) - func = loss.loss - grad = loss.gradient - hess = loss.gradient_hessian_product # hess = [gradient, hessp] warm_start_sag = {"coef": np.expand_dims(w0, axis=1)} + if solver == "lbfgs": + func = linear_loss.loss_gradient + + # To save some memory, we preallocate a ndarray used as per row loss and + # gradient inside od LinearLoss, e.g. by LinearLoss.base_loss.gradient (and + # others). + per_sample_loss_out = np.empty_like(target) + if linear_loss.base_loss.is_multiclass: + per_sample_gradient_out = np.empty( + shape=(X.shape[0], classes.size), dtype=X.dtype, order="C" + ) + else: + per_sample_gradient_out = np.empty_like(target, order="C") + elif solver == "newton-cg": + func = linear_loss.loss + grad = linear_loss.gradient + hess = linear_loss.gradient_hessian_product # hess = [gradient, hessp] + + # To save some memory, we preallocate a ndarray used as per row loss and + # gradient inside od LinearLoss, e.g. by LinearLoss.base_loss.gradient (and + # others). 
+ per_sample_loss_out = np.empty_like(target) + if linear_loss.base_loss.is_multiclass: + per_sample_gradient_out = np.empty( + shape=(X.shape[0], classes.size), dtype=X.dtype, order="C" + ) + else: + per_sample_gradient_out = np.empty_like(target, order="C") + per_sample_hessian_out = np.empty_like(per_sample_gradient_out) + coefs = list() n_iter = np.zeros(len(Cs), dtype=np.int32) for i, C in enumerate(Cs): @@ -438,7 +458,17 @@ def _logistic_regression_path( w0, method="L-BFGS-B", jac=True, - args=(X, target, sample_weight, l2_reg_strength, n_threads), + args=( + X, + target, + sample_weight, + l2_reg_strength, + n_threads, + { + "per_sample_loss_out": per_sample_loss_out, + "per_sample_gradient_out": per_sample_gradient_out, + }, + ), options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}, ) n_iter_i = _check_optimize_result( @@ -450,7 +480,18 @@ def _logistic_regression_path( w0, loss = opt_res.x, opt_res.fun elif solver == "newton-cg": l2_reg_strength = 1.0 / C - args = (X, target, sample_weight, l2_reg_strength, n_threads) + args = ( + X, + target, + sample_weight, + l2_reg_strength, + n_threads, + { + "per_sample_loss_out": per_sample_loss_out, + "per_sample_gradient_out": per_sample_gradient_out, + "per_sample_hessian_out": per_sample_hessian_out, + }, + ) w0, n_iter_i = _newton_cg( hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol ) From eac91e1577f97d32f809a957796b475dfd61fef2 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 10 Apr 2022 10:41:17 +0200 Subject: [PATCH 2/2] CLN use functools --- sklearn/linear_model/_glm/glm.py | 8 +-- sklearn/linear_model/_linear_loss.py | 97 +++++++++------------------- sklearn/linear_model/_logistic.py | 51 +++++++-------- 3 files changed, 55 insertions(+), 101 deletions(-) diff --git a/sklearn/linear_model/_glm/glm.py b/sklearn/linear_model/_glm/glm.py index 6b1b0182cb7a5..2ff296be3da49 100644 --- a/sklearn/linear_model/_glm/glm.py +++ b/sklearn/linear_model/_glm/glm.py @@ -272,7 +272,7 @@ def fit(self, X, y, sample_weight=None): coef = np.zeros(n_features, dtype=loss_dtype) # To save some memory, we preallocate a ndarray used as per row loss and - # gradient inside od LinearLoss, e.g. by LinearLoss.base_loss.gradient (and + # gradient inside of LinearLoss, e.g. by LinearLoss.base_loss.gradient (and # others). per_sample_loss_out = np.empty_like(y) per_sample_gradient_out = np.empty_like(y) @@ -301,10 +301,8 @@ def fit(self, X, y, sample_weight=None): sample_weight, l2_reg_strength, n_threads, - { - "per_sample_loss_out": per_sample_loss_out, - "per_sample_gradient_out": per_sample_gradient_out, - }, + per_sample_loss_out, + per_sample_gradient_out, ), ) self.n_iter_ = _check_optimize_result("lbfgs", opt_res) diff --git a/sklearn/linear_model/_linear_loss.py b/sklearn/linear_model/_linear_loss.py index 55f2b4bc71273..9fcdc7cd396da 100644 --- a/sklearn/linear_model/_linear_loss.py +++ b/sklearn/linear_model/_linear_loss.py @@ -120,7 +120,7 @@ def loss( sample_weight=None, l2_reg_strength=0.0, n_threads=1, - temporary_array_dict=None, + per_sample_loss_out=None, ): """Compute the loss as sum over point-wise losses. @@ -141,12 +141,10 @@ def loss( L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. - temporary_array_dict: None or dict, default=None - Providing such an array can save a little memory. Possible entry: - - "per_sample_loss_out" : C-contiguous array of shape (n_samples,) - A location into which the per sample loss is stored. If None, a new - array might be created. 
+ per_sample_loss_out : None or C-contiguous array of shape (n_samples,), \ + default=None + A location into which the per sample loss is stored. If None, a new array + might be created. Providing such an array can save a little memory. Returns ------- @@ -154,10 +152,6 @@ def loss( Sum of losses per sample plus penalty. """ weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - if temporary_array_dict is not None: - per_sample_loss_out = temporary_array_dict.get("per_sample_loss_out", None) - else: - per_sample_loss_out = None loss = self.base_loss.loss( y_true=y, @@ -179,7 +173,8 @@ def loss_gradient( sample_weight=None, l2_reg_strength=0.0, n_threads=1, - temporary_array_dict=None, + per_sample_loss_out=None, + per_sample_gradient_out=None, ): """Computes the sum of loss and gradient w.r.t. coef. @@ -200,17 +195,14 @@ def loss_gradient( L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. - temporary_array_dict: None or dict, default=None - Providing such arrays can save a little memory. Possible entry: - - "per_sample_loss_out" : C-contiguous array of shape (n_samples,) - A location into which the per sample loss is stored. If None, a new - array might be created. - - "per_sample_gradient_out" : None or C-contiguous array of shape - (n_samples,) or array of shape (n_samples, n_classes) - A location into which the per sample gradient is stored. If None, a - new array might be created. + per_sample_loss_out : None or C-contiguous array of shape (n_samples,), \ + default=None + A location into which the per sample loss is stored. If None, a new array + might be created. Providing such an array can save a little memory. + per_sample_gradient_out : None or C-contiguous array of shape (n_samples,) or + array of shape (n_samples, n_classes), default=None + A location into which the per sample gradient is stored. If None, a new + array might be created. Providing such an array can save a little memory. Returns ------- @@ -223,14 +215,6 @@ def loss_gradient( n_features, n_classes = X.shape[1], self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - if temporary_array_dict is not None: - per_sample_loss_out = temporary_array_dict.get("per_sample_loss_out", None) - per_sample_gradient_out = temporary_array_dict.get( - "per_sample_gradient_out", None - ) - else: - per_sample_loss_out = None - per_sample_gradient_out = None loss, grad_per_sample = self.base_loss.loss_gradient( y_true=y, @@ -268,7 +252,7 @@ def gradient( sample_weight=None, l2_reg_strength=0.0, n_threads=1, - temporary_array_dict=None, + per_sample_gradient_out=None, ): """Computes the gradient w.r.t. coef. @@ -289,13 +273,10 @@ def gradient( L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. - temporary_array_dict: None or dict, default=None - Providing such an array can save a little memory. Possible entry: - - "per_sample_gradient_out" : None or C-contiguous array of shape - (n_samples,) or array of shape (n_samples, n_classes) - A location into which the per sample gradient is stored. If None, a new - array might be created. + per_sample_gradient_out : None or C-contiguous array of shape (n_samples,) or + array of shape (n_samples, n_classes), default=None + A location into which the per sample gradient is stored. If None, a new + array might be created. Providing such an array can save a little memory. 
Returns ------- @@ -305,12 +286,6 @@ def gradient( n_features, n_classes = X.shape[1], self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - if temporary_array_dict is not None: - per_sample_gradient_out = temporary_array_dict.get( - "per_sample_gradient_out", None - ) - else: - per_sample_gradient_out = None grad_per_sample = self.base_loss.gradient( y_true=y, @@ -345,7 +320,8 @@ def gradient_hessian_product( sample_weight=None, l2_reg_strength=0.0, n_threads=1, - temporary_array_dict=None, + per_sample_gradient_out=None, + per_sample_hessian_out=None, ): """Computes gradient and hessp (hessian product function) w.r.t. coef. @@ -366,17 +342,14 @@ def gradient_hessian_product( L2 regularization strength n_threads : int, default=1 Number of OpenMP threads to use. - temporary_array_dict: None or dict, default=None - Providing such an array can save a little memory. Possible entry: - - "per_sample_gradient_out" : None or C-contiguous array of shape - (n_samples,) or array of shape (n_samples, n_classes) - A location into which the per sample gradient is stored. If None, a new - array might be created. - "per_sample_hessian_out" : None or C-contiguous array of shape (n_samples,) - or array of shape (n_samples, n_classes) - A location into which the per sample hessian is stored. If None, a new - array might be created. + per_sample_gradient_out : None or C-contiguous array of shape (n_samples,) or + array of shape (n_samples, n_classes), default=None + A location into which the per sample gradient is stored. If None, a new + array might be created. Providing such an array can save a little memory. + per_sample_hessian_out : None or C-contiguous array of shape (n_samples,) or + array of shape (n_samples, n_classes), default=None + A location into which the per sample hessian is stored. If None, a new + array might be created. Providing such an array can save a little memory. Returns ------- @@ -390,16 +363,6 @@ def gradient_hessian_product( (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes n_dof = n_features + int(self.fit_intercept) weights, intercept, raw_prediction = self._w_intercept_raw(coef, X) - if temporary_array_dict is not None: - per_sample_gradient_out = temporary_array_dict.get( - "per_sample_gradient_out", None - ) - per_sample_hessian_out = temporary_array_dict.get( - "per_sample_hessian_out", None - ) - else: - per_sample_gradient_out = None - per_sample_hessian_out = None if not self.base_loss.is_multiclass: gradient, hessian = self.base_loss.gradient_hessian( diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 469ff158ae1de..87aaf2afc3eaf 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -10,6 +10,7 @@ # Simon Wu # Arthur Mensch