@@ -864,15 +864,11 @@ def _log_mean(logx, axis):
864
864
865
865
def _log_var (logx , xp , axis ):
866
866
# compute log of variance of x from log(x)
867
- logmean = _log_mean (logx , axis = axis )
868
- # get complex dtype with component dtypes same as `logx` dtype;
869
- dtype = xp .result_type (logx .dtype , 1j )
870
- pij = xp .full (logx .shape , pi * 1j , dtype = dtype )
871
- logxmu = special .logsumexp (xp .stack ((logx , logmean + pij )), axis = 0 )
872
- return (
873
- xp .real (xp .asarray (special .logsumexp (2 * logxmu , axis = axis )))
874
- - math .log (logx .shape [axis ])
875
- )
867
+ logmean = xp .broadcast_to (_log_mean (logx , axis = axis ), logx .shape )
868
+ ones = xp .ones_like (logx )
869
+ logxmu , _ = special .logsumexp (xp .stack ((logx , logmean ), axis = 0 ), axis = 0 ,
870
+ b = xp .stack ((ones , - ones ), axis = 0 ), return_sign = True )
871
+ return special .logsumexp (2 * logxmu , axis = axis ) - math .log (logx .shape [axis ])
876
872
877
873
878
874
def boxcox_llf (lmb , data , * , axis = 0 , keepdims = False , nan_policy = 'propagate' ):
@@ -991,7 +987,7 @@ def boxcox_llf(lmb, data, *, axis=0, keepdims=False, nan_policy='propagate'):
991
987
result_to_tuple = lambda x : (x ,))
992
988
def _boxcox_llf (data , axis = 0 , * , lmb ):
993
989
xp = array_namespace (data )
994
- data = xp_promote (data , force_floating = True , xp = xp )
990
+ lmb , data = xp_promote (lmb , data , force_floating = True , xp = xp )
995
991
N = data .shape [axis ]
996
992
if N == 0 :
997
993
return _get_nan (data , xp = xp )
@@ -1010,7 +1006,7 @@ def _boxcox_llf(data, axis=0, *, lmb):
1010
1006
logvar = _log_var (logx , xp , axis ) - 2 * math .log (abs (lmb ))
1011
1007
1012
1008
res = (lmb - 1 ) * xp .sum (logdata , axis = axis ) - N / 2 * logvar
1013
- res = xp .astype (res , data .dtype )
1009
+ res = xp .astype (res , data .dtype , copy = False ) # compensate for NumPy <2.0
1014
1010
res = res [()] if res .ndim == 0 else res
1015
1011
return res
1016
1012
0 commit comments