|
4 | 4 | # Andreas Mueller <amueller@ais.uni-bonn.de>
|
5 | 5 | # Joel Nothman <joel.nothman@gmail.com>
|
6 | 6 | # Hamzeh Alsalhi <ha258@cornell.edu>
|
| 7 | +# Charles Moyes <me@charlesmoyes.com> |
7 | 8 | # License: BSD 3 clause
|
8 | 9 |
|
9 | 10 | from collections import defaultdict
|
|
30 | 31 | 'LabelBinarizer',
|
31 | 32 | 'LabelEncoder',
|
32 | 33 | 'MultiLabelBinarizer',
|
| 34 | + 'MultiLabelHistogram', |
33 | 35 | ]
|
34 | 36 |
|
35 | 37 |
|
@@ -1026,3 +1028,229 @@ def inverse_transform(self, yt):
|
1026 | 1028 |
|
1027 | 1029 | def _more_tags(self):
|
1028 | 1030 | return {'X_types': ['2dlabels']}
|
| 1031 | + |
| 1032 | + |
| 1033 | +class MultiLabelHistogram(TransformerMixin, BaseEstimator): |
| 1034 | + """Transform between iterable of dicts and a multilabel format |
| 1035 | +
|
| 1036 | + This transformer converts between a class mapping with numerical values |
| 1037 | + (such as a histogram) to the multilabel format: a (samples x classes) |
| 1038 | + histogram matrix indicating the presence of a class label with an |
| 1039 | + associated value. |
| 1040 | +
|
| 1041 | + Parameters |
| 1042 | + ---------- |
| 1043 | + classes : array-like of shape [n_classes] (optional) |
| 1044 | + Indicates an ordering for the class labels. |
| 1045 | + All entries should be unique (cannot contain duplicate classes). |
| 1046 | +
|
| 1047 | + sparse_output : boolean (default: False), |
| 1048 | + Set to true if output binary array is desired in CSR sparse format |
| 1049 | +
|
| 1050 | + Attributes |
| 1051 | + ---------- |
| 1052 | + classes_ : array of labels |
| 1053 | + A copy of the `classes` parameter where provided, |
| 1054 | + or otherwise, the sorted set of classes found when fitting. |
| 1055 | +
|
| 1056 | + Examples |
| 1057 | + -------- |
| 1058 | + >>> from sklearn.preprocessing import MultiLabelHistogram |
| 1059 | + >>> mlh = MultiLabelHistogram() |
| 1060 | + >>> mlh.fit_transform([{1: 5.5, 2: -3.0}, {3: 999}]) |
| 1061 | + array([[ 5.5 -3. 0. ], |
| 1062 | + [ 0. 0. 999. ]]) |
| 1063 | + >>> mlh.classes_ |
| 1064 | + array([1, 2, 3]) |
| 1065 | +
|
| 1066 | + >>> mlh.fit_transform([{'sci-fi': -2.0, 'thriller': 5.0}, {'comedy': 0.1}]) |
| 1067 | + array([[ 0. -2. 5. ], |
| 1068 | + [ 0.1 0. 0. ]]) |
| 1069 | + >>> list(mlh.classes_) |
| 1070 | + ['comedy', 'sci-fi', 'thriller'] |
| 1071 | +
|
| 1072 | + See also |
| 1073 | + -------- |
| 1074 | + MultiLabelBinarizer : Transform between iterable |
| 1075 | + of iterables and a multilabel format |
| 1076 | + """ |
| 1077 | + |
| 1078 | + def __init__(self, classes=None, sparse_output=False): |
| 1079 | + self.classes = classes |
| 1080 | + self.sparse_output = sparse_output |
| 1081 | + |
| 1082 | + def fit(self, y): |
| 1083 | + """Fit the label sets binarizer, storing :term:`classes_` |
| 1084 | +
|
| 1085 | + Parameters |
| 1086 | + ---------- |
| 1087 | + y : iterable of iterables |
| 1088 | + A set of labels (any orderable and hashable object) for each |
| 1089 | + sample. If the `classes` parameter is set, `y` will not be |
| 1090 | + iterated. |
| 1091 | +
|
| 1092 | + Returns |
| 1093 | + ------- |
| 1094 | + self : returns this MultiLabelHistogram instance |
| 1095 | + """ |
| 1096 | + self._cached_dict = None |
| 1097 | + if self.classes is None: |
| 1098 | + classes = sorted(set(itertools.chain.from_iterable(y))) |
| 1099 | + elif len(set(self.classes)) < len(self.classes): |
| 1100 | + raise ValueError("The classes argument contains duplicate " |
| 1101 | + "classes. Remove these duplicates before passing " |
| 1102 | + "them to MultiLabelHistogram.") |
| 1103 | + else: |
| 1104 | + classes = self.classes |
| 1105 | + dtype = np.int if all(isinstance(c, int) for c in classes) else object |
| 1106 | + self.classes_ = np.empty(len(classes), dtype=dtype) |
| 1107 | + self.classes_[:] = classes |
| 1108 | + return self |
| 1109 | + |
| 1110 | + def fit_transform(self, y): |
| 1111 | + """Fit the label sets binarizer and transform the given label sets |
| 1112 | +
|
| 1113 | + Parameters |
| 1114 | + ---------- |
| 1115 | + y : iterable of iterables |
| 1116 | + A set of labels (any orderable and hashable object) for each |
| 1117 | + sample. If the `classes` parameter is set, `y` will not be |
| 1118 | + iterated. |
| 1119 | +
|
| 1120 | + Returns |
| 1121 | + ------- |
| 1122 | + y_indicator : array or CSR matrix, shape (n_samples, n_classes) |
| 1123 | + A matrix such that `y_indicator[i, j] = (C != 0)` iff |
| 1124 | + `classes_[j]` is in `y[i]` with associated value C, and 0 |
| 1125 | + otherwise. |
| 1126 | + """ |
| 1127 | + self._cached_dict = None |
| 1128 | + |
| 1129 | + if self.classes is not None: |
| 1130 | + return self.fit(y).transform(y) |
| 1131 | + |
| 1132 | + # Automatically increment on new class |
| 1133 | + class_mapping = defaultdict(int) |
| 1134 | + class_mapping.default_factory = class_mapping.__len__ |
| 1135 | + yt = self._transform(y, class_mapping) |
| 1136 | + |
| 1137 | + # sort classes and reorder columns |
| 1138 | + tmp = sorted(class_mapping, key=class_mapping.get) |
| 1139 | + |
| 1140 | + # (make safe for tuples) |
| 1141 | + dtype = np.int if all(isinstance(c, int) for c in tmp) else object |
| 1142 | + class_mapping = np.empty(len(tmp), dtype=dtype) |
| 1143 | + class_mapping[:] = tmp |
| 1144 | + self.classes_, inverse = np.unique(class_mapping, return_inverse=True) |
| 1145 | + # ensure yt.indices keeps its current dtype |
| 1146 | + yt.indices = np.array(inverse[yt.indices], dtype=yt.indices.dtype, |
| 1147 | + copy=False) |
| 1148 | + |
| 1149 | + if not self.sparse_output: |
| 1150 | + yt = yt.toarray() |
| 1151 | + |
| 1152 | + return yt |
| 1153 | + |
| 1154 | + def transform(self, y): |
| 1155 | + """Transform the given label sets |
| 1156 | +
|
| 1157 | + Parameters |
| 1158 | + ---------- |
| 1159 | + y : iterable of dicts |
| 1160 | + A dict of labels for each sample, containing associated |
| 1161 | + values for each class. If the `classes` parameter is set, |
| 1162 | + `y` will not be iterated. |
| 1163 | +
|
| 1164 | + Returns |
| 1165 | + ------- |
| 1166 | + y_indicator : array or CSR matrix, shape (n_samples, n_classes) |
| 1167 | + A matrix such that `y_indicator[i, j] = (C ! = 0)` iff |
| 1168 | + `classes_[j]` is in `y[i]` with associated value C, and 0 |
| 1169 | + otherwise. |
| 1170 | + """ |
| 1171 | + check_is_fitted(self) |
| 1172 | + |
| 1173 | + class_to_index = self._build_cache() |
| 1174 | + yt = self._transform(y, class_to_index) |
| 1175 | + |
| 1176 | + if not self.sparse_output: |
| 1177 | + yt = yt.toarray() |
| 1178 | + |
| 1179 | + return yt |
| 1180 | + |
| 1181 | + def _build_cache(self): |
| 1182 | + if self._cached_dict is None: |
| 1183 | + self._cached_dict = dict(zip(self.classes_, |
| 1184 | + range(len(self.classes_)))) |
| 1185 | + |
| 1186 | + return self._cached_dict |
| 1187 | + |
| 1188 | + def _transform(self, y, class_mapping): |
| 1189 | + """Transforms the label sets with a given mapping |
| 1190 | +
|
| 1191 | + Parameters |
| 1192 | + ---------- |
| 1193 | + y : iterable of iterables |
| 1194 | + class_mapping : Mapping |
| 1195 | + Maps from label to column index in label value matrix |
| 1196 | +
|
| 1197 | + Returns |
| 1198 | + ------- |
| 1199 | + y_indicator : sparse CSR matrix, shape (n_samples, n_classes) |
| 1200 | + Label indicator matrix |
| 1201 | + """ |
| 1202 | + data = array.array('f') |
| 1203 | + indices = array.array('i') |
| 1204 | + indptr = array.array('i', [0]) |
| 1205 | + unknown = set() |
| 1206 | + for labels in y: |
| 1207 | + index = set() |
| 1208 | + for label, value in labels.items(): |
| 1209 | + try: |
| 1210 | + index.add(class_mapping[label]) |
| 1211 | + except KeyError: |
| 1212 | + unknown.add(label) |
| 1213 | + |
| 1214 | + data.append(value) |
| 1215 | + |
| 1216 | + indices.extend(index) |
| 1217 | + indptr.append(len(indices)) |
| 1218 | + if unknown: |
| 1219 | + warnings.warn('unknown class(es) {0} will be ignored' |
| 1220 | + .format(sorted(unknown, key=str))) |
| 1221 | + |
| 1222 | + return sp.csr_matrix((data, indices, indptr), |
| 1223 | + shape=(len(indptr) - 1, len(class_mapping))) |
| 1224 | + |
| 1225 | + def inverse_transform(self, yt): |
| 1226 | + """Transform the given indicator matrix into label sets |
| 1227 | +
|
| 1228 | + Parameters |
| 1229 | + ---------- |
| 1230 | + yt : array or sparse matrix of shape (n_samples, n_classes) |
| 1231 | + A matrix containing associated values for each class in each |
| 1232 | + sample. |
| 1233 | +
|
| 1234 | + Returns |
| 1235 | + ------- |
| 1236 | + y : list of tuples |
| 1237 | + The set of labels for each sample such that `y[i]` consists of |
| 1238 | + `classes_[j]` for each `yt[i, j] == (C != 0)`. |
| 1239 | + """ |
| 1240 | + check_is_fitted(self) |
| 1241 | + |
| 1242 | + if yt.shape[1] != len(self.classes_): |
| 1243 | + raise ValueError('Expected indicator for {0} classes, but got {1}' |
| 1244 | + .format(len(self.classes_), yt.shape[1])) |
| 1245 | + |
| 1246 | + if sp.issparse(yt): |
| 1247 | + yt = yt.tocsr() |
| 1248 | + return [dict(zip(self.classes_.take(yt.indices[start:end]), yt.data[start:end])) |
| 1249 | + for start, end in zip(yt.indptr[:-1], yt.indptr[1:])] |
| 1250 | + else: |
| 1251 | + print([ind for ind in yt]) |
| 1252 | + return [dict(zip(self.classes_.compress(indicators), filter(None, indicators))) |
| 1253 | + for indicators in yt] |
| 1254 | + |
| 1255 | + def _more_tags(self): |
| 1256 | + return {'X_types': ['2dlabels']} |
0 commit comments