From 9d9b622b6b49dafc815c6943a5f54906337ce3b5 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:30:10 -0500 Subject: [PATCH 1/9] ENH: Add Array API support to hamming_loss --- doc/modules/array_api.rst | 1 + sklearn/metrics/_classification.py | 13 ++++++++++--- sklearn/metrics/tests/test_common.py | 5 +++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index b50815e1f7fb3..b1d1272e3b173 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -136,6 +136,7 @@ Metrics - :func:`sklearn.metrics.explained_variance_score` - :func:`sklearn.metrics.f1_score` - :func:`sklearn.metrics.fbeta_score` +- :func:`sklearn.metrics.hamming_loss` - :func:`sklearn.metrics.max_error` - :func:`sklearn.metrics.mean_absolute_error` - :func:`sklearn.metrics.mean_absolute_percentage_error` diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 2a08a1893766e..0a14aafbb7e36 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2997,15 +2997,22 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None): y_type, y_true, y_pred = _check_targets(y_true, y_pred) check_consistent_length(y_true, y_pred, sample_weight) + xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight) + if sample_weight is None: weight_average = 1.0 else: - weight_average = np.mean(sample_weight) + weight_average = xp.mean(sample_weight) if y_type.startswith("multilabel"): - n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight) + if _is_numpy_namespace(xp): + n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight) + else: + n_differences = _count_nonzero( + y_true - y_pred, xp=xp, device=device, sample_weight=sample_weight + ) return float( - n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average) + int(n_differences) / 
(y_true.shape[0] * y_true.shape[1] * weight_average) ) elif y_type in ["binary", "multiclass"]: diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 5f44e7b212105..af1973c095c86 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -2139,6 +2139,11 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) check_array_api_multiclass_classification_metric, check_array_api_multilabel_classification_metric, ], + hamming_loss: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], mean_tweedie_deviance: [check_array_api_regression_metric], partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric], partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric], From ccd0adedf3428050113d26704b29ea144d8871f6 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:33:17 -0500 Subject: [PATCH 2/9] add whatsnew --- doc/whats_new/upcoming_changes/array-api/30838.feature.rst | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 doc/whats_new/upcoming_changes/array-api/30838.feature.rst diff --git a/doc/whats_new/upcoming_changes/array-api/30838.feature.rst b/doc/whats_new/upcoming_changes/array-api/30838.feature.rst new file mode 100644 index 0000000000000..f733f1c6476a6 --- /dev/null +++ b/doc/whats_new/upcoming_changes/array-api/30838.feature.rst @@ -0,0 +1,2 @@ +- :func:`sklearn.metrics.hamming_loss` now supports Array API compatible inputs.
+ By :user:`Thomas Li <lithomas1>` From 6157383522dbd199b5f2028498848ecea64c413d Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 15 Feb 2025 12:59:51 -0500 Subject: [PATCH 3/9] fixes --- sklearn/metrics/_classification.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 0a14aafbb7e36..4e6add789b219 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -3002,6 +3002,10 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None): if sample_weight is None: weight_average = 1.0 else: + if _is_numpy_namespace(xp): + # calling np.mean(torch.tensor([...])) crashes + # workaround for now + sample_weight = xp.asarray(sample_weight) weight_average = xp.mean(sample_weight) if y_type.startswith("multilabel"): @@ -3011,8 +3015,8 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None): n_differences = _count_nonzero( y_true - y_pred, xp=xp, device=device, sample_weight=sample_weight ) - return float( - int(n_differences) / (y_true.shape[0] * y_true.shape[1] * weight_average) + return float(n_differences) / ( + y_true.shape[0] * y_true.shape[1] * weight_average ) elif y_type in ["binary", "multiclass"]: From 264ef5336f254e955025c9d0d5e4b61360915072 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 22 Feb 2025 08:56:02 -0500 Subject: [PATCH 4/9] simplify from code review --- sklearn/metrics/_classification.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 4e6add789b219..e390646ee2691 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -3002,11 +3002,7 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None): if sample_weight is None: weight_average = 1.0 else: - if _is_numpy_namespace(xp): - # calling
np.mean(torch.tensor([...])) crashes - # workaround for now - sample_weight = xp.asarray(sample_weight) - weight_average = xp.mean(sample_weight) + weight_average = _average(sample_weight, xp=xp) if y_type.startswith("multilabel"): if _is_numpy_namespace(xp): From 035d65096b153edf07c3e2791e64755d83d33f2d Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 25 Feb 2025 19:45:13 -0500 Subject: [PATCH 5/9] Update sklearn/metrics/_classification.py Co-authored-by: Virgil Chan --- sklearn/metrics/_classification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index e390646ee2691..16f8a9a3d14b3 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -3002,6 +3002,7 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None): if sample_weight is None: weight_average = 1.0 else: + sample_weight = xp.asarray(sample_weight, device=device) weight_average = _average(sample_weight, xp=xp) if y_type.startswith("multilabel"): From 21fbe2b013edffe6c562a3e5a040f833d4df42d6 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 26 Feb 2025 07:22:42 -0500 Subject: [PATCH 6/9] Update sklearn/metrics/_classification.py Co-authored-by: Omar Salman --- sklearn/metrics/_classification.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 16f8a9a3d14b3..5f4f59afa8ac3 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -3006,12 +3006,9 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None): weight_average = _average(sample_weight, xp=xp) if y_type.startswith("multilabel"): - if _is_numpy_namespace(xp): - n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight) - else: - n_differences = _count_nonzero( - y_true - y_pred, xp=xp, 
device=device, sample_weight=sample_weight - ) + n_differences = _count_nonzero( + y_true - y_pred, xp=xp, device=device, sample_weight=sample_weight + ) return float(n_differences) / ( y_true.shape[0] * y_true.shape[1] * weight_average ) From c9fb17c9cfc217f9500c80fdca69cf61740af8d0 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 1 Mar 2025 21:10:27 -0500 Subject: [PATCH 7/9] cleanup --- sklearn/metrics/_classification.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 5f4f59afa8ac3..13d36acc1bc82 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -229,12 +229,9 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None): check_consistent_length(y_true, y_pred, sample_weight) if y_type.startswith("multilabel"): - if _is_numpy_namespace(xp): - differing_labels = count_nonzero(y_true - y_pred, axis=1) - else: - differing_labels = _count_nonzero( - y_true - y_pred, xp=xp, device=device, axis=1 - ) + differing_labels = _count_nonzero( + y_true - y_pred, xp=xp, device=device, axis=1 + ) score = xp.asarray(differing_labels == 0, device=device) else: score = y_true == y_pred From 97fdc1e6de93127d38196c5aaf975f6cbb229158 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 1 Mar 2025 21:24:20 -0500 Subject: [PATCH 8/9] lint --- sklearn/metrics/_classification.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 13d36acc1bc82..2c73782752283 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -229,9 +229,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None): check_consistent_length(y_true, y_pred, sample_weight) if y_type.startswith("multilabel"): - 
differing_labels = _count_nonzero( - y_true - y_pred, xp=xp, device=device, axis=1 - ) + differing_labels = _count_nonzero(y_true - y_pred, xp=xp, device=device, axis=1) score = xp.asarray(differing_labels == 0, device=device) else: score = y_true == y_pred From 5c13dac18d74fcd602611f89c02a54d4085e10b2 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 1 Mar 2025 21:27:04 -0500 Subject: [PATCH 9/9] more lint --- sklearn/metrics/_classification.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 2c73782752283..0fefbd529ee40 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -51,7 +51,6 @@ from ..utils._unique import attach_unique from ..utils.extmath import _nanaverage from ..utils.multiclass import type_of_target, unique_labels -from ..utils.sparsefuncs import count_nonzero from ..utils.validation import ( _check_pos_label_consistency, _check_sample_weight,