
Commit fe927e3

lucascolley and mdhaber authored
ENH: stats.boxcox_llf: vectorize for n-D arrays (scipy#21233)
* ENH: `stats.boxcox_llf`: vectorize for n-D arrays

Co-authored-by: Matt Haberland <mhaberla@calpoly.edu>
1 parent 78239d8 commit fe927e3
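For orientation, a minimal usage sketch (not part of the commit; the array shape, seed, and lambda value are arbitrary) of what the vectorized function computes for 2-D NumPy input, based on the new signature `boxcox_llf(lmb, data, *, axis=0, keepdims=False, nan_policy='propagate')` added below:

# Illustrative only: one log-likelihood per column when reducing along axis=0.
import numpy as np
from scipy import stats

rng = np.random.default_rng(12345)
data = rng.exponential(size=(100, 3))          # three columns of positive data

llf = stats.boxcox_llf(1.5, data, axis=0)      # shape (3,)

# Should agree with evaluating each column separately.
expected = np.asarray([stats.boxcox_llf(1.5, data[:, j]) for j in range(3)])
print(np.allclose(llf, expected))              # True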

File tree: 3 files changed, +74 -18 lines changed

scipy/stats/_morestats.py

Lines changed: 51 additions & 17 deletions
@@ -854,25 +854,28 @@ def ppcc_plot(x, a, b, dist='tukeylambda', plot=None, N=80):
     return svals, ppcc
 
 
-def _log_mean(logx):
+def _log_mean(logx, axis):
     # compute log of mean of x from log(x)
-    res = special.logsumexp(logx, axis=0) - math.log(logx.shape[0])
-    return res
+    return (
+        special.logsumexp(logx, axis=axis, keepdims=True)
+        - math.log(logx.shape[axis])
+    )
 
 
-def _log_var(logx, xp):
+def _log_var(logx, xp, axis):
     # compute log of variance of x from log(x)
-    logmean = _log_mean(logx)
+    logmean = _log_mean(logx, axis=axis)
     # get complex dtype with component dtypes same as `logx` dtype;
     dtype = xp.result_type(logx.dtype, 1j)
     pij = xp.full(logx.shape, pi * 1j, dtype=dtype)
     logxmu = special.logsumexp(xp.stack((logx, logmean + pij)), axis=0)
-    res = (xp.real(xp.asarray(special.logsumexp(2 * logxmu, axis=0)))
-           - math.log(logx.shape[0]))
-    return res
+    return (
+        xp.real(xp.asarray(special.logsumexp(2 * logxmu, axis=axis)))
+        - math.log(logx.shape[axis])
+    )
 
 
-def boxcox_llf(lmb, data):
+def boxcox_llf(lmb, data, *, axis=0, keepdims=False, nan_policy='propagate'):
     r"""The boxcox log-likelihood function.
 
     Parameters
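The rewritten `_log_mean` relies on the identity log(mean(x)) = logsumexp(log x, axis) - log n, now evaluated along an arbitrary axis with keepdims=True so the result broadcasts against `logx` inside `_log_var`. A standalone check of that identity (illustrative, not from the commit; the array values and axis are arbitrary):

import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
x = rng.uniform(0.1, 10.0, size=(4, 5))
logx = np.log(x)

axis = 1
# log of the mean computed entirely in log space, as _log_mean does
log_mean = logsumexp(logx, axis=axis, keepdims=True) - np.log(x.shape[axis])
print(np.allclose(log_mean, np.log(np.mean(x, axis=axis, keepdims=True))))  # True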
@@ -883,6 +886,26 @@ def boxcox_llf(lmb, data):
         Data to calculate Box-Cox log-likelihood for. If `data` is
         multi-dimensional, the log-likelihood is calculated along the first
         axis.
+    axis : int, default: 0
+        If an int, the axis of the input along which to compute the statistic.
+        The statistic of each axis-slice (e.g. row) of the input will appear in a
+        corresponding element of the output.
+        If ``None``, the input will be raveled before computing the statistic.
+    nan_policy : {'propagate', 'omit', 'raise'}
+        Defines how to handle input NaNs.
+
+        - ``propagate``: if a NaN is present in the axis slice (e.g. row) along
+          which the statistic is computed, the corresponding entry of the output
+          will be NaN.
+        - ``omit``: NaNs will be omitted when performing the calculation.
+          If insufficient data remains in the axis slice along which the
+          statistic is computed, the corresponding entry of the output will be
+          NaN.
+        - ``raise``: if a NaN is present, a ``ValueError`` will be raised.
+    keepdims : bool, default: False
+        If this is set to True, the axes which are reduced are left
+        in the result as dimensions with size one. With this option,
+        the result will broadcast correctly against the input array.
 
     Returns
     -------
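A sketch of how the documented keywords compose for NumPy input (illustrative, not from the commit; per the comment in the next hunk, `nan_policy` and `keepdims` are only forwarded to the decorated helper, which currently supports them for NumPy arrays only). The data values and lmb=2.0 are arbitrary:

import numpy as np
from scipy import stats

data = np.array([[1.0, 2.0, 3.0, 4.0],
                 [2.0, np.nan, 6.0, 8.0]])

# Default nan_policy='propagate': the row containing NaN produces NaN.
print(stats.boxcox_llf(2.0, data, axis=1))

# nan_policy='omit': the NaN is dropped before computing that row's statistic.
print(stats.boxcox_llf(2.0, data, axis=1, nan_policy='omit'))

# keepdims=True keeps the reduced axis as size one: shape (2, 1) here,
# so the result broadcasts against `data`.
print(stats.boxcox_llf(2.0, data, axis=1, keepdims=True).shape)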
@@ -955,28 +978,39 @@ def boxcox_llf(lmb, data):
     >>> plt.show()
 
     """
+    # _axis_nan_policy decorator does not currently support these for non-NumPy arrays
+    kwargs = {}
+    if keepdims is not False:
+        kwargs['keepdims'] = keepdims
+    if nan_policy != 'propagate':
+        kwargs['nan_policy'] = nan_policy
+    return _boxcox_llf(data, lmb=lmb, axis=axis, **kwargs)
+
+
+@_axis_nan_policy_factory(lambda x: x, n_outputs=1, default_axis=0,
+                          result_to_tuple=lambda x: (x,))
+def _boxcox_llf(data, axis=0, *, lmb):
     xp = array_namespace(data)
     data = xp_promote(data, force_floating=True, xp=xp)
-
-    N = data.shape[0]
+    N = data.shape[axis]
     if N == 0:
-        return xp.nan
+        return _get_nan(data, xp=xp)
 
     logdata = xp.log(data)
 
     # Compute the variance of the transformed data.
     if lmb == 0:
-        logvar = xp.log(xp.var(logdata, axis=0))
+        logvar = xp.log(xp.var(logdata, axis=axis))
     else:
         # Transform without the constant offset 1/lmb. The offset does
         # not affect the variance, and the subtraction of the offset can
         # lead to loss of precision.
         # Division by lmb can be factored out to enhance numerical stability.
         logx = lmb * logdata
-        logvar = _log_var(logx, xp) - 2 * math.log(abs(lmb))
+        logvar = _log_var(logx, xp, axis) - 2 * math.log(abs(lmb))
 
-    res = (lmb - 1) * xp.sum(logdata, axis=0) - N/2 * logvar
-    res = xp.astype(res, data.dtype, copy=False)
+    res = (lmb - 1) * xp.sum(logdata, axis=axis) - N/2 * logvar
+    res = xp.astype(res, data.dtype)
     res = res[()] if res.ndim == 0 else res
     return res
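The refactored core still evaluates the Box-Cox log-likelihood exactly as written in the last lines of the hunk, llf = (lmb - 1) * sum(log x) - N/2 * log(var(y)) with y the transformed data; the log-space helpers above exist only to compute log(var(y)) without overflow or cancellation. A naive reference evaluation for comparison (illustrative, ignores those stability concerns; the data and lambda values are arbitrary):

import numpy as np
from scipy import stats

def boxcox_llf_naive(lmb, x, axis=0):
    # Direct transcription of the formula, transforming the data explicitly.
    x = np.asarray(x, dtype=float)
    n = x.shape[axis]
    y = np.log(x) if lmb == 0 else (x**lmb - 1) / lmb
    return ((lmb - 1) * np.sum(np.log(x), axis=axis)
            - n / 2 * np.log(np.var(y, axis=axis)))

rng = np.random.default_rng(1)
data = rng.uniform(0.5, 5.0, size=(50, 4))
for lmb in (-0.5, 0.0, 1.3):
    print(np.allclose(boxcox_llf_naive(lmb, data, axis=0),
                      stats.boxcox_llf(lmb, data, axis=0)))   # True each time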

@@ -1081,7 +1115,7 @@ def boxcox(x, lmbda=None, alpha=None, optimizer=None):
     Notes
     -----
     The Box-Cox transform is given by:
-
+
     .. math::
 
         y =

scipy/stats/tests/test_axis_nan_policy.py

Lines changed: 1 addition & 0 deletions
@@ -173,6 +173,7 @@ def weightedtau_weighted(x, y, rank, **kwargs):
     (gstd, tuple(), dict(), 1, 1, False, lambda x: (x,)),
     (stats.power_divergence, tuple(), dict(), 1, 2, False, None),
     (stats.chisquare, tuple(), dict(), 1, 2, False, None),
+    (stats._morestats._boxcox_llf, tuple(), dict(lmb=1.5), 1, 1, False, lambda x: (x,)),
 ]
 
 # If the message is one of those expected, put nans in

scipy/stats/tests/test_morestats.py

Lines changed: 22 additions & 1 deletion
@@ -5,6 +5,7 @@
 import math
 import warnings
 import sys
+import contextlib
 from functools import partial
 
 import numpy as np
@@ -2024,7 +2025,11 @@ def test_2d_input(self, xp):
         xp_assert_close(xp.asarray([llf, llf]), xp.asarray(llf2), rtol=1e-12)
 
     def test_empty(self, xp):
-        assert xp.isnan(xp.asarray(stats.boxcox_llf(1, xp.asarray([]))))
+        message = "One or more sample arguments is too small..."
+        context = (pytest.warns(SmallSampleWarning, match=message) if is_numpy(xp)
+                   else contextlib.nullcontext())
+        with context:
+            assert xp.isnan(xp.asarray(stats.boxcox_llf(1, xp.asarray([]))))
 
     def test_gh_6873(self, xp):
         # Regression test for gh-6873.
@@ -2041,6 +2046,22 @@ def test_instability_gh20021(self, xp):
         # expect float64 output for integer input
         xp_assert_close(llf, xp.asarray(-15.32401272869016598, dtype=xp.float64))
 
+    def test_axis(self, xp):
+        data = xp.asarray([[100, 200], [300, 400]])
+        llf_axis_0 = stats.boxcox_llf(1, data, axis=0)
+        data_axes_swapped = xp.moveaxis(data, 0, -1)
+        llf_0 = xp.asarray([
+            stats.boxcox_llf(1, data_axes_swapped[0, :]),
+            stats.boxcox_llf(1, data_axes_swapped[1, :]),
+        ])
+        xp_assert_close(llf_axis_0, llf_0)
+        llf_axis_1 = stats.boxcox_llf(1, data, axis=1)
+        llf_1 = xp.asarray([
+            stats.boxcox_llf(1, data[0, :]),
+            stats.boxcox_llf(1, data[1, :]),
+        ])
+        xp_assert_close(llf_axis_1, llf_1)
+
 
 # This is the data from GitHub user Qukaiyi, given as an example
 # of a data set that caused boxcox to fail.
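The new `test_axis` case pins down the axis semantics on a tiny 2x2 array. The same property can be stated generically for NumPy input (illustrative, not part of the test suite) via `np.apply_along_axis`:

import numpy as np
from scipy import stats

data = np.array([[100.0, 200.0], [300.0, 400.0]])
for axis in (0, 1):
    vectorized = stats.boxcox_llf(1, data, axis=axis)
    looped = np.apply_along_axis(lambda s: stats.boxcox_llf(1, s), axis, data)
    print(np.allclose(vectorized, looped))   # True for both axes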
