From 081e784a37694678239ce7686f8b65090c530c93 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 27 Nov 2017 10:08:33 +0100 Subject: [PATCH 01/13] Optionally use hashtable or pandas for LabelEncoder --- sklearn/preprocessing/label.py | 119 ++++++++++++++++++---- sklearn/preprocessing/tests/test_label.py | 35 +++++++ 2 files changed, 137 insertions(+), 17 deletions(-) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index 61b8d4a21af30..ad3ea9b7e5e74 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -36,6 +36,91 @@ ] +_PANDAS_INSTALLED = None + + +def _check_pandas(): + # return False + global _PANDAS_INSTALLED + + if _PANDAS_INSTALLED is None: + try: + import pandas + _PANDAS_INSTALLED = True + except ImportError: + _PANDAS_INSTALLED = False + + return _PANDAS_INSTALLED + + +# def _encode_numpy(values, uniques=None, encode=True): +# if uniques is None: +# if encode: +# uniques, encoded = np.unique(values, return_inverse=True) +# return uniques, encoded +# else: +# # unique sorts +# return np.unique(values) +# if encode: +# encoded = np.searchsorted(uniques, values) +# return uniques, encoded +# else: +# return uniques + + +def _factorize_numpy(values): + # unique sorts + return np.unique(values), None + + +def _encode_numpy(values, uniques): + return np.searchsorted(uniques, values) + + +def _factorize_sets(values, uniques=None, encode=True): + uniques = sorted(set(values)) + uniques = np.array(uniques, dtype=values.dtype) + table = {val: i for i, val in enumerate(uniques)} + return uniques, table + + +def _encode_dict(values, uniques, table): + return np.array([table[v] for v in values]) + + +def _factorize_pandas(values): + import pandas as pd + _, categories = pd.factorize(values, sort=True) + return categories, None + + +def _encode_pandas(values, uniques): + import pandas as pd + return pd.Categorical(values, categories=uniques).codes + + +def _factorize(values): + has_pd = _check_pandas() + if has_pd: + return _factorize_pandas(values) + elif values.dtype == object: + return _factorize_sets(values) + else: + return _factorize_numpy(values) + + +def _encode(values, uniques, table): + has_pd = _check_pandas() + if has_pd: + return _encode_pandas(values, uniques) + elif values.dtype == object: + if table is None: + table = {val: i for i, val in enumerate(uniques)} + return _encode_dict(values, uniques, table) + else: + return _encode_numpy(values, uniques) + + class LabelEncoder(BaseEstimator, TransformerMixin): """Encode labels with value between 0 and n_classes-1. @@ -93,24 +178,24 @@ def fit(self, y): self : returns an instance of self. """ y = column_or_1d(y, warn=True) - self.classes_ = np.unique(y) + self.classes_, self._table = _factorize(y) return self - def fit_transform(self, y): - """Fit label encoder and return encoded labels - - Parameters - ---------- - y : array-like of shape [n_samples] - Target values. - - Returns - ------- - y : array-like of shape [n_samples] - """ - y = column_or_1d(y, warn=True) - self.classes_, y = np.unique(y, return_inverse=True) - return y + # def fit_transform(self, y): + # """Fit label encoder and return encoded labels + # + # Parameters + # ---------- + # y : array-like of shape [n_samples] + # Target values. + # + # Returns + # ------- + # y : array-like of shape [n_samples] + # """ + # y = column_or_1d(y, warn=True) + # self.classes_, y = np.unique(y, return_inverse=True) + # return y def transform(self, y): """Transform labels to normalized encoding. @@ -132,7 +217,7 @@ def transform(self, y): diff = np.setdiff1d(classes, self.classes_) raise ValueError( "y contains previously unseen labels: %s" % str(diff)) - return np.searchsorted(self.classes_, y) + return _encode(y, uniques=self.classes_, table=self._table) def inverse_transform(self, y): """Transform labels back to original encoding. diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index 4f64fc6b4638c..7243b4e5bea11 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -1,5 +1,7 @@ import numpy as np +import pytest + from scipy.sparse import issparse from scipy.sparse import coo_matrix from scipy.sparse import csc_matrix @@ -23,6 +25,9 @@ from sklearn.preprocessing.label import _inverse_binarize_thresholding from sklearn.preprocessing.label import _inverse_binarize_multiclass +from sklearn.preprocessing.label import ( + _encode_numpy, _encode_pandas, _encode_dict, _factorize_numpy, + _factorize_pandas, _factorize_sets) from sklearn import datasets @@ -513,3 +518,33 @@ def test_inverse_binarize_multiclass(): [0, 0, 0]]), np.arange(3)) assert_array_equal(got, np.array([1, 1, 0])) + + +@pytest.mark.parametrize('engine', ['numpy', 'python', 'pandas']) +@pytest.mark.parametrize( + "values, expected", + [(np.array([2, 1, 3, 1, 3], dtype='int64'), + np.array([1, 2, 3], dtype='int64')), + (np.array(['b', 'a', 'c', 'a', 'c'], dtype=object), + np.array(['a', 'b', 'c'], dtype=object)), + (np.array(['b', 'a', 'c', 'a', 'c']), + np.array(['a', 'b', 'c']))], + ids=['int64', 'object', 'str']) +def test_factorize_encode_utils(engine, values, expected): + # test that all different encoders are equivalent + + if engine == 'numpy': + factorize = lambda values: _factorize_numpy(values) + encode = lambda values, uniques, table: _encode_numpy(values, uniques) + elif engine == 'python': + factorize = lambda values: _factorize_sets(values) + encode = lambda values, uniques, table: _encode_dict(values, uniques, table) + elif engine == 'pandas': + pytest.importorskip('pandas') + factorize = lambda values: _factorize_pandas(values) + encode = lambda values, uniques, table: _encode_pandas(values, uniques) + + uniques, table = factorize(values) + assert_array_equal(uniques, expected) + encoded = encode(values, uniques, table) + assert_array_equal(encoded, np.array([1, 0, 2, 0, 2])) From 391f007d1e2aab7f1a3aa7b404058cb1b601b2c0 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 6 Jun 2018 14:29:36 +0200 Subject: [PATCH 02/13] clean-up --- sklearn/preprocessing/label.py | 26 +++++++++++------------ sklearn/preprocessing/tests/test_label.py | 16 +++++++------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index ad3ea9b7e5e74..aa1e007b133e3 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -45,7 +45,7 @@ def _check_pandas(): if _PANDAS_INSTALLED is None: try: - import pandas + import pandas # noqa _PANDAS_INSTALLED = True except ImportError: _PANDAS_INSTALLED = False @@ -70,28 +70,28 @@ def _check_pandas(): def _factorize_numpy(values): # unique sorts - return np.unique(values), None + return np.unique(values) def _encode_numpy(values, uniques): return np.searchsorted(uniques, values) -def _factorize_sets(values, uniques=None, encode=True): +def _factorize_python(values): uniques = sorted(set(values)) uniques = np.array(uniques, dtype=values.dtype) - table = {val: i for i, val in enumerate(uniques)} - return uniques, table + return uniques -def _encode_dict(values, uniques, table): +def _encode_python(values, uniques): + table = {val: i for i, val in enumerate(uniques)} return np.array([table[v] for v in values]) def _factorize_pandas(values): import pandas as pd _, categories = pd.factorize(values, sort=True) - return categories, None + return np.array(categories, dtype=values.dtype) def _encode_pandas(values, uniques): @@ -104,19 +104,17 @@ def _factorize(values): if has_pd: return _factorize_pandas(values) elif values.dtype == object: - return _factorize_sets(values) + return _factorize_python(values) else: return _factorize_numpy(values) -def _encode(values, uniques, table): +def _encode(values, uniques): has_pd = _check_pandas() if has_pd: return _encode_pandas(values, uniques) elif values.dtype == object: - if table is None: - table = {val: i for i, val in enumerate(uniques)} - return _encode_dict(values, uniques, table) + return _encode_python(values, uniques) else: return _encode_numpy(values, uniques) @@ -178,7 +176,7 @@ def fit(self, y): self : returns an instance of self. """ y = column_or_1d(y, warn=True) - self.classes_, self._table = _factorize(y) + self.classes_ = _factorize(y) return self # def fit_transform(self, y): @@ -217,7 +215,7 @@ def transform(self, y): diff = np.setdiff1d(classes, self.classes_) raise ValueError( "y contains previously unseen labels: %s" % str(diff)) - return _encode(y, uniques=self.classes_, table=self._table) + return _encode(y, uniques=self.classes_) def inverse_transform(self, y): """Transform labels back to original encoding. diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index 7243b4e5bea11..00f1dbba14306 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -26,8 +26,8 @@ from sklearn.preprocessing.label import _inverse_binarize_thresholding from sklearn.preprocessing.label import _inverse_binarize_multiclass from sklearn.preprocessing.label import ( - _encode_numpy, _encode_pandas, _encode_dict, _factorize_numpy, - _factorize_pandas, _factorize_sets) + _encode_numpy, _encode_pandas, _encode_python, _factorize_numpy, + _factorize_pandas, _factorize_python) from sklearn import datasets @@ -535,16 +535,16 @@ def test_factorize_encode_utils(engine, values, expected): if engine == 'numpy': factorize = lambda values: _factorize_numpy(values) - encode = lambda values, uniques, table: _encode_numpy(values, uniques) + encode = lambda values, uniques: _encode_numpy(values, uniques) elif engine == 'python': - factorize = lambda values: _factorize_sets(values) - encode = lambda values, uniques, table: _encode_dict(values, uniques, table) + factorize = lambda values: _factorize_python(values) + encode = lambda values, uniques: _encode_python(values, uniques) elif engine == 'pandas': pytest.importorskip('pandas') factorize = lambda values: _factorize_pandas(values) - encode = lambda values, uniques, table: _encode_pandas(values, uniques) + encode = lambda values, uniques: _encode_pandas(values, uniques) - uniques, table = factorize(values) + uniques = factorize(values) assert_array_equal(uniques, expected) - encoded = encode(values, uniques, table) + encoded = encode(values, uniques) assert_array_equal(encoded, np.array([1, 0, 2, 0, 2])) From 9131c46a463195e4433b181430cc83d95dab2c24 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 20 Jun 2018 18:25:13 +0200 Subject: [PATCH 03/13] remove pandas for now; simplify; document --- sklearn/preprocessing/label.py | 157 ++++++++++------------ sklearn/preprocessing/tests/test_label.py | 27 +--- 2 files changed, 75 insertions(+), 109 deletions(-) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index 8c3c57c8a2e38..7cb330807bb75 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -37,87 +37,65 @@ ] -_PANDAS_INSTALLED = None - - -def _check_pandas(): - # return False - global _PANDAS_INSTALLED - - if _PANDAS_INSTALLED is None: - try: - import pandas # noqa - _PANDAS_INSTALLED = True - except ImportError: - _PANDAS_INSTALLED = False - - return _PANDAS_INSTALLED - - -# def _encode_numpy(values, uniques=None, encode=True): -# if uniques is None: -# if encode: -# uniques, encoded = np.unique(values, return_inverse=True) -# return uniques, encoded -# else: -# # unique sorts -# return np.unique(values) -# if encode: -# encoded = np.searchsorted(uniques, values) -# return uniques, encoded -# else: -# return uniques - - -def _factorize_numpy(values): - # unique sorts - return np.unique(values) - - -def _encode_numpy(values, uniques): - return np.searchsorted(uniques, values) - - -def _factorize_python(values): - uniques = sorted(set(values)) - uniques = np.array(uniques, dtype=values.dtype) - return uniques - - -def _encode_python(values, uniques): - table = {val: i for i, val in enumerate(uniques)} - return np.array([table[v] for v in values]) +def _encode_numpy(values, uniques=None, encode=False): + if uniques is None: + if encode: + uniques, encoded = np.unique(values, return_inverse=True) + return uniques, encoded + else: + # unique sorts + return np.unique(values) + if encode: + encoded = np.searchsorted(uniques, values) + return uniques, encoded + else: + return uniques -def _factorize_pandas(values): - import pandas as pd - _, categories = pd.factorize(values, sort=True) - return np.array(categories, dtype=values.dtype) +def _encode_python(values, uniques=None, encode=False): + if uniques is None: + uniques = sorted(set(values)) + uniques = np.array(uniques, dtype=values.dtype) + if encode: + table = {val: i for i, val in enumerate(uniques)} + encoded = np.array([table[v] for v in values]) + return uniques, encoded + else: + return uniques -def _encode_pandas(values, uniques): - import pandas as pd - return pd.Categorical(values, categories=uniques).codes +def _encode(values, uniques=None, encode=False): + """ + Helper function to factorize (find uniques) and encode values. + Uses pure python method for object dtype, and numpy method for + all other dtypes. + The numpy method has the limitation that the `uniques` need to + be sorted. -def _factorize(values): - has_pd = _check_pandas() - if has_pd: - return _factorize_pandas(values) - elif values.dtype == object: - return _factorize_python(values) - else: - return _factorize_numpy(values) + Parameters + ---------- + values : array + Values to factorize or encode. + uniques : array, optional + If passed, uniques are not determined from passed values (this + can be because the user specified categories, or because they + already have been determined in fit) + encode : bool, default False + If True, also encode the values into integer codes based on `uniques` + Returns + ------- + uniques + If decode=False + (uniques, encoded) + If decode=True -def _encode(values, uniques): - has_pd = _check_pandas() - if has_pd: - return _encode_pandas(values, uniques) - elif values.dtype == object: - return _encode_python(values, uniques) + """ + if values.dtype == object: + return _encode_python(values, uniques, encode) else: - return _encode_numpy(values, uniques) + return _encode_numpy(values, uniques, encode) class LabelEncoder(BaseEstimator, TransformerMixin): @@ -177,24 +155,24 @@ def fit(self, y): self : returns an instance of self. """ y = column_or_1d(y, warn=True) - self.classes_ = _factorize(y) + self.classes_ = _encode(y) return self - # def fit_transform(self, y): - # """Fit label encoder and return encoded labels - # - # Parameters - # ---------- - # y : array-like of shape [n_samples] - # Target values. - # - # Returns - # ------- - # y : array-like of shape [n_samples] - # """ - # y = column_or_1d(y, warn=True) - # self.classes_, y = np.unique(y, return_inverse=True) - # return y + def fit_transform(self, y): + """Fit label encoder and return encoded labels + + Parameters + ---------- + y : array-like of shape [n_samples] + Target values. + + Returns + ------- + y : array-like of shape [n_samples] + """ + y = column_or_1d(y, warn=True) + self.classes_, y = _encode(y, encode=True) + return y def transform(self, y): """Transform labels to normalized encoding. @@ -219,7 +197,8 @@ def transform(self, y): diff = np.setdiff1d(classes, self.classes_) raise ValueError( "y contains previously unseen labels: %s" % str(diff)) - return _encode(y, uniques=self.classes_) + _, y = _encode(y, uniques=self.classes_, encode=True) + return y def inverse_transform(self, y): """Transform labels back to original encoding. diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index cee5ed9adf4b6..9888f0cde4eab 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -26,9 +26,7 @@ from sklearn.preprocessing.label import _inverse_binarize_thresholding from sklearn.preprocessing.label import _inverse_binarize_multiclass -from sklearn.preprocessing.label import ( - _encode_numpy, _encode_pandas, _encode_python, _factorize_numpy, - _factorize_pandas, _factorize_python) +from sklearn.preprocessing.label import _encode from sklearn import datasets @@ -543,7 +541,6 @@ def test_inverse_binarize_multiclass(): assert_array_equal(got, np.array([1, 1, 0])) -@pytest.mark.parametrize('engine', ['numpy', 'python', 'pandas']) @pytest.mark.parametrize( "values, expected", [(np.array([2, 1, 3, 1, 3], dtype='int64'), @@ -553,21 +550,11 @@ def test_inverse_binarize_multiclass(): (np.array(['b', 'a', 'c', 'a', 'c']), np.array(['a', 'b', 'c']))], ids=['int64', 'object', 'str']) -def test_factorize_encode_utils(engine, values, expected): - # test that all different encoders are equivalent - - if engine == 'numpy': - factorize = lambda values: _factorize_numpy(values) - encode = lambda values, uniques: _encode_numpy(values, uniques) - elif engine == 'python': - factorize = lambda values: _factorize_python(values) - encode = lambda values, uniques: _encode_python(values, uniques) - elif engine == 'pandas': - pytest.importorskip('pandas') - factorize = lambda values: _factorize_pandas(values) - encode = lambda values, uniques: _encode_pandas(values, uniques) - - uniques = factorize(values) +def test_encode_util(values, expected): + uniques = _encode(values) assert_array_equal(uniques, expected) - encoded = encode(values, uniques) + uniques, encoded = _encode(values, encode=True) + assert_array_equal(uniques, expected) + assert_array_equal(encoded, np.array([1, 0, 2, 0, 2])) + _, encoded = _encode(values, uniques, encode=True) assert_array_equal(encoded, np.array([1, 0, 2, 0, 2])) From ed64b91afac6058d36fcf220cc5d65b884864e43 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 20 Jun 2018 19:08:31 +0200 Subject: [PATCH 04/13] move detection unseen labels into encode function --- sklearn/preprocessing/label.py | 16 ++++++++++------ sklearn/preprocessing/tests/test_label.py | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index 7cb330807bb75..fa47b70648254 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -46,6 +46,11 @@ def _encode_numpy(values, uniques=None, encode=False): # unique sorts return np.unique(values) if encode: + uniques_values = np.unique(values) + if len(np.intersect1d(uniques_values, uniques)) < len(uniques_values): + diff = np.setdiff1d(uniques_values, uniques) + raise ValueError( + "y contains previously unseen labels: %s" % str(diff)) encoded = np.searchsorted(uniques, values) return uniques, encoded else: @@ -58,7 +63,11 @@ def _encode_python(values, uniques=None, encode=False): uniques = np.array(uniques, dtype=values.dtype) if encode: table = {val: i for i, val in enumerate(uniques)} - encoded = np.array([table[v] for v in values]) + try: + encoded = np.array([table[v] for v in values]) + except KeyError as e: + raise ValueError( + "y contains previously unseen labels: %s" % str(e)) return uniques, encoded else: return uniques @@ -192,11 +201,6 @@ def transform(self, y): if _num_samples(y) == 0: return np.array([]) - classes = np.unique(y) - if len(np.intersect1d(classes, self.classes_)) < len(classes): - diff = np.setdiff1d(classes, self.classes_) - raise ValueError( - "y contains previously unseen labels: %s" % str(diff)) _, y = _encode(y, uniques=self.classes_, encode=True) return y diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index 9888f0cde4eab..d4aa4169a8493 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -188,6 +188,27 @@ def test_label_encoder(): assert_raise_message(ValueError, msg, le.transform, "apple") +@pytest.mark.parametrize( + "values, classes, unknown", + [(np.array([2, 1, 3, 1, 3], dtype='int64'), + np.array([1, 2, 3], dtype='int64'), np.array([4], dtype='int64')), + (np.array(['b', 'a', 'c', 'a', 'c'], dtype=object), + np.array(['a', 'b', 'c'], dtype=object), + np.array(['d'], dtype=object)), + (np.array(['b', 'a', 'c', 'a', 'c']), + np.array(['a', 'b', 'c']), np.array(['d']))], + ids=['int64', 'object', 'str']) +def test_label_encoder_dtypes(values, classes, unknown): + # Test LabelEncoder's transform and inverse_transform methods + le = LabelEncoder() + le.fit(values) + assert_array_equal(le.classes_, classes) + assert_array_equal(le.transform(values), [1, 0, 2, 0, 2]) + assert_array_equal(le.inverse_transform([1, 0, 2, 0, 2]), values) + msg = "unseen labels" + assert_raise_message(ValueError, msg, le.transform, unknown) + + def test_label_encoder_fit_transform(): # Test fit_transform le = LabelEncoder() From 7268e6bc4ca1a7614e0e0e48229092fc78fb02e6 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 21 Jun 2018 09:21:49 +0200 Subject: [PATCH 05/13] use new _encode instead of LabelEncoder in CategoricalEncoder --- sklearn/preprocessing/data.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 4df7c295bd834..ea4b331b6a3b0 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -33,7 +33,7 @@ min_max_axis) from ..utils.validation import (check_is_fitted, check_random_state, FLOAT_DTYPES) -from .label import LabelEncoder +from .label import _encode BOUNDS_THRESHOLD = 1e-7 @@ -3029,13 +3029,12 @@ def fit(self, X, y=None): n_samples, n_features = X.shape - self._label_encoders_ = [LabelEncoder() for _ in range(n_features)] + self.categories_ = [] for i in range(n_features): - le = self._label_encoders_[i] Xi = X[:, i] if self.categories == 'auto': - le.fit(Xi) + cats = _encode(Xi) else: if self.handle_unknown == 'error': valid_mask = np.in1d(Xi, self.categories[i]) @@ -3044,9 +3043,8 @@ def fit(self, X, y=None): msg = ("Found unknown categories {0} in column {1}" " during fit".format(diff, i)) raise ValueError(msg) - le.classes_ = np.array(self.categories[i]) - - self.categories_ = [le.classes_ for le in self._label_encoders_] + cats = np.array(self.categories[i]) + self.categories_.append(cats) return self @@ -3091,7 +3089,8 @@ def transform(self, X): X_mask[:, i] = valid_mask Xi = Xi.copy() Xi[~valid_mask] = self.categories_[i][0] - X_int[:, i] = self._label_encoders_[i].transform(Xi) + _, encoded = _encode(Xi, self.categories_[i], encode=True) + X_int[:, i] = encoded if self.encoding == 'ordinal': return X_int.astype(self.dtype, copy=False) From 787d61759126a960970840cd7721fac40bf851c0 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 21 Jun 2018 09:37:40 +0200 Subject: [PATCH 06/13] parametrize more tests --- sklearn/preprocessing/tests/test_label.py | 55 ++++++++++++----------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index d4aa4169a8493..aa7bcd7a965fd 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -172,22 +172,6 @@ def test_label_binarizer_errors(): [1, 2, 3]) -def test_label_encoder(): - # Test LabelEncoder's transform and inverse_transform methods - le = LabelEncoder() - le.fit([1, 1, 4, 5, -1, 0]) - assert_array_equal(le.classes_, [-1, 0, 1, 4, 5]) - assert_array_equal(le.transform([0, 1, 4, 4, 5, -1, -1]), - [1, 2, 3, 3, 4, 0, 0]) - assert_array_equal(le.inverse_transform([1, 2, 3, 3, 4, 0, 0]), - [0, 1, 4, 4, 5, -1, -1]) - assert_raises(ValueError, le.transform, [0, 6]) - - le.fit(["apple", "orange"]) - msg = "bad input shape" - assert_raise_message(ValueError, msg, le.transform, "apple") - - @pytest.mark.parametrize( "values, classes, unknown", [(np.array([2, 1, 3, 1, 3], dtype='int64'), @@ -198,26 +182,39 @@ def test_label_encoder(): (np.array(['b', 'a', 'c', 'a', 'c']), np.array(['a', 'b', 'c']), np.array(['d']))], ids=['int64', 'object', 'str']) -def test_label_encoder_dtypes(values, classes, unknown): - # Test LabelEncoder's transform and inverse_transform methods +def test_label_encoder(values, classes, unknown): + # Test LabelEncoder's transform, fit_transform and + # inverse_transform methods le = LabelEncoder() le.fit(values) assert_array_equal(le.classes_, classes) assert_array_equal(le.transform(values), [1, 0, 2, 0, 2]) assert_array_equal(le.inverse_transform([1, 0, 2, 0, 2]), values) + le = LabelEncoder() + ret = le.fit_transform(values) + assert_array_equal(ret, [1, 0, 2, 0, 2]) + msg = "unseen labels" assert_raise_message(ValueError, msg, le.transform, unknown) -def test_label_encoder_fit_transform(): - # Test fit_transform +def test_label_encoder_negative_ints(): le = LabelEncoder() - ret = le.fit_transform([1, 1, 4, 5, -1, 0]) - assert_array_equal(ret, [2, 2, 3, 4, 0, 1]) + le.fit([1, 1, 4, 5, -1, 0]) + assert_array_equal(le.classes_, [-1, 0, 1, 4, 5]) + assert_array_equal(le.transform([0, 1, 4, 4, 5, -1, -1]), + [1, 2, 3, 3, 4, 0, 0]) + assert_array_equal(le.inverse_transform([1, 2, 3, 3, 4, 0, 0]), + [0, 1, 4, 4, 5, -1, -1]) + assert_raises(ValueError, le.transform, [0, 6]) + +@pytest.mark.parametrize("dtype", ['str', 'object']) +def test_label_encoder_str_bad_shape(dtype): le = LabelEncoder() - ret = le.fit_transform(["paris", "paris", "tokyo", "amsterdam"]) - assert_array_equal(ret, [1, 1, 2, 0]) + le.fit(np.array(["apple", "orange"], dtype=dtype)) + msg = "bad input shape" + assert_raise_message(ValueError, msg, le.transform, "apple") def test_label_encoder_errors(): @@ -238,9 +235,15 @@ def test_label_encoder_errors(): assert_raise_message(ValueError, msg, le.inverse_transform, "") -def test_label_encoder_empty_array(): +@pytest.mark.parametrize( + "values", + [np.array([2, 1, 3, 1, 3], dtype='int64'), + np.array(['b', 'a', 'c', 'a', 'c'], dtype=object), + np.array(['b', 'a', 'c', 'a', 'c'])], + ids=['int64', 'object', 'str']) +def test_label_encoder_empty_array(values): le = LabelEncoder() - le.fit(np.array(["1", "2", "1", "2", "2"])) + le.fit(values) # test empty transform transformed = le.transform([]) assert_array_equal(np.array([]), transformed) From 7502e5e96ed29db831fbf33c37c654140350cfc8 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 21 Jun 2018 10:34:35 +0200 Subject: [PATCH 07/13] also properly handle unknown categories in CategoricalEncoder --- sklearn/preprocessing/data.py | 13 ++++---- sklearn/preprocessing/label.py | 55 ++++++++++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index ea4b331b6a3b0..f871995863da2 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -33,7 +33,7 @@ min_max_axis) from ..utils.validation import (check_is_fitted, check_random_state, FLOAT_DTYPES) -from .label import _encode +from .label import _encode, _encode_check_unknown BOUNDS_THRESHOLD = 1e-7 @@ -3036,14 +3036,13 @@ def fit(self, X, y=None): if self.categories == 'auto': cats = _encode(Xi) else: + cats = np.array(self.categories[i]) if self.handle_unknown == 'error': - valid_mask = np.in1d(Xi, self.categories[i]) - if not np.all(valid_mask): - diff = np.unique(Xi[~valid_mask]) + diff = _encode_check_unknown(Xi, cats) + if diff: msg = ("Found unknown categories {0} in column {1}" " during fit".format(diff, i)) raise ValueError(msg) - cats = np.array(self.categories[i]) self.categories_.append(cats) return self @@ -3074,11 +3073,11 @@ def transform(self, X): for i in range(n_features): Xi = X[:, i] - valid_mask = np.in1d(Xi, self.categories_[i]) + diff, valid_mask = _encode_check_unknown(Xi, self.categories_[i], + return_mask=True) if not np.all(valid_mask): if self.handle_unknown == 'error': - diff = np.unique(X[~valid_mask, i]) msg = ("Found unknown categories {0} in column {1}" " during transform".format(diff, i)) raise ValueError(msg) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index fa47b70648254..eb9c3ddbf4fc6 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -96,9 +96,9 @@ def _encode(values, uniques=None, encode=False): Returns ------- uniques - If decode=False + If encode=False (uniques, encoded) - If decode=True + If encode=True """ if values.dtype == object: @@ -107,6 +107,57 @@ def _encode(values, uniques=None, encode=False): return _encode_numpy(values, uniques, encode) +def _encode_check_unknown(values, uniques, return_mask=False): + """ + Helper function to check for unknowns in values to be encoded. + + Uses pure python method for object dtype, and numpy method for + all other dtypes. + + Parameters + ---------- + values : array + Values to check for unknowns. + uniques : array + Allowed uniques values. + return_mask : bool, default False + If True, return a mask of the same shape as `values` indicating + the valid values. + + Returns + ------- + diff : list + The unique values present in `values` and not in `uniques` (the + unknown values).If encode=False + valid_mask : boolean array + If return_mask=True + + """ + if values.dtype == object: + unique_values = set(values) + diff = list(unique_values - set(uniques)) + if return_mask: + if diff: + uniques_set = set(uniques) + valid_mask = np.array([val in uniques_set for val in values]) + else: + valid_mask = np.ones(len(values), dtype=bool) + return diff, valid_mask + else: + return diff + else: + unique_values = np.unique(values) + diff = list(np.setdiff1d(unique_values, uniques)) + if return_mask: + if diff: + valid_mask = np.in1d(values, uniques) + else: + valid_mask = np.ones(len(values), dtype=bool) + return diff, valid_mask + else: + return diff + + class LabelEncoder(BaseEstimator, TransformerMixin): """Encode labels with value between 0 and n_classes-1. From f7e195ab08e1ad32bbf274dbcc10c7ddfd269a6a Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 21 Jun 2018 10:40:32 +0200 Subject: [PATCH 08/13] reuse check function for LabelEncoder as well --- sklearn/preprocessing/label.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index eb9c3ddbf4fc6..20902c01072fd 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -46,11 +46,10 @@ def _encode_numpy(values, uniques=None, encode=False): # unique sorts return np.unique(values) if encode: - uniques_values = np.unique(values) - if len(np.intersect1d(uniques_values, uniques)) < len(uniques_values): - diff = np.setdiff1d(uniques_values, uniques) + diff = _encode_check_unknown(values, uniques) + if diff: raise ValueError( - "y contains previously unseen labels: %s" % str(diff)) + "y contains previously unseen labels: %s" % str(diff)) encoded = np.searchsorted(uniques, values) return uniques, encoded else: From 3d1281c34403d12b06c5dc2283bd4123f821df32 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 21 Jun 2018 14:35:22 +0200 Subject: [PATCH 09/13] allow unsorted categories passed by user for object dtype --- sklearn/preprocessing/_encoders.py | 17 ++++++++------- sklearn/preprocessing/tests/test_encoders.py | 23 +++++++++++++++----- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index 75a453207f141..96311627bfeae 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -104,10 +104,11 @@ def _fit(self, X, handle_unknown='error'): n_samples, n_features = X.shape if self._categories != 'auto': - for cats in self._categories: - if not np.all(np.sort(cats) == np.array(cats)): - raise ValueError("Unsorted categories are not yet " - "supported") + if X.dtype != object: + for cats in self._categories: + if not np.all(np.sort(cats) == np.array(cats)): + raise ValueError("Unsorted categories are not " + "supported for numerical categories") if len(self._categories) != n_features: raise ValueError("Shape mismatch: if n_values is an array," " it has to be of shape (n_features,).") @@ -193,8 +194,8 @@ class OneHotEncoder(_BaseEncoder): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith - column. The passed categories must be sorted and should not mix - strings and numeric values. + column. The passed categories should not mix strings and numeric + values, and should be sorted in case of numeric values. The used categories can be found in the ``categories_`` attribute. @@ -711,8 +712,8 @@ class OrdinalEncoder(_BaseEncoder): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith - column. The passed categories must be sorted and should not mix - strings and numeric values. + column. The passed categories should not mix strings and numeric + values, and should be sorted in case of numeric values. The used categories can be found in the ``categories_`` attribute. diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index e9abce28c8639..97596be130a0b 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -429,11 +429,6 @@ def test_one_hot_encoder_specified_categories(): assert enc.categories_[0].tolist() == ['a', 'b', 'c'] assert np.issubdtype(enc.categories_[0].dtype, np.str_) - # unsorted passed categories raises for now - enc = OneHotEncoder(categories=[['c', 'b', 'a']]) - msg = re.escape('Unsorted categories are not yet supported') - assert_raises_regex(ValueError, msg, enc.fit_transform, X) - # multiple columns X = np.array([['a', 'b'], [0, 2]], dtype=object).T enc = OneHotEncoder(categories=[['a', 'b', 'c'], [0, 1, 2]]) @@ -455,6 +450,24 @@ def test_one_hot_encoder_specified_categories(): assert_array_equal(enc.fit(X).transform(X).toarray(), exp) +def test_one_hot_encoder_unsorted_categories(): + X = np.array([['a', 'b']], dtype=object).T + + enc = OneHotEncoder(categories=[['b', 'a', 'c']]) + exp = np.array([[0., 1., 0.], + [1., 0., 0.]]) + assert_array_equal(enc.fit(X).transform(X).toarray(), exp) + assert_array_equal(enc.fit_transform(X).toarray(), exp) + assert enc.categories_[0].tolist() == ['b', 'a', 'c'] + assert np.issubdtype(enc.categories_[0].dtype, np.str_) + + # unsorted passed categories still raise for numerical values + X = np.array([[1, 2]]).T + enc = OneHotEncoder(categories=[[2, 1, 3]]) + msg = re.escape('Unsorted categories are not supported') + assert_raises_regex(ValueError, msg, enc.fit_transform, X) + + def test_one_hot_encoder_pandas(): pd = pytest.importorskip('pandas') From 626d217cea2ccffca250055c26c1be7dbe8cf5c5 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 22 Jun 2018 14:13:30 +0200 Subject: [PATCH 10/13] parametrize some OneHotEncoder tests for object/int dtypes --- sklearn/preprocessing/tests/test_encoders.py | 68 +++++++++++++------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index 97596be130a0b..020729f4099bb 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -351,7 +351,12 @@ def check_categorical_onehot(X): return Xtr1.toarray() -def test_one_hot_encoder(): +@pytest.mark.parametrize("X", [ + [['abc', 1, 55], ['def', 2, 55]], + np.array([[10, 1, 55], [5, 2, 55]]), + np.array([['b', 'A', 'cat'], ['a', 'B', 'cat']], dtype=object) + ], ids=['mixed', 'numeric', 'object']) +def test_one_hot_encoder(X): X = [['abc', 1, 55], ['def', 2, 55]] Xtr = check_categorical_onehot(np.array(X)[:, [0]]) @@ -404,31 +409,50 @@ def test_one_hot_encoder_inverse(): assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr) -def test_one_hot_encoder_categories(): - X = [['abc', 1, 55], ['def', 2, 55]] - +@pytest.mark.parametrize("X, cat_exp", [ + ([['abc', 55], ['def', 55]], [['abc', 'def'], [55]]), + (np.array([[1, 2], [3, 2]]), [[1, 3], [2]]), + (np.array([['A', 'cat'], ['B', 'cat']], dtype=object), + [['A', 'B'], ['cat']]) + ], ids=['mixed', 'numeric', 'object']) +def test_one_hot_encoder_categories(X, cat_exp): # order of categories should not depend on order of samples for Xi in [X, X[::-1]]: - enc = OneHotEncoder() + enc = OneHotEncoder(categories='auto') enc.fit(Xi) # assert enc.categories == 'auto' assert isinstance(enc.categories_, list) - cat_exp = [['abc', 'def'], [1, 2], [55]] for res, exp in zip(enc.categories_, cat_exp): assert res.tolist() == exp -def test_one_hot_encoder_specified_categories(): - X = np.array([['a', 'b']], dtype=object).T - - enc = OneHotEncoder(categories=[['a', 'b', 'c']]) +@pytest.mark.parametrize("X, X2, cats, cat_dtype", [ + (np.array([['a', 'b']], dtype=object).T, + np.array([['a', 'd']], dtype=object).T, + [['a', 'b', 'c']], np.str_), + (np.array([[1, 2]], dtype='int64').T, + np.array([[1, 4]], dtype='int64').T, + [[1, 2, 3]], np.integer), + ], ids=['object', 'numeric']) +def test_one_hot_encoder_specified_categories(X, X2, cats, cat_dtype): + enc = OneHotEncoder(categories=cats) exp = np.array([[1., 0., 0.], [0., 1., 0.]]) assert_array_equal(enc.fit_transform(X).toarray(), exp) - assert enc.categories[0] == ['a', 'b', 'c'] - assert enc.categories_[0].tolist() == ['a', 'b', 'c'] - assert np.issubdtype(enc.categories_[0].dtype, np.str_) + assert enc.categories[0] == cats[0] + assert enc.categories_[0].tolist() == cats[0] + assert np.issubdtype(enc.categories_[0].dtype, cat_dtype) + + # when specifying categories manually, unknown categories should already + # raise when fitting + enc = OneHotEncoder(categories=cats) + assert_raises(ValueError, enc.fit, X2) + enc = OneHotEncoder(categories=cats, handle_unknown='ignore') + exp = np.array([[1., 0., 0.], [0., 0., 0.]]) + assert_array_equal(enc.fit(X2).transform(X2).toarray(), exp) + +def test_one_hot_encoder_specified_categories_mixed_columns(): # multiple columns X = np.array([['a', 'b'], [0, 2]], dtype=object).T enc = OneHotEncoder(categories=[['a', 'b', 'c'], [0, 1, 2]]) @@ -440,15 +464,6 @@ def test_one_hot_encoder_specified_categories(): assert enc.categories_[1].tolist() == [0, 1, 2] assert np.issubdtype(enc.categories_[1].dtype, np.integer) - # when specifying categories manually, unknown categories should already - # raise when fitting - X = np.array([['a', 'b', 'c']]).T - enc = OneHotEncoder(categories=[['a', 'b']]) - assert_raises(ValueError, enc.fit, X) - enc = OneHotEncoder(categories=[['a', 'b']], handle_unknown='ignore') - exp = np.array([[1., 0.], [0., 1.], [0., 0.]]) - assert_array_equal(enc.fit(X).transform(X).toarray(), exp) - def test_one_hot_encoder_unsorted_categories(): X = np.array([['a', 'b']], dtype=object).T @@ -477,9 +492,12 @@ def test_one_hot_encoder_pandas(): assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) -def test_ordinal_encoder(): - X = [['abc', 2, 55], ['def', 1, 55]] - +@pytest.mark.parametrize("X", [ + [['abc', 2, 55], ['def', 1, 55]], + np.array([[10, 2, 55], [20, 1, 55]]), + np.array([['a', 'B', 'cat'], ['b', 'A', 'cat']], dtype=object) + ], ids=['mixed', 'numeric', 'object']) +def test_ordinal_encoder(X): enc = OrdinalEncoder() exp = np.array([[0, 1, 0], [1, 0, 0]], dtype='int64') From 662da670616526b417307e4e4f04880c71a9524f Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 27 Jun 2018 10:06:57 +0200 Subject: [PATCH 11/13] feedback guillaume --- sklearn/preprocessing/label.py | 17 +++++++------- sklearn/preprocessing/tests/test_encoders.py | 24 ++++++++++---------- sklearn/preprocessing/tests/test_label.py | 4 ++-- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index e805dd0aa000b..921fa877e60ea 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -88,16 +88,16 @@ def _encode(values, uniques=None, encode=False): uniques : array, optional If passed, uniques are not determined from passed values (this can be because the user specified categories, or because they - already have been determined in fit) + already have been determined in fit). encode : bool, default False - If True, also encode the values into integer codes based on `uniques` + If True, also encode the values into integer codes based on `uniques`. Returns ------- uniques - If encode=False + If ``encode=False``. (uniques, encoded) - If encode=True + If ``encode=True``. """ if values.dtype == object: @@ -127,17 +127,16 @@ def _encode_check_unknown(values, uniques, return_mask=False): ------- diff : list The unique values present in `values` and not in `uniques` (the - unknown values).If encode=False + unknown values). valid_mask : boolean array - If return_mask=True + Additionally returned if ``return_mask=True``. """ if values.dtype == object: - unique_values = set(values) - diff = list(unique_values - set(uniques)) + uniques_set = set(uniques) + diff = list(set(values) - uniques_set) if return_mask: if diff: - uniques_set = set(uniques) valid_mask = np.array([val in uniques_set for val in values]) else: valid_mask = np.ones(len(values), dtype=bool) diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index 020729f4099bb..d7419cb1c9186 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -339,10 +339,10 @@ def test_one_hot_encoder_set_params(): def check_categorical_onehot(X): - enc = OneHotEncoder() + enc = OneHotEncoder(categories='auto') Xtr1 = enc.fit_transform(X) - enc = OneHotEncoder(sparse=False) + enc = OneHotEncoder(categories='auto', sparse=False) Xtr2 = enc.fit_transform(X) assert_allclose(Xtr1.toarray(), Xtr2) @@ -352,21 +352,19 @@ def check_categorical_onehot(X): @pytest.mark.parametrize("X", [ - [['abc', 1, 55], ['def', 2, 55]], + [['def', 1, 55], ['abc', 2, 55]], np.array([[10, 1, 55], [5, 2, 55]]), np.array([['b', 'A', 'cat'], ['a', 'B', 'cat']], dtype=object) ], ids=['mixed', 'numeric', 'object']) def test_one_hot_encoder(X): - X = [['abc', 1, 55], ['def', 2, 55]] - Xtr = check_categorical_onehot(np.array(X)[:, [0]]) - assert_allclose(Xtr, [[1, 0], [0, 1]]) + assert_allclose(Xtr, [[0, 1], [1, 0]]) Xtr = check_categorical_onehot(np.array(X)[:, [0, 1]]) - assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) + assert_allclose(Xtr, [[0, 1, 1, 0], [1, 0, 0, 1]]) - Xtr = OneHotEncoder().fit_transform(X) - assert_allclose(Xtr.toarray(), [[1, 0, 1, 0, 1], [0, 1, 0, 1, 1]]) + Xtr = OneHotEncoder(categories='auto').fit_transform(X) + assert_allclose(Xtr.toarray(), [[0, 1, 1, 0, 1], [1, 0, 0, 1, 1]]) def test_one_hot_encoder_inverse(): @@ -446,7 +444,8 @@ def test_one_hot_encoder_specified_categories(X, X2, cats, cat_dtype): # when specifying categories manually, unknown categories should already # raise when fitting enc = OneHotEncoder(categories=cats) - assert_raises(ValueError, enc.fit, X2) + with pytest.raises(ValueError, match="Found unknown categories"): + enc.fit(X2) enc = OneHotEncoder(categories=cats, handle_unknown='ignore') exp = np.array([[1., 0., 0.], [0., 0., 0.]]) assert_array_equal(enc.fit(X2).transform(X2).toarray(), exp) @@ -479,8 +478,9 @@ def test_one_hot_encoder_unsorted_categories(): # unsorted passed categories still raise for numerical values X = np.array([[1, 2]]).T enc = OneHotEncoder(categories=[[2, 1, 3]]) - msg = re.escape('Unsorted categories are not supported') - assert_raises_regex(ValueError, msg, enc.fit_transform, X) + msg = 'Unsorted categories are not supported' + with pytest.raises(ValueError, match=msg): + enc.fit_transform(X) def test_one_hot_encoder_pandas(): diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index eaf72f0b53c8f..f8f4ee4870acf 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -194,8 +194,8 @@ def test_label_encoder(values, classes, unknown): ret = le.fit_transform(values) assert_array_equal(ret, [1, 0, 2, 0, 2]) - msg = "unseen labels" - assert_raise_message(ValueError, msg, le.transform, unknown) + with pytest.raises(ValueError, match="unseen labels"): + le.transform(unknown) def test_label_encoder_negative_ints(): From 0bbbe828b9a2ea777687a06b9c8e58ee0c45b849 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Sun, 1 Jul 2018 11:14:34 +0200 Subject: [PATCH 12/13] fixup merge master --- sklearn/preprocessing/tests/test_encoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index f21ea21fbd8eb..d4f8aaefc34af 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -468,7 +468,7 @@ def test_one_hot_encoder_unsorted_categories(): assert_array_equal(enc.fit(X).transform(X).toarray(), exp) assert_array_equal(enc.fit_transform(X).toarray(), exp) assert enc.categories_[0].tolist() == ['b', 'a', 'c'] - assert np.issubdtype(enc.categories_[0].dtype, np.str_) + assert np.issubdtype(enc.categories_[0].dtype, np.object_) # unsorted passed categories still raise for numerical values X = np.array([[1, 2]]).T From 43655363583ad92b0eed4557c7dae23df6f1ac2f Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 2 Jul 2018 08:49:49 +0200 Subject: [PATCH 13/13] feedback Joel --- sklearn/preprocessing/_encoders.py | 3 ++- sklearn/preprocessing/label.py | 22 +++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index 1d468c703d48c..7516b2af9ec82 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -195,7 +195,8 @@ class OneHotEncoder(_BaseEncoder): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith column. The passed categories should not mix strings and numeric - values, and should be sorted in case of numeric values. + values within a single feature, and should be sorted in case of + numeric values. The used categories can be found in the ``categories_`` attribute. diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index 921fa877e60ea..51faccf1a30a1 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -38,6 +38,7 @@ def _encode_numpy(values, uniques=None, encode=False): + # only used in _encode below, see docstring there for details if uniques is None: if encode: uniques, encoded = np.unique(values, return_inverse=True) @@ -48,8 +49,8 @@ def _encode_numpy(values, uniques=None, encode=False): if encode: diff = _encode_check_unknown(values, uniques) if diff: - raise ValueError( - "y contains previously unseen labels: %s" % str(diff)) + raise ValueError("y contains previously unseen labels: %s" + % str(diff)) encoded = np.searchsorted(uniques, values) return uniques, encoded else: @@ -57,6 +58,7 @@ def _encode_numpy(values, uniques=None, encode=False): def _encode_python(values, uniques=None, encode=False): + # only used in _encode below, see docstring there for details if uniques is None: uniques = sorted(set(values)) uniques = np.array(uniques, dtype=values.dtype) @@ -65,21 +67,22 @@ def _encode_python(values, uniques=None, encode=False): try: encoded = np.array([table[v] for v in values]) except KeyError as e: - raise ValueError( - "y contains previously unseen labels: %s" % str(e)) + raise ValueError("y contains previously unseen labels: %s" + % str(e)) return uniques, encoded else: return uniques def _encode(values, uniques=None, encode=False): - """ - Helper function to factorize (find uniques) and encode values. + """Helper function to factorize (find uniques) and encode values. Uses pure python method for object dtype, and numpy method for all other dtypes. The numpy method has the limitation that the `uniques` need to - be sorted. + be sorted. Importantly, this is not checked but assumed to already be + the case. The calling method needs to ensure this for all non-object + values. Parameters ---------- @@ -95,7 +98,8 @@ def _encode(values, uniques=None, encode=False): Returns ------- uniques - If ``encode=False``. + If ``encode=False``. The unique values are sorted if the `uniques` + parameter was None (and thus inferred from the data). (uniques, encoded) If ``encode=True``. @@ -145,7 +149,7 @@ def _encode_check_unknown(values, uniques, return_mask=False): return diff else: unique_values = np.unique(values) - diff = list(np.setdiff1d(unique_values, uniques)) + diff = list(np.setdiff1d(unique_values, uniques, assume_unique=True)) if return_mask: if diff: valid_mask = np.in1d(values, uniques)