28 changes: 28 additions & 0 deletions doc/modules/model_evaluation.rst
@@ -1001,6 +1001,34 @@ if the predicted outputs have been binarized.
for an example of using ROC to
model species distribution.

.. _det_curve:

Detection error tradeoff (DET)
------------------------------

The function :func:`detection_error_tradeoff` computes the
`detection error tradeoff curve, or DET curve <https://en.wikipedia.org/wiki/Detection_error_tradeoff>`_.
Quoting Wikipedia:

"A detection error tradeoff (DET) graph is a graphical plot of error rates for binary classification systems, plotting false reject rate vs. false accept rate. The x- and y-axes are scaled non-linearly by their standard normal deviates (or just by logarithmic transformation), yielding tradeoff curves that are more linear than ROC curves, and use most of the image area to highlight the differences of importance in the critical operating region."

This function requires the true binary values and the target scores, which
can be probability estimates of the positive class, confidence values, or
binary decisions.
Here is a small example of how to use the :func:`detection_error_tradeoff` function::

>>> import numpy as np
>>> from sklearn.metrics import detection_error_tradeoff
>>> y = np.array([1, 1, 2, 2])
>>> scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, fnr, thresholds = detection_error_tradeoff(y, scores, pos_label=2)
>>> fpr
array([ 0.5, 0.5, 0. ])
>>> fnr
array([ 0. , 0.5, 0.5])
>>> thresholds
array([ 0.35, 0.4 , 0.8 ])
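
The characteristic near-linearity of DET curves comes from rescaling both
axes by their standard normal deviates, as described in the quotation above.
A minimal plotting sketch, assuming ``matplotlib`` and ``scipy`` are
available (the ``plot_det`` helper and the clipping value are illustrative
and not part of scikit-learn)::

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

def plot_det(fpr, fnr, eps=1e-6):
    # probit-transform both error rates; clip to keep the transform finite
    # at exactly 0 or 1
    x = norm.ppf(np.clip(fpr, eps, 1 - eps))
    y = norm.ppf(np.clip(fnr, eps, 1 - eps))
    plt.plot(x, y)
    plt.xlabel("false positive rate (probit scale)")
    plt.ylabel("false negative rate (probit scale)")
    plt.show()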

.. _zero_one_loss:

Zero one loss
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -7,6 +7,7 @@
from .ranking import auc
from .ranking import average_precision_score
from .ranking import coverage_error
from .ranking import detection_error_tradeoff
from .ranking import label_ranking_average_precision_score
from .ranking import label_ranking_loss
from .ranking import precision_recall_curve
@@ -74,6 +75,7 @@
'confusion_matrix',
'consensus_score',
'coverage_error',
'detection_error_tradeoff',
'euclidean_distances',
'explained_variance_score',
'f1_score',
79 changes: 79 additions & 0 deletions sklearn/metrics/ranking.py
@@ -15,6 +15,7 @@
# Lars Buitinck
# Joel Nothman <joel.nothman@gmail.com>
# Noel Dawe <noel@dawe.me>
# Jeremy Karnowski <jeremy.karnowski@gmail.com>
# License: BSD 3 clause

from __future__ import division
@@ -184,6 +185,84 @@ def _binary_average_precision(y_true, y_score, sample_weight=None):
average, sample_weight=sample_weight)


def detection_error_tradeoff(y_true, y_score, pos_label=None,
sample_weight=None):
"""Compute error rates for different probability thresholds

Note: this implementation is restricted to the binary classification task.

Parameters
----------
y_true : array, shape = [n_samples]
True targets of binary classification in range {-1, 1} or {0, 1}.

y_score : array, shape = [n_samples]
Estimated probabilities or decision function.

pos_label : int, optional (default=None)
The label of the positive class

sample_weight : array-like of shape = [n_samples], optional
Sample weights.

Returns
-------
fpr : array, shape = [n_thresholds]

Contributor: If I am not mistaken the DET curve should return rates not counts. This looks like a copy/paste legacy which we should update accordingly.

The false positive rate (false acceptance rate), at index i being the
fraction of negative samples assigned a score >= thresholds[i].

fnr : array, shape = [n_thresholds]
The false negative rate (false rejection rate, or miss rate), at index i
being the fraction of positive samples assigned a score < thresholds[i].

thresholds : array, shape = [n_thresholds]
Decreasing score values.

References
----------
.. [1] `Wikipedia entry for Detection error tradeoff
<https://en.wikipedia.org/wiki/Detection_error_tradeoff>`_
.. [2] `The DET Curve in Assessment of Detection Task Performance
<http://www.itl.nist.gov/iad/mig/publications/storage_paper/det.pdf>`_
Contributor: This link doesn't work. It appears to be a problem on https://www.nist.gov/ site though.

.. [3] `2008 NIST Speaker Recognition Evaluation Results
<http://www.itl.nist.gov/iad/mig/tests/sre/2008/official_results/>`_
Contributor: see above

.. [4] `DET-Curve Plotting software for use with MATLAB
<http://www.itl.nist.gov/iad/mig/tools/DETware_v2.1.targz.htm>`_
Contributor: see above


Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import detection_error_tradeoff
>>> y_true = np.array([0, 0, 1, 1])
>>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, fnr, thresholds = detection_error_tradeoff(y_true, y_scores)
>>> fpr
array([ 0.5, 0.5, 0. ])
>>> fnr
array([ 0. , 0.5, 0.5])
>>> thresholds
array([ 0.35, 0.4 , 0.8 ])

"""
fps, tps, thresholds = _binary_clf_curve(y_true, y_score,
pos_label=pos_label,
sample_weight=sample_weight)
fns = tps[-1] - tps
# total numbers of positive and negative samples
p_count = tps[-1]
n_count = fps[-1]

# start where false positives are zero and stop where false negatives are
# zero, then reverse the outputs so that the false positive rate is
# decreasing
last_ind = tps.searchsorted(tps[-1]) + 1
first_ind = fps[::-1].searchsorted(fps[0])
sl = range(first_ind, last_ind)[::-1]
# normalize the counts into rates: FPR = FP / #negatives, FNR = FN / #positives
return fps[sl] / n_count, fns[sl] / p_count, thresholds[sl]
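
As an illustrative aside, the false negative rate returned above is simply
the complement of the true positive rate, so the same information can be
read off roc_curve (modulo its own threshold selection and endpoint
handling). A minimal sketch using the docstring's example data:

import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

# roc_curve reports the true positive rate; its complement is the
# false negative rate that a DET curve plots on the y-axis.
fpr, tpr, thresholds = roc_curve(y_true, y_scores)
fnr = 1 - tpr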


def roc_auc_score(y_true, y_score, average="macro", sample_weight=None):
"""Compute Area Under the Curve (AUC) from prediction scores
