
ENH Array API support for f1_score and multilabel_confusion_matrix #27369


Merged: 48 commits on Nov 25, 2024
Commits (48):
9bcbc4c  ENH Array API support for f1_score (OmarManzoor, Sep 14, 2023)
87b65b9  Merge branch 'main' into f1_array_api (OmarManzoor, May 17, 2024)
17457b9  Merge branch 'main' into f1_array_api (OmarManzoor, May 17, 2024)
c80617f  Add array api support for f1_score (OmarManzoor, May 20, 2024)
649ce17  Add changelog (OmarManzoor, May 20, 2024)
8f0db56  Merge branch 'main' into f1_array_api (OmarManzoor, May 20, 2024)
6a02fcd  Fix sample weights in _bincount (OmarManzoor, May 20, 2024)
c150c9c  Add some fixes (OmarManzoor, May 20, 2024)
a01f2d7  Correct and add tests for nanmean (OmarManzoor, May 20, 2024)
aa2f521  Add options for testing with various average values (OmarManzoor, May 20, 2024)
75c7d5a  Use reshape when creating arrays in micro average (OmarManzoor, May 20, 2024)
8b21b51  Add LabelEncoder and f1_score in array_api.rst (OmarManzoor, May 27, 2024)
bc8c2df  Merge branch 'main' into f1_array_api (OmarManzoor, May 27, 2024)
ef33cf6  Merge branch 'main' into f1_array_api (ogrisel, Jun 5, 2024)
91ab0d5  Merge branch 'main' into f1_array_api (OmarManzoor, Jun 6, 2024)
696e65b  Update: PR suggestions (OmarManzoor, Jun 6, 2024)
d0b647b  Use xp.reshape with (1,) (OmarManzoor, Jun 6, 2024)
842e269  Simplify count in _nanmean (OmarManzoor, Jun 6, 2024)
6e9596e  Merge branch 'main' into f1_array_api (OmarManzoor, Jun 6, 2024)
5cd9a11  Merge branch 'main' into f1_array_api (OmarManzoor, Jun 7, 2024)
2c3cc32  Merge branch 'main' into f1_array_api (OmarManzoor, Jun 14, 2024)
5c73766  Merge branch 'main' into f1_array_api (OmarManzoor, Jun 25, 2024)
0df1e0f  Add multilabel confusion metrics as it seems to work (OmarManzoor, Jun 25, 2024)
3d42289  Merge branch 'main' into f1_array_api (OmarManzoor, Jul 2, 2024)
78e4f31  Handle multi-label case (OmarManzoor, Jul 2, 2024)
74ccf6a  Fix commented tests (OmarManzoor, Jul 2, 2024)
6428b8e  Merge branch 'main' into f1_array_api (ogrisel, Jul 10, 2024)
03a2432  Merge branch 'main' into f1_array_api (OmarManzoor, Jul 11, 2024)
3b3555b  Fix errors because of update in numpy (OmarManzoor, Jul 11, 2024)
fd79874  Merge branch 'main' into f1_array_api (OmarManzoor, Sep 16, 2024)
ecc0fb1  Minor updates (OmarManzoor, Sep 16, 2024)
c57398f  Update the conversion to float (OmarManzoor, Sep 16, 2024)
e6970cb  Fix the doctests (OmarManzoor, Sep 18, 2024)
1461ad7  Fix f1 score doctests (OmarManzoor, Sep 18, 2024)
def2e42  Merge branch 'main' into f1_array_api (OmarManzoor, Sep 30, 2024)
c828e4d  Merge branch 'main' into f1_array_api (OmarManzoor, Oct 24, 2024)
a21a875  Add new changelog (OmarManzoor, Oct 24, 2024)
0c80235  Add docstring (OmarManzoor, Oct 24, 2024)
7b8fb15  Add xp after array param in _bincount (OmarManzoor, Oct 24, 2024)
82cf2bc  Merge branch 'main' into f1_array_api (OmarManzoor, Nov 6, 2024)
24d91a4  Refactor based on PR suggestions (OmarManzoor, Nov 6, 2024)
c71c974  Handle the _tolist method and add doc for device param (OmarManzoor, Nov 6, 2024)
48b24ce  Use _convert_to_numpy is _tolist (OmarManzoor, Nov 7, 2024)
6679ae3  Further PR suggestions (OmarManzoor, Nov 7, 2024)
fad41ad  Add sparse check specifically when multiplying (OmarManzoor, Nov 7, 2024)
8c6c4a0  Merge branch 'main' into f1_array_api (OmarManzoor, Nov 7, 2024)
0912004  Merge branch 'main' into f1_array_api (OmarManzoor, Nov 15, 2024)
88a48e0  Update 'and' to 'or' condition (OmarManzoor, Nov 15, 2024)
3 changes: 3 additions & 0 deletions doc/modules/array_api.rst
@@ -94,6 +94,7 @@ Estimators
- :class:`linear_model.Ridge` (with `solver="svd"`)
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
- :class:`preprocessing.KernelCenterer`
- :class:`preprocessing.LabelEncoder`
@ogrisel (Member) commented on Jul 8, 2024:

Note: this was introduced in #27381 but was missing from the array API doc.
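
As an illustrative sketch (not from the diff): this is what `LabelEncoder` array API support enables, assuming PyTorch and `array-api-compat` are installed:

```python
import sklearn
import torch
from sklearn.preprocessing import LabelEncoder

with sklearn.config_context(array_api_dispatch=True):
    le = LabelEncoder()
    # Under dispatch, fit_transform stays in the torch namespace.
    encoded = le.fit_transform(torch.asarray([2, 2, 7, 5]))
    # encoded is a torch tensor: tensor([0, 0, 2, 1])
```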

- :class:`preprocessing.MaxAbsScaler`
- :class:`preprocessing.MinMaxScaler`
- :class:`preprocessing.Normalizer`
@@ -115,6 +116,7 @@ Metrics
- :func:`sklearn.metrics.cluster.entropy`
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.f1_score`
- :func:`sklearn.metrics.max_error`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_absolute_percentage_error`
@@ -123,6 +125,7 @@
- :func:`sklearn.metrics.mean_squared_error`
- :func:`sklearn.metrics.mean_squared_log_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.multilabel_confusion_matrix`
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
- :func:`sklearn.metrics.pairwise.chi2_kernel`
- :func:`sklearn.metrics.pairwise.cosine_similarity`
3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/27369.feature.rst
@@ -0,0 +1,3 @@
- :func:`sklearn.metrics.f1_score` now supports Array API compatible
  inputs.
  By :user:`Omar Salman <OmarManzoor>`
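
For context, a minimal usage sketch of what this entry announces (illustrative only; assumes PyTorch and `array-api-compat` are installed):

```python
import sklearn
import torch
from sklearn.metrics import f1_score

y_true = torch.asarray([0, 1, 2, 0, 1, 2])
y_pred = torch.asarray([0, 2, 1, 0, 0, 1])

with sklearn.config_context(array_api_dispatch=True):
    # Intermediate computation happens in the torch namespace; the
    # averaged score comes back as a plain Python float (see the
    # float(...) conversions in the diff below).
    score = f1_score(y_true, y_pred, average="macro")

print(round(score, 2))  # 0.27
```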
166 changes: 102 additions & 64 deletions sklearn/metrics/_classification.py
@@ -15,7 +15,7 @@
from numbers import Integral, Real

import numpy as np
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse import coo_matrix, csr_matrix, issparse
from scipy.special import xlogy

from ..exceptions import UndefinedMetricWarning
@@ -28,9 +28,15 @@
)
from ..utils._array_api import (
_average,
_bincount,
_count_nonzero,
_find_matching_floating_dtype,
_is_numpy_namespace,
_searchsorted,
_setdiff1d,
_tolist,
_union1d,
device,
get_namespace,
get_namespace_and_device,
)
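
Background note (not part of the diff): these private helpers mirror familiar NumPy semantics across array namespaces. For example, `_bincount` follows `np.bincount`, which counts occurrences of each non-negative integer, optionally weighted, padded to at least `minlength` bins:

```python
import numpy as np

y = np.array([0, 1, 1, 3])
w = np.array([0.5, 1.0, 1.0, 2.0])

# Weighted count per integer value, padded to 5 bins.
print(np.bincount(y, weights=w, minlength=5))
# [0.5 2.  0.  2.  0. ]
```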
@@ -521,9 +527,11 @@ def multilabel_confusion_matrix(
[1, 2]]])
"""
y_true, y_pred = attach_unique(y_true, y_pred)
xp, _ = get_namespace(y_true, y_pred)
device_ = device(y_true, y_pred)
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
if sample_weight is not None:
sample_weight = column_or_1d(sample_weight)
sample_weight = column_or_1d(sample_weight, device=device_)
check_consistent_length(y_true, y_pred, sample_weight)

if y_type not in ("binary", "multiclass", "multilabel-indicator"):
@@ -534,9 +542,11 @@
labels = present_labels
n_labels = None
else:
n_labels = len(labels)
labels = np.hstack(
[labels, np.setdiff1d(present_labels, labels, assume_unique=True)]
labels = xp.asarray(labels, device=device_)
n_labels = labels.shape[0]
labels = xp.concat(
[labels, _setdiff1d(present_labels, labels, assume_unique=True, xp=xp)],
axis=-1,
)

if y_true.ndim == 1:
@@ -556,77 +566,102 @@
tp = y_true == y_pred
tp_bins = y_true[tp]
if sample_weight is not None:
tp_bins_weights = np.asarray(sample_weight)[tp]
tp_bins_weights = sample_weight[tp]
else:
tp_bins_weights = None

if len(tp_bins):
tp_sum = np.bincount(
tp_bins, weights=tp_bins_weights, minlength=len(labels)
if tp_bins.shape[0]:
tp_sum = _bincount(
tp_bins, weights=tp_bins_weights, minlength=labels.shape[0], xp=xp
)
else:
# Pathological case
true_sum = pred_sum = tp_sum = np.zeros(len(labels))
if len(y_pred):
pred_sum = np.bincount(y_pred, weights=sample_weight, minlength=len(labels))
if len(y_true):
true_sum = np.bincount(y_true, weights=sample_weight, minlength=len(labels))
true_sum = pred_sum = tp_sum = xp.zeros(labels.shape[0])
if y_pred.shape[0]:
pred_sum = _bincount(
y_pred, weights=sample_weight, minlength=labels.shape[0], xp=xp
)
if y_true.shape[0]:
true_sum = _bincount(
y_true, weights=sample_weight, minlength=labels.shape[0], xp=xp
)

# Retain only selected labels
indices = np.searchsorted(sorted_labels, labels[:n_labels])
tp_sum = tp_sum[indices]
true_sum = true_sum[indices]
pred_sum = pred_sum[indices]
indices = _searchsorted(sorted_labels, labels[:n_labels], xp=xp)
tp_sum = xp.take(tp_sum, indices, axis=0)
true_sum = xp.take(true_sum, indices, axis=0)
pred_sum = xp.take(pred_sum, indices, axis=0)

else:
sum_axis = 1 if samplewise else 0

# All labels are index integers for multilabel.
# Select labels:
if not np.array_equal(labels, present_labels):
if np.max(labels) > np.max(present_labels):
if labels.shape != present_labels.shape or xp.any(
xp.not_equal(labels, present_labels)
):
if xp.max(labels) > xp.max(present_labels):
raise ValueError(
"All labels must be in [0, n labels) for "
"multilabel targets. "
"Got %d > %d" % (np.max(labels), np.max(present_labels))
"Got %d > %d" % (xp.max(labels), xp.max(present_labels))
)
if np.min(labels) < 0:
if xp.min(labels) < 0:
raise ValueError(
"All labels must be in [0, n labels) for "
"multilabel targets. "
"Got %d < 0" % np.min(labels)
"Got %d < 0" % xp.min(labels)
)

if n_labels is not None:
y_true = y_true[:, labels[:n_labels]]
y_pred = y_pred[:, labels[:n_labels]]

if issparse(y_true) or issparse(y_pred):
true_and_pred = y_true.multiply(y_pred)
else:
true_and_pred = xp.multiply(y_true, y_pred)

# calculate weighted counts
true_and_pred = y_true.multiply(y_pred)
tp_sum = count_nonzero(
true_and_pred, axis=sum_axis, sample_weight=sample_weight
tp_sum = _count_nonzero(
true_and_pred,
axis=sum_axis,
sample_weight=sample_weight,
xp=xp,
device=device_,
)
pred_sum = _count_nonzero(
y_pred,
axis=sum_axis,
sample_weight=sample_weight,
xp=xp,
device=device_,
)
true_sum = _count_nonzero(
y_true,
axis=sum_axis,
sample_weight=sample_weight,
xp=xp,
device=device_,
)
pred_sum = count_nonzero(y_pred, axis=sum_axis, sample_weight=sample_weight)
true_sum = count_nonzero(y_true, axis=sum_axis, sample_weight=sample_weight)

fp = pred_sum - tp_sum
fn = true_sum - tp_sum
tp = tp_sum

if sample_weight is not None and samplewise:
sample_weight = np.array(sample_weight)
tp = np.array(tp)
fp = np.array(fp)
fn = np.array(fn)
tp = xp.asarray(tp)
fp = xp.asarray(fp)
fn = xp.asarray(fn)
tn = sample_weight * y_true.shape[1] - tp - fp - fn
elif sample_weight is not None:
tn = sum(sample_weight) - tp - fp - fn
tn = xp.sum(sample_weight) - tp - fp - fn
elif samplewise:
tn = y_true.shape[1] - tp - fp - fn
else:
tn = y_true.shape[0] - tp - fp - fn

return np.array([tn, fp, fn, tp]).T.reshape(-1, 2, 2)
return xp.reshape(xp.stack([tn, fp, fn, tp]).T, (-1, 2, 2))
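
For reference, a small NumPy-only sketch of the shape contract this return statement preserves (illustrative, not from the diff): one 2x2 matrix of (tn, fp / fn, tp) per class.

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array([0, 1, 2, 0])
y_pred = np.array([0, 2, 2, 0])

mcm = multilabel_confusion_matrix(y_true, y_pred)
print(mcm.shape)  # (3, 2, 2): one confusion matrix per class
```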


@validate_params(
@@ -1262,21 +1297,21 @@ def f1_score(
>>> y_true = [0, 1, 2, 0, 1, 2]
>>> y_pred = [0, 2, 1, 0, 0, 1]
>>> f1_score(y_true, y_pred, average='macro')
np.float64(0.26...)
0.26...
>>> f1_score(y_true, y_pred, average='micro')
np.float64(0.33...)
0.33...
>>> f1_score(y_true, y_pred, average='weighted')
np.float64(0.26...)
0.26...
>>> f1_score(y_true, y_pred, average=None)
array([0.8, 0. , 0. ])

>>> # binary classification
>>> y_true_empty = [0, 0, 0, 0, 0, 0]
>>> y_pred_empty = [0, 0, 0, 0, 0, 0]
>>> f1_score(y_true_empty, y_pred_empty)
np.float64(0.0...)
0.0...
>>> f1_score(y_true_empty, y_pred_empty, zero_division=1.0)
np.float64(1.0...)
1.0...
>>> f1_score(y_true_empty, y_pred_empty, zero_division=np.nan)
nan...

@@ -1466,17 +1501,17 @@ def fbeta_score(
>>> y_true = [0, 1, 2, 0, 1, 2]
>>> y_pred = [0, 2, 1, 0, 0, 1]
>>> fbeta_score(y_true, y_pred, average='macro', beta=0.5)
np.float64(0.23...)
0.23...
>>> fbeta_score(y_true, y_pred, average='micro', beta=0.5)
np.float64(0.33...)
0.33...
>>> fbeta_score(y_true, y_pred, average='weighted', beta=0.5)
np.float64(0.23...)
0.23...
>>> fbeta_score(y_true, y_pred, average=None, beta=0.5)
array([0.71..., 0. , 0. ])
>>> y_pred_empty = [0, 0, 0, 0, 0, 0]
>>> fbeta_score(y_true, y_pred_empty,
... average="macro", zero_division=np.nan, beta=0.5)
np.float64(0.12...)
0.12...
"""

_, _, f, _ = precision_recall_fscore_support(
@@ -1505,12 +1540,14 @@ def _prf_divide(
The metric, modifier and average arguments are used only for determining
an appropriate warning.
"""
mask = denominator == 0.0
denominator = denominator.copy()
xp, _ = get_namespace(numerator, denominator)
dtype_float = _find_matching_floating_dtype(numerator, denominator, xp=xp)
mask = denominator == 0
denominator = xp.asarray(denominator, copy=True, dtype=dtype_float)
denominator[mask] = 1 # avoid infs/nans
result = numerator / denominator
result = xp.asarray(numerator, dtype=dtype_float) / denominator

if not np.any(mask):
if not xp.any(mask):
return result

# set those with 0 denominator to `zero_division`, and 0 when "warn"
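
As an aside, a minimal NumPy-only sketch of the masked-division pattern used here (hypothetical helper, not the scikit-learn implementation):

```python
import numpy as np

def safe_divide(numerator, denominator, fill=0.0):
    # Replace zero denominators with 1 to avoid inf/nan, divide,
    # then patch the masked positions with the fill value
    # (the zero_division policy in the real code).
    mask = denominator == 0
    denom = denominator.astype(float)  # copy; original left untouched
    denom[mask] = 1
    result = numerator / denom
    result[mask] = fill
    return result

print(safe_divide(np.array([1.0, 2.0]), np.array([2.0, 0.0])))
# [0.5 0. ]
```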
Expand Down Expand Up @@ -1559,7 +1596,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
# Convert to Python primitive type to avoid NumPy type / Python str
# comparison. See https://github.com/numpy/numpy/issues/6784
present_labels = unique_labels(y_true, y_pred).tolist()
present_labels = _tolist(unique_labels(y_true, y_pred))
if average == "binary":
if y_type == "binary":
if pos_label not in present_labels:
@@ -1774,11 +1811,11 @@ def precision_recall_fscore_support(
>>> y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
>>> y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])
>>> precision_recall_fscore_support(y_true, y_pred, average='macro')
(np.float64(0.22...), np.float64(0.33...), np.float64(0.26...), None)
(0.22..., 0.33..., 0.26..., None)
>>> precision_recall_fscore_support(y_true, y_pred, average='micro')
(np.float64(0.33...), np.float64(0.33...), np.float64(0.33...), None)
(0.33..., 0.33..., 0.33..., None)
>>> precision_recall_fscore_support(y_true, y_pred, average='weighted')
(np.float64(0.22...), np.float64(0.33...), np.float64(0.26...), None)
(0.22..., 0.33..., 0.26..., None)

It is possible to compute per-label precisions, recalls, F1-scores and
supports instead of averaging:
@@ -1805,10 +1842,11 @@
pred_sum = tp_sum + MCM[:, 0, 1]
true_sum = tp_sum + MCM[:, 1, 0]

xp, _ = get_namespace(y_true, y_pred)
if average == "micro":
tp_sum = np.array([tp_sum.sum()])
pred_sum = np.array([pred_sum.sum()])
true_sum = np.array([true_sum.sum()])
tp_sum = xp.reshape(xp.sum(tp_sum), (1,))
pred_sum = xp.reshape(xp.sum(pred_sum), (1,))
true_sum = xp.reshape(xp.sum(true_sum), (1,))

# Finally, we have all our sufficient statistics. Divide! #
beta2 = beta**2
@@ -1851,10 +1889,10 @@
weights = None

if average is not None:
assert average != "binary" or len(precision) == 1
precision = _nanaverage(precision, weights=weights)
recall = _nanaverage(recall, weights=weights)
f_score = _nanaverage(f_score, weights=weights)
assert average != "binary" or precision.shape[0] == 1
A reviewer (Member) commented:

I'm okay to leave this as is in this PR since it's existing code, but we really shouldn't be `assert`-ing here. If this never happens, the line shouldn't be here; if it can happen, we should raise a meaningful error.
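
A hedged sketch of that suggestion (hypothetical, not part of the diff):

```python
def _check_binary_shape(precision, average):
    # Replace the bare assert with a descriptive, user-facing error.
    if average == "binary" and precision.shape[0] != 1:
        raise ValueError(
            "average='binary' expects exactly one precision value, got "
            f"{precision.shape[0]}. Check the pos_label argument."
        )
```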

precision = float(_nanaverage(precision, weights=weights))
recall = float(_nanaverage(recall, weights=weights))
f_score = float(_nanaverage(f_score, weights=weights))
true_sum = None # return no support

return precision, recall, f_score, true_sum
@@ -2185,11 +2223,11 @@ def precision_score(
>>> y_true = [0, 1, 2, 0, 1, 2]
>>> y_pred = [0, 2, 1, 0, 0, 1]
>>> precision_score(y_true, y_pred, average='macro')
np.float64(0.22...)
0.22...
>>> precision_score(y_true, y_pred, average='micro')
np.float64(0.33...)
0.33...
>>> precision_score(y_true, y_pred, average='weighted')
np.float64(0.22...)
0.22...
>>> precision_score(y_true, y_pred, average=None)
array([0.66..., 0. , 0. ])
>>> y_pred = [0, 0, 0, 0, 0, 0]
@@ -2367,11 +2405,11 @@ def recall_score(
>>> y_true = [0, 1, 2, 0, 1, 2]
>>> y_pred = [0, 2, 1, 0, 0, 1]
>>> recall_score(y_true, y_pred, average='macro')
np.float64(0.33...)
0.33...
>>> recall_score(y_true, y_pred, average='micro')
np.float64(0.33...)
0.33...
>>> recall_score(y_true, y_pred, average='weighted')
np.float64(0.33...)
0.33...
>>> recall_score(y_true, y_pred, average=None)
array([1., 0., 0.])
>>> y_true = [0, 0, 0, 0, 0, 0]