Skip to content

[MRG] Fix sklearn.base.clone when estimator has any kind of sparse matrix as attribute #6910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ class ChangedBehaviorWarning(_ChangedBehaviorWarning):


##############################################################################
def _first_and_last_element(arr):
"""Returns first and last element of numpy array or sparse matrix."""
if isinstance(arr, np.ndarray) or hasattr(arr, 'data'):
# numpy array or sparse matrix with .data attribute
data = arr.data if sparse.issparse(arr) else arr
return data.flat[0], data.flat[-1]
else:
# Sparse matrices without .data attribute. Only dok_matrix at
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a strong test, but okay.

# the time of writing, in this case indexing is fast
return arr[0, 0], arr[-1, -1]


def clone(estimator, safe=True):
"""Constructs a new estimator with the same parameters.

Expand Down Expand Up @@ -73,9 +85,8 @@ def clone(estimator, safe=True):
equality_test = (
param1.shape == param2.shape
and param1.dtype == param2.dtype
# We have to use '.flat' for 2D arrays
and param1.flat[0] == param2.flat[0]
and param1.flat[-1] == param2.flat[-1]
and (_first_and_last_element(param1) ==
_first_and_last_element(param2))
)
else:
equality_test = np.all(param1 == param2)
Expand All @@ -92,8 +103,8 @@ def clone(estimator, safe=True):
else:
equality_test = (
param1.__class__ == param2.__class__
and param1.data[0] == param2.data[0]
and param1.data[-1] == param2.data[-1]
and (_first_and_last_element(param1) ==
_first_and_last_element(param2))
and param1.nnz == param2.nnz
and param1.shape == param2.shape
)
Expand Down
20 changes: 20 additions & 0 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Author: Gael Varoquaux
# License: BSD 3 clause

import sys

import numpy as np
import scipy.sparse as sp

Expand Down Expand Up @@ -143,6 +145,24 @@ def test_clone_nan():
assert_true(clf.empty is clf2.empty)


def test_clone_sparse_matrices():
    """Check that clone preserves sparse matrix attributes of any format.

    Every ``*_matrix`` class exposed by ``scipy.sparse`` is used as an
    estimator parameter; the clone must hold a matrix of the same class
    with the same dense content.
    """
    matrix_types = []
    for attr_name in dir(sp):
        if attr_name.endswith('_matrix'):
            matrix_types.append(getattr(sp, attr_name))

    if sys.version_info[:2] == (2, 6):
        # sp.dok_matrix can not be deepcopied in Python 2.6:
        # deepcopy(sp.dok_matrix(np.eye(5))) fails there with
        # "AttributeError: shape not found" (the reconstructed matrix is
        # missing some attributes). Per the review discussion this is
        # presumably due to how Python 2.6 reconstructs dict subclasses,
        # so dok_matrix is simply not tested on that version.
        matrix_types.remove(sp.dok_matrix)

    for matrix_type in matrix_types:
        original = MyEstimator(empty=matrix_type(np.eye(5)))
        duplicate = clone(original)
        assert_true(original.empty.__class__ is duplicate.empty.__class__)
        assert_array_equal(original.empty.toarray(),
                           duplicate.empty.toarray())


def test_repr():
# Smoke test the repr of the base estimator.
my_estimator = MyEstimator()
Expand Down