1 change: 1 addition & 0 deletions doc/modules/array_api.rst
@@ -136,6 +136,7 @@ Metrics
- :func:`sklearn.metrics.explained_variance_score`
- :func:`sklearn.metrics.f1_score`
- :func:`sklearn.metrics.fbeta_score`
- :func:`sklearn.metrics.hamming_loss`
- :func:`sklearn.metrics.max_error`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_absolute_percentage_error`
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/30838.feature.rst
@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.hamming_loss` now supports Array API compatible inputs.
By :user:`Thomas Li <lithomas1>`
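A quick usage sketch of what this changelog entry enables (illustrative assumptions: a build that includes this change, `array_api_strict` installed, and `array_api_dispatch` turned on):

```python
# Sketch: hamming_loss on array API inputs instead of NumPy arrays.
import array_api_strict as xp

import sklearn
from sklearn.metrics import hamming_loss

sklearn.set_config(array_api_dispatch=True)

y_true = xp.asarray([[0, 1], [1, 1]])
y_pred = xp.asarray([[0, 0], [1, 1]])

# One of the four label assignments differs, so the loss should be 0.25.
print(hamming_loss(y_true, y_pred))
```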
21 changes: 10 additions & 11 deletions sklearn/metrics/_classification.py
@@ -51,7 +51,6 @@
from ..utils._unique import attach_unique
from ..utils.extmath import _nanaverage
from ..utils.multiclass import type_of_target, unique_labels
from ..utils.sparsefuncs import count_nonzero
from ..utils.validation import (
_check_pos_label_consistency,
_check_sample_weight,
@@ -229,12 +228,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
check_consistent_length(y_true, y_pred, sample_weight)

if y_type.startswith("multilabel"):
if _is_numpy_namespace(xp):
differing_labels = count_nonzero(y_true - y_pred, axis=1)
else:
differing_labels = _count_nonzero(
y_true - y_pred, xp=xp, device=device, axis=1
)
differing_labels = _count_nonzero(y_true - y_pred, xp=xp, device=device, axis=1)
score = xp.asarray(differing_labels == 0, device=device)
else:
score = y_true == y_pred
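For intuition, a minimal sketch (illustrative only, not the private sklearn helper, whose exact signature and extra features such as sparse-input handling should be taken from the source) of counting nonzero entries per row with only array API standard operations:

```python
# Illustrative only: per-row nonzero count using array API standard ops.
import array_api_strict as xp


def count_nonzero_rows(a, axis=1):
    # Compare against zero, cast the boolean mask to integers, then sum per row.
    return xp.sum(xp.astype(a != 0, xp.int64), axis=axis)


diff = xp.asarray([[0, 1, 0], [0, 0, 0]])
print(count_nonzero_rows(diff))  # -> [1, 0]
```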
@@ -2997,15 +2991,20 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)

xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)

if sample_weight is None:
weight_average = 1.0
else:
weight_average = np.mean(sample_weight)
sample_weight = xp.asarray(sample_weight, device=device)
weight_average = _average(sample_weight, xp=xp)

if y_type.startswith("multilabel"):
n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight)
return float(
n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average)
n_differences = _count_nonzero(
y_true - y_pred, xp=xp, device=device, sample_weight=sample_weight
)
return float(n_differences) / (
y_true.shape[0] * y_true.shape[1] * weight_average
)

elif y_type in ["binary", "multiclass"]:
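For reference, a small NumPy-only check of the weighted multilabel formula used above (numbers chosen for illustration):

```python
# Checks the multilabel formula by hand:
#   n_differences / (n_samples * n_labels * weight_average)
import numpy as np

y_true = np.array([[0, 1, 1], [1, 0, 1]])
y_pred = np.array([[0, 0, 1], [1, 0, 0]])
sample_weight = np.array([1.0, 3.0])

# Each row has one disagreeing label; rows are weighted 1.0 and 3.0.
n_differences = ((y_true != y_pred) * sample_weight[:, None]).sum()  # 4.0
weight_average = sample_weight.mean()                                # 2.0

loss = n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average)
print(loss)  # 4.0 / (2 * 3 * 2.0) = 0.333...
```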
5 changes: 5 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -2139,6 +2139,11 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_multiclass_classification_metric,
check_array_api_multilabel_classification_metric,
],
hamming_loss: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
check_array_api_multilabel_classification_metric,
],
mean_tweedie_deviance: [check_array_api_regression_metric],
partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric],
partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric],