Skip to content

Commit fa65654

Browse files
TialoEdAbati
andauthored
ENH Array API for chi2_kernel (#29267)
Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
1 parent e82c14b commit fa65654

File tree

4 files changed

+8
-1
lines changed

4 files changed

+8
-1
lines changed

doc/modules/array_api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ Metrics
121121
- :func:`sklearn.metrics.mean_squared_error`
122122
- :func:`sklearn.metrics.mean_tweedie_deviance`
123123
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
124+
- :func:`sklearn.metrics.pairwise.chi2_kernel`
124125
- :func:`sklearn.metrics.pairwise.cosine_similarity`
125126
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
126127
- :func:`sklearn.metrics.r2_score`

doc/whats_new/v1.6.rst

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ See :ref:`array_api` for more details.
4141
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
4242
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
4343
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
44+
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
4445
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
4546
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`.
4647

sklearn/metrics/pairwise.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1828,9 +1828,12 @@ def chi2_kernel(X, Y=None, gamma=1.0):
18281828
array([[0.36..., 0.13...],
18291829
[0.13..., 0.36...]])
18301830
"""
1831+
xp, _ = get_namespace(X, Y)
18311832
K = additive_chi2_kernel(X, Y)
18321833
K *= gamma
1833-
return np.exp(K, K)
1834+
if _is_numpy_namespace(xp):
1835+
return np.exp(K, out=K)
1836+
return xp.exp(K)
18341837

18351838

18361839
# Helper functions - distance

sklearn/metrics/tests/test_common.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from sklearn.metrics._base import _average_binary_score
5454
from sklearn.metrics.pairwise import (
5555
additive_chi2_kernel,
56+
chi2_kernel,
5657
cosine_similarity,
5758
paired_cosine_distances,
5859
)
@@ -1979,6 +1980,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
19791980
additive_chi2_kernel: [check_array_api_metric_pairwise],
19801981
mean_gamma_deviance: [check_array_api_regression_metric],
19811982
max_error: [check_array_api_regression_metric],
1983+
chi2_kernel: [check_array_api_metric_pairwise],
19821984
}
19831985

19841986

0 commit comments

Comments
 (0)