@@ -65,7 +65,7 @@ def decode_column(data_bunch, col_idx):
65
65
66
66
67
67
def _fetch_dataset_from_openml (data_id , data_name , data_version ,
68
- target_column ,
68
+ ignore_strings , target_column ,
69
69
expected_observations , expected_features ,
70
70
expected_missing ,
71
71
expected_data_dtype , expected_target_dtype ,
@@ -75,17 +75,18 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
75
75
# result. Note that this function can be mocked (by invoking
76
76
# _monkey_patch_webbased_functions before invoking this function)
77
77
data_by_name_id = fetch_openml (name = data_name , version = data_version ,
78
- cache = False )
78
+ ignore_strings = ignore_strings , cache = False )
79
79
assert int (data_by_name_id .details ['id' ]) == data_id
80
80
81
81
# Please note that cache=False is crucial, as the monkey patched files are
82
82
# not consistent with reality
83
- fetch_openml (name = data_name , cache = False )
83
+ fetch_openml (name = data_name , ignore_strings = ignore_strings , cache = False )
84
84
# without specifying the version, there is no guarantee that the data id
85
85
# will be the same
86
86
87
87
# fetch with dataset id
88
88
data_by_id = fetch_openml (data_id = data_id , cache = False ,
89
+ ignore_strings = ignore_strings ,
89
90
target_column = target_column )
90
91
assert data_by_id .details ['name' ] == data_name
91
92
assert data_by_id .data .shape == (expected_observations , expected_features )
@@ -111,7 +112,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
111
112
112
113
if compare_default_target :
113
114
# check whether the data by id and data by id target are equal
114
- data_by_id_default = fetch_openml (data_id = data_id , cache = False )
115
+ data_by_id_default = fetch_openml (data_id = data_id ,
116
+ ignore_strings = ignore_strings ,
117
+ cache = False )
115
118
if data_by_id .data .dtype == np .float64 :
116
119
np .testing .assert_allclose (data_by_id .data ,
117
120
data_by_id_default .data )
@@ -132,8 +135,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
132
135
expected_missing )
133
136
134
137
# test return_X_y option
135
- fetch_func = partial (fetch_openml , data_id = data_id , cache = False ,
136
- target_column = target_column )
138
+ fetch_func = partial (fetch_openml , data_id = data_id ,
139
+ ignore_strings = ignore_strings ,
140
+ cache = False , target_column = target_column )
137
141
check_return_X_y (data_by_id , fetch_func )
138
142
return data_by_id
139
143
@@ -260,6 +264,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
260
264
data_id = 61
261
265
data_name = 'iris'
262
266
data_version = 1
267
+ ignore_strings = False
263
268
target_column = 'class'
264
269
expected_observations = 150
265
270
expected_features = 4
@@ -274,6 +279,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
274
279
_fetch_dataset_from_openml ,
275
280
** {'data_id' : data_id , 'data_name' : data_name ,
276
281
'data_version' : data_version ,
282
+ 'ignore_strings' : ignore_strings ,
277
283
'target_column' : target_column ,
278
284
'expected_observations' : expected_observations ,
279
285
'expected_features' : expected_features ,
@@ -297,13 +303,15 @@ def test_fetch_openml_iris_multitarget(monkeypatch, gzip_response):
297
303
data_id = 61
298
304
data_name = 'iris'
299
305
data_version = 1
306
+ ignore_strings = False
300
307
target_column = ['sepallength' , 'sepalwidth' ]
301
308
expected_observations = 150
302
309
expected_features = 3
303
310
expected_missing = 0
304
311
305
312
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
306
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
313
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
314
+ ignore_strings , target_column ,
307
315
expected_observations , expected_features ,
308
316
expected_missing ,
309
317
object , np .float64 , expect_sparse = False ,
@@ -316,13 +324,15 @@ def test_fetch_openml_anneal(monkeypatch, gzip_response):
316
324
data_id = 2
317
325
data_name = 'anneal'
318
326
data_version = 1
327
+ ignore_strings = False
319
328
target_column = 'class'
320
329
# Not all original instances included for space reasons
321
330
expected_observations = 11
322
331
expected_features = 38
323
332
expected_missing = 267
324
333
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
325
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
334
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
335
+ ignore_strings , target_column ,
326
336
expected_observations , expected_features ,
327
337
expected_missing ,
328
338
object , object , expect_sparse = False ,
@@ -341,13 +351,15 @@ def test_fetch_openml_anneal_multitarget(monkeypatch, gzip_response):
341
351
data_id = 2
342
352
data_name = 'anneal'
343
353
data_version = 1
354
+ ignore_strings = False
344
355
target_column = ['class' , 'product-type' , 'shape' ]
345
356
# Not all original instances included for space reasons
346
357
expected_observations = 11
347
358
expected_features = 36
348
359
expected_missing = 267
349
360
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
350
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
361
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
362
+ ignore_strings , target_column ,
351
363
expected_observations , expected_features ,
352
364
expected_missing ,
353
365
object , object , expect_sparse = False ,
@@ -360,12 +372,14 @@ def test_fetch_openml_cpu(monkeypatch, gzip_response):
360
372
data_id = 561
361
373
data_name = 'cpu'
362
374
data_version = 1
375
+ ignore_strings = False
363
376
target_column = 'class'
364
377
expected_observations = 209
365
378
expected_features = 7
366
379
expected_missing = 0
367
380
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
368
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
381
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
382
+ ignore_strings , target_column ,
369
383
expected_observations , expected_features ,
370
384
expected_missing ,
371
385
object , np .float64 , expect_sparse = False ,
@@ -387,6 +401,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
387
401
data_id = 292
388
402
data_name = 'Australian'
389
403
data_version = 1
404
+ ignore_strings = False
390
405
target_column = 'Y'
391
406
# Not all original instances included for space reasons
392
407
expected_observations = 85
@@ -399,6 +414,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
399
414
_fetch_dataset_from_openml ,
400
415
** {'data_id' : data_id , 'data_name' : data_name ,
401
416
'data_version' : data_version ,
417
+ 'ignore_strings' : ignore_strings ,
402
418
'target_column' : target_column ,
403
419
'expected_observations' : expected_observations ,
404
420
'expected_features' : expected_features ,
@@ -416,13 +432,15 @@ def test_fetch_openml_adultcensus(monkeypatch, gzip_response):
416
432
data_id = 1119
417
433
data_name = 'adult-census'
418
434
data_version = 1
435
+ ignore_strings = False
419
436
target_column = 'class'
420
437
# Not all original instances included for space reasons
421
438
expected_observations = 10
422
439
expected_features = 14
423
440
expected_missing = 0
424
441
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
425
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
442
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
443
+ ignore_strings , target_column ,
426
444
expected_observations , expected_features ,
427
445
expected_missing ,
428
446
np .float64 , object , expect_sparse = False ,
@@ -438,13 +456,15 @@ def test_fetch_openml_miceprotein(monkeypatch, gzip_response):
438
456
data_id = 40966
439
457
data_name = 'MiceProtein'
440
458
data_version = 4
459
+ ignore_strings = False
441
460
target_column = 'class'
442
461
# Not all original instances included for space reasons
443
462
expected_observations = 7
444
463
expected_features = 77
445
464
expected_missing = 7
446
465
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
447
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
466
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
467
+ ignore_strings , target_column ,
448
468
expected_observations , expected_features ,
449
469
expected_missing ,
450
470
np .float64 , object , expect_sparse = False ,
@@ -457,14 +477,16 @@ def test_fetch_openml_emotions(monkeypatch, gzip_response):
457
477
data_id = 40589
458
478
data_name = 'emotions'
459
479
data_version = 3
480
+ ignore_strings = False
460
481
target_column = ['amazed.suprised' , 'happy.pleased' , 'relaxing.calm' ,
461
482
'quiet.still' , 'sad.lonely' , 'angry.aggresive' ]
462
483
expected_observations = 13
463
484
expected_features = 72
464
485
expected_missing = 0
465
486
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
466
487
467
- _fetch_dataset_from_openml (data_id , data_name , data_version , target_column ,
488
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
489
+ ignore_strings , target_column ,
468
490
expected_observations , expected_features ,
469
491
expected_missing ,
470
492
np .float64 , object , expect_sparse = False ,
@@ -477,6 +499,27 @@ def test_decode_emotions(monkeypatch):
477
499
_test_features_list (data_id )
478
500
479
501
502
+ @pytest .mark .parametrize ('gzip_response' , [True , False ])
503
+ def test_fetch_titanic (monkeypatch , gzip_response ):
504
+ # check because of the string attributes
505
+ data_id = 40945
506
+ data_name = 'Titanic'
507
+ data_version = 1
508
+ ignore_strings = True
509
+ target_column = 'survived'
510
+ # Not all original features included because five are strings
511
+ expected_observations = 1309
512
+ expected_features = 8
513
+ expected_missing = 1454
514
+ _monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
515
+ _fetch_dataset_from_openml (data_id , data_name , data_version ,
516
+ ignore_strings , target_column ,
517
+ expected_observations , expected_features ,
518
+ expected_missing ,
519
+ np .float64 , object , expect_sparse = False ,
520
+ compare_default_target = True )
521
+
522
+
480
523
@pytest .mark .parametrize ('gzip_response' , [True , False ])
481
524
def test_open_openml_url_cache (monkeypatch , gzip_response , tmpdir ):
482
525
data_id = 61
@@ -659,14 +702,27 @@ def test_warn_ignore_attribute(monkeypatch, gzip_response):
659
702
cache = False )
660
703
661
704
705
+ @pytest .mark .parametrize ('gzip_response' , [True , False ])
706
+ def test_ignore_strings (monkeypatch , gzip_response ):
707
+ data_id = 40945
708
+ _monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
709
+ assert_warns_message (
710
+ UserWarning ,
711
+ "STRING attributes which are not yet supported. "
712
+ "Therefore, the following column(s) will not be returned:" ,
713
+ fetch_openml , data_id = data_id , ignore_strings = True , cache = False
714
+ )
715
+
716
+
662
717
@pytest .mark .parametrize ('gzip_response' , [True , False ])
663
718
def test_string_attribute (monkeypatch , gzip_response ):
664
719
data_id = 40945
665
720
_monkey_patch_webbased_functions (monkeypatch , data_id , gzip_response )
666
721
# single column test
667
722
assert_raise_message (ValueError ,
668
723
'STRING attributes are not yet supported' ,
669
- fetch_openml , data_id = data_id , cache = False )
724
+ fetch_openml , data_id = data_id , ignore_strings = False ,
725
+ cache = False )
670
726
671
727
672
728
@pytest .mark .parametrize ('gzip_response' , [True , False ])
0 commit comments