diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 2200458cc5d32..aa5ade8969003 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -267,7 +267,7 @@ Changelog meaning that it is note required to call `fit` before calling `transform`. Parameter validation only happens at `fit` time. :pr:`24230` by :user:`Guillaume Lemaitre `. - + :mod:`sklearn.feature_selection` ................................ @@ -294,6 +294,11 @@ Changelog :mod:`sklearn.linear_model` ........................... +- |Efficiency| Avoid data scaling when `sample_weight=None` and other + unnecessary data copies and unexpected dense to sparse data conversion in + :class:`linear_model.LinearRegression`. + :pr:`26207` by :user:`Olivier Grisel `. + - |Enhancement| :class:`linear_model.SGDClassifier`, :class:`linear_model.SGDRegressor` and :class:`linear_model.SGDOneClassSVM` now preserve dtype for `numpy.float32`. @@ -309,7 +314,7 @@ Changelog :class:`linear_model.ARDRegression` to expose the actual number of iterations required to reach the stopping criterion. :pr:`25697` by :user:`John Pangas `. - + - |Fix| Use a more robust criterion to detect convergence of :class:`linear_model.LogisticRegression(penalty="l1", solver="liblinear")` on linearly separable problems. diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index e8e907906efc3..d05713a34a139 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -187,6 +187,7 @@ def _preprocess_data( fit_intercept, normalize=False, copy=True, + copy_y=True, sample_weight=None, check_input=True, ): @@ -230,13 +231,14 @@ def _preprocess_data( if check_input: X = check_array(X, copy=copy, accept_sparse=["csr", "csc"], dtype=FLOAT_DTYPES) - elif copy: - if sp.issparse(X): - X = X.copy() - else: - X = X.copy(order="K") - - y = np.asarray(y, dtype=X.dtype) + y = check_array(y, dtype=X.dtype, copy=copy_y, ensure_2d=False) + else: + y = y.astype(X.dtype, copy=copy_y) + if copy: + if sp.issparse(X): + X = X.copy() + else: + X = X.copy(order="K") if fit_intercept: if sp.issparse(X): @@ -276,7 +278,7 @@ def _preprocess_data( X_scale = np.ones(X.shape[1], dtype=X.dtype) y_offset = np.average(y, axis=0, weights=sample_weight) - y = y - y_offset + y -= y_offset else: X_offset = np.zeros(X.shape[1], dtype=X.dtype) X_scale = np.ones(X.shape[1], dtype=X.dtype) @@ -293,7 +295,7 @@ def _preprocess_data( # sample_weight makes the refactoring tricky. -def _rescale_data(X, y, sample_weight): +def _rescale_data(X, y, sample_weight, inplace=False): """Rescale data sample-wise by square root of sample_weight. For many linear models, this enables easy support for sample_weight because @@ -315,14 +317,37 @@ def _rescale_data(X, y, sample_weight): y_rescaled : {array-like, sparse matrix} """ + # Assume that _validate_data and _check_sample_weight have been called by + # the caller. n_samples = X.shape[0] - sample_weight = np.asarray(sample_weight) - if sample_weight.ndim == 0: - sample_weight = np.full(n_samples, sample_weight, dtype=sample_weight.dtype) sample_weight_sqrt = np.sqrt(sample_weight) - sw_matrix = sparse.dia_matrix((sample_weight_sqrt, 0), shape=(n_samples, n_samples)) - X = safe_sparse_dot(sw_matrix, X) - y = safe_sparse_dot(sw_matrix, y) + + if sp.issparse(X) or sp.issparse(y): + sw_matrix = sparse.dia_matrix( + (sample_weight_sqrt, 0), shape=(n_samples, n_samples) + ) + + if sp.issparse(X): + X = safe_sparse_dot(sw_matrix, X) + else: + if inplace: + X *= sample_weight_sqrt[:, np.newaxis] + else: + X = X * sample_weight_sqrt[:, np.newaxis] + + if sp.issparse(y): + y = safe_sparse_dot(sw_matrix, y) + else: + if inplace: + if y.ndim == 1: + y *= sample_weight_sqrt + else: + y *= sample_weight_sqrt[:, np.newaxis] + else: + if y.ndim == 1: + y = y * sample_weight_sqrt + else: + y = y * sample_weight_sqrt[:, np.newaxis] return X, y, sample_weight_sqrt @@ -651,20 +676,32 @@ def fit(self, X, y, sample_weight=None): X, y, accept_sparse=accept_sparse, y_numeric=True, multi_output=True ) - sample_weight = _check_sample_weight( - sample_weight, X, dtype=X.dtype, only_non_negative=True - ) + has_sw = sample_weight is not None + if has_sw: + sample_weight = _check_sample_weight( + sample_weight, X, dtype=X.dtype, only_non_negative=True + ) + + # Note that neither _rescale_data nor the rest of the fit method of + # LinearRegression can benefit from in-place operations when X is a + # sparse matrix. Therefore, let's not copy X when it is sparse. + copy_X_in_preprocess_data = self.copy_X and not sp.issparse(X) X, y, X_offset, y_offset, X_scale = _preprocess_data( X, y, fit_intercept=self.fit_intercept, - copy=self.copy_X, + copy=copy_X_in_preprocess_data, sample_weight=sample_weight, ) - # Sample weight can be implemented via a simple rescaling. - X, y, sample_weight_sqrt = _rescale_data(X, y, sample_weight) + if has_sw: + # Sample weight can be implemented via a simple rescaling. Note + # that we safely do inplace rescaling when _preprocess_data has + # already made a copy if requested. + X, y, sample_weight_sqrt = _rescale_data( + X, y, sample_weight, inplace=copy_X_in_preprocess_data + ) if self.positive: if y.ndim < 2: @@ -678,11 +715,21 @@ def fit(self, X, y, sample_weight=None): elif sp.issparse(X): X_offset_scale = X_offset / X_scale - def matvec(b): - return X.dot(b) - sample_weight_sqrt * b.dot(X_offset_scale) + if has_sw: + + def matvec(b): + return X.dot(b) - sample_weight_sqrt * b.dot(X_offset_scale) + + def rmatvec(b): + return X.T.dot(b) - X_offset_scale * b.dot(sample_weight_sqrt) + + else: + + def matvec(b): + return X.dot(b) - b.dot(X_offset_scale) - def rmatvec(b): - return X.T.dot(b) - X_offset_scale * b.dot(sample_weight_sqrt) + def rmatvec(b): + return X.T.dot(b) - X_offset_scale * b.sum() X_centered = sparse.linalg.LinearOperator( shape=X.shape, matvec=matvec, rmatvec=rmatvec diff --git a/sklearn/linear_model/tests/test_base.py b/sklearn/linear_model/tests/test_base.py index f8b1bd7b7b65c..112583bae593a 100644 --- a/sklearn/linear_model/tests/test_base.py +++ b/sklearn/linear_model/tests/test_base.py @@ -315,6 +315,62 @@ def test_linear_regression_positive_vs_nonpositive_when_positive(global_random_s assert np.mean((reg.coef_ - regn.coef_) ** 2) < 1e-6 +@pytest.mark.parametrize("sparse_X", [True, False]) +@pytest.mark.parametrize("use_sw", [True, False]) +def test_inplace_data_preprocessing(sparse_X, use_sw, global_random_seed): + # Check that the data is not modified inplace by the linear regression + # estimator. + rng = np.random.RandomState(global_random_seed) + original_X_data = rng.randn(10, 12) + original_y_data = rng.randn(10, 2) + orginal_sw_data = rng.rand(10) + + if sparse_X: + X = sparse.csr_matrix(original_X_data) + else: + X = original_X_data.copy() + y = original_y_data.copy() + # XXX: Note hat y_sparse is not supported (broken?) in the current + # implementation of LinearRegression. + + if use_sw: + sample_weight = orginal_sw_data.copy() + else: + sample_weight = None + + # Do not allow inplace preprocessing of X and y: + reg = LinearRegression() + reg.fit(X, y, sample_weight=sample_weight) + if sparse_X: + assert_allclose(X.toarray(), original_X_data) + else: + assert_allclose(X, original_X_data) + assert_allclose(y, original_y_data) + + if use_sw: + assert_allclose(sample_weight, orginal_sw_data) + + # Allow inplace preprocessing of X and y + reg = LinearRegression(copy_X=False) + reg.fit(X, y, sample_weight=sample_weight) + if sparse_X: + # No optimization relying on the inplace modification of sparse input + # data has been implemented at this time. + assert_allclose(X.toarray(), original_X_data) + else: + # X has been offset (and optionally rescaled by sample weights) + # inplace. The 0.42 threshold is arbitrary and has been found to be + # robust to any random seed in the admissible range. + assert np.linalg.norm(X - original_X_data) > 0.42 + + # y should not have been modified inplace by LinearRegression.fit. + assert_allclose(y, original_y_data) + + if use_sw: + # Sample weights have no reason to ever be modified inplace. + assert_allclose(sample_weight, orginal_sw_data) + + def test_linear_regression_pd_sparse_dataframe_warning(): pd = pytest.importorskip("pandas") @@ -661,7 +717,8 @@ def test_dtype_preprocess_data(global_random_seed): @pytest.mark.parametrize("n_targets", [None, 2]) -def test_rescale_data_dense(n_targets, global_random_seed): +@pytest.mark.parametrize("sparse_data", [True, False]) +def test_rescale_data(n_targets, sparse_data, global_random_seed): rng = np.random.RandomState(global_random_seed) n_samples = 200 n_features = 2 @@ -672,14 +729,34 @@ def test_rescale_data_dense(n_targets, global_random_seed): y = rng.rand(n_samples) else: y = rng.rand(n_samples, n_targets) - rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight) - rescaled_X2 = X * sqrt_sw[:, np.newaxis] + + expected_sqrt_sw = np.sqrt(sample_weight) + expected_rescaled_X = X * expected_sqrt_sw[:, np.newaxis] + if n_targets is None: - rescaled_y2 = y * sqrt_sw + expected_rescaled_y = y * expected_sqrt_sw else: - rescaled_y2 = y * sqrt_sw[:, np.newaxis] - assert_array_almost_equal(rescaled_X, rescaled_X2) - assert_array_almost_equal(rescaled_y, rescaled_y2) + expected_rescaled_y = y * expected_sqrt_sw[:, np.newaxis] + + if sparse_data: + X = sparse.csr_matrix(X) + if n_targets is None: + y = sparse.csr_matrix(y.reshape(-1, 1)) + else: + y = sparse.csr_matrix(y) + + rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight) + + assert_allclose(sqrt_sw, expected_sqrt_sw) + + if sparse_data: + rescaled_X = rescaled_X.toarray() + rescaled_y = rescaled_y.toarray() + if n_targets is None: + rescaled_y = rescaled_y.ravel() + + assert_allclose(rescaled_X, expected_rescaled_X) + assert_allclose(rescaled_y, expected_rescaled_y) def test_fused_types_make_dataset():