Skip to content

Commit d28d434

Browse files
committed
Add MultiLabelHistogram
1 parent 308a54e commit d28d434

File tree

2 files changed

+243
-0
lines changed

2 files changed

+243
-0
lines changed

sklearn/preprocessing/_label.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Andreas Mueller <amueller@ais.uni-bonn.de>
55
# Joel Nothman <joel.nothman@gmail.com>
66
# Hamzeh Alsalhi <ha258@cornell.edu>
7+
# Charles Moyes <me@charlesmoyes.com>
78
# License: BSD 3 clause
89

910
from collections import defaultdict
@@ -30,6 +31,7 @@
3031
'LabelBinarizer',
3132
'LabelEncoder',
3233
'MultiLabelBinarizer',
34+
'MultiLabelHistogram',
3335
]
3436

3537

@@ -1026,3 +1028,229 @@ def inverse_transform(self, yt):
10261028

10271029
def _more_tags(self):
10281030
return {'X_types': ['2dlabels']}
1031+
1032+
1033+
class MultiLabelHistogram(TransformerMixin, BaseEstimator):
1034+
"""Transform between iterable of dicts and a multilabel format
1035+
1036+
This transformer converts between a class mapping with numerical values
1037+
(such as a histogram) to the multilabel format: a (samples x classes)
1038+
histogram matrix indicating the presence of a class label with an
1039+
associated value.
1040+
1041+
Parameters
1042+
----------
1043+
classes : array-like of shape [n_classes] (optional)
1044+
Indicates an ordering for the class labels.
1045+
All entries should be unique (cannot contain duplicate classes).
1046+
1047+
sparse_output : boolean (default: False),
1048+
Set to true if output binary array is desired in CSR sparse format
1049+
1050+
Attributes
1051+
----------
1052+
classes_ : array of labels
1053+
A copy of the `classes` parameter where provided,
1054+
or otherwise, the sorted set of classes found when fitting.
1055+
1056+
Examples
1057+
--------
1058+
>>> from sklearn.preprocessing import MultiLabelHistogram
1059+
>>> mlh = MultiLabelHistogram()
1060+
>>> mlh.fit_transform([{1: 5.5, 2: -3.0}, {3: 999}])
1061+
array([[ 5.5 -3. 0. ],
1062+
[ 0. 0. 999. ]])
1063+
>>> mlh.classes_
1064+
array([1, 2, 3])
1065+
1066+
>>> mlh.fit_transform([{'sci-fi': -2.0, 'thriller': 5.0}, {'comedy': 0.1}])
1067+
array([[ 0. -2. 5. ],
1068+
[ 0.1 0. 0. ]])
1069+
>>> list(mlh.classes_)
1070+
['comedy', 'sci-fi', 'thriller']
1071+
1072+
See also
1073+
--------
1074+
MultiLabelBinarizer : Transform between iterable
1075+
of iterables and a multilabel format
1076+
"""
1077+
1078+
def __init__(self, classes=None, sparse_output=False):
1079+
self.classes = classes
1080+
self.sparse_output = sparse_output
1081+
1082+
def fit(self, y):
1083+
"""Fit the label sets binarizer, storing :term:`classes_`
1084+
1085+
Parameters
1086+
----------
1087+
y : iterable of iterables
1088+
A set of labels (any orderable and hashable object) for each
1089+
sample. If the `classes` parameter is set, `y` will not be
1090+
iterated.
1091+
1092+
Returns
1093+
-------
1094+
self : returns this MultiLabelHistogram instance
1095+
"""
1096+
self._cached_dict = None
1097+
if self.classes is None:
1098+
classes = sorted(set(itertools.chain.from_iterable(y)))
1099+
elif len(set(self.classes)) < len(self.classes):
1100+
raise ValueError("The classes argument contains duplicate "
1101+
"classes. Remove these duplicates before passing "
1102+
"them to MultiLabelHistogram.")
1103+
else:
1104+
classes = self.classes
1105+
dtype = np.int if all(isinstance(c, int) for c in classes) else object
1106+
self.classes_ = np.empty(len(classes), dtype=dtype)
1107+
self.classes_[:] = classes
1108+
return self
1109+
1110+
def fit_transform(self, y):
1111+
"""Fit the label sets binarizer and transform the given label sets
1112+
1113+
Parameters
1114+
----------
1115+
y : iterable of iterables
1116+
A set of labels (any orderable and hashable object) for each
1117+
sample. If the `classes` parameter is set, `y` will not be
1118+
iterated.
1119+
1120+
Returns
1121+
-------
1122+
y_indicator : array or CSR matrix, shape (n_samples, n_classes)
1123+
A matrix such that `y_indicator[i, j] = (C != 0)` iff
1124+
`classes_[j]` is in `y[i]` with associated value C, and 0
1125+
otherwise.
1126+
"""
1127+
self._cached_dict = None
1128+
1129+
if self.classes is not None:
1130+
return self.fit(y).transform(y)
1131+
1132+
# Automatically increment on new class
1133+
class_mapping = defaultdict(int)
1134+
class_mapping.default_factory = class_mapping.__len__
1135+
yt = self._transform(y, class_mapping)
1136+
1137+
# sort classes and reorder columns
1138+
tmp = sorted(class_mapping, key=class_mapping.get)
1139+
1140+
# (make safe for tuples)
1141+
dtype = np.int if all(isinstance(c, int) for c in tmp) else object
1142+
class_mapping = np.empty(len(tmp), dtype=dtype)
1143+
class_mapping[:] = tmp
1144+
self.classes_, inverse = np.unique(class_mapping, return_inverse=True)
1145+
# ensure yt.indices keeps its current dtype
1146+
yt.indices = np.array(inverse[yt.indices], dtype=yt.indices.dtype,
1147+
copy=False)
1148+
1149+
if not self.sparse_output:
1150+
yt = yt.toarray()
1151+
1152+
return yt
1153+
1154+
def transform(self, y):
1155+
"""Transform the given label sets
1156+
1157+
Parameters
1158+
----------
1159+
y : iterable of dicts
1160+
A dict of labels for each sample, containing associated
1161+
values for each class. If the `classes` parameter is set,
1162+
`y` will not be iterated.
1163+
1164+
Returns
1165+
-------
1166+
y_indicator : array or CSR matrix, shape (n_samples, n_classes)
1167+
A matrix such that `y_indicator[i, j] = (C ! = 0)` iff
1168+
`classes_[j]` is in `y[i]` with associated value C, and 0
1169+
otherwise.
1170+
"""
1171+
check_is_fitted(self)
1172+
1173+
class_to_index = self._build_cache()
1174+
yt = self._transform(y, class_to_index)
1175+
1176+
if not self.sparse_output:
1177+
yt = yt.toarray()
1178+
1179+
return yt
1180+
1181+
def _build_cache(self):
1182+
if self._cached_dict is None:
1183+
self._cached_dict = dict(zip(self.classes_,
1184+
range(len(self.classes_))))
1185+
1186+
return self._cached_dict
1187+
1188+
def _transform(self, y, class_mapping):
1189+
"""Transforms the label sets with a given mapping
1190+
1191+
Parameters
1192+
----------
1193+
y : iterable of iterables
1194+
class_mapping : Mapping
1195+
Maps from label to column index in label value matrix
1196+
1197+
Returns
1198+
-------
1199+
y_indicator : sparse CSR matrix, shape (n_samples, n_classes)
1200+
Label indicator matrix
1201+
"""
1202+
data = array.array('f')
1203+
indices = array.array('i')
1204+
indptr = array.array('i', [0])
1205+
unknown = set()
1206+
for labels in y:
1207+
index = set()
1208+
for label, value in labels.items():
1209+
try:
1210+
index.add(class_mapping[label])
1211+
except KeyError:
1212+
unknown.add(label)
1213+
1214+
data.append(value)
1215+
1216+
indices.extend(index)
1217+
indptr.append(len(indices))
1218+
if unknown:
1219+
warnings.warn('unknown class(es) {0} will be ignored'
1220+
.format(sorted(unknown, key=str)))
1221+
1222+
return sp.csr_matrix((data, indices, indptr),
1223+
shape=(len(indptr) - 1, len(class_mapping)))
1224+
1225+
def inverse_transform(self, yt):
1226+
"""Transform the given indicator matrix into label sets
1227+
1228+
Parameters
1229+
----------
1230+
yt : array or sparse matrix of shape (n_samples, n_classes)
1231+
A matrix containing associated values for each class in each
1232+
sample.
1233+
1234+
Returns
1235+
-------
1236+
y : list of tuples
1237+
The set of labels for each sample such that `y[i]` consists of
1238+
`classes_[j]` for each `yt[i, j] == (C != 0)`.
1239+
"""
1240+
check_is_fitted(self)
1241+
1242+
if yt.shape[1] != len(self.classes_):
1243+
raise ValueError('Expected indicator for {0} classes, but got {1}'
1244+
.format(len(self.classes_), yt.shape[1]))
1245+
1246+
if sp.issparse(yt):
1247+
yt = yt.tocsr()
1248+
return [dict(zip(self.classes_.take(yt.indices[start:end]), yt.data[start:end]))
1249+
for start, end in zip(yt.indptr[:-1], yt.indptr[1:])]
1250+
else:
1251+
print([ind for ind in yt])
1252+
return [dict(zip(self.classes_.compress(indicators), filter(None, indicators)))
1253+
for indicators in yt]
1254+
1255+
def _more_tags(self):
1256+
return {'X_types': ['2dlabels']}

sklearn/preprocessing/tests/test_label.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from sklearn.preprocessing._label import LabelBinarizer
1919
from sklearn.preprocessing._label import MultiLabelBinarizer
20+
from sklearn.preprocessing._label import MultiLabelHistogram
2021
from sklearn.preprocessing._label import LabelEncoder
2122
from sklearn.preprocessing._label import label_binarize
2223

@@ -654,3 +655,17 @@ def test_encode_check_unknown():
654655
with pytest.raises(ValueError,
655656
match='y contains previously unseen labels'):
656657
_encode(values, uniques, encode=True, check_unknown=False)
658+
659+
660+
def test_multilabel_histogram():
661+
mlh = MultiLabelHistogram()
662+
y = [{1: 5.5, 2: -3.0}, {3: 999}]
663+
Y = np.array([[ 5.5 -3. 0. ],
664+
[ 0. 0. 999. ]])
665+
assert_array_equal(mlh.fit_transform(y), Y)
666+
667+
mlh = MultiLabelHistogram()
668+
y = [{'sci-fi': -2.0, 'thriller': 5.0}, {'comedy': 0.1}]
669+
Y = np.array([[ 0. -2. 5. ],
670+
[ 0.1 0. 0. ]])
671+
assert_array_equal(mlh.fit_transform(y), Y)

0 commit comments

Comments
 (0)