Skip to content

Commit 8c44bb6

Browse files
dohmatob authored and agramfort committed
REFACTOR: fixed huge code duplication in _RidgeGCV
1 parent 9fe25df commit 8c44bb6

File tree

2 files changed

+65
-16
lines changed

2 files changed

+65
-16
lines changed

sklearn/linear_model/ridge.py

+22-16
Original file line number | Diff line number | Diff line change
@@ -817,24 +817,28 @@ def _diag_dot(self, D, B):
817817
D = D[(slice(None), ) + (np.newaxis, ) * (len(B.shape) - 1)]
818818
return D * B
819819

820-
def _errors(self, alpha, y, v, Q, QT_y):
821-
# don't construct matrix G, instead compute action on y & diagonal
820+
def _errors_and_values_helper(self, alpha, y, v, Q, QT_y):
    """Compute the quantities shared by ``self._errors`` and ``self._values``.

    Returns the pair ``(G_diag, c)`` that both callers combine into their
    own result.

    Notes
    -----
    The matrix G is never built explicitly; only its action on y (``c``)
    and its diagonal (``G_diag``) are computed.
    """
    inv_shifted = 1.0 / (v + alpha)
    dual_coef = np.dot(Q, self._diag_dot(inv_shifted, QT_y))
    diag = self._decomp_diag(inv_shifted, Q)
    if len(y.shape) == 1:
        return diag, dual_coef
    # y is not 1-d: add a trailing axis so the diagonal broadcasts against c
    return diag[:, np.newaxis], dual_coef
835+
836+
def _errors(self, alpha, y, v, Q, QT_y):
    """Return ``((c / G_diag) ** 2, c)`` using the shared helper's output."""
    diag, dual_coef = self._errors_and_values_helper(alpha, y, v, Q, QT_y)
    squared_ratio = (dual_coef / diag) ** 2
    return squared_ratio, dual_coef
829839

830840
def _values(self, alpha, y, v, Q, QT_y):
831-
# don't construct matrix G, instead compute action on y & diagonal
832-
w = 1.0 / (v + alpha)
833-
c = np.dot(Q, self._diag_dot(w, QT_y))
834-
G_diag = self._decomp_diag(w, Q)
835-
# handle case where y is 2-d
836-
if len(y.shape) != 1:
837-
G_diag = G_diag[:, np.newaxis]
841+
G_diag, c = self._errors_and_values_helper(alpha, y, v, Q, QT_y)
838842
return y - (c / G_diag), c
839843

840844
def _pre_compute_svd(self, X, y):
@@ -845,22 +849,24 @@ def _pre_compute_svd(self, X, y):
845849
UT_y = np.dot(U.T, y)
846850
return v, U, UT_y
847851

848-
def _errors_svd(self, alpha, y, v, U, UT_y):
852+
def _errors_and_values_svd_helper(self, alpha, y, v, U, UT_y):
    """Compute the quantities shared by ``self._errors_svd`` and
    ``self._values_svd``.

    Returns the pair ``(G_diag, c)`` that both callers combine into their
    own result.
    """
    inv_alpha = alpha ** -1
    shifted = ((v + alpha) ** -1) - inv_alpha
    dual_coef = np.dot(U, self._diag_dot(shifted, UT_y)) + inv_alpha * y
    diag = self._decomp_diag(shifted, U) + inv_alpha
    if len(y.shape) == 1:
        return diag, dual_coef
    # y is not 1-d: add a trailing axis so the diagonal broadcasts against c
    return diag[:, np.newaxis], dual_coef
863+
864+
def _errors_svd(self, alpha, y, v, U, UT_y):
    """Return ``((c / G_diag) ** 2, c)`` using the SVD helper's output."""
    diag, dual_coef = self._errors_and_values_svd_helper(
        alpha, y, v, U, UT_y)
    squared_ratio = (dual_coef / diag) ** 2
    return squared_ratio, dual_coef
856867

857868
def _values_svd(self, alpha, y, v, U, UT_y):
858-
w = ((v + alpha) ** -1) - (alpha ** -1)
859-
c = np.dot(U, self._diag_dot(w, UT_y)) + (alpha ** -1) * y
860-
G_diag = self._decomp_diag(w, U) + (alpha ** -1)
861-
if len(y.shape) != 1:
862-
# handle case when y is 2-d
863-
G_diag = G_diag[:, np.newaxis]
869+
G_diag, c = self._errors_and_values_svd_helper(alpha, y, v, U, UT_y)
864870
return y - (c / G_diag), c
865871

866872
def fit(self, X, y, sample_weight=None):

sklearn/linear_model/tests/test_ridge.py

+43
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
from sklearn.model_selection import GridSearchCV
3333
from sklearn.model_selection import KFold
3434

35+
from sklearn.utils import check_random_state
3536

3637
diabetes = datasets.load_diabetes()
3738
X_diabetes, y_diabetes = diabetes.data, diabetes.target
@@ -711,3 +712,45 @@ def test_ridge_fit_intercept_sparse():
711712
assert_warns(UserWarning, sparse.fit, X_csr, y)
712713
assert_almost_equal(dense.intercept_, sparse.intercept_)
713714
assert_array_almost_equal(dense.coef_, sparse.coef_)
715+
716+
717+
def test_errors_and_values_helper():
    """Check that ``_errors`` and ``_values`` agree with their shared helper.

    Both methods must return the helper's ``c`` unchanged, and their first
    output must be the expected combination of ``c`` and ``G_diag``.
    """
    ridgecv = _RidgeGCV()
    rng = check_random_state(42)
    alpha = 1.
    n = 5
    y = rng.randn(n)
    v = rng.randn(n)
    Q = rng.randn(len(v), len(v))
    QT_y = Q.T.dot(y)
    G_diag, c = ridgecv._errors_and_values_helper(alpha, y, v, Q, QT_y)

    # test that helper function behaves as expected
    out, c_ = ridgecv._errors(alpha, y, v, Q, QT_y)
    np.testing.assert_array_equal(out, (c / G_diag) ** 2)
    # compare the c returned by _errors with the helper's c; comparing c
    # with itself would always pass and check nothing
    np.testing.assert_array_equal(c_, c)

    out, c_ = ridgecv._values(alpha, y, v, Q, QT_y)
    np.testing.assert_array_equal(out, y - (c / G_diag))
    np.testing.assert_array_equal(c_, c)
737+
738+
def test_errors_and_values_svd_helper():
    """Check that ``_errors_svd`` and ``_values_svd`` agree with their helper.

    Runs for two (n, p) shape pairs, (5, 12) and (10, 6). Both methods must
    return the helper's ``c`` unchanged, and their first output must be the
    expected combination of ``c`` and ``G_diag``.
    """
    ridgecv = _RidgeGCV()
    rng = check_random_state(42)
    alpha = 1.
    for n, p in zip((5, 10), (12, 6)):
        y = rng.randn(n)
        v = rng.randn(p)
        U = rng.randn(n, p)
        UT_y = U.T.dot(y)
        G_diag, c = ridgecv._errors_and_values_svd_helper(alpha, y, v, U, UT_y)

        # test that helper function behaves as expected
        out, c_ = ridgecv._errors_svd(alpha, y, v, U, UT_y)
        np.testing.assert_array_equal(out, (c / G_diag) ** 2)
        # compare the c returned by _errors_svd with the helper's c;
        # comparing c with itself would always pass and check nothing
        np.testing.assert_array_equal(c_, c)

        out, c_ = ridgecv._values_svd(alpha, y, v, U, UT_y)
        np.testing.assert_array_equal(out, y - (c / G_diag))
        np.testing.assert_array_equal(c_, c)

0 commit comments

Comments
 (0)