From 2e9796e97c9f5bd895023874d3a25bca4cfc06be Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 7 Feb 2025 16:31:31 +1100 Subject: [PATCH 1/5] add len check --- sklearn/metrics/_regression.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 65a3073f3691c..dce9f0978ff5b 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -907,9 +907,11 @@ def median_absolute_error( >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7]) 0.85 """ - y_type, y_true, y_pred, multioutput = _check_reg_targets( + _, y_true, y_pred, multioutput = _check_reg_targets( y_true, y_pred, multioutput ) + check_consistent_length(y_true, y_pred, sample_weight) + if sample_weight is None: output_errors = np.median(np.abs(y_pred - y_true), axis=0) else: From 890aa74ec56105a880614750cc9378704421ebe2 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 7 Feb 2025 16:50:48 +1100 Subject: [PATCH 2/5] fix test list --- sklearn/metrics/tests/test_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 9e8d0ce116394..7a2ec2bbf7a41 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -552,7 +552,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): # No Sample weight support METRICS_WITHOUT_SAMPLE_WEIGHT = { - "median_absolute_error", "max_error", "ovo_roc_auc", "weighted_ovo_roc_auc", @@ -1556,6 +1555,8 @@ def test_regression_sample_weight_invariance(name): y_true = random_state.random_sample(size=(n_samples,)) y_pred = random_state.random_sample(size=(n_samples,)) metric = ALL_METRICS[name] + print(f'XX {metric=}') + check_sample_weight_invariance(name, metric, y_true, y_pred) From 5bfc5b6f4d362a9e7774ca1b35b216ae0b7fe96c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Sat, 8 Feb 2025 19:57:10 +1100 Subject: [PATCH 3/5] fix tests --- sklearn/metrics/tests/test_common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 7a2ec2bbf7a41..f344f1f2ca1fe 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1549,13 +1549,12 @@ def check_sample_weight_invariance(name, metric, y1, y2): ), ) def test_regression_sample_weight_invariance(name): - n_samples = 50 - random_state = check_random_state(0) + n_samples = 51 + random_state = check_random_state(1) # regression y_true = random_state.random_sample(size=(n_samples,)) y_pred = random_state.random_sample(size=(n_samples,)) metric = ALL_METRICS[name] - print(f'XX {metric=}') check_sample_weight_invariance(name, metric, y_true, y_pred) From 967212cf3eb6ef1d93782f514ea7c4252a2cf1dd Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Sat, 8 Feb 2025 21:49:06 +1100 Subject: [PATCH 4/5] lint --- sklearn/metrics/_regression.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index dce9f0978ff5b..40b514835a1bf 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -907,9 +907,7 @@ def median_absolute_error( >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7]) 0.85 """ - _, y_true, y_pred, multioutput = _check_reg_targets( - y_true, y_pred, multioutput - ) + _, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput) check_consistent_length(y_true, y_pred, sample_weight) if sample_weight is None: From 92884bcc0da7d615fcc98a1355e1596d859ae9c6 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 13 Feb 2025 09:42:13 +1100 Subject: [PATCH 5/5] use arange sample weight --- sklearn/metrics/tests/test_common.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index f344f1f2ca1fe..b91cb7c9a11e5 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1444,9 +1444,10 @@ def test_averaging_multilabel_all_ones(name): check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score) -def check_sample_weight_invariance(name, metric, y1, y2): +def check_sample_weight_invariance(name, metric, y1, y2, sample_weight=None): rng = np.random.RandomState(0) - sample_weight = rng.randint(1, 10, size=len(y1)) + if sample_weight is None: + sample_weight = rng.randint(1, 10, size=len(y1)) # top_k_accuracy_score always lead to a perfect score for k > 1 in the # binary case @@ -1550,13 +1551,14 @@ def check_sample_weight_invariance(name, metric, y1, y2): ) def test_regression_sample_weight_invariance(name): n_samples = 51 - random_state = check_random_state(1) + random_state = check_random_state(0) # regression y_true = random_state.random_sample(size=(n_samples,)) y_pred = random_state.random_sample(size=(n_samples,)) + sample_weight = np.arange(len(y_true)) metric = ALL_METRICS[name] - check_sample_weight_invariance(name, metric, y_true, y_pred) + check_sample_weight_invariance(name, metric, y_true, y_pred, sample_weight) @pytest.mark.parametrize(