Skip to content

Commit 8f3136b

Browse files
EdAbatiogrisel
andauthored
TST Test Array API-compatible metrics with sample_weight (#27335)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 1f1329f commit 8f3136b

File tree

1 file changed

+33
-17
lines changed

1 file changed

+33
-17
lines changed

sklearn/metrics/tests/test_common.py

+33-17
Original file line numberDiff line numberDiff line change
@@ -1733,16 +1733,18 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
17331733

17341734

17351735
def check_array_api_metric(
1736-
metric, array_namespace, device, dtype, y_true_np, y_pred_np
1736+
metric, array_namespace, device, dtype, y_true_np, y_pred_np, sample_weight=None
17371737
):
17381738
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
17391739
y_true_xp = xp.asarray(y_true_np, device=device)
17401740
y_pred_xp = xp.asarray(y_pred_np, device=device)
17411741

1742-
metric_np = metric(y_true_np, y_pred_np)
1742+
metric_np = metric(y_true_np, y_pred_np, sample_weight=sample_weight)
17431743

17441744
with config_context(array_api_dispatch=True):
1745-
metric_xp = metric(y_true_xp, y_pred_xp)
1745+
if sample_weight is not None:
1746+
sample_weight = xp.asarray(sample_weight, device=device)
1747+
metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
17461748

17471749
assert_allclose(
17481750
metric_xp,
@@ -1754,27 +1756,41 @@ def check_array_api_metric(
17541756
def check_array_api_binary_classification_metric(
17551757
metric, array_namespace, device, dtype
17561758
):
1757-
return check_array_api_metric(
1758-
metric,
1759-
array_namespace,
1760-
device,
1761-
dtype,
1762-
y_true_np=np.array([0, 0, 1, 1]),
1763-
y_pred_np=np.array([0, 1, 0, 1]),
1759+
y_true_np = np.array([0, 0, 1, 1])
1760+
y_pred_np = np.array([0, 1, 0, 1])
1761+
check_array_api_metric(
1762+
metric, array_namespace, device, dtype, y_true_np=y_true_np, y_pred_np=y_pred_np
17641763
)
1764+
if "sample_weight" in signature(metric).parameters:
1765+
check_array_api_metric(
1766+
metric,
1767+
array_namespace,
1768+
device,
1769+
dtype,
1770+
y_true_np=y_true_np,
1771+
y_pred_np=y_pred_np,
1772+
sample_weight=np.array([0.0, 0.1, 2.0, 1.0]),
1773+
)
17651774

17661775

17671776
def check_array_api_multiclass_classification_metric(
17681777
metric, array_namespace, device, dtype
17691778
):
1770-
return check_array_api_metric(
1771-
metric,
1772-
array_namespace,
1773-
device,
1774-
dtype,
1775-
y_true_np=np.array([0, 1, 2, 3]),
1776-
y_pred_np=np.array([0, 1, 0, 2]),
1779+
y_true_np = np.array([0, 1, 2, 3])
1780+
y_pred_np = np.array([0, 1, 0, 2])
1781+
check_array_api_metric(
1782+
metric, array_namespace, device, dtype, y_true_np=y_true_np, y_pred_np=y_pred_np
17771783
)
1784+
if "sample_weight" in signature(metric).parameters:
1785+
check_array_api_metric(
1786+
metric,
1787+
array_namespace,
1788+
device,
1789+
dtype,
1790+
y_true_np=y_true_np,
1791+
y_pred_np=y_pred_np,
1792+
sample_weight=np.array([0.0, 0.1, 2.0, 1.0]),
1793+
)
17781794

17791795

17801796
metric_checkers = {

0 commit comments

Comments
 (0)