|
2 | 2 | import gzip
|
3 | 3 | from io import BytesIO
|
4 | 4 | import numpy as np
|
| 5 | +import scipy.sparse as sp |
5 | 6 | import os
|
6 | 7 | import shutil
|
7 | 8 | from tempfile import NamedTemporaryFile
|
@@ -200,67 +201,84 @@ def test_invalid_filename():
|
200 | 201 |
|
201 | 202 |
|
202 | 203 | def test_dump():
|
203 |
| - Xs, y = load_svmlight_file(datafile) |
204 |
| - Xd = Xs.toarray() |
| 204 | + X_sparse, y_dense = load_svmlight_file(datafile) |
| 205 | + X_dense = X_sparse.toarray() |
| 206 | + y_sparse = sp.csr_matrix(y_dense) |
205 | 207 |
|
206 | 208 | # slicing a csr_matrix can unsort its .indices, so test that we sort
|
207 | 209 | # those correctly
|
208 |
| - Xsliced = Xs[np.arange(Xs.shape[0])] |
209 |
| - |
210 |
| - for X in (Xs, Xd, Xsliced): |
211 |
| - for zero_based in (True, False): |
212 |
| - for dtype in [np.float32, np.float64, np.int32]: |
213 |
| - f = BytesIO() |
214 |
| - # we need to pass a comment to get the version info in; |
215 |
| - # LibSVM doesn't grok comments so they're not put in by |
216 |
| - # default anymore. |
217 |
| - dump_svmlight_file(X.astype(dtype), y, f, comment="test", |
218 |
| - zero_based=zero_based) |
219 |
| - f.seek(0) |
220 |
| - |
221 |
| - comment = f.readline() |
222 |
| - try: |
223 |
| - comment = str(comment, "utf-8") |
224 |
| - except TypeError: # fails in Python 2.x |
225 |
| - pass |
226 |
| - |
227 |
| - assert_in("scikit-learn %s" % sklearn.__version__, comment) |
228 |
| - |
229 |
| - comment = f.readline() |
230 |
| - try: |
231 |
| - comment = str(comment, "utf-8") |
232 |
| - except TypeError: # fails in Python 2.x |
233 |
| - pass |
234 |
| - |
235 |
| - assert_in(["one", "zero"][zero_based] + "-based", comment) |
236 |
| - |
237 |
| - X2, y2 = load_svmlight_file(f, dtype=dtype, |
238 |
| - zero_based=zero_based) |
239 |
| - assert_equal(X2.dtype, dtype) |
240 |
| - assert_array_equal(X2.sorted_indices().indices, X2.indices) |
241 |
| - if dtype == np.float32: |
242 |
| - assert_array_almost_equal( |
| 210 | + X_sliced = X_sparse[np.arange(X_sparse.shape[0])] |
| 211 | + y_sliced = y_sparse[np.arange(y_sparse.shape[0])] |
| 212 | + |
| 213 | + for X in (X_sparse, X_dense, X_sliced): |
| 214 | + for y in (y_sparse, y_dense, y_sliced): |
| 215 | + for zero_based in (True, False): |
| 216 | + for dtype in [np.float32, np.float64, np.int32]: |
| 217 | + f = BytesIO() |
| 218 | + # we need to pass a comment to get the version info in; |
| 219 | + # LibSVM doesn't grok comments so they're not put in by |
| 220 | + # default anymore. |
| 221 | + |
| 222 | + if (sp.issparse(y) and y.shape[0] == 1): |
| 223 | + # make sure y's shape is: (n_samples, n_labels) |
| 224 | + # when it is sparse |
| 225 | + y = y.T |
| 226 | + |
| 227 | + dump_svmlight_file(X.astype(dtype), y, f, comment="test", |
| 228 | + zero_based=zero_based) |
| 229 | + f.seek(0) |
| 230 | + |
| 231 | + comment = f.readline() |
| 232 | + try: |
| 233 | + comment = str(comment, "utf-8") |
| 234 | + except TypeError: # fails in Python 2.x |
| 235 | + pass |
| 236 | + |
| 237 | + assert_in("scikit-learn %s" % sklearn.__version__, comment) |
| 238 | + |
| 239 | + comment = f.readline() |
| 240 | + try: |
| 241 | + comment = str(comment, "utf-8") |
| 242 | + except TypeError: # fails in Python 2.x |
| 243 | + pass |
| 244 | + |
| 245 | + assert_in(["one", "zero"][zero_based] + "-based", comment) |
| 246 | + |
| 247 | + X2, y2 = load_svmlight_file(f, dtype=dtype, |
| 248 | + zero_based=zero_based) |
| 249 | + assert_equal(X2.dtype, dtype) |
| 250 | + assert_array_equal(X2.sorted_indices().indices, X2.indices) |
| 251 | + |
| 252 | + X2_dense = X2.toarray() |
| 253 | + |
| 254 | + if dtype == np.float32: |
243 | 255 | # allow a rounding error at the last decimal place
|
244 |
| - Xd.astype(dtype), X2.toarray(), 4) |
245 |
| - else: |
246 |
| - assert_array_almost_equal( |
| 256 | + assert_array_almost_equal( |
| 257 | + X_dense.astype(dtype), X2_dense, 4) |
| 258 | + assert_array_almost_equal( |
| 259 | + y_dense.astype(dtype), y2, 4) |
| 260 | + else: |
247 | 261 | # allow a rounding error at the last decimal place
|
248 |
| - Xd.astype(dtype), X2.toarray(), 15) |
249 |
| - assert_array_equal(y, y2) |
| 262 | + assert_array_almost_equal( |
| 263 | + X_dense.astype(dtype), X2_dense, 15) |
| 264 | + assert_array_almost_equal( |
| 265 | + y_dense.astype(dtype), y2, 15) |
250 | 266 |
|
251 | 267 |
|
252 | 268 | def test_dump_multilabel():
|
253 | 269 | X = [[1, 0, 3, 0, 5],
|
254 | 270 | [0, 0, 0, 0, 0],
|
255 | 271 | [0, 5, 0, 1, 0]]
|
256 |
| - y = [[0, 1, 0], [1, 0, 1], [1, 1, 0]] |
257 |
| - f = BytesIO() |
258 |
| - dump_svmlight_file(X, y, f, multilabel=True) |
259 |
| - f.seek(0) |
260 |
| - # make sure it dumps multilabel correctly |
261 |
| - assert_equal(f.readline(), b("1 0:1 2:3 4:5\n")) |
262 |
| - assert_equal(f.readline(), b("0,2 \n")) |
263 |
| - assert_equal(f.readline(), b("0,1 1:5 3:1\n")) |
| 272 | + y_dense = [[0, 1, 0], [1, 0, 1], [1, 1, 0]] |
| 273 | + y_sparse = sp.csr_matrix(y_dense) |
| 274 | + for y in [y_dense, y_sparse]: |
| 275 | + f = BytesIO() |
| 276 | + dump_svmlight_file(X, y, f, multilabel=True) |
| 277 | + f.seek(0) |
| 278 | + # make sure it dumps multilabel correctly |
| 279 | + assert_equal(f.readline(), b("1 0:1 2:3 4:5\n")) |
| 280 | + assert_equal(f.readline(), b("0,2 \n")) |
| 281 | + assert_equal(f.readline(), b("0,1 1:5 3:1\n")) |
264 | 282 |
|
265 | 283 |
|
266 | 284 | def test_dump_concise():
|
|
0 commit comments