Skip to content

[MRG + 1] Remove np.isclose() from ROC curve calculation #7353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions doc/developers/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,6 @@ Backports
Used in ``sklearn.cluster.hierarchical``, as well as in tests for
:mod:`sklearn.feature_extraction`.

- :func:`fixes.isclose`
(backported from ``numpy.isclose`` in numpy 1.8.1).
In versions before 1.7, this function was not available in
numpy. Used in ``sklearn.metrics``.


ARPACK
------
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,11 @@ Bug fixes
(`#6472 <https://github.com/scikit-learn/scikit-learn/pull/6472>`).
By `Andreas Müller`_.

- :func:`metrics.roc_curve` and :func:`metrics.precision_recall_curve` no
longer round ``y_score`` values when creating ROC curves; this was causing
problems for users with very small differences in scores (`#7353
<https://github.com/scikit-learn/scikit-learn/pull/7353>`_).

API changes summary
-------------------

Expand Down
6 changes: 1 addition & 5 deletions sklearn/metrics/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from ..utils import column_or_1d, check_array
from ..utils.multiclass import type_of_target
from ..utils.extmath import stable_cumsum
from ..utils.fixes import isclose
from ..utils.fixes import bincount
from ..utils.fixes import array_equal
from ..utils.stats import rankdata
Expand Down Expand Up @@ -331,10 +330,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
# y_score typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
# We need to use isclose to avoid spurious repeated thresholds
# stemming from floating point roundoff errors.
distinct_value_indices = np.where(np.logical_not(isclose(
np.diff(y_score), 0)))[0]
distinct_value_indices = np.where(np.diff(y_score))[0]
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

# accumulate the true positives with decreasing threshold
Expand Down
39 changes: 11 additions & 28 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from sklearn import datasets
from sklearn import svm
from sklearn import ensemble

from sklearn.datasets import make_multilabel_classification
from sklearn.random_projection import sparse_random_matrix
Expand Down Expand Up @@ -170,29 +169,6 @@ def test_roc_returns_consistency():
assert_equal(fpr.shape, thresholds.shape)


def test_roc_nonrepeating_thresholds():
# Test to ensure that we don't return spurious repeating thresholds.
# Duplicated thresholds can arise due to machine precision issues.
dataset = datasets.load_digits()
X = dataset['data']
y = dataset['target']

# This random forest classifier can only return probabilities
# significant to two decimal places
clf = ensemble.RandomForestClassifier(n_estimators=100, random_state=0)

# How well can the classifier predict whether a digit is less than 5?
# This task contributes floating point roundoff errors to the probabilities
train, test = slice(None, None, 2), slice(1, None, 2)
probas_pred = clf.fit(X[train], y[train]).predict_proba(X[test])
y_score = probas_pred[:, :5].sum(axis=1) # roundoff errors begin here
y_true = [yy < 5 for yy in y[test]]

# Check for repeating values in the thresholds
fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)
assert_equal(thresholds.size, np.unique(np.round(thresholds, 2)).size)


def test_roc_curve_multi():
# roc_curve not applicable for multi-class problems
y_true, _, probas_pred = make_prediction(binary=False)
Expand Down Expand Up @@ -621,18 +597,25 @@ def test_precision_recall_curve_toydata():
def test_score_scale_invariance():
# Test that average_precision_score and roc_auc_score are invariant by
# the scaling or shifting of probabilities
# This test was expanded (added scaled_down) in response to github
# issue #3864 (and others), where overly aggressive rounding was causing
# problems for users with very small y_score values
y_true, _, probas_pred = make_prediction(binary=True)

roc_auc = roc_auc_score(y_true, probas_pred)
roc_auc_scaled = roc_auc_score(y_true, 100 * probas_pred)
roc_auc_scaled_up = roc_auc_score(y_true, 100 * probas_pred)
roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * probas_pred)
roc_auc_shifted = roc_auc_score(y_true, probas_pred - 10)
assert_equal(roc_auc, roc_auc_scaled)
assert_equal(roc_auc, roc_auc_scaled_up)
assert_equal(roc_auc, roc_auc_scaled_down)
assert_equal(roc_auc, roc_auc_shifted)

pr_auc = average_precision_score(y_true, probas_pred)
pr_auc_scaled = average_precision_score(y_true, 100 * probas_pred)
pr_auc_scaled_up = average_precision_score(y_true, 100 * probas_pred)
pr_auc_scaled_down = average_precision_score(y_true, 1e-6 * probas_pred)
pr_auc_shifted = average_precision_score(y_true, probas_pred - 10)
assert_equal(pr_auc, pr_auc_scaled)
assert_equal(pr_auc, pr_auc_scaled_up)
assert_equal(pr_auc, pr_auc_scaled_down)
assert_equal(pr_auc, pr_auc_shifted)


Expand Down
43 changes: 0 additions & 43 deletions sklearn/utils/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,49 +227,6 @@ def combinations_with_replacement(iterable, r):
yield tuple(pool[i] for i in indices)


try:
from numpy import isclose
except ImportError:
def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
"""
Returns a boolean array where two arrays are element-wise equal within
a tolerance.

This function was added to numpy v1.7.0, and the version you are
running has been backported from numpy v1.8.1. See its documentation
for more details.
"""
def within_tol(x, y, atol, rtol):
with np.errstate(invalid='ignore'):
result = np.less_equal(abs(x - y), atol + rtol * abs(y))
if np.isscalar(a) and np.isscalar(b):
result = bool(result)
return result

x = np.array(a, copy=False, subok=True, ndmin=1)
y = np.array(b, copy=False, subok=True, ndmin=1)
xfin = np.isfinite(x)
yfin = np.isfinite(y)
if all(xfin) and all(yfin):
return within_tol(x, y, atol, rtol)
else:
finite = xfin & yfin
cond = np.zeros_like(finite, subok=True)
# Since we're using boolean indexing, x & y must be the same shape.
# Ideally, we'd just do x, y = broadcast_arrays(x, y). It's in
# lib.stride_tricks, though, so we can't import it here.
x = x * np.ones_like(cond)
y = y * np.ones_like(cond)
# Avoid subtraction with infinite/nan values...
cond[finite] = within_tol(x[finite], y[finite], atol, rtol)
# Check for equality of infinite values...
cond[~finite] = (x[~finite] == y[~finite])
if equal_nan:
# Make NaN == NaN
cond[np.isnan(x) & np.isnan(y)] = True
return cond


if np_version < (1, 7):
# Prior to 1.7.0, np.frombuffer wouldn't work for empty first arg.
def frombuffer_empty(buf, dtype):
Expand Down