[MRG-0] Make LabelEncoder more friendly to new labels #3483
Changes from all commits: 7fabf54, fac95e1, 0d3851f, 4ac58af, e314ed6, 751b585, 866e939, da4cafb, 0f3e3d3
@@ -1,16 +1,18 @@
 # Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
 #          Mathieu Blondel <mathieu@mblondel.org>
 #          Olivier Grisel <olivier.grisel@ensta.org>
 #          Andreas Mueller <amueller@ais.uni-bonn.de>
 #          Joel Nothman <joel.nothman@gmail.com>
 #          Hamzeh Alsalhi <ha258@cornell.edu>
+#          Michael Bommarito <michael@bommaritollc.com>
 # License: BSD 3 clause

 from collections import defaultdict
 import itertools
 import array
 import warnings

+import operator
 import numpy as np
 import scipy.sparse as sp
@@ -53,10 +55,34 @@ def _check_numpy_unicode_bug(labels):
 class LabelEncoder(BaseEstimator, TransformerMixin):
     """Encode labels with value between 0 and n_classes-1.

+    Parameters
+    ----------
+    new_labels : string, optional (default: "raise")
+        Determines how to handle new labels, i.e., data
+        not seen in the training domain.
+
+        - If ``"raise"``, then raise ValueError.
+        - If ``"update"``, then re-map the new labels to
+          classes ``[N, ..., N+m-1]``, where ``m`` is the number of new labels.
+        - If an integer value is passed, then re-label with this value.

Review comment: Could it work with a string label (string)?

+
+        N.B. that default values are in [0, 1, ...], so caution should be
+        taken if a non-negative value is passed so as not to accidentally
+        intersect. Additionally, ``inverse_transform`` will fail for a
+        value that does not intersect with the ``fit``-time label set.
+
     Attributes
     ----------
     `classes_` : array of shape (n_class,)
-        Holds the label for each class.
+        Property that holds the label for each class seen at fit time.
+        See ``get_classes()`` and ``set_classes()``, the getter and
+        setter for observed labels.
+
+    `fit_labels_` : array of shape (n_class,)

Review comment: This looks like a duplicate of classes_.
Reply: Re: classes_ — classes_ alone as an array cannot handle all new_labels. I won't have time to make this volume of changes for a few weeks. Thanks.

+        Stores the labels seen at ``fit``-time.
+
+    `new_label_mapping_` : dictionary

Review comment: Why not follow our convention and consider that the order in …

+        Stores the mapping for classes not seen during original ``fit``.
+
     Examples
     --------
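As an illustration of the ``new_labels`` options documented above, the following sketch contrasts the default ``"raise"`` behavior with the integer re-labeling strategy. It assumes the branch from this PR is installed (released scikit-learn's LabelEncoder has no ``new_labels`` argument), and the expected outputs are inferred from the diff's logic rather than taken from a released API.

    from sklearn.preprocessing import LabelEncoder

    # Integer strategy: unseen labels are encoded with the given value.
    le = LabelEncoder(new_labels=-1)
    le.fit([1, 2, 6])
    print(le.transform([1, 6, 7]))   # [0 2 -1]; 7 was never seen at fit time

    # Default strategy: unseen labels still raise, as before this PR.
    le_strict = LabelEncoder()       # new_labels="raise"
    le_strict.fit([1, 2, 6])
    # le_strict.transform([7]) would raise ValueError("y contains new labels: [7]")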
@@ -65,7 +91,7 @@ class LabelEncoder(BaseEstimator, TransformerMixin):
     >>> from sklearn import preprocessing
     >>> le = preprocessing.LabelEncoder()
     >>> le.fit([1, 2, 2, 6])
-    LabelEncoder()
+    LabelEncoder(new_labels='raise')
     >>> le.classes_
     array([1, 2, 6])
     >>> le.transform([1, 1, 2, 6]) #doctest: +ELLIPSIS
@@ -78,7 +104,7 @@ class LabelEncoder(BaseEstimator, TransformerMixin):

     >>> le = preprocessing.LabelEncoder()
     >>> le.fit(["paris", "paris", "tokyo", "amsterdam"])
-    LabelEncoder()
+    LabelEncoder(new_labels='raise')
     >>> list(le.classes_)
     ['amsterdam', 'paris', 'tokyo']
     >>> le.transform(["tokyo", "tokyo", "paris"]) #doctest: +ELLIPSIS
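For the ``"update"`` strategy, here is a hedged continuation of the city example above, again assuming this PR's branch is installed; the outputs shown follow from the diff's logic.

    from sklearn.preprocessing import LabelEncoder

    le = LabelEncoder(new_labels="update")
    le.fit(["paris", "paris", "tokyo", "amsterdam"])
    print(le.transform(["tokyo", "rome"]))  # [2 3]; "rome" is appended as class 3
    print(list(le.classes_))                # ['amsterdam', 'paris', 'tokyo', 'rome']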
@@ -88,10 +114,42 @@ class LabelEncoder(BaseEstimator, TransformerMixin):

     """

+    def __init__(self, new_labels="raise"):
+        """Constructor"""
+        self.new_labels = new_labels
+        self.new_label_mapping_ = {}
+        self.fit_labels_ = []
+
     def _check_fitted(self):
-        if not hasattr(self, "classes_"):
+        if len(self.fit_labels_) == 0:
             raise ValueError("LabelEncoder was not fitted yet.")

+    def get_classes(self):

Review comment: Can you make these methods private, as they don't really follow the API? Why not update the self.classes_ attribute?

+        """Get classes that have been observed by the encoder. Note that this
+        method returns classes seen both at original ``fit`` time (i.e.,
+        ``self.classes_``) and classes seen after ``fit`` (i.e.,
+        ``self.new_label_mapping_.keys()``) for applicable values of
+        ``new_labels``.
+
+        Returns
+        -------
+        classes : array-like of shape [n_classes]
+        """
+        # If we've seen updates, include them in the order they were added.
+        if len(self.new_label_mapping_) > 0:
+            # Sort the post-fit time labels to return into the class array.
+            sorted_new, _ = zip(*sorted(self.new_label_mapping_.items(),
+                                        key=operator.itemgetter(1)))
+            return np.append(self.fit_labels_, sorted_new)
+        else:
+            return self.fit_labels_
+
+    def set_classes(self, classes):

Review comment: This function adds a new method that is not part of the API. Furthermore, it's just one line, so I would inline it if needed.

+        """Set the classes via property."""
+        self.fit_labels_ = classes
+
+    classes_ = property(get_classes, set_classes)
+
     def fit(self, y):
         """Fit label encoder

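The interplay between ``fit_labels_``, ``new_label_mapping_``, and the ``classes_`` property can be seen in a standalone sketch of the merge that ``get_classes`` performs: fit-time labels come first (sorted by ``np.unique``), followed by post-fit labels ordered by the codes they were assigned. The function name below is illustrative and not part of the PR.

    import operator
    import numpy as np

    def merged_classes(fit_labels, new_label_mapping):
        """Mimic get_classes(): fit-time labels, then post-fit labels by code."""
        if new_label_mapping:
            # Order the post-fit labels by the integer code they were given.
            sorted_new, _ = zip(*sorted(new_label_mapping.items(),
                                        key=operator.itemgetter(1)))
            return np.append(fit_labels, sorted_new)
        return np.asarray(fit_labels)

    print(merged_classes(np.array(["amsterdam", "paris"]), {"tokyo": 2}))
    # ['amsterdam' 'paris' 'tokyo']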
@@ -104,9 +162,17 @@ def fit(self, y):
         Returns
         -------
         self : returns an instance of self.
         """
+        # Check new_labels parameter
+        if self.new_labels not in ["update", "raise"] and \
+                type(self.new_labels) not in [int]:
+            # Raise on invalid argument.
+            raise ValueError("Value of argument `new_labels`={0} "
+                             "is unknown and not integer."
+                             .format(self.new_labels))
+
         y = column_or_1d(y, warn=True)
         _check_numpy_unicode_bug(y)
-        self.classes_ = np.unique(y)
+        self.fit_labels_ = np.unique(y)
         return self

     def fit_transform(self, y):
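A small check of the argument validation added to ``fit`` above (assuming the PR branch): invalid ``new_labels`` values are rejected when ``fit`` is called rather than in the constructor, which keeps ``__init__`` free of validation side effects.

    from sklearn.preprocessing import LabelEncoder

    le = LabelEncoder(new_labels="bogus")   # no error yet; validation is deferred
    try:
        le.fit([1, 2, 3])
    except ValueError as exc:
        print(exc)
        # Value of argument `new_labels`=bogus is unknown and not integer.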
@@ -121,9 +187,17 @@ def fit_transform(self, y):
         Returns
         -------
         y : array-like of shape [n_samples]
         """
+        # Check new_labels parameter
+        if self.new_labels not in ["update", "raise"] and \
+                type(self.new_labels) not in [int]:
+            # Raise on invalid argument.
+            raise ValueError("Value of argument `new_labels`={0} "
+                             "is unknown and not integer."
+                             .format(self.new_labels))
+
         y = column_or_1d(y, warn=True)
         _check_numpy_unicode_bug(y)
-        self.classes_, y = np.unique(y, return_inverse=True)
+        self.fit_labels_, y = np.unique(y, return_inverse=True)
         return y

     def transform(self, y):
@@ -142,10 +216,53 @@ def transform(self, y):

         classes = np.unique(y)
         _check_numpy_unicode_bug(classes)
-        if len(np.intersect1d(classes, self.classes_)) < len(classes):
-            diff = np.setdiff1d(classes, self.classes_)
-            raise ValueError("y contains new labels: %s" % str(diff))
-        return np.searchsorted(self.classes_, y)
+        if len(np.intersect1d(classes, self.get_classes())) < len(classes):
+            # Get the new classes
+            diff_fit = np.setdiff1d(classes, self.fit_labels_)
+            diff_new = np.setdiff1d(classes, self.get_classes())
+
+            # Create a copy of the input array
+            y = np.array(y)
+
+            # If we are mapping new labels, get "new" ID and change in copy.
+            if self.new_labels == "update":

Review comment: This looks like a partial_fit method. Should we add a partial_fit to the label encoder?

+                # Update the new label mapping
+                next_label = len(self.get_classes())
+                self.new_label_mapping_.update(dict(zip(diff_new,
+                                                        range(next_label,
+                                                              next_label +
+                                                              len(diff_new)))))
+
+                # Find entries with new labels
+                missing_mask = np.in1d(y, diff_fit)
+
+                # Populate return array properly by mask and return
+                out = np.searchsorted(self.fit_labels_, y)
+                out[missing_mask] = [self.new_label_mapping_[value]
+                                     for value in y[missing_mask]]
+                return out
+            elif type(self.new_labels) in [int]:
+                # Update the new label mapping
+                self.new_label_mapping_.update(dict(zip(diff_new,
+                                                        [self.new_labels]
+                                                        * len(diff_new))))
+
+                # Find entries with new labels
+                missing_mask = np.in1d(y, diff_fit)
+
+                # Populate return array properly by mask and return
+                out = np.searchsorted(self.fit_labels_, y)
+                out[missing_mask] = self.new_labels
+                return out
+            elif self.new_labels == "raise":
+                # Raise ValueError, the original behavior.
+                raise ValueError("y contains new labels: %s" % str(diff_fit))
+            else:
+                # Raise on invalid argument.
+                raise ValueError("Value of argument `new_labels`={0} "
+                                 "is unknown.".format(self.new_labels))
+
+        return np.searchsorted(self.fit_labels_, y)

     def inverse_transform(self, y):
         """Transform labels back to original encoding.
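The masking logic in ``transform`` above can be summarized in a standalone sketch (variable names are illustrative, not part of the PR): codes for known labels come from ``np.searchsorted`` against the sorted fit-time classes, and positions holding unseen labels are then overwritten with their assigned codes.

    import numpy as np

    fit_labels = np.array(["amsterdam", "paris", "tokyo"])  # sorted at fit time
    new_label_mapping = {"rome": 3}                         # assigned after fit

    y = np.array(["paris", "rome", "tokyo"])
    missing_mask = np.in1d(y, list(new_label_mapping))      # [False, True, False]

    out = np.searchsorted(fit_labels, y)                    # codes for known labels
    out[missing_mask] = [new_label_mapping[v] for v in y[missing_mask]]
    print(out)                                              # [1 3 2]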
@@ -161,8 +278,24 @@ def inverse_transform(self, y):
         """
         self._check_fitted()

+        if type(self.new_labels) in [int]:
+            warnings.warn('When ``new_labels`` uses an integer '
+                          're-labeling strategy, ``inverse_transform`` '
+                          'is not necessarily a one-to-one mapping; any '
+                          'labels not present during initial ``fit`` will '
+                          'not be mapped.',
+                          UserWarning)
+
         y = np.asarray(y)
-        return self.classes_[y]
+        try:
+            return self.get_classes()[y]
+        except IndexError:
+            # Raise exception
+            num_classes = len(self.get_classes())
+            raise ValueError("Classes were passed to ``inverse_transform`` "
+                             "that do not correspond to any class seen at "
+                             "``fit``-time or added since: {0}"
+                             .format(np.setdiff1d(y, range(num_classes))))
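Finally, a hedged sketch of the ``inverse_transform`` caveats described above (assuming the PR branch): with an integer ``new_labels`` strategy the mapping is not one-to-one, a ``UserWarning`` is emitted, and codes that never corresponded to any class raise a ``ValueError``.

    from sklearn.preprocessing import LabelEncoder

    le = LabelEncoder(new_labels=-1)
    le.fit(["amsterdam", "paris", "tokyo"])
    print(le.transform(["rome"]))        # [-1]; "rome" was not seen at fit time
    print(le.inverse_transform([0, 1]))  # ['amsterdam' 'paris'], with a UserWarning
    # le.inverse_transform([10]) would raise ValueError: 10 maps to no known class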
 class LabelBinarizer(BaseEstimator, TransformerMixin):
Review comment: Could you undo those changes?