Commit a39c8ab

ENH svmlight chunk loader (scikit-learn#935)
1 parent 7238b46 commit a39c8ab

File tree

  doc/whats_new.rst
  sklearn/datasets/_svmlight_format.pyx
  sklearn/datasets/svmlight_format.py
  sklearn/datasets/tests/test_svmlight_format.py

4 files changed: +181 −17 lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions

@@ -202,11 +202,16 @@ Enhancements
 
 - Prevent cast from float32 to float64 in
   :class:`linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
+  :class:`sklearn.linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
   by :user:`Joan Massich <massich>`, :user:`Nicolas Cordier <ncordier>`
 
 - Add ``max_train_size`` parameter to :class:`model_selection.TimeSeriesSplit`
   :issue:`8282` by :user:`Aman Dalmia <dalmia>`.
 
+- Make it possible to load a chunk of an svmlight formatted file by
+  passing a range of bytes to :func:`datasets.load_svmlight_file`.
+  :issue:`935` by :user:`Olivier Grisel <ogrisel>`.
+
 Bug fixes
 .........
 
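
As a usage sketch of the new API (the file path and feature count below are hypothetical; per the docstrings added in this commit, n_features must be given explicitly whenever offset or length is used):

    import os
    import numpy as np
    import scipy.sparse as sp
    from sklearn.datasets import load_svmlight_file

    path = "/tmp/data.svmlight"   # hypothetical svmlight-formatted file
    n_features = 100              # must be known up front for chunked loading

    size = os.path.getsize(path)
    mid = size // 2

    # First half of the file by bytes; the row crossing the boundary is
    # read in full by this call.
    X_0, y_0 = load_svmlight_file(path, n_features=n_features,
                                  offset=0, length=mid)
    # Second half; the loader seeks to `mid` and skips the truncated row.
    X_1, y_1 = load_svmlight_file(path, n_features=n_features, offset=mid)

    X = sp.vstack([X_0, X_1])
    y = np.concatenate([y_0, y_1])

For a zero-based file, X and y now match loading the whole file in a single call.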

sklearn/datasets/_svmlight_format.pyx

Lines changed: 13 additions & 2 deletions

@@ -26,7 +26,7 @@ cdef bytes COLON = u':'.encode('ascii')
 @cython.boundscheck(False)
 @cython.wraparound(False)
 def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
-                        bint query_id):
+                        bint query_id, long long offset, long long length):
     cdef array.array data, indices, indptr
     cdef bytes line
     cdef char *hash_ptr
@@ -35,6 +35,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
     cdef Py_ssize_t i
     cdef bytes qid_prefix = b('qid')
     cdef Py_ssize_t n_features
+    cdef long long offset_max = offset + length if length > 0 else -1
 
     # Special-case float32 but use float64 for everything else;
     # the Python code will do further conversions.
@@ -52,6 +53,12 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
     else:
         labels = array.array("d")
 
+    if offset > 0:
+        f.seek(offset)
+        # drop the current line that might be truncated and is to be
+        # fetched by another call
+        f.readline()
+
     for line in f:
         # skip comments
         line_cstr = line
@@ -90,7 +97,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
             idx = int(idx_s)
             if idx < 0 or not zero_based and idx == 0:
                 raise ValueError(
-                        "Invalid index %d in SVMlight/LibSVM data file." % idx)
+                    "Invalid index %d in SVMlight/LibSVM data file." % idx)
             if idx <= prev_idx:
                 raise ValueError("Feature indices in SVMlight/LibSVM data "
                                  "file should be sorted and unique.")
@@ -106,4 +113,8 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
         array.resize_smart(indptr, len(indptr) + 1)
         indptr[len(indptr) - 1] = len(data)
 
+        if offset_max != -1 and f.tell() > offset_max:
+            # Stop here and let another call deal with the following.
+            break
+
     return (dtype, data, indices, indptr, labels, query)
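
The chunk protocol implemented above is symmetric: a reader that starts mid-file discards its first, possibly truncated line, and every reader finishes the line that crosses its end boundary, so no row is lost or read twice. A pure-Python sketch of the same reading logic (illustration only; the real reader is the Cython function above, which additionally parses each line):

    def read_chunk_lines(f, offset=0, length=-1):
        # f must be opened in binary mode so f.tell() reports byte positions.
        offset_max = offset + length if length > 0 else -1
        if offset > 0:
            f.seek(offset)
            # Drop the current line: it may be truncated, and the reader of
            # the previous chunk emits it in full.
            f.readline()
        while True:
            line = f.readline()
            if not line:
                break
            yield line
            # The line crossing offset + length belongs to this chunk; stop
            # before reading into the next chunk's range.
            if offset_max != -1 and f.tell() > offset_max:
                break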

sklearn/datasets/svmlight_format.py

Lines changed: 56 additions & 14 deletions

@@ -31,7 +31,8 @@
 
 
 def load_svmlight_file(f, n_features=None, dtype=np.float64,
-                       multilabel=False, zero_based="auto", query_id=False):
+                       multilabel=False, zero_based="auto", query_id=False,
+                       offset=0, length=-1):
     """Load datasets in the svmlight / libsvm format into sparse CSR matrix
 
     This format is a text-based format, with one sample per line. It does
@@ -76,6 +77,8 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         bigger sliced dataset: each subset might not have examples of
         every feature, hence the inferred shape might vary from one
         slice to another.
+        n_features is only required if ``offset`` or ``length`` are passed a
+        non-default value.
 
     multilabel : boolean, optional, default False
         Samples may have several labels each (see
@@ -88,7 +91,10 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         If set to "auto", a heuristic check is applied to determine this from
         the file contents. Both kinds of files occur "in the wild", but they
        are unfortunately not self-identifying. Using "auto" or True should
-        always be safe.
+        always be safe when no ``offset`` or ``length`` is passed.
+        If ``offset`` or ``length`` are passed, the "auto" mode falls back
+        to ``zero_based=True`` to avoid having the heuristic check yield
+        inconsistent results on different segments of the file.
 
     query_id : boolean, default False
         If True, will return the query_id array for each file.
@@ -97,6 +103,15 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
         Data type of dataset to be loaded. This will be the data type of the
         output numpy arrays ``X`` and ``y``.
 
+    offset : integer, optional, default 0
+        Ignore the offset first bytes by seeking forward, then
+        discarding the following bytes up until the next new line
+        character.
+
+    length : integer, optional, default -1
+        If strictly positive, stop reading any new line of data once the
+        position in the file has reached the (offset + length) bytes threshold.
+
     Returns
     -------
     X : scipy.sparse matrix of shape (n_samples, n_features)
@@ -129,7 +144,7 @@ def get_data():
         X, y = get_data()
     """
     return tuple(load_svmlight_files([f], n_features, dtype, multilabel,
-                                     zero_based, query_id))
+                                     zero_based, query_id, offset, length))
 
 
 def _gen_open(f):
@@ -149,15 +164,18 @@ def _gen_open(f):
         return open(f, "rb")
 
 
-def _open_and_load(f, dtype, multilabel, zero_based, query_id):
+def _open_and_load(f, dtype, multilabel, zero_based, query_id,
+                   offset=0, length=-1):
     if hasattr(f, "read"):
         actual_dtype, data, ind, indptr, labels, query = \
-            _load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
+            _load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
+                                offset, length)
     # XXX remove closing when Python 2.7+/3.1+ required
     else:
         with closing(_gen_open(f)) as f:
             actual_dtype, data, ind, indptr, labels, query = \
-                _load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
+                _load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
+                                    offset, length)
 
     # convert from array.array, give data the right dtype
     if not multilabel:
@@ -172,7 +190,8 @@ def _open_and_load(f, dtype, multilabel, zero_based, query_id):
 
 
 def load_svmlight_files(files, n_features=None, dtype=np.float64,
-                        multilabel=False, zero_based="auto", query_id=False):
+                        multilabel=False, zero_based="auto", query_id=False,
+                        offset=0, length=-1):
     """Load dataset from multiple files in SVMlight format
 
     This function is equivalent to mapping load_svmlight_file over a list of
@@ -216,7 +235,10 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
         If set to "auto", a heuristic check is applied to determine this from
         the file contents. Both kinds of files occur "in the wild", but they
         are unfortunately not self-identifying. Using "auto" or True should
-        always be safe.
+        always be safe when no offset or length is passed.
+        If offset or length are passed, the "auto" mode falls back
+        to zero_based=True to avoid having the heuristic check yield
+        inconsistent results on different segments of the file.
 
     query_id : boolean, defaults to False
         If True, will return the query_id array for each file.
@@ -225,6 +247,15 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
         Data type of dataset to be loaded. This will be the data type of the
         output numpy arrays ``X`` and ``y``.
 
+    offset : integer, optional, default 0
+        Ignore the offset first bytes by seeking forward, then
+        discarding the following bytes up until the next new line
+        character.
+
+    length : integer, optional, default -1
+        If strictly positive, stop reading any new line of data once the
+        position in the file has reached the (offset + length) bytes threshold.
+
     Returns
     -------
     [X1, y1, ..., Xn, yn]
@@ -245,16 +276,27 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
     --------
     load_svmlight_file
     """
-    r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id))
+    if (offset != 0 or length > 0) and zero_based == "auto":
+        # disable heuristic search to avoid getting inconsistent results on
+        # different segments of the file
+        zero_based = True
+
+    if (offset != 0 or length > 0) and n_features is None:
+        raise ValueError(
+            "n_features is required when offset or length is specified.")
+
+    r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id),
+                        offset=offset, length=length)
          for f in files]
 
-    if (zero_based is False
-            or zero_based == "auto" and all(np.min(tmp[1]) > 0 for tmp in r)):
-        for ind in r:
-            indices = ind[1]
+    if (zero_based is False or
+            zero_based == "auto" and all(len(tmp[1]) and np.min(tmp[1]) > 0
                                          for tmp in r)):
+        for _, indices, _, _, _ in r:
             indices -= 1
 
-    n_f = max(ind[1].max() for ind in r) + 1
+    n_f = max(ind[1].max() if len(ind[1]) else 0 for ind in r) + 1
+
     if n_features is None:
         n_features = n_f
     elif n_features < n_f:
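
Since each (offset, length) pair can be processed independently, one large file can be loaded by several workers at once. A sketch under the assumption that joblib is available (load_in_chunks, the path, and the chunk count are all hypothetical; note that with zero_based="auto" chunked calls fall back to zero-based indexing, so files with one-based indices need an explicit zero_based=False):

    import os
    import numpy as np
    import scipy.sparse as sp
    from joblib import Parallel, delayed   # assumption: joblib is installed
    from sklearn.datasets import load_svmlight_file

    def load_in_chunks(path, n_features, n_chunks=4, n_jobs=4):
        size = os.path.getsize(path)
        # Evenly spaced byte offsets; the chunk protocol guarantees rows
        # are neither dropped nor duplicated at the boundaries.
        offsets = [size * i // n_chunks for i in range(n_chunks)]
        lengths = [stop - start
                   for start, stop in zip(offsets, offsets[1:])] + [-1]
        results = Parallel(n_jobs=n_jobs)(
            delayed(load_svmlight_file)(path, n_features=n_features,
                                        offset=offset, length=length)
            for offset, length in zip(offsets, lengths))
        X = sp.vstack([X_i for X_i, _ in results])
        y = np.concatenate([y_i for _, y_i in results])
        return X, y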

sklearn/datasets/tests/test_svmlight_format.py

Lines changed: 107 additions & 1 deletion

@@ -1,3 +1,4 @@
+from __future__ import division
 from bz2 import BZ2File
 import gzip
 from io import BytesIO
@@ -13,8 +14,10 @@
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_raises_regex
 from sklearn.utils.testing import raises
 from sklearn.utils.testing import assert_in
+from sklearn.utils.fixes import sp_version
 
 import sklearn
 from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
@@ -401,4 +404,107 @@ def test_load_with_long_qid():
     f.seek(0)
     X, y = load_svmlight_file(f, query_id=False, zero_based=True)
     assert_array_equal(y, true_y)
-    assert_array_equal(X.toarray(), true_X)
+    assert_array_equal(X.toarray(), true_X)
+
+
+def test_load_zeros():
+    f = BytesIO()
+    true_X = sp.csr_matrix(np.zeros(shape=(3, 4)))
+    true_y = np.array([0, 1, 0])
+    dump_svmlight_file(true_X, true_y, f)
+
+    for zero_based in ['auto', True, False]:
+        f.seek(0)
+        X, y = load_svmlight_file(f, n_features=4, zero_based=zero_based)
+        assert_array_equal(y, true_y)
+        assert_array_equal(X.toarray(), true_X.toarray())
+
+
+def test_load_with_offsets():
+    def check_load_with_offsets(sparsity, n_samples, n_features):
+        rng = np.random.RandomState(0)
+        X = rng.uniform(low=0.0, high=1.0, size=(n_samples, n_features))
+        if sparsity:
+            X[X < sparsity] = 0.0
+        X = sp.csr_matrix(X)
+        y = rng.randint(low=0, high=2, size=n_samples)
+
+        f = BytesIO()
+        dump_svmlight_file(X, y, f)
+        f.seek(0)
+
+        size = len(f.getvalue())
+
+        # put some marks that are likely to happen anywhere in a row
+        mark_0 = 0
+        mark_1 = size // 3
+        length_0 = mark_1 - mark_0
+        mark_2 = 4 * size // 5
+        length_1 = mark_2 - mark_1
+
+        # load the original sparse matrix into 3 independent CSR matrices
+        X_0, y_0 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_0, length=length_0)
+        X_1, y_1 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_1, length=length_1)
+        X_2, y_2 = load_svmlight_file(f, n_features=n_features,
+                                      offset=mark_2)
+
+        y_concat = np.concatenate([y_0, y_1, y_2])
+        X_concat = sp.vstack([X_0, X_1, X_2])
+        assert_array_equal(y, y_concat)
+        assert_array_almost_equal(X.toarray(), X_concat.toarray())
+
+    # Generate a uniformly random sparse matrix
+    for sparsity in [0, 0.1, .5, 0.99, 1]:
+        for n_samples in [13, 101]:
+            for n_features in [2, 7, 41]:
+                yield check_load_with_offsets, sparsity, n_samples, n_features
+
+
+def test_load_offset_exhaustive_splits():
+    rng = np.random.RandomState(0)
+    X = np.array([
+        [0, 0, 0, 0, 0, 0],
+        [1, 2, 3, 4, 0, 6],
+        [1, 2, 3, 4, 0, 6],
+        [0, 0, 0, 0, 0, 0],
+        [1, 0, 3, 0, 0, 0],
+        [0, 0, 0, 0, 0, 1],
+        [1, 0, 0, 0, 0, 0],
+    ])
+    X = sp.csr_matrix(X)
+    n_samples, n_features = X.shape
+    y = rng.randint(low=0, high=2, size=n_samples)
+    query_id = np.arange(n_samples) // 2
+
+    f = BytesIO()
+    dump_svmlight_file(X, y, f, query_id=query_id)
+    f.seek(0)
+
+    size = len(f.getvalue())
+
+    # load the same data in 2 parts with all the possible byte offsets to
+    # locate the split so as to test for particular boundary cases
+    for mark in range(size):
+        if sp_version < (0, 14) and (mark == 0 or mark > size - 100):
+            # old scipy does not support sparse matrices with 0 rows.
+            continue
+        f.seek(0)
+        X_0, y_0, q_0 = load_svmlight_file(f, n_features=n_features,
+                                           query_id=True, offset=0,
+                                           length=mark)
+        X_1, y_1, q_1 = load_svmlight_file(f, n_features=n_features,
+                                           query_id=True, offset=mark,
+                                           length=-1)
+        q_concat = np.concatenate([q_0, q_1])
+        y_concat = np.concatenate([y_0, y_1])
+        X_concat = sp.vstack([X_0, X_1])
+        assert_array_equal(y, y_concat)
+        assert_array_equal(query_id, q_concat)
+        assert_array_almost_equal(X.toarray(), X_concat.toarray())
+
+
+def test_load_with_offsets_error():
+    assert_raises_regex(ValueError, "n_features is required",
+                        load_svmlight_file, datafile, offset=3, length=3)
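
The error path exercised by the last test can also be checked interactively; a minimal sketch ("data.svmlight" is a stand-in for any svmlight-formatted file):

    from sklearn.datasets import load_svmlight_file

    try:
        # offset/length without n_features is rejected up front, because
        # the number of features cannot be inferred from a partial file.
        load_svmlight_file("data.svmlight", offset=3, length=3)
    except ValueError as exc:
        print(exc)  # n_features is required when offset or length is specified.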
