Skip to content

[MRG + 1] FIX for LassoLarsCV on with readonly folds #4684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions sklearn/linear_model/least_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,12 @@ def __init__(self, alpha=1.0, fit_intercept=True, verbose=False,
###############################################################################
# Cross-validated estimator classes

def _check_copy_and_writeable(array, copy=False):
if copy or not array.flags.writeable:
return array.copy()
return array


def _lars_path_residues(X_train, y_train, X_test, y_test, Gram=None,
copy=True, method='lars', verbose=False,
fit_intercept=True, normalize=True, max_iter=500,
Expand Down Expand Up @@ -842,11 +848,10 @@ def _lars_path_residues(X_train, y_train, X_test, y_test, Gram=None,
residues : array, shape (n_alphas, n_samples)
Residues of the prediction on the test data
"""
if copy:
X_train = X_train.copy()
y_train = y_train.copy()
X_test = X_test.copy()
y_test = y_test.copy()
X_train = _check_copy_and_writeable(X_train, copy)
y_train = _check_copy_and_writeable(y_train, copy)
X_test = _check_copy_and_writeable(X_test, copy)
y_test = _check_copy_and_writeable(y_test, copy)

if fit_intercept:
X_mean = X_train.mean(axis=0)
Expand Down
26 changes: 25 additions & 1 deletion sklearn/linear_model/tests/test_least_angle.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import tempfile
import shutil
import os.path as op
from nose.tools import assert_equal

import numpy as np
from scipy import linalg

from sklearn.cross_validation import train_test_split
from sklearn.externals import joblib
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_less
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import ignore_warnings, assert_warns_message
from sklearn.utils.testing import ignore_warnings
from sklearn.utils.testing import assert_no_warnings, assert_warns
from sklearn.utils import ConvergenceWarning
from sklearn import linear_model, datasets
from sklearn.linear_model.least_angle import _lars_path_residues

diabetes = datasets.load_diabetes()
X, y = diabetes.data, diabetes.target
Expand Down Expand Up @@ -428,6 +434,24 @@ def test_no_warning_for_zero_mse():
assert_true(np.any(np.isinf(lars.criterion_)))


def test_lars_path_readonly_data():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could probably do a more light-weight test by using stringIO and np.memmap? I don't like IO in tests.... but seems good enough.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmap is a system call, you cannot do it on a Python StringIO instance.

# When using automated memory mapping on large input, the
# fold data is in read-only mode
# This is a non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/4597
splitted_data = train_test_split(X, y, random_state=42)
temp_folder = tempfile.mkdtemp()
try:
fpath = op.join(temp_folder, 'data.pkl')
joblib.dump(splitted_data, fpath)
X_train, X_test, y_train, y_test = joblib.load(fpath, mmap_mode='r')

# The following should not fail despite copy=False
_lars_path_residues(X_train, y_train, X_test, y_test, copy=False)
finally:
shutil.rmtree(temp_folder)


if __name__ == '__main__':
import nose
nose.runmodule()