@@ -1444,9 +1444,10 @@ def test_averaging_multilabel_all_ones(name):
1444
1444
check_averaging (name , y_true , y_true_binarize , y_pred , y_pred_binarize , y_score )
1445
1445
1446
1446
1447
- def check_sample_weight_invariance (name , metric , y1 , y2 ):
1447
+ def check_sample_weight_invariance (name , metric , y1 , y2 , sample_weight = None ):
1448
1448
rng = np .random .RandomState (0 )
1449
- sample_weight = rng .randint (1 , 10 , size = len (y1 ))
1449
+ if sample_weight is None :
1450
+ sample_weight = rng .randint (1 , 10 , size = len (y1 ))
1450
1451
1451
1452
# top_k_accuracy_score always lead to a perfect score for k > 1 in the
1452
1453
# binary case
@@ -1550,13 +1551,14 @@ def check_sample_weight_invariance(name, metric, y1, y2):
1550
1551
)
1551
1552
def test_regression_sample_weight_invariance (name ):
1552
1553
n_samples = 51
1553
- random_state = check_random_state (1 )
1554
+ random_state = check_random_state (0 )
1554
1555
# regression
1555
1556
y_true = random_state .random_sample (size = (n_samples ,))
1556
1557
y_pred = random_state .random_sample (size = (n_samples ,))
1558
+ sample_weight = np .arange (len (y_true ))
1557
1559
metric = ALL_METRICS [name ]
1558
1560
1559
- check_sample_weight_invariance (name , metric , y_true , y_pred )
1561
+ check_sample_weight_invariance (name , metric , y_true , y_pred , sample_weight )
1560
1562
1561
1563
1562
1564
@pytest .mark .parametrize (
0 commit comments