Skip to content

[MRG+1] Improve the error message for some metrics when the shape of sample_weight is inappropriate #9903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions sklearn/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):

# Compute accuracy for each possible representation
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
if y_type.startswith('multilabel'):
differing_labels = count_nonzero(y_true - y_pred, axis=1)
score = differing_labels == 0
Expand Down Expand Up @@ -263,7 +264,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
else:
sample_weight = np.asarray(sample_weight)

check_consistent_length(sample_weight, y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)

n_labels = labels.size
label_to_ind = dict((y, x) for x, y in enumerate(labels))
Expand Down Expand Up @@ -444,6 +445,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,

# Compute accuracy for each possible representation
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
if y_type.startswith('multilabel'):
with np.errstate(divide='ignore', invalid='ignore'):
# oddly, we may get an "invalid" rather than a "divide" error here
Expand Down Expand Up @@ -519,6 +521,7 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
-0.33...
"""
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
if y_type not in {"binary", "multiclass"}:
raise ValueError("%s is not supported" % y_type)

Expand Down Expand Up @@ -1023,6 +1026,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
raise ValueError("beta should be >0 in the F-beta score")

y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
present_labels = unique_labels(y_true, y_pred)

if average == 'binary':
Expand Down Expand Up @@ -1550,6 +1554,7 @@ def hamming_loss(y_true, y_pred, labels=None, sample_weight=None,
labels = classes

y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)

if labels is None:
labels = unique_labels(y_true, y_pred)
Expand Down Expand Up @@ -1638,7 +1643,7 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None,
The logarithm used is the natural logarithm (base-e).
"""
y_pred = check_array(y_pred, ensure_2d=False)
check_consistent_length(y_pred, y_true)
check_consistent_length(y_pred, y_true, sample_weight)

lb = LabelBinarizer()

Expand Down Expand Up @@ -1911,6 +1916,7 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None):
y_prob = column_or_1d(y_prob)
assert_all_finite(y_true)
assert_all_finite(y_prob)
check_consistent_length(y_true, y_prob, sample_weight)

if pos_label is None:
pos_label = y_true.max()
Expand Down
5 changes: 5 additions & 0 deletions sklearn/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def mean_absolute_error(y_true, y_pred,
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
check_consistent_length(y_true, y_pred, sample_weight)
output_errors = np.average(np.abs(y_pred - y_true),
weights=sample_weight, axis=0)
if isinstance(multioutput, string_types):
Expand Down Expand Up @@ -236,6 +237,7 @@ def mean_squared_error(y_true, y_pred,
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
check_consistent_length(y_true, y_pred, sample_weight)
output_errors = np.average((y_true - y_pred) ** 2, axis=0,
weights=sample_weight)
if isinstance(multioutput, string_types):
Expand Down Expand Up @@ -306,6 +308,7 @@ def mean_squared_log_error(y_true, y_pred,
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
check_consistent_length(y_true, y_pred, sample_weight)

if not (y_true >= 0).all() and not (y_pred >= 0).all():
raise ValueError("Mean Squared Logarithmic Error cannot be used when "
Expand Down Expand Up @@ -409,6 +412,7 @@ def explained_variance_score(y_true, y_pred,
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
check_consistent_length(y_true, y_pred, sample_weight)

y_diff_avg = np.average(y_true - y_pred, weights=sample_weight, axis=0)
numerator = np.average((y_true - y_pred - y_diff_avg) ** 2,
Expand Down Expand Up @@ -528,6 +532,7 @@ def r2_score(y_true, y_pred, sample_weight=None,
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
check_consistent_length(y_true, y_pred, sample_weight)

if sample_weight is not None:
sample_weight = column_or_1d(sample_weight)
Expand Down
14 changes: 10 additions & 4 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.datasets import make_multilabel_classification
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _num_samples
from sklearn.utils.validation import check_random_state
from sklearn.utils import shuffle

Expand Down Expand Up @@ -1005,10 +1006,15 @@ def check_sample_weight_invariance(name, metric, y1, y2):
err_msg="%s sample_weight is not invariant "
"under scaling" % name)

# Check that if sample_weight.shape[0] != y_true.shape[0], it raised an
# error
assert_raises(Exception, metric, y1, y2,
sample_weight=np.hstack([sample_weight, sample_weight]))
    # Check that a meaningful error is raised when the numbers of samples
    # in y_true and sample_weight are not equal.
error_message = ("Found input variables with inconsistent numbers of "
"samples: [{}, {}, {}]".format(
_num_samples(y1), _num_samples(y2),
_num_samples(sample_weight) * 2))
assert_raise_message(ValueError, error_message, metric, y1, y2,
sample_weight=np.hstack([sample_weight,
sample_weight]))


def test_sample_weight_invariance(n_samples=50):
Expand Down