Skip to content

Commit e6ab2cb

Browse files
committed
make LArray objects (un)picklable (to pave the way for #216)
1 parent f374a8f commit e6ab2cb

File tree

2 files changed

+154
-94
lines changed

2 files changed

+154
-94
lines changed

larray/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2506,6 +2506,14 @@ def __getattr__(self, key):
25062506
except KeyError:
25072507
return self.__getattribute__(key)
25082508

2509+
# needed to make *un*pickling work (because otherwise, __getattr__ is called before _map exists, which leads to
2510+
# an infinite recursion)
2511+
def __getstate__(self):
2512+
return self.__dict__
2513+
2514+
def __setstate__(self, d):
2515+
self.__dict__ = d
2516+
25092517
def __getitem__(self, key):
25102518
if isinstance(key, Axis):
25112519
try:
@@ -11353,6 +11361,13 @@ def evaluate(self, context):
1135311361

1135411362

1135511363
class AxisReferenceFactory(object):
11364+
# needed to make pickle work (because we have a __getattr__ which does not return AttributeError on __getstate__):
11365+
def __getstate__(self):
11366+
return self.__dict__
11367+
11368+
def __setstate__(self, d):
11369+
self.__dict__ = d
11370+
1135611371
def __getattr__(self, key):
1135711372
return AxisReference(key)
1135811373

larray/ufuncs.py

Lines changed: 139 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,51 @@
55

66
from larray.core import LArray, make_numpy_broadcastable
77

8+
__all__ = [
9+
# Trigonometric functions
10+
'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'hypot', 'arctan2', 'degrees', 'radians', 'unwrap',
11+
# 'deg2rad', 'rad2deg',
812

9-
def wrapper(func):
10-
def wrapped(*args, **kwargs):
13+
# Hyperbolic functions
14+
'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
15+
16+
# Rounding
17+
'round', 'around', 'round_', 'rint', 'fix', 'floor', 'ceil', 'trunc',
18+
19+
# Sums, products, differences
20+
# 'prod', 'sum', 'nansum', 'cumprod', 'cumsum',
21+
22+
# cannot use a simple wrapped ufunc because those ufuncs do not preserve shape or dimensions so labels are wrong
23+
# 'diff', 'ediff1d', 'gradient', 'cross', 'trapz',
24+
25+
# Exponents and logarithms
26+
'exp', 'expm1', 'exp2', 'log', 'log10', 'log2', 'log1p', 'logaddexp', 'logaddexp2',
27+
28+
# Other special functions
29+
'i0', 'sinc',
30+
31+
# Floating point routines
32+
'signbit', 'copysign', 'frexp', 'ldexp',
33+
34+
# Arithmetic operations
35+
# 'add', 'reciprocal', 'negative', 'multiply', 'divide', 'power', 'subtract', 'true_divide', 'floor_divide',
36+
# 'fmod', 'mod', 'modf', 'remainder',
37+
38+
# Handling complex numbers
39+
'angle', 'real', 'imag', 'conj',
40+
41+
# Miscellaneous
42+
'convolve', 'clip', 'sqrt',
43+
# 'square',
44+
'absolute', 'fabs', 'sign', 'maximum', 'minimum', 'fmax', 'fmin', 'nan_to_num', 'real_if_close',
45+
'interp', 'where', 'isnan', 'isinf',
46+
'inverse',
47+
]
48+
49+
50+
def broadcastify(func):
51+
# intentionally not using functools.wraps, because it does not work for wrapping a function from another module
52+
def wrapper(*args, **kwargs):
1153
# TODO: normalize args/kwargs like in LIAM2 so that we can also
1254
# broadcast if args are given via kwargs (eg out=)
1355
args, combined_axes = make_numpy_broadcastable(args)
@@ -34,130 +76,133 @@ def wrapped(*args, **kwargs):
3476
return LArray(res_data, combined_axes)
3577
else:
3678
return res_data
37-
# return func(*args, **kwargs)
38-
wrapped.__name__ = func.__name__
39-
wrapped.__doc__ = func.__doc__
40-
return wrapped
79+
# copy meaningful attributes (numpy ufuncs do not have __annotations__ nor __qualname__)
80+
wrapper.__name__ = func.__name__
81+
wrapper.__doc__ = func.__doc__
82+
# set __qualname__ explicitly (all these functions are supposed to be top-level function in the ufuncs module)
83+
wrapper.__qualname__ = func.__name__
84+
# we should not copy __module__
85+
return wrapper
4186

4287

4388
# Trigonometric functions
4489

45-
sin = wrapper(np.sin)
46-
cos = wrapper(np.cos)
47-
tan = wrapper(np.tan)
48-
arcsin = wrapper(np.arcsin)
49-
arccos = wrapper(np.arccos)
50-
arctan = wrapper(np.arctan)
51-
hypot = wrapper(np.hypot)
52-
arctan2 = wrapper(np.arctan2)
53-
degrees = wrapper(np.degrees)
54-
radians = wrapper(np.radians)
55-
unwrap = wrapper(np.unwrap)
56-
# deg2rad = wrapper(np.deg2rad)
57-
# rad2deg = wrapper(np.rad2deg)
90+
sin = broadcastify(np.sin)
91+
cos = broadcastify(np.cos)
92+
tan = broadcastify(np.tan)
93+
arcsin = broadcastify(np.arcsin)
94+
arccos = broadcastify(np.arccos)
95+
arctan = broadcastify(np.arctan)
96+
hypot = broadcastify(np.hypot)
97+
arctan2 = broadcastify(np.arctan2)
98+
degrees = broadcastify(np.degrees)
99+
radians = broadcastify(np.radians)
100+
unwrap = broadcastify(np.unwrap)
101+
# deg2rad = broadcastify(np.deg2rad)
102+
# rad2deg = broadcastify(np.rad2deg)
58103

59104
# Hyperbolic functions
60105

61-
sinh = wrapper(np.sinh)
62-
cosh = wrapper(np.cosh)
63-
tanh = wrapper(np.tanh)
64-
arcsinh = wrapper(np.arcsinh)
65-
arccosh = wrapper(np.arccosh)
66-
arctanh = wrapper(np.arctanh)
106+
sinh = broadcastify(np.sinh)
107+
cosh = broadcastify(np.cosh)
108+
tanh = broadcastify(np.tanh)
109+
arcsinh = broadcastify(np.arcsinh)
110+
arccosh = broadcastify(np.arccosh)
111+
arctanh = broadcastify(np.arctanh)
67112

68113
# Rounding
69114

70115
# all 3 are equivalent, I am unsure I should support around and round_
71-
round = wrapper(np.round)
72-
around = wrapper(np.around)
73-
round_ = wrapper(np.round_)
74-
rint = wrapper(np.rint)
75-
fix = wrapper(np.fix)
76-
floor = wrapper(np.floor)
77-
ceil = wrapper(np.ceil)
78-
trunc = wrapper(np.trunc)
116+
round = broadcastify(np.round)
117+
around = broadcastify(np.around)
118+
round_ = broadcastify(np.round_)
119+
rint = broadcastify(np.rint)
120+
fix = broadcastify(np.fix)
121+
floor = broadcastify(np.floor)
122+
ceil = broadcastify(np.ceil)
123+
trunc = broadcastify(np.trunc)
79124

80125
# Sums, products, differences
81126

82-
# prod = wrapper(np.prod)
83-
# sum = wrapper(np.sum)
84-
# nansum = wrapper(np.nansum)
85-
# cumprod = wrapper(np.cumprod)
86-
# cumsum = wrapper(np.cumsum)
127+
# prod = broadcastify(np.prod)
128+
# sum = broadcastify(np.sum)
129+
# nansum = broadcastify(np.nansum)
130+
# cumprod = broadcastify(np.cumprod)
131+
# cumsum = broadcastify(np.cumsum)
87132

88133
# cannot use a simple wrapped ufunc because those ufuncs do not preserve
89134
# shape or dimensions so labels are wrong
90-
# diff = wrapper(np.diff)
91-
# ediff1d = wrapper(np.ediff1d)
92-
# gradient = wrapper(np.gradient)
93-
# cross = wrapper(np.cross)
94-
# trapz = wrapper(np.trapz)
135+
# diff = broadcastify(np.diff)
136+
# ediff1d = broadcastify(np.ediff1d)
137+
# gradient = broadcastify(np.gradient)
138+
# cross = broadcastify(np.cross)
139+
# trapz = broadcastify(np.trapz)
95140

96141
# Exponents and logarithms
97142

98-
exp = wrapper(np.exp)
99-
expm1 = wrapper(np.expm1)
100-
exp2 = wrapper(np.exp2)
101-
log = wrapper(np.log)
102-
log10 = wrapper(np.log10)
103-
log2 = wrapper(np.log2)
104-
log1p = wrapper(np.log1p)
105-
logaddexp = wrapper(np.logaddexp)
106-
logaddexp2 = wrapper(np.logaddexp2)
143+
exp = broadcastify(np.exp)
144+
expm1 = broadcastify(np.expm1)
145+
exp2 = broadcastify(np.exp2)
146+
log = broadcastify(np.log)
147+
log10 = broadcastify(np.log10)
148+
log2 = broadcastify(np.log2)
149+
log1p = broadcastify(np.log1p)
150+
logaddexp = broadcastify(np.logaddexp)
151+
logaddexp2 = broadcastify(np.logaddexp2)
107152

108153
# Other special functions
109154

110-
i0 = wrapper(np.i0)
111-
sinc = wrapper(np.sinc)
155+
i0 = broadcastify(np.i0)
156+
sinc = broadcastify(np.sinc)
112157

113158
# Floating point routines
114159

115-
signbit = wrapper(np.signbit)
116-
copysign = wrapper(np.copysign)
117-
frexp = wrapper(np.frexp)
118-
ldexp = wrapper(np.ldexp)
160+
signbit = broadcastify(np.signbit)
161+
copysign = broadcastify(np.copysign)
162+
frexp = broadcastify(np.frexp)
163+
ldexp = broadcastify(np.ldexp)
119164

120165
# Arithmetic operations
121166

122-
# add = wrapper(np.add)
123-
# reciprocal = wrapper(np.reciprocal)
124-
# negative = wrapper(np.negative)
125-
# multiply = wrapper(np.multiply)
126-
# divide = wrapper(np.divide)
127-
# power = wrapper(np.power)
128-
# subtract = wrapper(np.subtract)
129-
# true_divide = wrapper(np.true_divide)
130-
# floor_divide = wrapper(np.floor_divide)
131-
# fmod = wrapper(np.fmod)
132-
# mod = wrapper(np.mod)
133-
modf = wrapper(np.modf)
134-
# remainder = wrapper(np.remainder)
167+
# add = broadcastify(np.add)
168+
# reciprocal = broadcastify(np.reciprocal)
169+
# negative = broadcastify(np.negative)
170+
# multiply = broadcastify(np.multiply)
171+
# divide = broadcastify(np.divide)
172+
# power = broadcastify(np.power)
173+
# subtract = broadcastify(np.subtract)
174+
# true_divide = broadcastify(np.true_divide)
175+
# floor_divide = broadcastify(np.floor_divide)
176+
# fmod = broadcastify(np.fmod)
177+
# mod = broadcastify(np.mod)
178+
# modf = broadcastify(np.modf)
179+
# remainder = broadcastify(np.remainder)
135180

136181
# Handling complex numbers
137182

138-
angle = wrapper(np.angle)
139-
real = wrapper(np.real)
140-
imag = wrapper(np.imag)
141-
conj = wrapper(np.conj)
183+
angle = broadcastify(np.angle)
184+
real = broadcastify(np.real)
185+
imag = broadcastify(np.imag)
186+
conj = broadcastify(np.conj)
142187

143188
# Miscellaneous
144189

145-
convolve = wrapper(np.convolve)
146-
clip = wrapper(np.clip)
147-
sqrt = wrapper(np.sqrt)
148-
# square = wrapper(np.square)
149-
absolute = wrapper(np.absolute)
150-
fabs = wrapper(np.fabs)
151-
sign = wrapper(np.sign)
152-
maximum = wrapper(np.maximum)
153-
minimum = wrapper(np.minimum)
154-
fmax = wrapper(np.fmax)
155-
fmin = wrapper(np.fmin)
156-
nan_to_num = wrapper(np.nan_to_num)
157-
real_if_close = wrapper(np.real_if_close)
158-
interp = wrapper(np.interp)
159-
where = wrapper(np.where)
160-
isnan = wrapper(np.isnan)
161-
isinf = wrapper(np.isinf)
162-
163-
inverse = wrapper(np.linalg.inv)
190+
convolve = broadcastify(np.convolve)
191+
clip = broadcastify(np.clip)
192+
sqrt = broadcastify(np.sqrt)
193+
# square = broadcastify(np.square)
194+
absolute = broadcastify(np.absolute)
195+
fabs = broadcastify(np.fabs)
196+
sign = broadcastify(np.sign)
197+
maximum = broadcastify(np.maximum)
198+
minimum = broadcastify(np.minimum)
199+
fmax = broadcastify(np.fmax)
200+
fmin = broadcastify(np.fmin)
201+
nan_to_num = broadcastify(np.nan_to_num)
202+
real_if_close = broadcastify(np.real_if_close)
203+
interp = broadcastify(np.interp)
204+
where = broadcastify(np.where)
205+
isnan = broadcastify(np.isnan)
206+
isinf = broadcastify(np.isinf)
207+
208+
inverse = broadcastify(np.linalg.inv)

0 commit comments

Comments
 (0)