
Commit 958b647

ENH: adds the ability to load OpenML datasets containing string attributes by providing an option to ignore those attributes.

Right now, an error is raised when a dataset containing string attributes (e.g., the Titanic dataset) is fetched from OpenML. This commit allows users to specify whether they are okay with loading only a subset of the data. Closes #11819.
1 parent b8d1226 commit 958b647
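
For orientation, here is a minimal usage sketch of the option this commit introduces (a hypothetical session, assuming the code in this commit; data_id 40945 is the Titanic dataset used in the tests below):

    from sklearn.datasets import fetch_openml

    # Before this change, fetching a dataset with STRING attributes raised
    # ValueError('STRING attributes are not yet supported'). With
    # ignore_strings=True the string columns are dropped with a warning and
    # only the remaining columns are returned.
    titanic = fetch_openml(data_id=40945, ignore_strings=True)
    print(titanic.data.shape)  # feature matrix without the string columns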

8 files changed, +95 −19 lines changed

sklearn/datasets/openml.py

Lines changed: 25 additions & 5 deletions
@@ -235,7 +235,7 @@ def _convert_arff_data(arff_data, col_slice_x, col_slice_y):
     y : np.array
     """
     if isinstance(arff_data, list):
-        data = np.array(arff_data, dtype=np.float64)
+        data = np.array(arff_data)
         X = np.array(data[:, col_slice_x], dtype=np.float64)
         y = np.array(data[:, col_slice_y], dtype=np.float64)
         return X, y
@@ -278,7 +278,7 @@ def _get_data_info_by_name(name, version, data_home):
     Returns
     -------
     first_dataset : json
-        json representation of the first dataset object that adhired to the
+        json representation of the first dataset object that adhered to the
         search criteria

     """
@@ -399,7 +399,8 @@ def _valid_data_column_names(features_list, target_columns):


 def fetch_openml(name=None, version='active', data_id=None, data_home=None,
-                 target_column='default-target', cache=True, return_X_y=False):
+                 ignore_strings=False, target_column='default-target',
+                 cache=True, return_X_y=False):
     """Fetch dataset from openml by name or dataset id.

     Datasets are uniquely identified by either an integer ID or by a
@@ -438,6 +439,9 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
         Specify another download and cache folder for the data sets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

+    ignore_strings : boolean, default=False
+        Whether to ignore string attributes when loading a dataset.
+
     target_column : string, list or None, default 'default-target'
         Specify the column name in the data to use as target. If
         'default-target', the standard target column a stored on the server
@@ -536,11 +540,27 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
     # download data features, meta-info about column types
     features_list = _get_data_features(data_id, data_home)

+    if ignore_strings:
+        string_features = list(filter(lambda f: f['data_type'] == 'string',
+                                      features_list))
+        if string_features:
+            string_feature_names = list(map(lambda f: f['name'],
+                                            string_features))
+            warn("Found STRING attributes, which are not yet supported. "
+                 "Therefore, the following column(s) will not be returned: {}"
+                 .format(",".join(string_feature_names)))
+            features_list = list(filter(lambda f: f['name'] not
+                                        in string_feature_names,
+                                        features_list))
+
     for feature in features_list:
         if 'true' in (feature['is_ignore'], feature['is_row_identifier']):
             continue
-        if feature['data_type'] == 'string':
-            raise ValueError('STRING attributes are not yet supported')
+        if feature['data_type'] == 'string' and not ignore_strings:
+            raise ValueError('STRING attributes are not yet supported. '
+                             'If you would like to return the data '
+                             'without STRING attributes, try using '
+                             'ignore_strings=True')

     if target_column == "default-target":
         # determines the default target based on the data feature results
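
To make the new filtering step above easier to follow, here is an illustrative sketch of the same logic on a toy features_list (the dicts mirror the OpenML feature metadata used in openml.py; the entries themselves are made up for illustration):

    # toy stand-in for the metadata returned by _get_data_features
    features_list = [
        {'name': 'age', 'data_type': 'numeric'},
        {'name': 'name', 'data_type': 'string'},
        {'name': 'sex', 'data_type': 'nominal'},
    ]
    string_feature_names = [f['name'] for f in features_list
                            if f['data_type'] == 'string']
    features_list = [f for f in features_list
                     if f['name'] not in string_feature_names]
    # features_list now keeps only 'age' and 'sex'; the string column 'name'
    # is reported in the warning and excluded from the returned data.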
4 binary files not shown.

sklearn/datasets/tests/test_openml.py

Lines changed: 70 additions & 14 deletions
@@ -65,7 +65,7 @@ def decode_column(data_bunch, col_idx):


 def _fetch_dataset_from_openml(data_id, data_name, data_version,
-                               target_column,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                expected_data_dtype, expected_target_dtype,
@@ -75,17 +75,18 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
     # result. Note that this function can be mocked (by invoking
     # _monkey_patch_webbased_functions before invoking this function)
     data_by_name_id = fetch_openml(name=data_name, version=data_version,
-                                   cache=False)
+                                   ignore_strings=ignore_strings, cache=False)
     assert int(data_by_name_id.details['id']) == data_id

     # Please note that cache=False is crucial, as the monkey patched files are
     # not consistent with reality
-    fetch_openml(name=data_name, cache=False)
+    fetch_openml(name=data_name, ignore_strings=ignore_strings, cache=False)
     # without specifying the version, there is no guarantee that the data id
     # will be the same

     # fetch with dataset id
     data_by_id = fetch_openml(data_id=data_id, cache=False,
+                              ignore_strings=ignore_strings,
                               target_column=target_column)
     assert data_by_id.details['name'] == data_name
     assert data_by_id.data.shape == (expected_observations, expected_features)
@@ -111,7 +112,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,

     if compare_default_target:
         # check whether the data by id and data by id target are equal
-        data_by_id_default = fetch_openml(data_id=data_id, cache=False)
+        data_by_id_default = fetch_openml(data_id=data_id,
+                                          ignore_strings=ignore_strings,
+                                          cache=False)
         if data_by_id.data.dtype == np.float64:
             np.testing.assert_allclose(data_by_id.data,
                                        data_by_id_default.data)
@@ -132,8 +135,9 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
                                          expected_missing)

     # test return_X_y option
-    fetch_func = partial(fetch_openml, data_id=data_id, cache=False,
-                         target_column=target_column)
+    fetch_func = partial(fetch_openml, data_id=data_id,
+                         ignore_strings=ignore_strings,
+                         cache=False, target_column=target_column)
     check_return_X_y(data_by_id, fetch_func)
     return data_by_id

@@ -260,6 +264,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
     data_id = 61
     data_name = 'iris'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     expected_observations = 150
     expected_features = 4
@@ -274,6 +279,7 @@ def test_fetch_openml_iris(monkeypatch, gzip_response):
         _fetch_dataset_from_openml,
         **{'data_id': data_id, 'data_name': data_name,
            'data_version': data_version,
+           'ignore_strings': ignore_strings,
            'target_column': target_column,
            'expected_observations': expected_observations,
            'expected_features': expected_features,
@@ -297,13 +303,15 @@ def test_fetch_openml_iris_multitarget(monkeypatch, gzip_response):
     data_id = 61
     data_name = 'iris'
     data_version = 1
+    ignore_strings = False
     target_column = ['sepallength', 'sepalwidth']
     expected_observations = 150
     expected_features = 3
     expected_missing = 0

     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, np.float64, expect_sparse=False,
@@ -316,13 +324,15 @@ def test_fetch_openml_anneal(monkeypatch, gzip_response):
     data_id = 2
     data_name = 'anneal'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     # Not all original instances included for space reasons
     expected_observations = 11
     expected_features = 38
     expected_missing = 267
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, object, expect_sparse=False,
@@ -341,13 +351,15 @@ def test_fetch_openml_anneal_multitarget(monkeypatch, gzip_response):
     data_id = 2
     data_name = 'anneal'
     data_version = 1
+    ignore_strings = False
     target_column = ['class', 'product-type', 'shape']
     # Not all original instances included for space reasons
     expected_observations = 11
     expected_features = 36
     expected_missing = 267
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, object, expect_sparse=False,
@@ -360,12 +372,14 @@ def test_fetch_openml_cpu(monkeypatch, gzip_response):
     data_id = 561
     data_name = 'cpu'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     expected_observations = 209
     expected_features = 7
     expected_missing = 0
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                object, np.float64, expect_sparse=False,
@@ -387,6 +401,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
     data_id = 292
     data_name = 'Australian'
     data_version = 1
+    ignore_strings = False
     target_column = 'Y'
     # Not all original instances included for space reasons
     expected_observations = 85
@@ -399,6 +414,7 @@ def test_fetch_openml_australian(monkeypatch, gzip_response):
         _fetch_dataset_from_openml,
         **{'data_id': data_id, 'data_name': data_name,
            'data_version': data_version,
+           'ignore_strings': ignore_strings,
            'target_column': target_column,
            'expected_observations': expected_observations,
            'expected_features': expected_features,
@@ -416,13 +432,15 @@ def test_fetch_openml_adultcensus(monkeypatch, gzip_response):
     data_id = 1119
     data_name = 'adult-census'
     data_version = 1
+    ignore_strings = False
     target_column = 'class'
     # Not all original instances included for space reasons
     expected_observations = 10
     expected_features = 14
     expected_missing = 0
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                np.float64, object, expect_sparse=False,
@@ -438,13 +456,15 @@ def test_fetch_openml_miceprotein(monkeypatch, gzip_response):
     data_id = 40966
     data_name = 'MiceProtein'
     data_version = 4
+    ignore_strings = False
     target_column = 'class'
     # Not all original instances included for space reasons
     expected_observations = 7
     expected_features = 77
     expected_missing = 7
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                np.float64, object, expect_sparse=False,
@@ -457,14 +477,16 @@ def test_fetch_openml_emotions(monkeypatch, gzip_response):
     data_id = 40589
     data_name = 'emotions'
     data_version = 3
+    ignore_strings = False
     target_column = ['amazed.suprised', 'happy.pleased', 'relaxing.calm',
                      'quiet.still', 'sad.lonely', 'angry.aggresive']
     expected_observations = 13
     expected_features = 72
     expected_missing = 0
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)

-    _fetch_dataset_from_openml(data_id, data_name, data_version, target_column,
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
                                expected_observations, expected_features,
                                expected_missing,
                                np.float64, object, expect_sparse=False,
@@ -477,6 +499,27 @@ def test_decode_emotions(monkeypatch):
     _test_features_list(data_id)


+@pytest.mark.parametrize('gzip_response', [True, False])
+def test_fetch_titanic(monkeypatch, gzip_response):
+    # check because of the string attributes
+    data_id = 40945
+    data_name = 'Titanic'
+    data_version = 1
+    ignore_strings = True
+    target_column = 'survived'
+    # Not all original features included because five are strings
+    expected_observations = 1309
+    expected_features = 8
+    expected_missing = 1454
+    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+    _fetch_dataset_from_openml(data_id, data_name, data_version,
+                               ignore_strings, target_column,
+                               expected_observations, expected_features,
+                               expected_missing,
+                               np.float64, object, expect_sparse=False,
+                               compare_default_target=True)
+
+
 @pytest.mark.parametrize('gzip_response', [True, False])
 def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir):
     data_id = 61
@@ -659,14 +702,27 @@ def test_warn_ignore_attribute(monkeypatch, gzip_response):
                          cache=False)


+@pytest.mark.parametrize('gzip_response', [True, False])
+def test_ignore_strings(monkeypatch, gzip_response):
+    data_id = 40945
+    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+    assert_warns_message(
+        UserWarning,
+        "Found STRING attributes, which are not yet supported. "
+        "Therefore, the following column(s) will not be returned:",
+        fetch_openml, data_id=data_id, ignore_strings=True, cache=False
+    )
+
+
 @pytest.mark.parametrize('gzip_response', [True, False])
 def test_string_attribute(monkeypatch, gzip_response):
     data_id = 40945
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
     # single column test
     assert_raise_message(ValueError,
                          'STRING attributes are not yet supported',
-                         fetch_openml, data_id=data_id, cache=False)
+                         fetch_openml, data_id=data_id, ignore_strings=False,
+                         cache=False)


 @pytest.mark.parametrize('gzip_response', [True, False])
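
Outside the mocked test harness, the behaviour exercised by the new test_ignore_strings can be reproduced roughly as follows (a sketch only; it performs a live download from OpenML instead of using the monkey-patched responses):

    import warnings
    from sklearn.datasets import fetch_openml

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        fetch_openml(data_id=40945, ignore_strings=True, cache=False)
    # the dropped string columns are listed in the emitted UserWarning
    assert any("will not be returned" in str(w.message) for w in caught)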
