[WIP] RidgeGCV with sample weights is broken #4490
@@ -31,7 +31,8 @@
 from sklearn.grid_search import GridSearchCV
-from sklearn.cross_validation import KFold
+from sklearn.cross_validation import KFold, LeaveOneOut
+from sklearn.utils import check_random_state


 diabetes = datasets.load_diabetes()
@@ -715,3 +716,57 @@ def test_ridge_fit_intercept_sparse():
     assert_warns(UserWarning, sparse.fit, X_csr, y)
     assert_almost_equal(dense.intercept_, sparse.intercept_)
     assert_array_almost_equal(dense.coef_, sparse.coef_)
+
+
+def make_noisy_forward_data(
+        n_samples=100,
+        n_features=200,
+        n_targets=10,
+        train_frac=.8,
+        noise_levels=None,
+        random_state=42):

Review comment: please use a more classical formatting for the args; it's
confusing this way. For instance:

    def make_noisy_forward_data(n_samples=100, n_features=200, n_targets=10,
                                train_frac=.8, noise_levels=None,
                                random_state=42):
        ...

Reply: rebase must have gone wrong. I had removed this in favor of a …

+    """Creates a simple, dense, noisy forward linear model with multiple
+    output."""

Review comment: please follow PEP 257 for the docstring: one-liner.
+    rng = check_random_state(random_state)
+    n_train = int(train_frac * n_samples)
+    train = slice(None, n_train)
+    test = slice(n_train, None)
+    X = rng.randn(n_samples, n_features)
+    W = rng.randn(n_features, n_targets)
+    Y_clean = X.dot(W)
+    if noise_levels is None:
+        noise_levels = rng.randn(n_targets) ** 2
+    noise_levels = np.atleast_1d(noise_levels) * np.ones(n_targets)
+    noise = rng.randn(*Y_clean.shape) * noise_levels * Y_clean.std(0)
+    Y = Y_clean + noise
+    return X, Y, W, train, test
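For orientation, here is a minimal sketch of how this helper might be used
(illustrative only, not part of the diff; it assumes the imports already at
the top of test_ridge.py, i.e. numpy as np and Ridge):

    X, Y, W, train, test = make_noisy_forward_data(
        n_samples=50, n_features=20, n_targets=3, random_state=0)
    # train and test are slice objects over the rows of X and Y.
    model = Ridge(alpha=1., fit_intercept=False).fit(X[train], Y[train])
    r2 = model.score(X[test], Y[test])  # held-out R^2 of the recovered model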
+
+
+def test_ridge_gcv_with_sample_weights():
+    n_samples, n_features, n_targets = 20, 5, 7
+    X, Y = datasets.make_regression(n_samples, n_features,
+                                    n_targets=n_targets)
+    alphas = np.logspace(-3, 3, 9)
+
+    rng = np.random.RandomState(42)
+    sample_weights = rng.randn(n_samples) ** 2
+    cv = LeaveOneOut(n_samples)
+    cv_predictions = np.array([[
+        Ridge(solver='cholesky', alpha=alpha, fit_intercept=False).fit(
+            X[train], Y[train], sample_weight=sample_weights[train]
+        ).predict(X[test])
+        for train, test in cv] for alpha in alphas]).squeeze()
+
+    cv_errors = Y[np.newaxis] - cv_predictions.reshape(
+        len(alphas), n_samples, n_targets)
+
+    ridge_gcv = _RidgeGCV(alphas=alphas, store_cv_values=True,
+                          gcv_mode='eigen', fit_intercept=False)
+    # emulate the sample weight stuff from _RidgeGCV
+    ridge_gcv.fit(X, Y, sample_weight=sample_weights)
+    loo_predictions = ridge_gcv.cv_values_
+
+    assert_array_almost_equal(cv_errors ** 2,
+                              loo_predictions.transpose(2, 0, 1))
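For context on what the test compares against cv_values_: as I understand it,
_RidgeGCV computes leave-one-out residuals in closed form rather than by
refitting, using the standard identity for ridge without intercept,
e_i_loo = (y_i - yhat_i) / (1 - H_ii). A standalone numpy sketch verifying
that identity (my illustration, not code from this PR):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(20, 5)
    y = rng.randn(20)
    alpha = 1.

    # Ridge hat matrix (no intercept): H = X (X^T X + alpha I)^{-1} X^T
    A = X.T.dot(X) + alpha * np.eye(X.shape[1])
    H = X.dot(np.linalg.solve(A, X.T))
    residuals = y - H.dot(y)

    # Closed-form LOO residuals: e_i = (y_i - yhat_i) / (1 - H_ii)
    loo_fast = residuals / (1 - np.diag(H))

    # Brute-force check: refit the ridge with each sample left out.
    loo_slow = np.empty_like(y)
    for i in range(len(y)):
        mask = np.arange(len(y)) != i
        Ai = X[mask].T.dot(X[mask]) + alpha * np.eye(X.shape[1])
        coef = np.linalg.solve(Ai, X[mask].T.dot(y[mask]))
        loo_slow[i] = y[i] - X[i].dot(coef)

    np.testing.assert_allclose(loo_fast, loo_slow)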
Review comment: the comment should be at the top and start with #.
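Finally, the sample-weight handling that the test exercises is commonly
implemented by rescaling rows with the square root of the weights: with
fit_intercept=False, minimising sum_i w_i (y_i - x_i . b)^2 + alpha ||b||^2
is exactly the unweighted ridge problem on (sqrt(w_i) x_i, sqrt(w_i) y_i).
A self-contained sketch of that equivalence (illustrative only, under the
no-intercept assumption; not code from this PR):

    import numpy as np
    from sklearn.linear_model import Ridge

    rng = np.random.RandomState(0)
    X = rng.randn(20, 5)
    y = rng.randn(20)
    w = rng.rand(20) + .5  # strictly positive sample weights

    # Weighted fit through the sample_weight argument.
    coef_weighted = Ridge(alpha=1., fit_intercept=False).fit(
        X, y, sample_weight=w).coef_

    # The same fit from an unweighted problem on sqrt(w)-rescaled rows.
    sw = np.sqrt(w)
    coef_rescaled = Ridge(alpha=1., fit_intercept=False).fit(
        X * sw[:, np.newaxis], y * sw).coef_

    np.testing.assert_allclose(coef_weighted, coef_rescaled)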