Description
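If an estimator has a scipy.sparse.dia_matrix as a parameter, the equality check that clone runs on sparse parameters raises a ValueError. This seems to happen because dia_matrix.data is a 2-D ndarray (shape (1, n) in this example, one row per stored diagonal), while common sparse types such as CSR, CSC and COO store data as a 1-D (n,) ndarray. As a result, param1.data[0] selects a whole row instead of a single value, and the comparison returns an array rather than a bool.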
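The layout difference can be seen directly with a short sketch that uses only scipy and numpy (no scikit-learn). It compares the shape of .data for a CSR and a DIA version of the same diagonal matrix, then shows that indexing [0] and comparing gives an array for DIA. That array cannot be used in a boolean context, which is what the "and" in the traceback below does.

```python
import numpy as np
import scipy.sparse

A = np.diag([1, 2, 3])
csr = scipy.sparse.csr_matrix(A)
dia = scipy.sparse.dia_matrix(A)

# CSR keeps its stored values in a flat 1-D array
print(csr.data.shape)  # (3,)
# DIA keeps one row per stored diagonal: shape (n_diagonals, n)
print(dia.data.shape)  # (1, 3)

# For CSR, data[0] is a single value, so == gives one boolean
scalar_cmp = csr.data[0] == csr.data[0]
# For DIA, data[0] is a whole row, so == gives an elementwise array
row_cmp = dia.data[0] == dia.data[0]
print(row_cmp.shape)  # (3,)

# Evaluating that array as a bool (as `x and y` does) raises ValueError
try:
    bool(row_cmp)
except ValueError as exc:
    print("ValueError:", exc)
```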
Steps/Code to Reproduce
import sklearn.base
import scipy.sparse
import numpy as np


class DiaEstimator(sklearn.base.BaseEstimator):
    def __init__(self, p=None):
        self.p = p

    def transform(self, X, y=None):
        return X

    def fit(self, X, y=None):
        return self


M = scipy.sparse.csr_matrix(np.diag([1, 2, 3]))
e = DiaEstimator(p=M)
# works fine
sklearn.base.clone(e)

M = scipy.sparse.dia_matrix(np.diag([1, 2, 3]))
e = DiaEstimator(p=M)
# fails
sklearn.base.clone(e)
"""
Traceback (most recent call last):
  File "diabug.py", line 23, in <module>
    sklearn.base.clone(e)
  File "/Users/elawrence/anaconda/lib/python2.7/site-packages/sklearn/base.py", line 90, in clone
    and param1.data[0] == param2.data[0]
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
"""
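One way to make a check like the one in the traceback independent of the storage layout is to index the data array through its flat view (.flat), which treats 1-D and 2-D arrays alike. The helper sparse_params_match below is hypothetical: it is a minimal sketch of a layout-agnostic comparison, not scikit-learn's actual code or its fix for this issue. copy.deepcopy is used only as a stand-in for "a copy of the parameter".

```python
import copy

import numpy as np
import scipy.sparse


def sparse_params_match(param1, param2):
    """Hypothetical helper: cheap check that two sparse matrices look alike.

    Compares type, shape, size of the stored data, and the first and last
    stored values. Indexing via .data.flat works whether data is 1-D
    (CSR, CSC, COO) or 2-D (DIA).
    """
    if type(param1) is not type(param2) or param1.shape != param2.shape:
        return False
    if param1.data.size != param2.data.size:
        return False
    if param1.data.size == 0:
        return True
    # .flat[0] and .flat[-1] are always single values, so this is a plain bool
    return bool(param1.data.flat[0] == param2.data.flat[0]
                and param1.data.flat[-1] == param2.data.flat[-1])


M = scipy.sparse.dia_matrix(np.diag([1, 2, 3]))
print(sparse_params_match(M, copy.deepcopy(M)))  # True
print(sparse_params_match(M, M.tocsr()))  # False (different sparse types)
```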
Versions
In [7]: import platform; print(platform.platform())
Darwin-15.5.0-x86_64-i386-64bit
In [8]: import sys; print("Python", sys.version)
('Python', '2.7.11 |Anaconda custom (x86_64)| (default, Dec 6 2015, 18:57:58) \n[GCC 4.2.1 (Apple Inc. build 5577)]')
In [9]: import numpy; print("NumPy", numpy.__version__)
('NumPy', '1.10.4')
In [10]: import scipy; print("SciPy", scipy.__version__)
('SciPy', '0.17.0')
In [11]: import sklearn; print("Scikit-Learn", sklearn.__version__)
('Scikit-Learn', '0.17')