base.clone fails if estimator has dia_matrix as a parameter #6855

Closed
@nileracecrew

Description

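If an estimator has a scipy.sparse.dia_matrix as a parameter, the equality test in clone raises a ValueError. It looks like this is because dia_matrix.data is a 2-D ndarray (shape (1, n) for the matrix in the reproduction below), while most (all?) other sparse types store data as a 1-D (n,) ndarray. As a result, param1.data[0] == param2.data[0] compares two arrays rather than two scalars, and the result cannot be used as a boolean.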
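A minimal sketch of the shape difference described above, independent of scikit-learn. The variable names are illustrative only, and the shapes in the comments apply to this 3x3 diagonal example.

```python
import numpy as np
import scipy.sparse

dense = np.diag([1, 2, 3])
csr = scipy.sparse.csr_matrix(dense)
dia = scipy.sparse.dia_matrix(dense)

# CSR stores its nonzero values in a 1-D array.
print(csr.data.shape)  # (3,)
# DIA stores one row per stored diagonal, so the array is 2-D.
print(dia.data.shape)  # (1, 3)

# Indexing with [0] therefore gives a scalar for CSR but a whole row for DIA.
first_csr = csr.data[0]
first_dia = dia.data[0]

# Using an elementwise comparison of two rows as a boolean is what fails.
try:
    bool(first_dia == first_dia)
    ambiguous = False
except ValueError:
    ambiguous = True
print(ambiguous)  # True
```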

Steps/Code to Reproduce

import sklearn.base
import scipy.sparse
import numpy as np

class DiaEstimator(sklearn.base.BaseEstimator):
    def __init__(self, p=None):
        self.p = p

    def transform(self, X, y=None):
        return X

    def fit(self, X, y=None):
        return self

M = scipy.sparse.csr_matrix(np.diag([1,2,3]))
e = DiaEstimator(p=M)
# works fine
sklearn.base.clone(e)

M = scipy.sparse.dia_matrix(np.diag([1,2,3]))
e = DiaEstimator(p=M)
# fails
sklearn.base.clone(e)

"""
Traceback (most recent call last):
  File "diabug.py", line 23, in <module>
    sklearn.base.clone(e)
  File "/Users/elawrence/anaconda/lib/python2.7/site-packages/sklearn/base.py", line 90, in clone
    and param1.data[0] == param2.data[0]
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
"""
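For illustration, here is one way a comparison could avoid depending on the layout of .data. This is only a sketch: sparse_params_equal is a hypothetical helper, not part of scikit-learn and not necessarily how clone should be fixed. It relies on the sparse != operator, which returns a sparse boolean matrix with an entry stored wherever the operands differ.

```python
import numpy as np
import scipy.sparse


def sparse_params_equal(a, b):
    """Hypothetical helper: True if two sparse matrices have the same
    class, shape, and values, regardless of how .data is laid out."""
    if a.__class__ is not b.__class__ or a.shape != b.shape:
        return False
    # (a != b) is sparse; no stored entries means no differing elements.
    return (a != b).nnz == 0


M = scipy.sparse.dia_matrix(np.diag([1, 2, 3]))
same = sparse_params_equal(M, M.copy())
different = sparse_params_equal(M, scipy.sparse.dia_matrix(np.diag([1, 2, 4])))
print(same, different)
```

Comparing whole matrices costs more than checking a single element, so whether this is acceptable inside clone is a separate question.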

Versions

In [7]: import platform; print(platform.platform())
Darwin-15.5.0-x86_64-i386-64bit

In [8]: import sys; print("Python", sys.version)
('Python', '2.7.11 |Anaconda custom (x86_64)| (default, Dec  6 2015, 18:57:58) \n[GCC 4.2.1 (Apple Inc. build 5577)]')

In [9]: import numpy; print("NumPy", numpy.__version__)
('NumPy', '1.10.4')

In [10]: import scipy; print("SciPy", scipy.__version__)
('SciPy', '0.17.0')

In [11]: import sklearn; print("Scikit-Learn", sklearn.__version__)
('Scikit-Learn', '0.17')
