[WIP] RidgeGCV with sample weights is broken #4490

Closed
52 changes: 36 additions & 16 deletions sklearn/linear_model/ridge.py
@@ -797,11 +797,21 @@ def __init__(self, alphas=(0.1, 1.0, 10.0),
         self.gcv_mode = gcv_mode
         self.store_cv_values = store_cv_values

-    def _pre_compute(self, X, y):
+    def _pre_compute(self, X, y, sample_weight=None):
         # even if X is very sparse, K is usually very dense
         K = safe_sparse_dot(X, X.T, dense_output=True)
+        if sample_weight is not None:
+            has_sw = True
+            sqrt_sw = np.sqrt(sample_weight).reshape(K.shape[0], 1)
+            K *= sqrt_sw
+            K *= sqrt_sw.T
+        else:
+            has_sw = False
         v, Q = linalg.eigh(K)
-        QT_y = np.dot(Q.T, y)
+        if has_sw:
+            QT_y = np.dot(Q.T, sqrt_sw * y.reshape(y.shape[0], -1))
+        else:
+            QT_y = np.dot(Q.T, y)
         return v, Q, QT_y
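
The two-sided rescaling of K above matches the usual reduction of weighted to
ordinary least squares: multiplying K by sqrt(sample_weight) on both sides is
the same as building the Gram matrix from a design matrix whose rows are
scaled by the square roots of the weights. A minimal numpy sketch of that
identity (illustration only, not part of the patch):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(6, 4)
    sample_weight = rng.rand(6) + 0.1  # positive weights

    # two-sided rescaling of the Gram matrix, as in _pre_compute
    K = X.dot(X.T)
    sqrt_sw = np.sqrt(sample_weight).reshape(-1, 1)
    K_weighted = K * sqrt_sw * sqrt_sw.T

    # equivalently, the Gram matrix of the row-rescaled design matrix
    X_weighted = X * sqrt_sw
    assert np.allclose(K_weighted, X_weighted.dot(X_weighted.T))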

     def _decomp_diag(self, v_prime, Q):
@@ -815,15 +825,28 @@ def _diag_dot(self, D, B):
         D = D[(slice(None), ) + (np.newaxis, ) * (len(B.shape) - 1)]
         return D * B

-    def _errors(self, alpha, y, v, Q, QT_y):
+    def _errors(self, alpha, y, v, Q, QT_y, sample_weight=None):
         # don't construct matrix G, instead compute action on y & diagonal
         w = 1.0 / (v + alpha)
         c = np.dot(Q, self._diag_dot(w, QT_y))
+        c.shape = (c.shape[0], -1)
+        if sample_weight is not None:
+            has_sw = True
+            sqrt_sw = np.sqrt(sample_weight).reshape(y.shape[0], 1)
+            c *= sqrt_sw
+        else:
+            has_sw = False
         G_diag = self._decomp_diag(w, Q)
         # handle case where y is 2-d
         if len(y.shape) != 1:
             G_diag = G_diag[:, np.newaxis]
-        return (c / G_diag) ** 2, c
+
+        errors = c / G_diag
+        if has_sw:
+            errors /= sqrt_sw ** 2
+
+        if y.ndim == 1:
+            errors.shape = (-1,)
+            c.shape = (-1,)
+        return errors ** 2, c
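
The quantity c / G_diag is the standard ridge leave-one-out shortcut: with
G = K + alpha * I we have c = G^-1 y, and the LOO residual for sample i is
c[i] / G^-1[i, i], so no model is ever refit. A small check of that identity
against explicit refitting (unweighted, single target, no intercept;
illustration only):

    import numpy as np
    from sklearn.linear_model import Ridge

    rng = np.random.RandomState(0)
    X, y, alpha = rng.randn(20, 5), rng.randn(20), 1.0

    # closed form: residual_i = (G^-1 y)_i / (G^-1)_ii
    G_inv = np.linalg.inv(X.dot(X.T) + alpha * np.eye(len(y)))
    loo_residuals = G_inv.dot(y) / np.diag(G_inv)

    # brute force: refit with each sample held out
    for i in range(len(y)):
        mask = np.arange(len(y)) != i
        pred = Ridge(alpha=alpha, fit_intercept=False).fit(
            X[mask], y[mask]).predict(X[i:i + 1])[0]
        assert np.isclose(loo_residuals[i], y[i] - pred)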

     def _values(self, alpha, y, v, Q, QT_y):
         # don't construct matrix G, instead compute action on y & diagonal
@@ -835,15 +858,15 @@ def _values(self, alpha, y, v, Q, QT_y):
             G_diag = G_diag[:, np.newaxis]
         return y - (c / G_diag), c

-    def _pre_compute_svd(self, X, y):
+    def _pre_compute_svd(self, X, y, sample_weight=None):
         if sparse.issparse(X):
             raise TypeError("SVD not supported for sparse matrices")
         U, s, _ = linalg.svd(X, full_matrices=0)
         v = s ** 2
         UT_y = np.dot(U.T, y)
         return v, U, UT_y

-    def _errors_svd(self, alpha, y, v, U, UT_y):
+    def _errors_svd(self, alpha, y, v, U, UT_y, sample_weight=None):
         w = ((v + alpha) ** -1) - (alpha ** -1)
         c = np.dot(U, self._diag_dot(w, UT_y)) + (alpha ** -1) * y
         G_diag = self._decomp_diag(w, U) + (alpha ** -1)
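
In this hunk the new sample_weight argument is accepted but not yet used,
consistent with the [WIP] title. For reference, the SVD path relies on the
identity (X X.T + alpha * I)^-1 = U diag(1/(s**2 + alpha) - 1/alpha) U.T +
(1/alpha) * I for the thin SVD X = U diag(s) V.T, which is what w and G_diag
encode. A quick numpy check of that identity (illustration only):

    import numpy as np
    from scipy import linalg

    rng = np.random.RandomState(0)
    X, alpha = rng.randn(20, 5), 1.0

    U, s, _ = linalg.svd(X, full_matrices=0)
    w = (s ** 2 + alpha) ** -1 - alpha ** -1

    # low-rank correction plus scaled identity
    G_inv_svd = (U * w).dot(U.T) + np.eye(len(X)) / alpha
    G_inv = np.linalg.inv(X.dot(X.T) + alpha * np.eye(len(X)))
    assert np.allclose(G_inv_svd, G_inv)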
@@ -889,7 +912,7 @@ def fit(self, X, y, sample_weight=None):
                             sample_weight=sample_weight)

         gcv_mode = self.gcv_mode
-        with_sw = len(np.shape(sample_weight))
+        with_sw = sample_weight is not None

         if gcv_mode is None or gcv_mode == 'auto':
             if sparse.issparse(X) or n_features > n_samples or with_sw:
@@ -914,7 +937,7 @@ def fit(self, X, y, sample_weight=None):
         else:
             raise ValueError('bad gcv_mode "%s"' % gcv_mode)

-        v, Q, QT_y = _pre_compute(X, y)
+        v, Q, QT_y = _pre_compute(X, y, sample_weight)
         n_y = 1 if len(y.shape) == 1 else y.shape[1]
         cv_values = np.zeros((n_samples * n_y, len(self.alphas)))
         C = []
@@ -923,13 +946,10 @@ def fit(self, X, y, sample_weight=None):
         error = scorer is None

         for i, alpha in enumerate(self.alphas):
-            weighted_alpha = (sample_weight * alpha
-                              if sample_weight is not None
-                              else alpha)
             if error:
-                out, c = _errors(weighted_alpha, y, v, Q, QT_y)
+                out, c = _errors(alpha, y, v, Q, QT_y, sample_weight)
             else:
-                out, c = _values(weighted_alpha, y, v, Q, QT_y)
+                out, c = _values(alpha, y, v, Q, QT_y)
             cv_values[:, i] = out.ravel()
             C.append(c)

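From the public API, the fix is exercised through RidgeCV, which delegates to
_RidgeGCV when cv=None and forwards sample_weight to its fit. A hedged usage
sketch (data and parameter values are arbitrary):

    import numpy as np
    from sklearn.linear_model import RidgeCV

    rng = np.random.RandomState(0)
    X, y = rng.randn(50, 10), rng.randn(50)
    sample_weight = rng.rand(50)

    # cv=None (the default) selects the efficient LOO path
    ridge = RidgeCV(alphas=(0.1, 1.0, 10.0), fit_intercept=False,
                    store_cv_values=True)
    ridge.fit(X, y, sample_weight=sample_weight)
    print(ridge.alpha_)            # selected regularization strength
    print(ridge.cv_values_.shape)  # (n_samples, n_alphas) == (50, 3)
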
57 changes: 56 additions & 1 deletion sklearn/linear_model/tests/test_ridge.py
@@ -31,7 +31,8 @@

 from sklearn.grid_search import GridSearchCV

-from sklearn.cross_validation import KFold
+from sklearn.cross_validation import KFold, LeaveOneOut
+from sklearn.utils import check_random_state


 diabetes = datasets.load_diabetes()
@@ -715,3 +716,57 @@ def test_ridge_fit_intercept_sparse():
     assert_warns(UserWarning, sparse.fit, X_csr, y)
     assert_almost_equal(dense.intercept_, sparse.intercept_)
     assert_array_almost_equal(dense.coef_, sparse.coef_)


+def make_noisy_forward_data(
+        n_samples=100,
+        n_features=200,
+        n_targets=10,
+        train_frac=.8,
+        noise_levels=None,
+        random_state=42):
+    """Creates a simple, dense, noisy forward linear model with multiple
+    output."""
+    rng = check_random_state(random_state)
+    n_train = int(train_frac * n_samples)
+    train = slice(None, n_train)
+    test = slice(n_train, None)
+    X = rng.randn(n_samples, n_features)
+    W = rng.randn(n_features, n_targets)
+    Y_clean = X.dot(W)
+    if noise_levels is None:
+        noise_levels = rng.randn(n_targets) ** 2
+    noise_levels = np.atleast_1d(noise_levels) * np.ones(n_targets)
+    noise = rng.randn(*Y_clean.shape) * noise_levels * Y_clean.std(0)
+    Y = Y_clean + noise
+    return X, Y, W, train, test

Review comments on make_noisy_forward_data:

Member: the comment should be at the top and start with #

Member: please use a more classical formatting for the args. It's confusing
this way. For instance:

    def make_noisy_forward_data(n_samples=100, n_features=200, n_targets=10,
                                train_frac=.8, noise_levels=None, random_state=42):
        ...

Contributor Author: rebase must have gone wrong. I had removed this in favor
of a make_regression.

Member: Please follow PEP 257 for the docstring: one-liner.

    """Creates a dense, noisy forward linear model with multiple output."""


+def test_ridge_gcv_with_sample_weights():
+    n_samples, n_features, n_targets = 20, 5, 7
+    X, Y = datasets.make_regression(n_samples, n_features,
+                                    n_targets=n_targets)
+    alphas = np.logspace(-3, 3, 9)
+
+    rng = np.random.RandomState(42)
+    sample_weights = rng.randn(n_samples) ** 2
+
+    # reference: explicit leave-one-out refits with weighted Ridge
+    cv = LeaveOneOut(n_samples)
+    cv_predictions = np.array([[
+        Ridge(solver='cholesky', alpha=alpha, fit_intercept=False).fit(
+            X[train], Y[train], sample_weight=sample_weights[train]
+        ).predict(X[test])
+        for train, test in cv] for alpha in alphas]).squeeze()
+
+    cv_errors = Y[np.newaxis] - cv_predictions.reshape(
+        len(alphas), n_samples, n_targets)
+
+    ridge_gcv = _RidgeGCV(alphas=alphas, store_cv_values=True,
+                          gcv_mode='eigen', fit_intercept=False)
+    ridge_gcv.fit(X, Y, sample_weight=sample_weights)
+    # cv_values_ holds the squared LOO errors for each sample and alpha
+    loo_predictions = ridge_gcv.cv_values_
+
+    assert_array_almost_equal(cv_errors ** 2,
+                              loo_predictions.transpose(2, 0, 1))
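
For context on the final transpose: with store_cv_values=True and multiple
targets, cv_values_ is stored with shape (n_samples, n_targets, n_alphas), so
transpose(2, 0, 1) aligns it with cv_errors, which has shape
(n_alphas, n_samples, n_targets).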