Skip to content

Commit 9486a7a

Browse files
committed
Merge pull request #5161 from beepee14/sparse_prediction_check
[MRG+1] Add check for sparse prediction in cross_val_predict (fixes #5132)
2 parents 42b9590 + b68f922 commit 9486a7a

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

sklearn/cross_validation.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -1042,14 +1042,20 @@ def cross_val_predict(estimator, X, y=None, cv=None, n_jobs=1,
10421042
train, test, verbose,
10431043
fit_params)
10441044
for train, test in cv)
1045-
p = np.concatenate([p for p, _ in preds_blocks])
1045+
1046+
preds = [p for p, _ in preds_blocks]
10461047
locs = np.concatenate([loc for _, loc in preds_blocks])
10471048
if not _check_is_partition(locs, _num_samples(X)):
10481049
raise ValueError('cross_val_predict only works for partitions')
1049-
preds = p.copy()
1050-
preds[locs] = p
1051-
return preds
1052-
1050+
inv_locs = np.empty(len(locs), dtype=int)
1051+
inv_locs[locs] = np.arange(len(locs))
1052+
1053+
# Check for sparse predictions
1054+
if sp.issparse(preds[0]):
1055+
preds = sp.vstack(preds, format=preds[0].format)
1056+
else :
1057+
preds = np.concatenate(preds)
1058+
return preds[inv_locs]
10531059

10541060
def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params):
10551061
"""Fit estimator and predict values for a given dataset split.

sklearn/tests/test_cross_validation.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
from scipy.sparse import coo_matrix
7+
from scipy.sparse import csr_matrix
78
from scipy import stats
89

910
from sklearn.utils.testing import assert_true
@@ -25,14 +26,15 @@
2526
from sklearn.datasets import load_boston
2627
from sklearn.datasets import load_digits
2728
from sklearn.datasets import load_iris
29+
from sklearn.datasets import make_multilabel_classification
2830
from sklearn.metrics import explained_variance_score
2931
from sklearn.metrics import make_scorer
3032
from sklearn.metrics import precision_score
31-
3233
from sklearn.externals import six
3334
from sklearn.externals.six.moves import zip
3435

3536
from sklearn.linear_model import Ridge
37+
from sklearn.multiclass import OneVsRestClassifier
3638
from sklearn.neighbors import KNeighborsClassifier
3739
from sklearn.svm import SVC
3840
from sklearn.cluster import KMeans
@@ -1094,3 +1096,18 @@ def test_check_is_partition():
10941096

10951097
p[0] = 23
10961098
assert_false(cval._check_is_partition(p, 100))
1099+
1100+
def test_cross_val_predict_sparse_prediction():
    # Dense and sparse versions of the same multilabel problem must yield
    # the same cross-validated predictions.
    dense_X, dense_y = make_multilabel_classification(n_classes=2,
                                                      n_labels=1,
                                                      allow_unlabeled=False,
                                                      return_indicator=True,
                                                      random_state=1)
    sparse_X = csr_matrix(dense_X)
    sparse_y = csr_matrix(dense_y)
    ovr = OneVsRestClassifier(SVC(kernel='linear'))

    # Reference predictions computed from the dense arrays.
    expected = cval.cross_val_predict(ovr, dense_X, dense_y, cv=10)

    # With CSR inputs the result is expected to be a sparse matrix (hence
    # the .toarray() call); densify it so it can be compared element-wise.
    actual = cval.cross_val_predict(ovr, sparse_X, sparse_y, cv=10).toarray()

    assert_array_almost_equal(actual, expected)
1113+

0 commit comments

Comments
 (0)