diff --git a/sklearn/base.py b/sklearn/base.py
index 8c3a9a8eba4da..f98f4bf7e8cc5 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -21,6 +21,18 @@ class ChangedBehaviorWarning(_ChangedBehaviorWarning):
 
 
 ##############################################################################
+def _first_and_last_element(arr):
+    """Returns first and last element of numpy array or sparse matrix."""
+    if isinstance(arr, np.ndarray) or hasattr(arr, 'data'):
+        # numpy array or sparse matrix with .data attribute
+        data = arr.data if sparse.issparse(arr) else arr
+        return data.flat[0], data.flat[-1]
+    else:
+        # Sparse matrices without .data attribute. Only dok_matrix at
+        # the time of writing, in this case indexing is fast
+        return arr[0, 0], arr[-1, -1]
+
+
 def clone(estimator, safe=True):
     """Constructs a new estimator with the same parameters.
 
@@ -73,9 +85,8 @@ def clone(estimator, safe=True):
                 equality_test = (
                     param1.shape == param2.shape
                     and param1.dtype == param2.dtype
-                    # We have to use '.flat' for 2D arrays
-                    and param1.flat[0] == param2.flat[0]
-                    and param1.flat[-1] == param2.flat[-1]
+                    and (_first_and_last_element(param1) ==
+                         _first_and_last_element(param2))
                 )
             else:
                 equality_test = np.all(param1 == param2)
@@ -92,8 +103,8 @@ def clone(estimator, safe=True):
             else:
                 equality_test = (
                     param1.__class__ == param2.__class__
-                    and param1.data[0] == param2.data[0]
-                    and param1.data[-1] == param2.data[-1]
+                    and (_first_and_last_element(param1) ==
+                         _first_and_last_element(param2))
                     and param1.nnz == param2.nnz
                     and param1.shape == param2.shape
                 )
diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py
index 873808ff914af..6f4be0dcc8ab7 100644
--- a/sklearn/tests/test_base.py
+++ b/sklearn/tests/test_base.py
@@ -1,6 +1,8 @@
 # Author: Gael Varoquaux
 # License: BSD 3 clause
 
+import sys
+
 import numpy as np
 import scipy.sparse as sp
 
@@ -143,6 +145,24 @@ def test_clone_nan():
     assert_true(clf.empty is clf2.empty)
 
 
+def test_clone_sparse_matrices():
+    sparse_matrix_classes = [
+        getattr(sp, name)
+        for name in dir(sp) if name.endswith('_matrix')]
+
+    PY26 = sys.version_info[:2] == (2, 6)
+    if PY26:
+        # sp.dok_matrix can not be deepcopied in Python 2.6
+        sparse_matrix_classes.remove(sp.dok_matrix)
+
+    for cls in sparse_matrix_classes:
+        sparse_matrix = cls(np.eye(5))
+        clf = MyEstimator(empty=sparse_matrix)
+        clf_cloned = clone(clf)
+        assert_true(clf.empty.__class__ is clf_cloned.empty.__class__)
+        assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray())
+
+
 def test_repr():
     # Smoke test the repr of the base estimator.
     my_estimator = MyEstimator()