Skip to content

Commit eed5fc5

Browse files
committed
Merge pull request #6395 from yenchenlin1994/make-dump_svmlight_file-support-sparse-y
[MRG+1] Make dump_svmlight_file support sparse y (fixes #6301)
2 parents 56d625f + 22d7cd5 commit eed5fc5

File tree

2 files changed

+101
-63
lines changed

2 files changed

+101
-63
lines changed

sklearn/datasets/svmlight_format.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
276276

277277

278278
def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
279-
is_sp = int(hasattr(X, "tocsr"))
279+
X_is_sp = int(hasattr(X, "tocsr"))
280+
y_is_sp = int(hasattr(y, "tocsr"))
280281
if X.dtype.kind == 'i':
281282
value_pattern = u("%d:%d")
282283
else:
@@ -302,7 +303,7 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
302303
f.writelines(b("# %s\n" % line) for line in comment.splitlines())
303304

304305
for i in range(X.shape[0]):
305-
if is_sp:
306+
if X_is_sp:
306307
span = slice(X.indptr[i], X.indptr[i + 1])
307308
row = zip(X.indices[span], X.data[span])
308309
else:
@@ -312,10 +313,16 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
312313
s = " ".join(value_pattern % (j + one_based, x) for j, x in row)
313314

314315
if multilabel:
315-
nz_labels = np.where(y[i] != 0)[0]
316+
if y_is_sp:
317+
nz_labels = y[i].nonzero()[1]
318+
else:
319+
nz_labels = np.where(y[i] != 0)[0]
316320
labels_str = ",".join(label_pattern % j for j in nz_labels)
317321
else:
318-
labels_str = label_pattern % y[i]
322+
if y_is_sp:
323+
labels_str = label_pattern % y.data[i]
324+
else:
325+
labels_str = label_pattern % y[i]
319326

320327
if query_id is not None:
321328
feat = (labels_str, query_id[i], s)
@@ -341,9 +348,10 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None,
341348
Training vectors, where n_samples is the number of samples and
342349
n_features is the number of features.
343350
344-
y : array-like, shape = [n_samples] or [n_samples, n_labels]
345-
Target values. Class labels must be an integer or float, or array-like
346-
objects of integer or float for multilabel classifications.
351+
y : {array-like, sparse matrix}, shape = [n_samples (, n_labels)]
352+
Target values. Class labels must be an
353+
integer or float, or array-like objects of integer or float for
354+
multilabel classifications.
347355
348356
f : string or file-like in binary mode
349357
If string, specifies the path that will contain the data.
@@ -385,19 +393,31 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None,
385393
if six.b("\0") in comment:
386394
raise ValueError("comment string contains NUL byte")
387395

388-
y = np.asarray(y)
389-
if y.ndim != 1 and not multilabel:
390-
raise ValueError("expected y of shape (n_samples,), got %r"
391-
% (y.shape,))
396+
yval = check_array(y, accept_sparse='csr', ensure_2d=False)
397+
if sp.issparse(yval):
398+
if yval.shape[1] != 1 and not multilabel:
399+
raise ValueError("expected y of shape (n_samples, 1),"
400+
" got %r" % (yval.shape,))
401+
else:
402+
if yval.ndim != 1 and not multilabel:
403+
raise ValueError("expected y of shape (n_samples,), got %r"
404+
% (yval.shape,))
392405

393406
Xval = check_array(X, accept_sparse='csr')
394-
if Xval.shape[0] != y.shape[0]:
407+
if Xval.shape[0] != yval.shape[0]:
395408
raise ValueError("X.shape[0] and y.shape[0] should be the same, got"
396-
" %r and %r instead." % (Xval.shape[0], y.shape[0]))
409+
" %r and %r instead." % (Xval.shape[0], yval.shape[0]))
397410

398411
# We had some issues with CSR matrices with unsorted indices (e.g. #1501),
399412
# so sort them here, but first make sure we don't modify the user's X or y.
400413
# TODO We can do this cheaper; sorted_indices copies the whole matrix.
414+
if yval is y and hasattr(yval, "sorted_indices"):
415+
y = yval.sorted_indices()
416+
else:
417+
y = yval
418+
if hasattr(y, "sort_indices"):
419+
y.sort_indices()
420+
401421
if Xval is X and hasattr(Xval, "sorted_indices"):
402422
X = Xval.sorted_indices()
403423
else:

sklearn/datasets/tests/test_svmlight_format.py

Lines changed: 68 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import gzip
33
from io import BytesIO
44
import numpy as np
5+
import scipy.sparse as sp
56
import os
67
import shutil
78
from tempfile import NamedTemporaryFile
@@ -200,67 +201,84 @@ def test_invalid_filename():
200201

201202

202203
def test_dump():
203-
Xs, y = load_svmlight_file(datafile)
204-
Xd = Xs.toarray()
204+
X_sparse, y_dense = load_svmlight_file(datafile)
205+
X_dense = X_sparse.toarray()
206+
y_sparse = sp.csr_matrix(y_dense)
205207

206208
# slicing a csr_matrix can unsort its .indices, so test that we sort
207209
# those correctly
208-
Xsliced = Xs[np.arange(Xs.shape[0])]
209-
210-
for X in (Xs, Xd, Xsliced):
211-
for zero_based in (True, False):
212-
for dtype in [np.float32, np.float64, np.int32]:
213-
f = BytesIO()
214-
# we need to pass a comment to get the version info in;
215-
# LibSVM doesn't grok comments so they're not put in by
216-
# default anymore.
217-
dump_svmlight_file(X.astype(dtype), y, f, comment="test",
218-
zero_based=zero_based)
219-
f.seek(0)
220-
221-
comment = f.readline()
222-
try:
223-
comment = str(comment, "utf-8")
224-
except TypeError: # fails in Python 2.x
225-
pass
226-
227-
assert_in("scikit-learn %s" % sklearn.__version__, comment)
228-
229-
comment = f.readline()
230-
try:
231-
comment = str(comment, "utf-8")
232-
except TypeError: # fails in Python 2.x
233-
pass
234-
235-
assert_in(["one", "zero"][zero_based] + "-based", comment)
236-
237-
X2, y2 = load_svmlight_file(f, dtype=dtype,
238-
zero_based=zero_based)
239-
assert_equal(X2.dtype, dtype)
240-
assert_array_equal(X2.sorted_indices().indices, X2.indices)
241-
if dtype == np.float32:
242-
assert_array_almost_equal(
210+
X_sliced = X_sparse[np.arange(X_sparse.shape[0])]
211+
y_sliced = y_sparse[np.arange(y_sparse.shape[0])]
212+
213+
for X in (X_sparse, X_dense, X_sliced):
214+
for y in (y_sparse, y_dense, y_sliced):
215+
for zero_based in (True, False):
216+
for dtype in [np.float32, np.float64, np.int32]:
217+
f = BytesIO()
218+
# we need to pass a comment to get the version info in;
219+
# LibSVM doesn't grok comments so they're not put in by
220+
# default anymore.
221+
222+
if (sp.issparse(y) and y.shape[0] == 1):
223+
# make sure y's shape is: (n_samples, n_labels)
224+
# when it is sparse
225+
y = y.T
226+
227+
dump_svmlight_file(X.astype(dtype), y, f, comment="test",
228+
zero_based=zero_based)
229+
f.seek(0)
230+
231+
comment = f.readline()
232+
try:
233+
comment = str(comment, "utf-8")
234+
except TypeError: # fails in Python 2.x
235+
pass
236+
237+
assert_in("scikit-learn %s" % sklearn.__version__, comment)
238+
239+
comment = f.readline()
240+
try:
241+
comment = str(comment, "utf-8")
242+
except TypeError: # fails in Python 2.x
243+
pass
244+
245+
assert_in(["one", "zero"][zero_based] + "-based", comment)
246+
247+
X2, y2 = load_svmlight_file(f, dtype=dtype,
248+
zero_based=zero_based)
249+
assert_equal(X2.dtype, dtype)
250+
assert_array_equal(X2.sorted_indices().indices, X2.indices)
251+
252+
X2_dense = X2.toarray()
253+
254+
if dtype == np.float32:
243255
# allow a rounding error at the last decimal place
244-
Xd.astype(dtype), X2.toarray(), 4)
245-
else:
246-
assert_array_almost_equal(
256+
assert_array_almost_equal(
257+
X_dense.astype(dtype), X2_dense, 4)
258+
assert_array_almost_equal(
259+
y_dense.astype(dtype), y2, 4)
260+
else:
247261
# allow a rounding error at the last decimal place
248-
Xd.astype(dtype), X2.toarray(), 15)
249-
assert_array_equal(y, y2)
262+
assert_array_almost_equal(
263+
X_dense.astype(dtype), X2_dense, 15)
264+
assert_array_almost_equal(
265+
y_dense.astype(dtype), y2, 15)
250266

251267

252268
def test_dump_multilabel():
253269
X = [[1, 0, 3, 0, 5],
254270
[0, 0, 0, 0, 0],
255271
[0, 5, 0, 1, 0]]
256-
y = [[0, 1, 0], [1, 0, 1], [1, 1, 0]]
257-
f = BytesIO()
258-
dump_svmlight_file(X, y, f, multilabel=True)
259-
f.seek(0)
260-
# make sure it dumps multilabel correctly
261-
assert_equal(f.readline(), b("1 0:1 2:3 4:5\n"))
262-
assert_equal(f.readline(), b("0,2 \n"))
263-
assert_equal(f.readline(), b("0,1 1:5 3:1\n"))
272+
y_dense = [[0, 1, 0], [1, 0, 1], [1, 1, 0]]
273+
y_sparse = sp.csr_matrix(y_dense)
274+
for y in [y_dense, y_sparse]:
275+
f = BytesIO()
276+
dump_svmlight_file(X, y, f, multilabel=True)
277+
f.seek(0)
278+
# make sure it dumps multilabel correctly
279+
assert_equal(f.readline(), b("1 0:1 2:3 4:5\n"))
280+
assert_equal(f.readline(), b("0,2 \n"))
281+
assert_equal(f.readline(), b("0,1 1:5 3:1\n"))
264282

265283

266284
def test_dump_concise():

0 commit comments

Comments
 (0)