From e74881d6642345c2850c590424291ddddb4a3473 Mon Sep 17 00:00:00 2001 From: Tialo Date: Thu, 30 May 2024 23:30:57 +0300 Subject: [PATCH 1/8] fix --- sklearn/metrics/_regression.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 61bb1caa2d9da..288b7709cb661 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -139,10 +139,10 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None): multioutput = check_array(multioutput, ensure_2d=False) if n_outputs == 1: raise ValueError("Custom weights are useful only in multi-output cases.") - elif n_outputs != len(multioutput): + elif n_outputs != multioutput.shape[0]: raise ValueError( - "There must be equally many custom weights (%d) as outputs (%d)." - % (len(multioutput), n_outputs) + f"There must be equally many custom weights " + f"({multioutput.shape[0]}) as outputs ({n_outputs})." ) y_type = "continuous" if n_outputs == 1 else "continuous-multioutput" From ba4f351c9715ccf6edd9ddfba40bbe80d8c4c84d Mon Sep 17 00:00:00 2001 From: Tialo Date: Thu, 30 May 2024 23:33:27 +0300 Subject: [PATCH 2/8] remove f --- sklearn/metrics/_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 288b7709cb661..ba436466e2f71 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -141,7 +141,7 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None): raise ValueError("Custom weights are useful only in multi-output cases.") elif n_outputs != multioutput.shape[0]: raise ValueError( - f"There must be equally many custom weights " + "There must be equally many custom weights " f"({multioutput.shape[0]}) as outputs ({n_outputs})." ) y_type = "continuous" if n_outputs == 1 else "continuous-multioutput" From d68b63bd08caa5a59587b84e74e4d04c9113e759 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Sat, 15 Jun 2024 11:40:03 +0200 Subject: [PATCH 3/8] Add test --- sklearn/metrics/tests/test_common.py | 35 +++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 42f2a36445642..1845607e7bb87 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1758,6 +1758,12 @@ def check_array_api_metric( metric_kwargs["sample_weight"], device=device ) + multioutput = metric_kwargs.get("multioutput") + if isinstance(multioutput, np.ndarray): + metric_kwargs["multioutput"] = xp.asarray( + metric_kwargs["multioutput"], device=device + ) + with config_context(array_api_dispatch=True): metric_xp = metric(a_xp, b_xp, **metric_kwargs) @@ -1856,8 +1862,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam def check_array_api_regression_metric_multioutput( metric, array_namespace, device, dtype_name ): - y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name) - y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name) + y_true_np = np.array([[1, 3, 2], [1, 2, 2]], dtype=dtype_name) + y_pred_np = np.array([[1, 4, 4], [1, 1, 1]], dtype=dtype_name) check_array_api_metric( metric, @@ -1881,12 +1887,25 @@ def check_array_api_regression_metric_multioutput( sample_weight=sample_weight, ) + check_array_api_metric( + metric, + array_namespace, + device, + dtype_name, + a_np=y_true_np, + b_np=y_pred_np, + multioutput=np.array([0.1, 0.3, 0.7], dtype=dtype_name), + ) -def check_array_api_multioutput_regression_metric( - metric, array_namespace, device, dtype_name -): - metric = partial(metric, multioutput="raw_values") - check_array_api_regression_metric(metric, array_namespace, device, dtype_name) + metric_multioutput = partial(metric, multioutput="raw_values") + check_array_api_metric( + metric_multioutput, + array_namespace, + device, + dtype_name, + a_np=y_true_np, + b_np=y_pred_np, + ) def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name): @@ -1936,7 +1955,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) cosine_similarity: [check_array_api_metric_pairwise], mean_absolute_error: [ check_array_api_regression_metric, - check_array_api_multioutput_regression_metric, + check_array_api_regression_metric_multioutput, ], } From 83d5f7d457dbf449008a11b7ea65cbda419885da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Mon, 17 Jun 2024 10:39:50 +0200 Subject: [PATCH 4/8] Fix after merge --- sklearn/metrics/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index ae6f441b0922d..9fd650263f0be 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1972,7 +1972,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) ], mean_squared_error: [ check_array_api_regression_metric, - check_array_api_multioutput_regression_metric, + check_array_api_regression_metric_multioutput, ], d2_tweedie_score: [ check_array_api_regression_metric, From 54088bdb46c7909da996e200d1b2631121449c37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Mon, 17 Jun 2024 11:47:58 +0200 Subject: [PATCH 5/8] simplify test --- sklearn/metrics/tests/test_common.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 9fd650263f0be..ca72c9834fac9 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1763,10 +1763,10 @@ def check_array_api_metric( ) multioutput = metric_kwargs.get("multioutput") - if isinstance(multioutput, np.ndarray): - metric_kwargs["multioutput"] = xp.asarray( - metric_kwargs["multioutput"], device=device - ) + if multioutput is not None: + if isinstance(multioutput, np.ndarray): + multioutput = xp.asarray(multioutput, device=device) + metric_kwargs["multioutput"] = multioutput with config_context(array_api_dispatch=True): metric_xp = metric(a_xp, b_xp, **metric_kwargs) @@ -1910,14 +1910,14 @@ def check_array_api_regression_metric_multioutput( multioutput=np.array([0.1, 0.3, 0.7], dtype=dtype_name), ) - metric_multioutput = partial(metric, multioutput="raw_values") check_array_api_metric( - metric_multioutput, + metric, array_namespace, device, dtype_name, a_np=y_true_np, b_np=y_pred_np, + multioutput="raw_values", ) From 94aa797916652f1049e487692e64194e77399981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Tue, 18 Jun 2024 07:04:44 +0200 Subject: [PATCH 6/8] Add changelog --- doc/whats_new/v1.6.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index fd43347cf7ac8..fa8b4c7cb9e64 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -35,7 +35,8 @@ See :ref:`array_api` for more details. - :func:`sklearn.metrics.cluster.entropy` :pr:`29141` by :user:`Yaroslav Korobko `; - :func:`sklearn.metrics.d2_tweedie_score` :pr:`29207` by :user:`Emily Chen `; - :func:`sklearn.metrics.max_error` :pr:`29212` by :user:`Edoardo Abati `; -- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `; +- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati ` + and :pr:`29143` by :user:`Tialo ` and :user:`Loïc Estève `; - :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :usser:`Emily Chen `; - :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko `; - :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li `; From 1d7baf20d8377cad775137306628bf59de36e30c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 20 Jun 2024 09:33:17 +0200 Subject: [PATCH 7/8] Update sklearn/metrics/tests/test_common.py Co-authored-by: Omar Salman --- sklearn/metrics/tests/test_common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index ca72c9834fac9..f12d8ed4c2164 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1765,8 +1765,7 @@ def check_array_api_metric( multioutput = metric_kwargs.get("multioutput") if multioutput is not None: if isinstance(multioutput, np.ndarray): - multioutput = xp.asarray(multioutput, device=device) - metric_kwargs["multioutput"] = multioutput + metric_kwargs["multioutput"] = xp.asarray(multioutput, device=device) with config_context(array_api_dispatch=True): metric_xp = metric(a_xp, b_xp, **metric_kwargs) From 2535a13f2d415a2a2dbdeff560fbf9a4c621cef9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 20 Jun 2024 09:35:45 +0200 Subject: [PATCH 8/8] Further simplification --- sklearn/metrics/tests/test_common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index f12d8ed4c2164..b34d4dcd4a256 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1763,9 +1763,8 @@ def check_array_api_metric( ) multioutput = metric_kwargs.get("multioutput") - if multioutput is not None: - if isinstance(multioutput, np.ndarray): - metric_kwargs["multioutput"] = xp.asarray(multioutput, device=device) + if isinstance(multioutput, np.ndarray): + metric_kwargs["multioutput"] = xp.asarray(multioutput, device=device) with config_context(array_api_dispatch=True): metric_xp = metric(a_xp, b_xp, **metric_kwargs)