Skip to content

[MRG] FIX corner cases with unique_labels #2015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 8, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,44 @@
from ..externals.six import string_types


def unique_labels(*lists_of_labels):
def _unique_multiclass(y):
if isinstance(y, np.ndarray):
return np.unique(y)
else:
return set(y)


def _unique_sequence_of_sequence(y):
return set(chain.from_iterable(y))


def _unique_indicator(y):
return np.arange(y.shape[1])


# Dispatch table: target-type name -> helper that extracts the unique labels
# from one target of that type. 'binary' and 'multiclass' share a helper.
_FN_UNIQUE_LABELS = {
    'binary': _unique_multiclass,
    'multiclass': _unique_multiclass,
    'multilabel-sequences': _unique_sequence_of_sequence,
    'multilabel-indicator': _unique_indicator,
}


def unique_labels(*ys):
"""Extract an ordered array of unique labels

We don't allow:
- mix of multilabel and multiclass (single label) targets
- mix of label indicator matrix and anything else
(because there are no explicit labels)
- mix of label indicator matrices of different sizes
- mix of string and integer labels

At the moment, we also don't allow "multiclass-multioutput" input type.

Parameters
----------
lists_of_labels : list of labels,
The supported "list of labels" are:
- a list / tuple / numpy array of int
- a list of lists / tuples of int;
- a binary indicator matrix (2D numpy array)
ys : array-likes,

Returns
-------
Expand All @@ -45,23 +73,37 @@ def unique_labels(*lists_of_labels):
array([1, 2, 3])

"""
def _unique_labels(y):
classes = None
if is_multilabel(y):
if is_label_indicator_matrix(y):
classes = np.arange(y.shape[1])
else:
classes = np.array(sorted(set(chain(*y))))
if not ys:
raise ValueError('No argument has been passed.')

# Check that we don't mix label format
ys_types = set(type_of_target(x) for x in ys)
if ys_types == set(["binary", "multiclass"]):
ys_types = set(["multiclass"])

if len(ys_types) > 1:
raise ValueError("Mix type of y not allowed, got types %s" % ys_types)

label_type = ys_types.pop()

# Check consistency for the indicator format
if (label_type == "multilabel-indicator" and
len(set(y.shape[1] for y in ys)) > 1):
raise ValueError("Multi-label binary indicator input with "
"different numbers of labels")

else:
classes = np.unique(y)
# Get the unique set of labels
_unique_labels = _FN_UNIQUE_LABELS.get(label_type, None)
if not _unique_labels:
raise ValueError("Unknown label type")

return classes
ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys))

if not lists_of_labels:
raise ValueError('No list of labels has been passed.')
# Check that we don't mix string type with number type
if (len(set(isinstance(label, string_types) for label in ys_labels)) > 1):
raise ValueError("Mix of label input types (string and number)")

return np.unique(np.hstack(_unique_labels(y) for y in lists_of_labels))
return np.array(sorted(ys_labels))


def _is_integral_float(y):
Expand Down
61 changes: 60 additions & 1 deletion sklearn/utils/tests/test_multiclass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np

from itertools import product
from sklearn.externals.six.moves import xrange
from sklearn.externals.six import iteritems

Expand Down Expand Up @@ -136,12 +136,71 @@ def test_unique_labels():
[0, 0, 0]])),
np.arange(3))

assert_array_equal(unique_labels(np.array([[0, 0, 1],
[0, 0, 0]])),
np.arange(3))

# Several arrays passed
assert_array_equal(unique_labels([4, 0, 2], xrange(5)),
np.arange(5))
assert_array_equal(unique_labels((0, 1, 2), (0,), (2, 1)),
np.arange(3))

# Border line case with binary indicator matrix
assert_raises(ValueError, unique_labels, [4, 0, 2], np.ones((5, 5)))
assert_raises(ValueError, unique_labels, np.ones((5, 4)), np.ones((5, 5)))
assert_array_equal(unique_labels(np.ones((4, 5)), np.ones((5, 5))),
np.arange(5))

# Some tests with strings input
assert_array_equal(unique_labels(["a", "b", "c"], ["d"]),
["a", "b", "c", "d"])
assert_array_equal(unique_labels([["a", "b"], ["c"]], [["d"]]),
["a", "b", "c", "d"])

# Smoke test for all supported format
for format in ["binary", "multiclass", "multilabel-sequences",
"multilabel-indicator"]:
for y in EXAMPLES[format]:
unique_labels(y)

# We don't support those format at the moment
for example in NON_ARRAY_LIKE_EXAMPLES:
assert_raises(ValueError, unique_labels, example)

for y_type in ["unknown", "continuous", 'continuous-multioutput',
'multiclass-multioutput']:
for example in EXAMPLES[y_type]:
assert_raises(ValueError, unique_labels, example)

#Mix of multilabel-indicator and multilabel-sequences
mix_multilabel_format = product(EXAMPLES["multilabel-indicator"],
EXAMPLES["multilabel-sequences"])
for y_multilabel, y_multiclass in mix_multilabel_format:
assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel)
assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass)

#Mix with binary or multiclass and multilabel
mix_clf_format = product(EXAMPLES["multilabel-indicator"] +
EXAMPLES["multilabel-sequences"],
EXAMPLES["multiclass"] +
EXAMPLES["binary"])

for y_multilabel, y_multiclass in mix_clf_format:
assert_raises(ValueError, unique_labels, y_multiclass, y_multilabel)
assert_raises(ValueError, unique_labels, y_multilabel, y_multiclass)

# Mix string and number input type
assert_raises(ValueError, unique_labels, [[1, 2], [3]],
[["a", "d"]])
assert_raises(ValueError, unique_labels, ["1", 2])
assert_raises(ValueError, unique_labels, [["1", 2], [3]])
assert_raises(ValueError, unique_labels, [["1", "2"], [3]])

assert_array_equal(unique_labels([(2,), (0, 2,)], [(), ()]), [0, 2])
assert_array_equal(unique_labels([("2",), ("0", "2",)], [(), ()]),
["0", "2"])


def test_is_multilabel():
for group, group_examples in iteritems(EXAMPLES):
Expand Down