-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
MRG added classes
parameter to LabelBinarizer
#1643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9beffbd
c6bac77
5cf9219
36e3b57
092c56d
3d15d1a
9690888
eafd52d
a7816b1
1f19bea
384df60
983bcd9
07b5ed1
a18ab06
2f5a5b1
cd09bee
1f70d84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
from .utils import check_arrays, array2d, atleast2d_or_csr, safe_asarray | ||
from .utils import warn_if_not_float | ||
from .utils.fixes import unique | ||
from .utils import deprecated | ||
|
||
from .utils.sparsefuncs import inplace_csr_row_normalize_l1 | ||
from .utils.sparsefuncs import inplace_csr_row_normalize_l2 | ||
|
@@ -622,6 +623,18 @@ def _is_multilabel(y): | |
_is_label_indicator_matrix(y)) | ||
|
||
|
||
def _get_label_type(y): | ||
multilabel = _is_multilabel(y) | ||
if multilabel: | ||
if _is_label_indicator_matrix(y): | ||
label_type = "multilabel-indicator" | ||
else: | ||
label_type = "multilabel-list" | ||
else: | ||
label_type = "multiclass" | ||
return label_type | ||
|
||
|
||
class OneHotEncoder(BaseEstimator, TransformerMixin): | ||
"""Encode categorical integer features using a one-hot aka one-of-K scheme. | ||
|
||
|
@@ -919,24 +932,41 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): | |
|
||
Parameters | ||
---------- | ||
|
||
neg_label: int (default: 0) | ||
neg_label : int (default: 0) | ||
Value with which negative labels must be encoded. | ||
|
||
pos_label: int (default: 1) | ||
pos_label : int (default: 1) | ||
Value with which positive labels must be encoded. | ||
|
||
classes : ndarray of int or None (default) | ||
Array of possible classes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The purpose of the Are there cases where finding an absent class means we should raise an error, but others where it makes sense to remove the entry, or substitute a different class? Should we have a There is really no reason the same functionality shouldn't apply to |
||
|
||
label_type : string, default="auto" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we call this |
||
Expected type of y. | ||
Possible values are: | ||
- "multiclass", y is an array-like of ints | ||
- "multilabel-indicator", y is an indicator matrix of classes | ||
- "multiclass-list", y is a list of lists of labels | ||
- "auto", the form of y is determined during 'fit'. If 'fit' is not | ||
called, multiclass is assumed. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When you wrote |
||
|
||
Attributes | ||
---------- | ||
`classes_`: array of shape [n_class] | ||
`classes_` : array of shape [n_class] | ||
Holds the label for each class. | ||
|
||
`label_type_` : string | ||
The type of label used. Inferred from training data if | ||
``label_type="auto"``, otherwise identical to the ``label_type`` | ||
parameter. | ||
|
||
|
||
Examples | ||
-------- | ||
>>> from sklearn import preprocessing | ||
>>> lb = preprocessing.LabelBinarizer() | ||
>>> lb.fit([1, 2, 6, 4, 2]) | ||
LabelBinarizer(neg_label=0, pos_label=1) | ||
LabelBinarizer(classes=None, label_type='auto', neg_label=0, pos_label=1) | ||
>>> lb.classes_ | ||
array([1, 2, 4, 6]) | ||
>>> lb.transform([1, 6]) | ||
|
@@ -950,19 +980,29 @@ class LabelBinarizer(BaseEstimator, TransformerMixin): | |
array([1, 2, 3]) | ||
""" | ||
|
||
def __init__(self, neg_label=0, pos_label=1): | ||
if neg_label >= pos_label: | ||
raise ValueError("neg_label must be strictly less than pos_label.") | ||
def __init__(self, neg_label=0, pos_label=1, classes=None, | ||
label_type='auto'): | ||
|
||
self.neg_label = neg_label | ||
self.pos_label = pos_label | ||
self.classes = classes | ||
self.label_type = label_type | ||
|
||
def _check_fitted(self): | ||
if not hasattr(self, "classes_"): | ||
raise ValueError("LabelBinarizer was not fitted yet.") | ||
if self.classes is not None: | ||
self.classes_ = np.unique(self.classes) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The order of the classes won't be preserve. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know. I was wondering about that. I think I'd like to keep it like this for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only place, where I have seen such things is in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "such things" meaning preserving the class order? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the class order must be known to output the precision, recall, fscore in the right order (see averaging type is equal to None). |
||
# default to not doing multi-label things | ||
self.label_type_ = (self.label_type | ||
if self.label_type != "auto" | ||
else "multiclass") | ||
else: | ||
raise ValueError("LabelBinarizer was not fitted yet.") | ||
|
||
def fit(self, y): | ||
"""Fit label binarizer | ||
"""Fit label binarizer. | ||
|
||
No-op if parameter ``classes`` was specified. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -973,16 +1013,36 @@ def fit(self, y): | |
Returns | ||
------- | ||
self : returns an instance of self. | ||
|
||
""" | ||
self.multilabel = _is_multilabel(y) | ||
if self.multilabel: | ||
self.indicator_matrix_ = _is_label_indicator_matrix(y) | ||
if self.indicator_matrix_: | ||
self.classes_ = np.arange(y.shape[1]) | ||
else: | ||
self.classes_ = np.array(sorted(set.union(*map(set, y)))) | ||
if self.neg_label >= self.pos_label: | ||
raise ValueError("neg_label must be strictly less than pos_label.") | ||
|
||
label_type = _get_label_type(y) | ||
|
||
if self.label_type not in ["auto", label_type]: | ||
raise ValueError("label_type was set to %s, but got y of type %s." | ||
% (self.label_type, label_type)) | ||
|
||
self.label_type_ = label_type | ||
|
||
if label_type == "multilabel-indicator": | ||
classes = np.arange(y.shape[1]) | ||
elif label_type == "multilabel-list": | ||
classes = np.array(sorted(set.union(*map(set, y)))) | ||
else: | ||
classes = np.unique(y) | ||
|
||
if self.classes is not None: | ||
classes_set = set(classes) | ||
if not set.issubset(classes_set, self.classes): | ||
difference = set.difference(classes_set, self.classes) | ||
warnings.warn("Found class(es) %s, which was not contained " | ||
"in parameter ``classes`` and will be ignored." | ||
% str(list(difference))) | ||
self.classes_ = np.unique(self.classes) | ||
else: | ||
self.classes_ = np.unique(y) | ||
self.classes_ = classes | ||
return self | ||
|
||
def transform(self, y): | ||
|
@@ -1000,31 +1060,24 @@ def transform(self, y): | |
Returns | ||
------- | ||
Y : numpy array of shape [n_samples, n_classes] | ||
|
||
""" | ||
self._check_fitted() | ||
|
||
if self.multilabel or len(self.classes_) > 2: | ||
if _is_label_indicator_matrix(y): | ||
# nothing to do as y is already a label indicator matrix | ||
return y | ||
|
||
label_type = _get_label_type(y) | ||
if label_type != self.label_type_: | ||
raise ValueError("label_type was set to %s, but got y of type %s." | ||
% (self.label_type_, label_type)) | ||
if label_type == "multilabel-indicator": | ||
# nothing to do as y is already a label indicator matrix | ||
return y | ||
elif label_type == "multilabel-list" or len(self.classes_) > 2: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is the case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because for two classes, the output is 1d. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks ! |
||
Y = np.zeros((len(y), len(self.classes_)), dtype=np.int) | ||
else: | ||
Y = np.zeros((len(y), 1), dtype=np.int) | ||
|
||
Y += self.neg_label | ||
|
||
y_is_multilabel = _is_multilabel(y) | ||
|
||
if y_is_multilabel and not self.multilabel: | ||
raise ValueError("The object was not fitted with multilabel" | ||
" input!") | ||
|
||
elif self.multilabel: | ||
if not _is_multilabel(y): | ||
raise ValueError("y should be a list of label lists/tuples," | ||
"got %r" % (y,)) | ||
|
||
if label_type == "multilabel-list": | ||
# inverse map: label => column index | ||
imap = dict((v, k) for k, v in enumerate(self.classes_)) | ||
|
||
|
@@ -1089,10 +1142,10 @@ def inverse_transform(self, Y, threshold=None): | |
half = (self.pos_label - self.neg_label) / 2.0 | ||
threshold = self.neg_label + half | ||
|
||
if self.multilabel: | ||
if self.multilabel_: | ||
Y = np.array(Y > threshold, dtype=int) | ||
# Return the predictions in the same format as in fit | ||
if self.indicator_matrix_: | ||
if self.label_type_ == "multilabel-indicator": | ||
# Label indicator matrix format | ||
return Y | ||
else: | ||
|
@@ -1108,6 +1161,15 @@ def inverse_transform(self, Y, threshold=None): | |
|
||
return self.classes_[y] | ||
|
||
@property | ||
def multilabel_(self): | ||
return self.label_type_ in ["multilabel-list", "multilabel-indicator"] | ||
|
||
@property | ||
@deprecated("it will be removed in 0.15. Use ``label_type_`` instead.") | ||
def label_indicator_(self): | ||
return self.label_type_ == "multilabel-indicator" | ||
|
||
|
||
class KernelCenterer(BaseEstimator, TransformerMixin): | ||
"""Center a kernel matrix | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
from sklearn.utils.testing import assert_array_equal | ||
from sklearn.utils.testing import assert_equal | ||
from sklearn.utils.testing import assert_raises | ||
from sklearn.utils.testing import assert_raise_message | ||
from sklearn.utils.testing import assert_true | ||
from sklearn.utils.testing import assert_false | ||
|
||
|
@@ -640,6 +641,67 @@ def test_label_binarizer_iris(): | |
assert_almost_equal(accuracy, accuracy2) | ||
|
||
|
||
def test_label_binarizer_classes(): | ||
# check that explictly giving classes works | ||
lb = LabelBinarizer(classes=np.arange(3)) | ||
y = np.ones(10) | ||
# if classes is specified, we don't need to fit | ||
assert_equal(lb.transform(y).shape, (10, 3)) | ||
assert_array_equal(y, np.argmax(lb.transform(y), axis=1)) | ||
|
||
# check that fitting doesn't change the shape | ||
assert_equal(lb.fit_transform(y).shape, (10, 3)) | ||
|
||
# also works with weird classes: | ||
lb = LabelBinarizer(classes=['a', 'b', 'see']) | ||
transformed = lb.transform(['see', 'see']) | ||
assert_equal(transformed.shape, (2, 3)) | ||
assert_array_equal(np.argmax(transformed, axis=1), [2, 2]) | ||
# test inverse transform | ||
assert_array_equal(['see', 'see'], lb.inverse_transform(transformed)) | ||
|
||
# also works with multilabel data if we say so: | ||
lb = LabelBinarizer(classes=np.arange(1, 3), | ||
label_type="multilabel-list") | ||
y = [(1, 2), (1,), ()] | ||
Y = np.array([[1, 1], | ||
[1, 0], | ||
[0, 0]]) | ||
assert_array_equal(lb.transform(y), Y) | ||
assert_array_equal(lb.fit_transform(y), Y) | ||
# inverse transform of label indicator matrix to label | ||
assert_array_equal(lb.inverse_transform(Y), y) | ||
|
||
# inverse transform with indicator_matrix=True | ||
lb = LabelBinarizer(classes=np.arange(1, 3), | ||
label_type="multilabel-indicator") | ||
assert_array_equal(lb.inverse_transform(Y), Y) | ||
|
||
lb = LabelBinarizer(classes=np.arange(1, 3), label_type="multiclass") | ||
assert_raise_message(ValueError, "label_type was set to multiclass, " | ||
"but got y of type multilabel-list.", | ||
lb.fit, y) | ||
lb = LabelBinarizer(classes=np.arange(1, 3)) | ||
assert_raise_message(ValueError, "label_type was set to multiclass," | ||
" but got y of type multilabel-list.", | ||
lb.transform, y) | ||
|
||
# check that labels present at fit time that are not in 'classes' | ||
# will be ignored but a warning will be shown | ||
lb = LabelBinarizer(classes=[1, 2]) | ||
with warnings.catch_warnings(record=True) as w: | ||
transformed = lb.fit_transform([0, 1, 2]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for testing this :) |
||
|
||
# warning was raised: | ||
assert_equal(len(w), 1) | ||
assert_true("not contained in parameter ``classes`` and will be ignored." | ||
in str(w[0])) | ||
|
||
# result is as for binary case | ||
assert_equal(transformed.shape, (3, 1)) | ||
assert_array_equal(transformed.ravel(), [0, 0, 1]) | ||
|
||
|
||
def test_label_binarizer_multilabel_unlabeled(): | ||
"""Check that LabelBinarizer can handle an unlabeled sample""" | ||
lb = LabelBinarizer() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now, you can put this in
utils.multiclass
. :-)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I propose we replace
is_multilabel
with something like this, that moreover may handle multipley
s: #1985