@@ -1733,16 +1733,18 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
1733
1733
1734
1734
1735
1735
def check_array_api_metric (
1736
- metric , array_namespace , device , dtype , y_true_np , y_pred_np
1736
+ metric , array_namespace , device , dtype , y_true_np , y_pred_np , sample_weight = None
1737
1737
):
1738
1738
xp , device , dtype = _array_api_for_tests (array_namespace , device , dtype )
1739
1739
y_true_xp = xp .asarray (y_true_np , device = device )
1740
1740
y_pred_xp = xp .asarray (y_pred_np , device = device )
1741
1741
1742
- metric_np = metric (y_true_np , y_pred_np )
1742
+ metric_np = metric (y_true_np , y_pred_np , sample_weight = sample_weight )
1743
1743
1744
1744
with config_context (array_api_dispatch = True ):
1745
- metric_xp = metric (y_true_xp , y_pred_xp )
1745
+ if sample_weight is not None :
1746
+ sample_weight = xp .asarray (sample_weight , device = device )
1747
+ metric_xp = metric (y_true_xp , y_pred_xp , sample_weight = sample_weight )
1746
1748
1747
1749
assert_allclose (
1748
1750
metric_xp ,
@@ -1754,27 +1756,41 @@ def check_array_api_metric(
1754
1756
def check_array_api_binary_classification_metric (
1755
1757
metric , array_namespace , device , dtype
1756
1758
):
1757
- return check_array_api_metric (
1758
- metric ,
1759
- array_namespace ,
1760
- device ,
1761
- dtype ,
1762
- y_true_np = np .array ([0 , 0 , 1 , 1 ]),
1763
- y_pred_np = np .array ([0 , 1 , 0 , 1 ]),
1759
+ y_true_np = np .array ([0 , 0 , 1 , 1 ])
1760
+ y_pred_np = np .array ([0 , 1 , 0 , 1 ])
1761
+ check_array_api_metric (
1762
+ metric , array_namespace , device , dtype , y_true_np = y_true_np , y_pred_np = y_pred_np
1764
1763
)
1764
+ if "sample_weight" in signature (metric ).parameters :
1765
+ check_array_api_metric (
1766
+ metric ,
1767
+ array_namespace ,
1768
+ device ,
1769
+ dtype ,
1770
+ y_true_np = y_true_np ,
1771
+ y_pred_np = y_pred_np ,
1772
+ sample_weight = np .array ([0.0 , 0.1 , 2.0 , 1.0 ]),
1773
+ )
1765
1774
1766
1775
1767
1776
def check_array_api_multiclass_classification_metric (
1768
1777
metric , array_namespace , device , dtype
1769
1778
):
1770
- return check_array_api_metric (
1771
- metric ,
1772
- array_namespace ,
1773
- device ,
1774
- dtype ,
1775
- y_true_np = np .array ([0 , 1 , 2 , 3 ]),
1776
- y_pred_np = np .array ([0 , 1 , 0 , 2 ]),
1779
+ y_true_np = np .array ([0 , 1 , 2 , 3 ])
1780
+ y_pred_np = np .array ([0 , 1 , 0 , 2 ])
1781
+ check_array_api_metric (
1782
+ metric , array_namespace , device , dtype , y_true_np = y_true_np , y_pred_np = y_pred_np
1777
1783
)
1784
+ if "sample_weight" in signature (metric ).parameters :
1785
+ check_array_api_metric (
1786
+ metric ,
1787
+ array_namespace ,
1788
+ device ,
1789
+ dtype ,
1790
+ y_true_np = y_true_np ,
1791
+ y_pred_np = y_pred_np ,
1792
+ sample_weight = np .array ([0.0 , 0.1 , 2.0 , 1.0 ]),
1793
+ )
1778
1794
1779
1795
1780
1796
metric_checkers = {
0 commit comments