-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH Array API support for f1_score and multilabel_confusion_matrix #27369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
9bcbc4c
ENH Array API support for f1_score
OmarManzoor 87b65b9
Merge branch 'main' into f1_array_api
OmarManzoor 17457b9
Merge branch 'main' into f1_array_api
OmarManzoor c80617f
Add array api support for f1_score
OmarManzoor 649ce17
Add changelog
OmarManzoor 8f0db56
Merge branch 'main' into f1_array_api
OmarManzoor 6a02fcd
Fix sample weights in _bincount
OmarManzoor c150c9c
Add some fixes
OmarManzoor a01f2d7
Correct and add tests for nanmean
OmarManzoor aa2f521
Add options for testing with various average values
OmarManzoor 75c7d5a
Use reshape when creating arrays in micro average
OmarManzoor 8b21b51
Add LabelEncoder and f1_score in array_api.rst
OmarManzoor bc8c2df
Merge branch 'main' into f1_array_api
OmarManzoor ef33cf6
Merge branch 'main' into f1_array_api
ogrisel 91ab0d5
Merge branch 'main' into f1_array_api
OmarManzoor 696e65b
Update: PR suggestions
OmarManzoor d0b647b
Use xp.reshape with (1,)
OmarManzoor 842e269
Simplify count in _nanmean
OmarManzoor 6e9596e
Merge branch 'main' into f1_array_api
OmarManzoor 5cd9a11
Merge branch 'main' into f1_array_api
OmarManzoor 2c3cc32
Merge branch 'main' into f1_array_api
OmarManzoor 5c73766
Merge branch 'main' into f1_array_api
OmarManzoor 0df1e0f
Add multilabel confusion metrics as it seems to work
OmarManzoor 3d42289
Merge branch 'main' into f1_array_api
OmarManzoor 78e4f31
Handle multi-label case
OmarManzoor 74ccf6a
Fix commented tests
OmarManzoor 6428b8e
Merge branch 'main' into f1_array_api
ogrisel 03a2432
Merge branch 'main' into f1_array_api
OmarManzoor 3b3555b
Fix errors because of update in numpy
OmarManzoor fd79874
Merge branch 'main' into f1_array_api
OmarManzoor ecc0fb1
Minor updates
OmarManzoor c57398f
Update the conversion to float
OmarManzoor e6970cb
Fix the doctests
OmarManzoor 1461ad7
Fix f1 score doctests
OmarManzoor def2e42
Merge branch 'main' into f1_array_api
OmarManzoor c828e4d
Merge branch 'main' into f1_array_api
OmarManzoor a21a875
Add new changelog
OmarManzoor 0c80235
Add docstring
OmarManzoor 7b8fb15
Add xp after array param in _bincount
OmarManzoor 82cf2bc
Merge branch 'main' into f1_array_api
OmarManzoor 24d91a4
Refactor based on PR suggestions
OmarManzoor c71c974
Handle the _tolist method and add doc for device param
OmarManzoor 48b24ce
Use _convert_to_numpy is _tolist
OmarManzoor 6679ae3
Further PR suggestions
OmarManzoor fad41ad
Add sparse check specifically when multiplying
OmarManzoor 8c6c4a0
Merge branch 'main' into f1_array_api
OmarManzoor 0912004
Merge branch 'main' into f1_array_api
OmarManzoor 88a48e0
Update 'and' to 'or' condition
OmarManzoor File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
- :func:`sklearn.metrics.f1_score` now supports Array API compatible | ||
inputs. | ||
By :user:`Omar Salman <OmarManzoor>` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
from numbers import Integral, Real | ||
|
||
import numpy as np | ||
from scipy.sparse import coo_matrix, csr_matrix | ||
from scipy.sparse import coo_matrix, csr_matrix, issparse | ||
from scipy.special import xlogy | ||
|
||
from ..exceptions import UndefinedMetricWarning | ||
|
@@ -28,9 +28,15 @@ | |
) | ||
from ..utils._array_api import ( | ||
_average, | ||
_bincount, | ||
_count_nonzero, | ||
_find_matching_floating_dtype, | ||
_is_numpy_namespace, | ||
_searchsorted, | ||
_setdiff1d, | ||
_tolist, | ||
_union1d, | ||
device, | ||
get_namespace, | ||
get_namespace_and_device, | ||
) | ||
|
@@ -521,9 +527,11 @@ def multilabel_confusion_matrix( | |
[1, 2]]]) | ||
""" | ||
y_true, y_pred = attach_unique(y_true, y_pred) | ||
xp, _ = get_namespace(y_true, y_pred) | ||
device_ = device(y_true, y_pred) | ||
y_type, y_true, y_pred = _check_targets(y_true, y_pred) | ||
if sample_weight is not None: | ||
sample_weight = column_or_1d(sample_weight) | ||
sample_weight = column_or_1d(sample_weight, device=device_) | ||
check_consistent_length(y_true, y_pred, sample_weight) | ||
|
||
if y_type not in ("binary", "multiclass", "multilabel-indicator"): | ||
|
@@ -534,9 +542,11 @@ def multilabel_confusion_matrix( | |
labels = present_labels | ||
n_labels = None | ||
else: | ||
n_labels = len(labels) | ||
labels = np.hstack( | ||
[labels, np.setdiff1d(present_labels, labels, assume_unique=True)] | ||
labels = xp.asarray(labels, device=device_) | ||
n_labels = labels.shape[0] | ||
labels = xp.concat( | ||
[labels, _setdiff1d(present_labels, labels, assume_unique=True, xp=xp)], | ||
axis=-1, | ||
) | ||
|
||
if y_true.ndim == 1: | ||
|
@@ -556,77 +566,102 @@ def multilabel_confusion_matrix( | |
tp = y_true == y_pred | ||
tp_bins = y_true[tp] | ||
if sample_weight is not None: | ||
tp_bins_weights = np.asarray(sample_weight)[tp] | ||
tp_bins_weights = sample_weight[tp] | ||
else: | ||
tp_bins_weights = None | ||
|
||
if len(tp_bins): | ||
tp_sum = np.bincount( | ||
tp_bins, weights=tp_bins_weights, minlength=len(labels) | ||
if tp_bins.shape[0]: | ||
tp_sum = _bincount( | ||
tp_bins, weights=tp_bins_weights, minlength=labels.shape[0], xp=xp | ||
) | ||
else: | ||
# Pathological case | ||
true_sum = pred_sum = tp_sum = np.zeros(len(labels)) | ||
if len(y_pred): | ||
pred_sum = np.bincount(y_pred, weights=sample_weight, minlength=len(labels)) | ||
if len(y_true): | ||
true_sum = np.bincount(y_true, weights=sample_weight, minlength=len(labels)) | ||
true_sum = pred_sum = tp_sum = xp.zeros(labels.shape[0]) | ||
if y_pred.shape[0]: | ||
pred_sum = _bincount( | ||
y_pred, weights=sample_weight, minlength=labels.shape[0], xp=xp | ||
) | ||
if y_true.shape[0]: | ||
true_sum = _bincount( | ||
y_true, weights=sample_weight, minlength=labels.shape[0], xp=xp | ||
) | ||
|
||
# Retain only selected labels | ||
indices = np.searchsorted(sorted_labels, labels[:n_labels]) | ||
tp_sum = tp_sum[indices] | ||
true_sum = true_sum[indices] | ||
pred_sum = pred_sum[indices] | ||
indices = _searchsorted(sorted_labels, labels[:n_labels], xp=xp) | ||
tp_sum = xp.take(tp_sum, indices, axis=0) | ||
true_sum = xp.take(true_sum, indices, axis=0) | ||
pred_sum = xp.take(pred_sum, indices, axis=0) | ||
|
||
else: | ||
OmarManzoor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sum_axis = 1 if samplewise else 0 | ||
|
||
# All labels are index integers for multilabel. | ||
# Select labels: | ||
if not np.array_equal(labels, present_labels): | ||
if np.max(labels) > np.max(present_labels): | ||
if labels.shape != present_labels.shape or xp.any( | ||
xp.not_equal(labels, present_labels) | ||
): | ||
OmarManzoor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if xp.max(labels) > xp.max(present_labels): | ||
OmarManzoor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError( | ||
"All labels must be in [0, n labels) for " | ||
"multilabel targets. " | ||
"Got %d > %d" % (np.max(labels), np.max(present_labels)) | ||
"Got %d > %d" % (xp.max(labels), xp.max(present_labels)) | ||
) | ||
if np.min(labels) < 0: | ||
if xp.min(labels) < 0: | ||
raise ValueError( | ||
"All labels must be in [0, n labels) for " | ||
"multilabel targets. " | ||
"Got %d < 0" % np.min(labels) | ||
"Got %d < 0" % xp.min(labels) | ||
) | ||
|
||
if n_labels is not None: | ||
y_true = y_true[:, labels[:n_labels]] | ||
y_pred = y_pred[:, labels[:n_labels]] | ||
|
||
if issparse(y_true) or issparse(y_pred): | ||
true_and_pred = y_true.multiply(y_pred) | ||
else: | ||
true_and_pred = xp.multiply(y_true, y_pred) | ||
|
||
# calculate weighted counts | ||
true_and_pred = y_true.multiply(y_pred) | ||
tp_sum = count_nonzero( | ||
true_and_pred, axis=sum_axis, sample_weight=sample_weight | ||
tp_sum = _count_nonzero( | ||
true_and_pred, | ||
axis=sum_axis, | ||
sample_weight=sample_weight, | ||
xp=xp, | ||
device=device_, | ||
) | ||
pred_sum = _count_nonzero( | ||
y_pred, | ||
axis=sum_axis, | ||
sample_weight=sample_weight, | ||
xp=xp, | ||
device=device_, | ||
) | ||
true_sum = _count_nonzero( | ||
y_true, | ||
axis=sum_axis, | ||
sample_weight=sample_weight, | ||
xp=xp, | ||
device=device_, | ||
) | ||
pred_sum = count_nonzero(y_pred, axis=sum_axis, sample_weight=sample_weight) | ||
true_sum = count_nonzero(y_true, axis=sum_axis, sample_weight=sample_weight) | ||
|
||
fp = pred_sum - tp_sum | ||
fn = true_sum - tp_sum | ||
tp = tp_sum | ||
|
||
if sample_weight is not None and samplewise: | ||
sample_weight = np.array(sample_weight) | ||
tp = np.array(tp) | ||
fp = np.array(fp) | ||
fn = np.array(fn) | ||
tp = xp.asarray(tp) | ||
fp = xp.asarray(fp) | ||
fn = xp.asarray(fn) | ||
tn = sample_weight * y_true.shape[1] - tp - fp - fn | ||
elif sample_weight is not None: | ||
tn = sum(sample_weight) - tp - fp - fn | ||
tn = xp.sum(sample_weight) - tp - fp - fn | ||
elif samplewise: | ||
tn = y_true.shape[1] - tp - fp - fn | ||
else: | ||
tn = y_true.shape[0] - tp - fp - fn | ||
|
||
return np.array([tn, fp, fn, tp]).T.reshape(-1, 2, 2) | ||
return xp.reshape(xp.stack([tn, fp, fn, tp]).T, (-1, 2, 2)) | ||
|
||
|
||
@validate_params( | ||
|
@@ -1262,21 +1297,21 @@ def f1_score( | |
>>> y_true = [0, 1, 2, 0, 1, 2] | ||
>>> y_pred = [0, 2, 1, 0, 0, 1] | ||
>>> f1_score(y_true, y_pred, average='macro') | ||
np.float64(0.26...) | ||
0.26... | ||
OmarManzoor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
>>> f1_score(y_true, y_pred, average='micro') | ||
np.float64(0.33...) | ||
0.33... | ||
>>> f1_score(y_true, y_pred, average='weighted') | ||
np.float64(0.26...) | ||
0.26... | ||
>>> f1_score(y_true, y_pred, average=None) | ||
array([0.8, 0. , 0. ]) | ||
|
||
>>> # binary classification | ||
>>> y_true_empty = [0, 0, 0, 0, 0, 0] | ||
>>> y_pred_empty = [0, 0, 0, 0, 0, 0] | ||
>>> f1_score(y_true_empty, y_pred_empty) | ||
np.float64(0.0...) | ||
0.0... | ||
>>> f1_score(y_true_empty, y_pred_empty, zero_division=1.0) | ||
np.float64(1.0...) | ||
1.0... | ||
>>> f1_score(y_true_empty, y_pred_empty, zero_division=np.nan) | ||
nan... | ||
|
||
|
@@ -1466,17 +1501,17 @@ def fbeta_score( | |
>>> y_true = [0, 1, 2, 0, 1, 2] | ||
>>> y_pred = [0, 2, 1, 0, 0, 1] | ||
>>> fbeta_score(y_true, y_pred, average='macro', beta=0.5) | ||
np.float64(0.23...) | ||
0.23... | ||
>>> fbeta_score(y_true, y_pred, average='micro', beta=0.5) | ||
np.float64(0.33...) | ||
0.33... | ||
>>> fbeta_score(y_true, y_pred, average='weighted', beta=0.5) | ||
np.float64(0.23...) | ||
0.23... | ||
>>> fbeta_score(y_true, y_pred, average=None, beta=0.5) | ||
array([0.71..., 0. , 0. ]) | ||
>>> y_pred_empty = [0, 0, 0, 0, 0, 0] | ||
>>> fbeta_score(y_true, y_pred_empty, | ||
... average="macro", zero_division=np.nan, beta=0.5) | ||
np.float64(0.12...) | ||
0.12... | ||
""" | ||
|
||
_, _, f, _ = precision_recall_fscore_support( | ||
|
@@ -1505,12 +1540,14 @@ def _prf_divide( | |
The metric, modifier and average arguments are used only for determining | ||
an appropriate warning. | ||
""" | ||
mask = denominator == 0.0 | ||
denominator = denominator.copy() | ||
xp, _ = get_namespace(numerator, denominator) | ||
dtype_float = _find_matching_floating_dtype(numerator, denominator, xp=xp) | ||
mask = denominator == 0 | ||
denominator = xp.asarray(denominator, copy=True, dtype=dtype_float) | ||
denominator[mask] = 1 # avoid infs/nans | ||
result = numerator / denominator | ||
result = xp.asarray(numerator, dtype=dtype_float) / denominator | ||
|
||
if not np.any(mask): | ||
if not xp.any(mask): | ||
return result | ||
|
||
# set those with 0 denominator to `zero_division`, and 0 when "warn" | ||
|
@@ -1559,7 +1596,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label): | |
y_type, y_true, y_pred = _check_targets(y_true, y_pred) | ||
# Convert to Python primitive type to avoid NumPy type / Python str | ||
# comparison. See https://github.com/numpy/numpy/issues/6784 | ||
present_labels = unique_labels(y_true, y_pred).tolist() | ||
present_labels = _tolist(unique_labels(y_true, y_pred)) | ||
if average == "binary": | ||
if y_type == "binary": | ||
if pos_label not in present_labels: | ||
|
@@ -1774,11 +1811,11 @@ def precision_recall_fscore_support( | |
>>> y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig']) | ||
>>> y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog']) | ||
>>> precision_recall_fscore_support(y_true, y_pred, average='macro') | ||
(np.float64(0.22...), np.float64(0.33...), np.float64(0.26...), None) | ||
(0.22..., 0.33..., 0.26..., None) | ||
>>> precision_recall_fscore_support(y_true, y_pred, average='micro') | ||
(np.float64(0.33...), np.float64(0.33...), np.float64(0.33...), None) | ||
(0.33..., 0.33..., 0.33..., None) | ||
>>> precision_recall_fscore_support(y_true, y_pred, average='weighted') | ||
(np.float64(0.22...), np.float64(0.33...), np.float64(0.26...), None) | ||
(0.22..., 0.33..., 0.26..., None) | ||
|
||
It is possible to compute per-label precisions, recalls, F1-scores and | ||
supports instead of averaging: | ||
|
@@ -1805,10 +1842,11 @@ def precision_recall_fscore_support( | |
pred_sum = tp_sum + MCM[:, 0, 1] | ||
true_sum = tp_sum + MCM[:, 1, 0] | ||
|
||
xp, _ = get_namespace(y_true, y_pred) | ||
if average == "micro": | ||
tp_sum = np.array([tp_sum.sum()]) | ||
pred_sum = np.array([pred_sum.sum()]) | ||
true_sum = np.array([true_sum.sum()]) | ||
tp_sum = xp.reshape(xp.sum(tp_sum), (1,)) | ||
pred_sum = xp.reshape(xp.sum(pred_sum), (1,)) | ||
true_sum = xp.reshape(xp.sum(true_sum), (1,)) | ||
|
||
# Finally, we have all our sufficient statistics. Divide! # | ||
beta2 = beta**2 | ||
|
@@ -1851,10 +1889,10 @@ def precision_recall_fscore_support( | |
weights = None | ||
|
||
if average is not None: | ||
assert average != "binary" or len(precision) == 1 | ||
precision = _nanaverage(precision, weights=weights) | ||
recall = _nanaverage(recall, weights=weights) | ||
f_score = _nanaverage(f_score, weights=weights) | ||
assert average != "binary" or precision.shape[0] == 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm okay to leave this as is in this PR with this since it's existing code, but we really shouldn't be |
||
precision = float(_nanaverage(precision, weights=weights)) | ||
recall = float(_nanaverage(recall, weights=weights)) | ||
f_score = float(_nanaverage(f_score, weights=weights)) | ||
OmarManzoor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
true_sum = None # return no support | ||
|
||
return precision, recall, f_score, true_sum | ||
|
@@ -2185,11 +2223,11 @@ def precision_score( | |
>>> y_true = [0, 1, 2, 0, 1, 2] | ||
>>> y_pred = [0, 2, 1, 0, 0, 1] | ||
>>> precision_score(y_true, y_pred, average='macro') | ||
np.float64(0.22...) | ||
0.22... | ||
>>> precision_score(y_true, y_pred, average='micro') | ||
np.float64(0.33...) | ||
0.33... | ||
>>> precision_score(y_true, y_pred, average='weighted') | ||
np.float64(0.22...) | ||
0.22... | ||
>>> precision_score(y_true, y_pred, average=None) | ||
array([0.66..., 0. , 0. ]) | ||
>>> y_pred = [0, 0, 0, 0, 0, 0] | ||
|
@@ -2367,11 +2405,11 @@ def recall_score( | |
>>> y_true = [0, 1, 2, 0, 1, 2] | ||
>>> y_pred = [0, 2, 1, 0, 0, 1] | ||
>>> recall_score(y_true, y_pred, average='macro') | ||
np.float64(0.33...) | ||
0.33... | ||
>>> recall_score(y_true, y_pred, average='micro') | ||
np.float64(0.33...) | ||
0.33... | ||
>>> recall_score(y_true, y_pred, average='weighted') | ||
np.float64(0.33...) | ||
0.33... | ||
>>> recall_score(y_true, y_pred, average=None) | ||
array([1., 0., 0.]) | ||
>>> y_true = [0, 0, 0, 0, 0, 0] | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: this was introduced in #27381 but was missing from the array API doc.