Skip to content

[MRG] ENH remove mix of multilabel input format #2024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 7 additions & 22 deletions sklearn/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,11 +930,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True):

# Handle mix representation
if type(y_true) != type(y_pred):
labels = unique_labels(y_true, y_pred)
lb = LabelBinarizer()
lb.fit([labels.tolist()])
y_true = lb.transform(y_true)
y_pred = lb.transform(y_pred)
raise ValueError("Mix of multilabel type input")

if is_label_indicator_matrix(y_true):
try:
Expand Down Expand Up @@ -1052,11 +1048,8 @@ def accuracy_score(y_true, y_pred, normalize=True):

# Handle mix representation
if type(y_true) != type(y_pred):
labels = unique_labels(y_true, y_pred)
lb = LabelBinarizer()
lb.fit([labels.tolist()])
y_true = lb.transform(y_true)
y_pred = lb.transform(y_pred)
raise ValueError("Mix of multilabel type input")


if is_label_indicator_matrix(y_true):
score = (y_pred != y_true).sum(axis=1) == 0
Expand Down Expand Up @@ -1416,11 +1409,8 @@ def _tp_tn_fp_fn(y_true, y_pred, labels=None):
if is_multilabel(y_true):
# Handle mix representation
if type(y_true) != type(y_pred):
labels = unique_labels(y_true, y_pred)
lb = LabelBinarizer()
lb.fit([labels.tolist()])
y_true = lb.transform(y_true)
y_pred = lb.transform(y_pred)
raise ValueError("Mix of multilabel type input")


if is_label_indicator_matrix(y_true):
true_pos = np.sum(np.logical_and(y_true == 1,
Expand Down Expand Up @@ -1644,11 +1634,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
if is_multilabel(y_true):
# Handle mix representation
if type(y_true) != type(y_pred):
labels = unique_labels(y_true, y_pred)
lb = LabelBinarizer()
lb.fit([labels.tolist()])
y_true = lb.transform(y_true)
y_pred = lb.transform(y_pred)
raise ValueError("Mix of multilabel type input")

if is_label_indicator_matrix(y_true):
y_true_pos_label = y_true == 1
Expand Down Expand Up @@ -2210,8 +2196,7 @@ def hamming_loss(y_true, y_pred, classes=None):
lb.fit([classes.tolist()])

if type(y_true) != type(y_pred):
y_true = lb.transform(y_true)
y_pred = lb.transform(y_pred)
raise ValueError("Mix of multilabel type input")

if is_label_indicator_matrix(y_true):
return np.mean(y_true != y_pred)
Expand Down
24 changes: 7 additions & 17 deletions sklearn/metrics/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,12 +1046,12 @@ def test_multilabel_representation_invariance():
# allows to return the shuffled tuple.
rng = check_random_state(42)
shuffled = lambda x: sorted(x, key=lambda *args: rng.rand())
y1_shuffle = [shuffled(x) for x in y1]
y2_shuffle = [shuffled(x) for x in y2]
y1_shuffle = tuple(shuffled(x) for x in y1)
y2_shuffle = tuple(shuffled(x) for x in y2)

# Let's have redundant labels
y1_redundant = [x * rng.randint(1, 4) for x in y1]
y2_redundant = [x * rng.randint(1, 4) for x in y2]
y1_redundant = tuple(x * rng.randint(1, 4) for x in y1)
y2_redundant = tuple(x * rng.randint(1, 4) for x in y2)

# Binary indicator matrix format
lb = LabelBinarizer().fit([range(n_classes)])
Expand Down Expand Up @@ -1099,19 +1099,9 @@ def test_multilabel_representation_invariance():
% name)

# Check invariance with mix input representation
assert_almost_equal(metric(y1, y2_binary_indicator), measure,
err_msg="%s failed mix input representation "
"invariance: y_true in list of list of "
"labels format and y_pred in dense binary "
"indicator format"
% name)

assert_almost_equal(metric(y1_binary_indicator, y2), measure,
err_msg="%s failed mix input representation "
"invariance: y_true in dense binary "
"indicator format and y_pred in list of "
"list of labels format."
% name)
print name
assert_raises(ValueError, metric, y1, y2_binary_indicator)
assert_raises(ValueError, metric, y1_binary_indicator, y2)


def test_multilabel_zero_one_loss_subset():
Expand Down