diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 5a97f503682e3..667679918bf82 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -910,11 +910,9 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True): # Compute accuracy for each possible representation y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) if y_type == 'multilabel-indicator': - try: + with np.errstate(divide='ignore', invalid='ignore'): # oddly, we may get an "invalid" rather than a "divide" # error here - old_err_settings = np.seterr(divide='ignore', - invalid='ignore') y_pred_pos_label = y_pred == 1 y_true_pos_label = y_true == 1 pred_inter_true = np.sum(np.logical_and(y_pred_pos_label, @@ -929,8 +927,6 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True): # the jaccard to 1: lim_{x->0} x/x = 1 # Note with py2.6 and np 1.3: we can't check safely for nan. score[pred_union_true == 0.0] = 1.0 - finally: - np.seterr(**old_err_settings) elif y_type == 'multilabel-sequences': score = np.empty(len(y_true), dtype=np.float) @@ -1448,24 +1444,37 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, size_true[i] = len(true_set) else: raise ValueError("Example-based precision, recall, fscore is " - "not meaning full outside multilabe" - "classification. See the accuracy_score instead.") + "not meaningful outside of multilabel" + "classification. Use accuracy_score instead.") - try: + warning_msg = "" + if np.any(size_pred == 0): + warning_msg += ("Sample-based precision is undefined for some " + "samples. ") + + if np.any(size_true == 0): + warning_msg += ("Sample-based recall is undefined for some " + "samples. ") + + if np.any((beta2 * size_true + size_pred) == 0): + warning_msg += ("Sample-based f_score is undefined for some " + "samples. ") + + if warning_msg: + warnings.warn(warning_msg) + + with np.errstate(divide="ignore", invalid="ignore"): # oddly, we may get an "invalid" rather than a "divide" error # here - old_err_settings = np.seterr(divide='ignore', invalid='ignore') - - precision = size_inter / size_true - recall = size_inter / size_pred - f_score = ((1 + beta2 ** 2) * size_inter / - (beta2 * size_pred + size_true)) - finally: - np.seterr(**old_err_settings) + precision = divide(size_inter, size_pred, dtype=np.double) + recall = divide(size_inter, size_true, dtype=np.double) + f_score = divide((1 + beta2) * size_inter, + (beta2 * size_true + size_pred), + dtype=np.double) - precision[size_true == 0] = 1.0 - recall[size_pred == 0] = 1.0 - f_score[(beta2 * size_pred + size_true) == 0] = 1.0 + precision[size_pred == 0] = 0.0 + recall[size_true == 0] = 0.0 + f_score[(beta2 * size_true + size_pred) == 0] = 0.0 precision = np.mean(precision) recall = np.mean(recall) @@ -1476,26 +1485,50 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, true_pos, _, false_pos, false_neg = _tp_tn_fp_fn(y_true, y_pred, labels) support = true_pos + false_neg - try: + with np.errstate(divide='ignore', invalid='ignore'): # oddly, we may get an "invalid" rather than a "divide" error here - old_err_settings = np.seterr(divide='ignore', invalid='ignore') # precision and recall precision = divide(true_pos.astype(np.float), true_pos + false_pos) recall = divide(true_pos.astype(np.float), true_pos + false_neg) + idx_ill_defined_precision = (true_pos + false_pos) == 0 + idx_ill_defined_recall = (true_pos + false_neg) == 0 + # handle division by 0 in precision and recall - precision[(true_pos + false_pos) == 0] = 0.0 - recall[(true_pos + false_neg) == 0] = 0.0 + precision[idx_ill_defined_precision] = 0.0 + recall[idx_ill_defined_recall] = 0.0 # fbeta score fscore = divide((1 + beta2) * precision * recall, beta2 * precision + recall) # handle division by 0 in fscore - fscore[(beta2 * precision + recall) == 0] = 0.0 - finally: - np.seterr(**old_err_settings) + idx_ill_defined_fbeta_score = (beta2 * precision + recall) == 0 + fscore[idx_ill_defined_fbeta_score] = 0.0 + + if average in (None, "macro", "weighted"): + warning_msg = "" + if np.any(idx_ill_defined_precision): + warning_msg += ("The sum of true positives and false positives " + "are equal to zero for some labels. Precision is " + "ill defined for those labels %s. " + % labels[idx_ill_defined_precision]) + + if np.any(idx_ill_defined_recall): + warning_msg += ("The sum of true positives and false negatives " + "are equal to zero for some labels. Recall is ill " + "defined for those labels %s. " + % labels[idx_ill_defined_recall]) + + if np.any(idx_ill_defined_fbeta_score): + warning_msg += ("The precision and recall are equal to zero for " + "some labels. fbeta_score is ill defined for " + "those labels %s. " + % labels[idx_ill_defined_fbeta_score]) + + if warning_msg: + warnings.warn(warning_msg, stacklevel=2) if not average: return precision, recall, fscore, support @@ -1513,24 +1546,40 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, else: average_options = (None, 'micro', 'macro', 'weighted', 'samples') if average == 'micro': - avg_precision = divide(true_pos.sum(), - true_pos.sum() + false_pos.sum(), - dtype=np.double) - avg_recall = divide(true_pos.sum(), - true_pos.sum() + false_neg.sum(), - dtype=np.double) - avg_fscore = divide((1 + beta2) * (avg_precision * avg_recall), - beta2 * avg_precision + avg_recall, - dtype=np.double) - - if np.isnan(avg_precision): + with np.errstate(divide='ignore', invalid='ignore'): + # oddly, we may get an "invalid" rather than a "divide" error + # here + + tp_sum = true_pos.sum() + fp_sum = false_pos.sum() + fn_sum = false_neg.sum() + avg_precision = divide(tp_sum, tp_sum + fp_sum, + dtype=np.double) + avg_recall = divide(tp_sum, tp_sum + fn_sum, dtype=np.double) + avg_fscore = divide((1 + beta2) * (avg_precision * avg_recall), + beta2 * avg_precision + avg_recall, + dtype=np.double) + + warning_msg = "" + if tp_sum + fp_sum == 0: avg_precision = 0. + warning_msg += ("The sum of true positives and false " + "positives are equal to zero. Micro-precision" + " is ill defined. ") - if np.isnan(avg_recall): + if tp_sum + fn_sum == 0: avg_recall = 0. + warning_msg += ("The sum of true positives and false " + "negatives are equal to zero. Micro-recall " + "is ill defined. ") - if np.isnan(avg_fscore): + if beta2 * avg_precision + avg_recall == 0: avg_fscore = 0. + warning_msg += ("Micro-precision and micro-recall are equal " + "to zero. Micro-fbeta_score is ill defined.") + + if warning_msg: + warnings.warn(warning_msg, stacklevel=2) elif average == 'macro': avg_precision = np.mean(precision) @@ -1542,6 +1591,11 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None, avg_precision = 0. avg_recall = 0. avg_fscore = 0. + warnings.warn("There isn't any labels in y_true. " + "Weighted-precision, weighted-recall and " + "weighted-fbeta_score are ill defined.", + stacklevel=2) + else: avg_precision = np.average(precision, weights=support) avg_recall = np.average(recall, weights=support) @@ -1698,6 +1752,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'): >>> recall_score(y_true, y_pred, average=None) array([ 1., 0., 0.]) + """ _, r, _, _ = precision_recall_fscore_support(y_true, y_pred, labels=labels, diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 90cdd7f258da4..1e00feddb6512 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -21,6 +21,7 @@ assert_not_equal, assert_array_equal, assert_array_almost_equal, + assert_warns, assert_greater) @@ -55,6 +56,7 @@ from sklearn.externals.six.moves import xrange + REGRESSION_METRICS = { "mean_absolute_error": mean_absolute_error, "mean_squared_error": mean_squared_error, @@ -503,12 +505,17 @@ def test_precision_recall_f1_score_binary(): fs = f1_score(y_true, y_pred) assert_array_almost_equal(fs, 0.76, 2) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2), + (1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2) + def test_precision_recall_f_binary_single_class(): """Test precision, recall and F1 score behave with a single positive or negative class Such a case may occur with non-stratified cross-validation""" + warnings.simplefilter("ignore") + assert_equal(1., precision_score([1, 1], [1, 1])) assert_equal(1., recall_score([1, 1], [1, 1])) assert_equal(1., f1_score([1, 1], [1, 1])) @@ -1218,45 +1225,49 @@ def test_multilabel_representation_invariance(): y2_shuffle_binary_indicator = lb.transform(y2_shuffle) for name, metric in MULTILABELS_METRICS.items(): - measure = metric(y1, y2) - - # Check representation invariance - assert_almost_equal(metric(y1_binary_indicator, y2_binary_indicator), - measure, - err_msg="%s failed representation invariance " - "between list of list of labels format " - "and dense binary indicator format." - % name) - - # Check invariance with redundant labels with list of labels - assert_almost_equal(metric(y1, y2_redundant), measure, - err_msg="%s failed rendundant label invariance" - % name) - - assert_almost_equal(metric(y1_redundant, y2_redundant), measure, - err_msg="%s failed rendundant label invariance" - % name) - - assert_almost_equal(metric(y1_redundant, y2), measure, - err_msg="%s failed rendundant label invariance" - % name) - - # Check shuffling invariance with list of labels - assert_almost_equal(metric(y1_shuffle, y2_shuffle), measure, - err_msg="%s failed shuffling invariance " - "with list of list of labels format." - % name) - - # Check shuffling invariance with dense binary indicator matrix - assert_almost_equal(metric(y1_shuffle_binary_indicator, - y2_shuffle_binary_indicator), measure, - err_msg="%s failed shuffling invariance " - " with dense binary indicator format." - % name) - - # Check raises error with mix input representation - assert_raises(ValueError, metric, y1, y2_binary_indicator) - assert_raises(ValueError, metric, y1_binary_indicator, y2) + with warnings.catch_warnings(True): + warnings.simplefilter("always") + + measure = metric(y1, y2) + + # Check representation invariance + assert_almost_equal(metric(y1_binary_indicator, + y2_binary_indicator), + measure, + err_msg="%s failed representation invariance " + "between list of list of labels " + "format and dense binary indicator " + "format." % name) + + # Check invariance with redundant labels with list of labels + assert_almost_equal(metric(y1, y2_redundant), measure, + err_msg="%s failed rendundant label invariance" + % name) + + assert_almost_equal(metric(y1_redundant, y2_redundant), measure, + err_msg="%s failed rendundant label invariance" + % name) + + assert_almost_equal(metric(y1_redundant, y2), measure, + err_msg="%s failed rendundant label invariance" + % name) + + # Check shuffling invariance with list of labels + assert_almost_equal(metric(y1_shuffle, y2_shuffle), measure, + err_msg="%s failed shuffling invariance " + "with list of list of labels format." + % name) + + # Check shuffling invariance with dense binary indicator matrix + assert_almost_equal(metric(y1_shuffle_binary_indicator, + y2_shuffle_binary_indicator), measure, + err_msg="%s failed shuffling invariance " + " with dense binary indicator format." + % name) + + # Check raises error with mix input representation + assert_raises(ValueError, metric, y1, y2_binary_indicator) + assert_raises(ValueError, metric, y1_binary_indicator, y2) def test_multilabel_zero_one_loss_subset(): @@ -1454,19 +1465,26 @@ def test_precision_recall_f1_score_multilabel_1(): y_true_bi = lb.transform(y_true_ll) y_pred_bi = lb.transform(y_pred_ll) + warnings.simplefilter("ignore") + for y_true, y_pred in [(y_true_ll, y_pred_ll), (y_true_bi, y_pred_bi)]: + p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None) #tp = [0, 1, 1, 0] #fn = [1, 0, 0, 1] #fp = [1, 1, 0, 0] - # Check per class + assert_array_almost_equal(p, [0.0, 0.5, 1.0, 0.0], 2) assert_array_almost_equal(r, [0.0, 1.0, 1.0, 0.0], 2) assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2) assert_array_almost_equal(s, [1, 1, 1, 1], 2) + f2 = fbeta_score(y_true, y_pred, beta=2, average=None) + support = s + assert_array_almost_equal(f2, [0, 0.83, 1, 0], 2) + # Check macro p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="macro") @@ -1474,6 +1492,9 @@ def test_precision_recall_f1_score_multilabel_1(): assert_almost_equal(r, 0.5) assert_almost_equal(f, 2.5 / 1.5 * 0.25) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="macro"), + np.mean(f2)) # Check micro p, r, f, s = precision_recall_fscore_support(y_true, y_pred, @@ -1482,6 +1503,9 @@ def test_precision_recall_f1_score_multilabel_1(): assert_almost_equal(r, 0.5) assert_almost_equal(f, 0.5) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="micro"), + (1 + 4) * p * r / (4 * p + r)) # Check weigted p, r, f, s = precision_recall_fscore_support(y_true, y_pred, @@ -1490,7 +1514,9 @@ def test_precision_recall_f1_score_multilabel_1(): assert_almost_equal(r, 0.5) assert_almost_equal(f, 2.5 / 1.5 * 0.25) assert_equal(s, None) - + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="weighted"), + np.average(f2, weights=support)) # Check weigted # |h(x_i) inter y_i | = [0, 1, 1] # |y_i| = [1, 1, 2] @@ -1501,6 +1527,9 @@ def test_precision_recall_f1_score_multilabel_1(): assert_almost_equal(r, 0.5) assert_almost_equal(f, 0.5) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="samples"), + 0.5) def test_precision_recall_f1_score_multilabel_2(): @@ -1514,6 +1543,8 @@ def test_precision_recall_f1_score_multilabel_2(): y_true_bi = lb.transform(y_true_ll) y_pred_bi = lb.transform(y_pred_ll) + warnings.simplefilter("ignore") + for y_true, y_pred in [(y_true_ll, y_pred_ll), (y_true_bi, y_pred_bi)]: # tp = [ 0. 1. 0. 0.] # fp = [ 1. 0. 0. 2.] @@ -1526,12 +1557,19 @@ def test_precision_recall_f1_score_multilabel_2(): assert_array_almost_equal(f, [0.0, 0.66, 0.0, 0.0], 2) assert_array_almost_equal(s, [1, 2, 1, 0], 2) + f2 = fbeta_score(y_true, y_pred, beta=2, average=None) + support = s + assert_array_almost_equal(f2, [0, 0.55, 0, 0], 2) + p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="micro") assert_almost_equal(p, 0.25) assert_almost_equal(r, 0.25) assert_almost_equal(f, 2 * 0.25 * 0.25 / 0.5) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="micro"), + (1 + 4) * p * r / (4 * p + r)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="macro") @@ -1539,6 +1577,9 @@ def test_precision_recall_f1_score_multilabel_2(): assert_almost_equal(r, 0.125) assert_almost_equal(f, 2 / 12) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="macro"), + np.mean(f2)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="weighted") @@ -1546,6 +1587,9 @@ def test_precision_recall_f1_score_multilabel_2(): assert_almost_equal(r, 1 / 4) assert_almost_equal(f, 2 / 3 * 2 / 4) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="weighted"), + np.average(f2, weights=support)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="samples") @@ -1553,10 +1597,14 @@ def test_precision_recall_f1_score_multilabel_2(): # |h(x_i) inter y_i | = [0, 0, 1] # |y_i| = [1, 1, 2] # |h(x_i)| = [1, 1, 2] + assert_almost_equal(p, 1 / 6) assert_almost_equal(r, 1 / 6) assert_almost_equal(f, 2 / 4 * 1 / 3) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="samples"), + 0.1666, 2) def test_precision_recall_f1_score_with_an_empty_prediction(): @@ -1568,11 +1616,12 @@ def test_precision_recall_f1_score_with_an_empty_prediction(): y_true_bi = lb.transform(y_true_ll) y_pred_bi = lb.transform(y_pred_ll) + warnings.simplefilter("ignore") + for y_true, y_pred in [(y_true_ll, y_pred_ll), (y_true_bi, y_pred_bi)]: # true_pos = [ 0. 1. 1. 0.] # false_pos = [ 0. 0. 0. 1.] # false_neg = [ 1. 1. 0. 0.] - p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None) assert_array_almost_equal(p, [0.0, 1.0, 1.0, 0.0], 2) @@ -1580,12 +1629,19 @@ def test_precision_recall_f1_score_with_an_empty_prediction(): assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2) assert_array_almost_equal(s, [1, 2, 1, 0], 2) + f2 = fbeta_score(y_true, y_pred, beta=2, average=None) + support = s + assert_array_almost_equal(f2, [0, 0.55, 1, 0], 2) + p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="macro") assert_almost_equal(p, 0.5) assert_almost_equal(r, 1.5 / 4) assert_almost_equal(f, 2.5 / (4 * 1.5)) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="macro"), + np.mean(f2)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="micro") @@ -1593,6 +1649,9 @@ def test_precision_recall_f1_score_with_an_empty_prediction(): assert_almost_equal(r, 0.5) assert_almost_equal(f, 2 / 3 / (2 / 3 + 0.5)) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="micro"), + (1 + 4) * p * r / (4 * p + r)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="weighted") @@ -1600,6 +1659,9 @@ def test_precision_recall_f1_score_with_an_empty_prediction(): assert_almost_equal(r, 0.5) assert_almost_equal(f, (2 / 1.5 + 1) / 4) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="weighted"), + np.average(f2, weights=support)) p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="samples") @@ -1607,62 +1669,54 @@ def test_precision_recall_f1_score_with_an_empty_prediction(): # |y_i| = [1, 1, 2] # |h(x_i)| = [0, 1, 2] assert_almost_equal(p, 1 / 3) - assert_almost_equal(r, 2 / 3) + assert_almost_equal(r, 1 / 3) assert_almost_equal(f, 1 / 3) assert_equal(s, None) + assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, + average="samples"), + 0.333, 2) def test_precision_recall_f1_no_labels(): y_true = np.zeros((20, 3)) y_pred = np.zeros_like(y_true) - p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average=None) - #tp = [0, 0, 0] - #fn = [0, 0, 0] - #fp = [0, 0, 0] - #support = [0, 0, 0] - - # Check per class - assert_array_almost_equal(p, [0, 0, 0], 2) - assert_array_almost_equal(r, [0, 0, 0], 2) - assert_array_almost_equal(f, [0, 0, 0], 2) - assert_array_almost_equal(s, [0, 0, 0], 2) - - # Check macro - p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average="macro") - assert_almost_equal(p, 0) - assert_almost_equal(r, 0) - assert_almost_equal(f, 0) - assert_equal(s, None) - - # Check micro - p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average="micro") - assert_almost_equal(p, 0) - assert_almost_equal(r, 0) - assert_almost_equal(f, 0) - assert_equal(s, None) - - # Check weighted - p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average="weighted") - assert_almost_equal(p, 0) - assert_almost_equal(r, 0) - assert_almost_equal(f, 0) - assert_equal(s, None) - - # # Check example - # |h(x_i) inter y_i | = [0, 0, 0] + # tp = [0, 0, 0] + # fn = [0, 0, 0] + # fp = [0, 0, 0] + # support = [0, 0, 0] + # |y_hat_i inter y_i | = [0, 0, 0] # |y_i| = [0, 0, 0] - # |h(x_i)| = [1, 1, 2] - p, r, f, s = precision_recall_fscore_support(y_true, y_pred, - average="samples") - assert_almost_equal(p, 1) - assert_almost_equal(r, 1) - assert_almost_equal(f, 1) - assert_equal(s, None) + # |y_hat_i| = [0, 0, 0] + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + + for beta in [1]: + p, r, f, s = assert_warns(UserWarning, + precision_recall_fscore_support, + y_true, y_pred, average=None, beta=beta) + assert_array_almost_equal(p, [0, 0, 0], 2) + assert_array_almost_equal(r, [0, 0, 0], 2) + assert_array_almost_equal(f, [0, 0, 0], 2) + assert_array_almost_equal(s, [0, 0, 0], 2) + + fbeta = assert_warns(UserWarning, fbeta_score, y_true, y_pred, + beta=beta, average=None) + assert_array_almost_equal(fbeta, [0, 0, 0], 2) + + for average in ["macro", "micro", "weighted", "samples"]: + p, r, f, s = assert_warns(UserWarning, + precision_recall_fscore_support, + y_true, y_pred, average=average, + beta=beta) + assert_almost_equal(p, 0) + assert_almost_equal(r, 0) + assert_almost_equal(f, 0) + assert_equal(s, None) + + fbeta = assert_warns(UserWarning, fbeta_score, y_true, y_pred, + beta=beta, average=average) + assert_almost_equal(fbeta, 0) def test__check_clf_targets(): diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 37b81c3d03d71..b4bdb0f7d9076 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -10,6 +10,7 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_false +from sklearn.utils.testing import assert_warns from sklearn.utils.sparsefuncs import mean_variance_axis0 from sklearn.preprocessing.data import _transform_selected @@ -306,13 +307,13 @@ def test_warning_scaling_integers(): X = np.array([[1, 2, 0], [0, 0, 0]], dtype=np.uint8) - with warnings.catch_warnings(record=True) as w: - StandardScaler().fit(X) - assert_equal(len(w), 1) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + assert_warns(UserWarning, StandardScaler().fit, X) - with warnings.catch_warnings(record=True) as w: - MinMaxScaler().fit(X) - assert_equal(len(w), 1) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + assert_warns(UserWarning, MinMaxScaler().fit, X) def test_normalizer_l1(): diff --git a/sklearn/svm/tests/test_sparse.py b/sklearn/svm/tests/test_sparse.py index 82ac3308c730c..72fd0605d02bf 100644 --- a/sklearn/svm/tests/test_sparse.py +++ b/sklearn/svm/tests/test_sparse.py @@ -3,7 +3,7 @@ from scipy import sparse from sklearn import datasets, svm, linear_model, base from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_equal) + assert_equal, assert_warns) from nose.tools import assert_raises, assert_true, assert_false from nose.tools import assert_equal as nose_assert_equal @@ -275,13 +275,7 @@ def test_sparse_svc_clone_with_callable_kernel(): def test_timeout(): sp = svm.SVC(C=1, kernel=lambda x, y: x * y.T, probability=True, max_iter=1) - with warnings.catch_warnings(record=True) as foo: - sp.fit(X_sp, Y) - nose_assert_equal(len(foo), 1, msg=foo) - nose_assert_equal(foo[0].category, ConvergenceWarning, - msg=foo[0].category) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") - -if __name__ == '__main__': - import nose - nose.runmodule() + assert_warns(ConvergenceWarning, sp.fit, X_sp, Y) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 43edb39e52460..eb2f742af14bf 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -9,6 +9,7 @@ # License: BSD 3 clause import inspect import pkgutil +import warnings import scipy as sp from functools import wraps @@ -76,6 +77,41 @@ def _assert_greater(a, b, msg=None): assert a > b, message +# To remove when we support numpy 1.7 +def assert_warns(warning_class, func, *args, **kw): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + + # Trigger a warning. + result = func(*args, **kw) + + # Verify some things + if not len(w) > 0: + raise AssertionError("No warning raised when calling %s" + % func.__name__) + + if not w[0].category is warning_class: + raise AssertionError("First warning for %s is not a " + "%s( is %s)" + % (func.__name__, warning_class, w[0])) + + return result + + +# To remove when we support numpy 1.7 +def assert_no_warnings(func, *args, **kw): + # XXX: once we may depend on python >= 2.6, this can be replaced by the + # warnings module context manager. + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + result = func(*args, **kw) + if len(w) > 0: + raise AssertionError("Got warnings when calling %s: %s" + % (func.__name__, w)) + return result + try: from nose.tools import assert_less except ImportError: diff --git a/sklearn/utils/tests/test_testing.py b/sklearn/utils/tests/test_testing.py index 40e33b1a04158..47ff9aa66532a 100644 --- a/sklearn/utils/tests/test_testing.py +++ b/sklearn/utils/tests/test_testing.py @@ -1,8 +1,14 @@ +import warnings +import unittest +import sys + from nose.tools import assert_raises from sklearn.utils.testing import ( _assert_less, _assert_greater, + assert_warns, + assert_no_warnings, assert_equal, set_random_state, assert_raise_message) @@ -62,3 +68,42 @@ def _raise_ValueError(message): assert_raises(ValueError, assert_raise_message, TypeError, "something else", _raise_ValueError, "test") + + + +# This class is taken from numpy 1.7 +class TestWarns(unittest.TestCase): + def test_warn(self): + def f(): + warnings.warn("yo") + return 3 + + before_filters = sys.modules['warnings'].filters[:] + assert_equal(assert_warns(UserWarning, f), 3) + after_filters = sys.modules['warnings'].filters + + assert_raises(AssertionError, assert_no_warnings, f) + assert_equal(assert_no_warnings(lambda x: x, 1), 1) + + # Check that the warnings state is unchanged + assert_equal(before_filters, after_filters, + "assert_warns does not preserver warnings state") + + def test_warn_wrong_warning(self): + def f(): + warnings.warn("yo", DeprecationWarning) + + failed = False + filters = sys.modules['warnings'].filters[:] + try: + try: + # Should raise an AssertionError + assert_warns(UserWarning, f) + failed = True + except AssertionError: + pass + finally: + sys.modules['warnings'].filters = filters + + if failed: + raise AssertionError("wrong warning caught by assert_warn")