Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions sklearn/linear_model/tests/test_sag.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def sag(X, y, step_size, alpha, n_iter=1, dloss=None, sparse=False,

def sag_sparse(X, y, step_size, alpha, n_iter=1,
dloss=None, sample_weight=None, sparse=False,
fit_intercept=True, saga=False):
fit_intercept=True, saga=False, random_state=0):
if step_size * alpha == 1.:
raise ZeroDivisionError("Sparse sag does not handle the case "
"step_size * alpha == 1")
Expand All @@ -130,7 +130,7 @@ def sag_sparse(X, y, step_size, alpha, n_iter=1,
sum_gradient = np.zeros(n_features)
last_updated = np.zeros(n_features, dtype=np.int)
gradient_memory = np.zeros(n_samples)
rng = np.random.RandomState(77)
rng = check_random_state(random_state)
intercept = 0.0
intercept_sum_gradient = 0.0
wscale = 1.0
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_sag_regressor_computed_correctly():
alpha = .1
n_features = 10
n_samples = 40
max_iter = 50
max_iter = 100
tol = .000001
fit_intercept = True
rng = np.random.RandomState(0)
Expand All @@ -378,7 +378,8 @@ def test_sag_regressor_computed_correctly():
step_size = get_step_size(X, alpha, fit_intercept, classification=False)

clf1 = Ridge(fit_intercept=fit_intercept, tol=tol, solver='sag',
alpha=alpha * n_samples, max_iter=max_iter)
alpha=alpha * n_samples, max_iter=max_iter,
random_state=rng)
clf2 = clone(clf1)

clf1.fit(X, y)
Expand All @@ -387,12 +388,14 @@ def test_sag_regressor_computed_correctly():
spweights1, spintercept1 = sag_sparse(X, y, step_size, alpha,
n_iter=max_iter,
dloss=squared_dloss,
fit_intercept=fit_intercept)
fit_intercept=fit_intercept,
random_state=rng)

spweights2, spintercept2 = sag_sparse(X, y, step_size, alpha,
n_iter=max_iter,
dloss=squared_dloss, sparse=True,
fit_intercept=fit_intercept)
fit_intercept=fit_intercept,
random_state=rng)

assert_array_almost_equal(clf1.coef_.ravel(),
spweights1.ravel(),
Expand Down