Skip to content

Commit bb27534

Browse files
authored
MAINT: stats.boxcox_llf: refactor for simplicity (scipy#22835)
* MAINT: stats.boxcox_llf: refactor for simplicity * MAINT: stats.boxcox_llf: compensate for NumPy <2.0
1 parent a9df0fd commit bb27534

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

scipy/stats/_morestats.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -864,15 +864,11 @@ def _log_mean(logx, axis):
864864

865865
def _log_var(logx, xp, axis):
866866
# compute log of variance of x from log(x)
867-
logmean = _log_mean(logx, axis=axis)
868-
# get complex dtype with component dtypes same as `logx` dtype;
869-
dtype = xp.result_type(logx.dtype, 1j)
870-
pij = xp.full(logx.shape, pi * 1j, dtype=dtype)
871-
logxmu = special.logsumexp(xp.stack((logx, logmean + pij)), axis=0)
872-
return (
873-
xp.real(xp.asarray(special.logsumexp(2 * logxmu, axis=axis)))
874-
- math.log(logx.shape[axis])
875-
)
867+
logmean = xp.broadcast_to(_log_mean(logx, axis=axis), logx.shape)
868+
ones = xp.ones_like(logx)
869+
logxmu, _ = special.logsumexp(xp.stack((logx, logmean), axis=0), axis=0,
870+
b=xp.stack((ones, -ones), axis=0), return_sign=True)
871+
return special.logsumexp(2 * logxmu, axis=axis) - math.log(logx.shape[axis])
876872

877873

878874
def boxcox_llf(lmb, data, *, axis=0, keepdims=False, nan_policy='propagate'):
@@ -991,7 +987,7 @@ def boxcox_llf(lmb, data, *, axis=0, keepdims=False, nan_policy='propagate'):
991987
result_to_tuple=lambda x: (x,))
992988
def _boxcox_llf(data, axis=0, *, lmb):
993989
xp = array_namespace(data)
994-
data = xp_promote(data, force_floating=True, xp=xp)
990+
lmb, data = xp_promote(lmb, data, force_floating=True, xp=xp)
995991
N = data.shape[axis]
996992
if N == 0:
997993
return _get_nan(data, xp=xp)
@@ -1010,7 +1006,7 @@ def _boxcox_llf(data, axis=0, *, lmb):
10101006
logvar = _log_var(logx, xp, axis) - 2 * math.log(abs(lmb))
10111007

10121008
res = (lmb - 1) * xp.sum(logdata, axis=axis) - N/2 * logvar
1013-
res = xp.astype(res, data.dtype)
1009+
res = xp.astype(res, data.dtype, copy=False) # compensate for NumPy <2.0
10141010
res = res[()] if res.ndim == 0 else res
10151011
return res
10161012

scipy/stats/tests/test_morestats.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2049,10 +2049,9 @@ def test_instability_gh20021(self, xp):
20492049
def test_axis(self, xp):
20502050
data = xp.asarray([[100, 200], [300, 400]])
20512051
llf_axis_0 = stats.boxcox_llf(1, data, axis=0)
2052-
data_axes_swapped = xp.moveaxis(data, 0, -1)
20532052
llf_0 = xp.asarray([
2054-
stats.boxcox_llf(1, data_axes_swapped[0, :]),
2055-
stats.boxcox_llf(1, data_axes_swapped[1, :]),
2053+
stats.boxcox_llf(1, data[:, 0]),
2054+
stats.boxcox_llf(1, data[:, 1]),
20562055
])
20572056
xp_assert_close(llf_axis_0, llf_0)
20582057
llf_axis_1 = stats.boxcox_llf(1, data, axis=1)

0 commit comments

Comments
 (0)