diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index be79114e3dfda..11289c433bf2d 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -14,16 +14,44 @@ from ..externals.six import string_types -def unique_labels(*lists_of_labels): +def _unique_multiclass(y): + if isinstance(y, np.ndarray): + return np.unique(y) + else: + return set(y) + + +def _unique_sequence_of_sequence(y): + return set(chain.from_iterable(y)) + + +def _unique_indicator(y): + return np.arange(y.shape[1]) + + +_FN_UNIQUE_LABELS = { + 'binary': _unique_multiclass, + 'multiclass': _unique_multiclass, + 'multilabel-sequences': _unique_sequence_of_sequence, + 'multilabel-indicator': _unique_indicator, +} + + +def unique_labels(*ys): """Extract an ordered array of unique labels + We don't allow: + - mix of multilabel and multiclass (single label) targets + - mix of label indicator matrix and anything else + (because there are no explicit labels) + - mix of label indicator matrices of different sizes + - mix of string and integer labels + + At the moment, we also don't allow "multiclass-multioutput" input type. 
+ Parameters ---------- - lists_of_labels : list of labels, - The supported "list of labels" are: - - a list / tuple / numpy array of int - - a list of lists / tuples of int; - - a binary indicator matrix (2D numpy array) + ys : array-likes, Returns ------- @@ -45,23 +73,37 @@ def unique_labels(*lists_of_labels): array([1, 2, 3]) """ - def _unique_labels(y): - classes = None - if is_multilabel(y): - if is_label_indicator_matrix(y): - classes = np.arange(y.shape[1]) - else: - classes = np.array(sorted(set(chain(*y)))) + if not ys: + raise ValueError('No argument has been passed.') + + # Check that we don't mix label format + ys_types = set(type_of_target(x) for x in ys) + if ys_types == set(["binary", "multiclass"]): + ys_types = set(["multiclass"]) + + if len(ys_types) > 1: + raise ValueError("Mix type of y not allowed, got types %s" % ys_types) + + label_type = ys_types.pop() + + # Check consistency for the indicator format + if (label_type == "multilabel-indicator" and + len(set(y.shape[1] for y in ys)) > 1): + raise ValueError("Multi-label binary indicator input with " + "different numbers of labels") - else: - classes = np.unique(y) + # Get the unique set of labels + _unique_labels = _FN_UNIQUE_LABELS.get(label_type, None) + if not _unique_labels: + raise ValueError("Unknown label type") - return classes + ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys)) - if not lists_of_labels: - raise ValueError('No list of labels has been passed.') + # Check that we don't mix string type with number type + if (len(set(isinstance(label, string_types) for label in ys_labels)) > 1): + raise ValueError("Mix of label input types (string and number)") - return np.unique(np.hstack(_unique_labels(y) for y in lists_of_labels)) + return np.array(sorted(ys_labels)) def _is_integral_float(y): diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py index e8ad3b77905d2..775a3a8a2248e 100644 --- 
a/sklearn/utils/tests/test_multiclass.py +++ b/sklearn/utils/tests/test_multiclass.py @@ -1,5 +1,5 @@ import numpy as np - +from itertools import product from sklearn.externals.six.moves import xrange from sklearn.externals.six import iteritems @@ -136,12 +136,71 @@ def test_unique_labels(): [0, 0, 0]])), np.arange(3)) + assert_array_equal(unique_labels(np.array([[0, 0, 1], + [0, 0, 0]])), + np.arange(3)) + # Several arrays passed assert_array_equal(unique_labels([4, 0, 2], xrange(5)), np.arange(5)) assert_array_equal(unique_labels((0, 1, 2), (0,), (2, 1)), np.arange(3)) + # Borderline case with binary indicator matrix + assert_raises(ValueError, unique_labels, [4, 0, 2], np.ones((5, 5))) + assert_raises(ValueError, unique_labels, np.ones((5, 4)), np.ones((5, 5))) + assert_array_equal(unique_labels(np.ones((4, 5)), np.ones((5, 5))), + np.arange(5)) + + # Some tests with string input + assert_array_equal(unique_labels(["a", "b", "c"], ["d"]), + ["a", "b", "c", "d"]) + assert_array_equal(unique_labels([["a", "b"], ["c"]], [["d"]]), + ["a", "b", "c", "d"]) + + # Smoke test for all supported formats + for format in ["binary", "multiclass", "multilabel-sequences", + "multilabel-indicator"]: + for y in EXAMPLES[format]: + unique_labels(y) + + # We don't support these formats at the moment + for example in NON_ARRAY_LIKE_EXAMPLES: + assert_raises(ValueError, unique_labels, example) + + for y_type in ["unknown", "continuous", 'continuous-multioutput', + 'multiclass-multioutput']: + for example in EXAMPLES[y_type]: + assert_raises(ValueError, unique_labels, example) + + # Mix of multilabel-indicator and multilabel-sequences + mix_multilabel_format = product(EXAMPLES["multilabel-indicator"], + EXAMPLES["multilabel-sequences"]) + for y_multilabel, y_multiclass in mix_multilabel_format: + assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel) + assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass) + + # Mix with binary or multiclass and multilabel + 
mix_clf_format = product(EXAMPLES["multilabel-indicator"] + + EXAMPLES["multilabel-sequences"], + EXAMPLES["multiclass"] + + EXAMPLES["binary"]) + + for y_multilabel, y_multiclass in mix_clf_format: + assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel) + assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass) + + # Mix string and number input type + assert_raises(ValueError, unique_labels, [[1, 2], [3]], + [["a", "d"]]) + assert_raises(ValueError, unique_labels, ["1", 2]) + assert_raises(ValueError, unique_labels, [["1", 2], [3]]) + assert_raises(ValueError, unique_labels, [["1", "2"], [3]]) + + assert_array_equal(unique_labels([(2,), (0, 2,)], [(), ()]), [0, 2]) + assert_array_equal(unique_labels([("2",), ("0", "2",)], [(), ()]), + ["0", "2"]) + def test_is_multilabel(): for group, group_examples in iteritems(EXAMPLES):