diff --git a/sklearn/datasets/kddcup99.py b/sklearn/datasets/kddcup99.py index 589749851ea6a..2b9686b4db2e0 100644 --- a/sklearn/datasets/kddcup99.py +++ b/sklearn/datasets/kddcup99.py @@ -24,7 +24,7 @@ from .base import get_data_home from .base import Bunch -from ..externals import joblib +from ..externals import joblib, six from ..utils import check_random_state from ..utils import shuffle as shuffle_method @@ -40,8 +40,8 @@ def fetch_kddcup99(subset=None, shuffle=False, random_state=None, - percent10=False): - """Load and return the kddcup 99 dataset (regression). + percent10=True, download_if_missing=True): + """Load and return the kddcup 99 dataset (classification). The KDD Cup '99 dataset was created by processing the tcpdump portions of the 1998 DARPA Intrusion Detection System (IDS) Evaluation dataset, @@ -93,7 +93,7 @@ def fetch_kddcup99(subset=None, shuffle=False, random_state=None, ================ ========================================== Samples total 699691 - Dimensionality 40 + Dimensionality 4 Features discrete (int) or continuous (float) Targets str, 'normal.' or name of the anomaly type ================ ========================================== @@ -102,7 +102,7 @@ def fetch_kddcup99(subset=None, shuffle=False, random_state=None, ================ ========================================== Samples total 619052 - Dimensionality 39 + Dimensionality 3 Features discrete (int) or continuous (float) Targets str, 'normal.' or name of the anomaly type ================ ========================================== @@ -111,7 +111,7 @@ def fetch_kddcup99(subset=None, shuffle=False, random_state=None, ================ ========================================== Samples total 95373 - Dimensionality 39 + Dimensionality 3 Features discrete (int) or continuous (float) Targets str, 'normal.' 
or name of the anomaly type ================ ========================================== @@ -135,6 +135,10 @@ def fetch_kddcup99(subset=None, shuffle=False, random_state=None, percent10 : bool, default=False Whether to load only 10 percent of the data. + download_if_missing : bool, default=True + If False, raise a IOError if the data is not locally available + instead of trying to download the data from the source site. + Returns ------- data : Bunch @@ -153,13 +157,14 @@ def fetch_kddcup99(subset=None, shuffle=False, random_state=None, Intrusions in Unlabeled Data (2002) by Eleazar Eskin, Andrew Arnold, Michael Prerau, Leonid Portnoy, Sal Stolfo """ - kddcup99 = _fetch_brute_kddcup99(shuffle=shuffle, percent10=percent10) + kddcup99 = _fetch_brute_kddcup99(shuffle=shuffle, percent10=percent10, + download_if_missing=download_if_missing) data = kddcup99.data target = kddcup99.target if subset == 'SA': - s = target == 'normal.' + s = target == b'normal.' t = np.logical_not(s) normal_samples = data[s, :] normal_targets = target[s] @@ -187,13 +192,13 @@ def fetch_kddcup99(subset=None, shuffle=False, random_state=None, data[:, 5] = np.log((data[:, 5] + 0.1).astype(float)) if subset == 'http': - s = data[:, 2] == 'http' + s = data[:, 2] == b'http' data = data[s] target = target[s] data = np.c_[data[:, 0], data[:, 4], data[:, 5]] if subset == 'smtp': - s = data[:, 2] == 'smtp' + s = data[:, 2] == b'smtp' data = data[s] target = target[s] data = np.c_[data[:, 0], data[:, 4], data[:, 5]] @@ -320,6 +325,8 @@ def _fetch_brute_kddcup99(subset=None, data_home=None, file_ = GzipFile(fileobj=f, mode='r') Xy = [] for line in file_.readlines(): + if six.PY3: + line = line.decode() Xy.append(line.replace('\n', '').split(',')) file_.close() print('extraction done') diff --git a/sklearn/datasets/tests/test_kddcup99.py b/sklearn/datasets/tests/test_kddcup99.py new file mode 100644 index 0000000000000..414c89763c1e8 --- /dev/null +++ b/sklearn/datasets/tests/test_kddcup99.py @@ -0,0 +1,41 
"""Test the kddcup99 loader.

Only the 'percent10' mode is tested, as the full data is too big to use in
unit-testing.

The test is skipped if the data wasn't previously fetched and saved to the
scikit-learn data folder.
"""

import errno

from sklearn.datasets import fetch_kddcup99
from sklearn.utils.testing import assert_equal, SkipTest


def test_percent10():
    # Checks the shapes of ``data`` and ``target`` for the default
    # ('percent10') load, for a shuffled load, and for each named subset.
    # Every call passes ``download_if_missing=False`` so that the test can
    # never trigger a network download.
    try:
        data = fetch_kddcup99(download_if_missing=False)
    except IOError as e:
        # Skip only when the data is simply not available locally: either
        # an explicit ENOENT, or an IOError raised without any errno
        # (``e.errno`` is None when the exception is built from a message
        # only). Any other I/O failure is a genuine error and is re-raised
        # instead of being swallowed, which previously surfaced later as a
        # confusing unbound-variable error on ``data``.
        if e.errno not in (None, errno.ENOENT):
            raise
        raise SkipTest("kddcup99 dataset can not be loaded.")

    assert_equal(data.data.shape, (494021, 41))
    assert_equal(data.target.shape, (494021,))

    # Shuffling must only reorder the samples, never change the shapes.
    data_shuffled = fetch_kddcup99(shuffle=True, random_state=0,
                                   download_if_missing=False)
    assert_equal(data.data.shape, data_shuffled.data.shape)
    assert_equal(data.target.shape, data_shuffled.target.shape)

    # (subset name, expected number of samples, expected number of features)
    expected_subsets = (
        ('SA', 100655, 41),
        ('SF', 73237, 4),
        ('http', 58725, 3),
        ('smtp', 9571, 3),
    )
    for subset, n_samples, n_features in expected_subsets:
        data = fetch_kddcup99(subset, download_if_missing=False)
        assert_equal(data.data.shape, (n_samples, n_features))
        assert_equal(data.target.shape, (n_samples,))