Skip to content

Commit e29a7a8

Browse files
authored
Merge pull request scipy#22843 from crusaderky/special_xp_capabilities
DOC: `special`: add `xp_capabilities` to logsumexp
2 parents 4e30da0 + 7577507 commit e29a7a8

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

scipy/special/_logsumexp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from scipy._lib._array_api import (
33
array_namespace,
4+
xp_capabilities,
45
xp_device,
56
xp_size,
67
xp_promote,
@@ -11,6 +12,7 @@
1112
__all__ = ["logsumexp", "softmax", "log_softmax"]
1213

1314

15+
@xp_capabilities()
1416
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
1517
"""Compute the log of the sum of exponentials of input elements.
1618
@@ -58,7 +60,8 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
5860
5961
See Also
6062
--------
61-
numpy.logaddexp, numpy.logaddexp2
63+
:data:`numpy.logaddexp`
64+
:data:`numpy.logaddexp2`
6265
6366
Notes
6467
-----
@@ -246,6 +249,7 @@ def _logsumexp(a, b, *, axis, return_sign, xp):
246249
return out, sgn
247250

248251

252+
@xp_capabilities()
249253
def softmax(x, axis=None):
250254
r"""Compute the softmax function.
251255
@@ -344,6 +348,7 @@ def softmax(x, axis=None):
344348
return exp_x_shifted / xp.sum(exp_x_shifted, axis=axis, keepdims=True)
345349

346350

351+
@xp_capabilities()
347352
def log_softmax(x, axis=None):
348353
r"""Compute the logarithm of the softmax function.
349354

scipy/special/tests/test_logsumexp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import numpy as np
66

7-
from scipy._lib._array_api import is_array_api_strict, xp_default_dtype, xp_device
7+
from scipy._lib._array_api import (is_array_api_strict, make_skip_xp_backends,
8+
xp_default_dtype, xp_device)
89
from scipy._lib._array_api_no_0d import (xp_assert_equal, xp_assert_close,
910
xp_assert_less)
1011

@@ -38,6 +39,7 @@ def test_wrap_radians(xp):
3839
@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
3940
@pytest.mark.filterwarnings("ignore:divide by zero encountered:RuntimeWarning")
4041
@pytest.mark.filterwarnings("ignore:overflow encountered:RuntimeWarning")
42+
@make_skip_xp_backends(logsumexp)
4143
class TestLogSumExp:
4244
def test_logsumexp(self, xp):
4345
# Test with zero-size array
@@ -310,6 +312,7 @@ def test_device(self, x_raw, xp, devices):
310312
assert xp_device(logsumexp(x, b=x)) == xp_device(x)
311313

312314

315+
@make_skip_xp_backends(softmax)
313316
class TestSoftmax:
314317
def test_softmax_fixtures(self, xp):
315318
xp_assert_close(softmax(xp.asarray([1000., 0., 0., 0.])),
@@ -378,6 +381,7 @@ def test_softmax_array_like(self):
378381
np.asarray([1., 0., 0., 0.]), rtol=1e-13)
379382

380383

384+
@make_skip_xp_backends(log_softmax)
381385
class TestLogSoftmax:
382386
def test_log_softmax_basic(self, xp):
383387
xp_assert_close(log_softmax(xp.asarray([1000., 1.])),

0 commit comments

Comments
 (0)