5
5
6
6
from larray .core import LArray , make_numpy_broadcastable
7
7
8
+ __all__ = [
9
+ # Trigonometric functions
10
+ 'sin' , 'cos' , 'tan' , 'arcsin' , 'arccos' , 'arctan' , 'hypot' , 'arctan2' , 'degrees' , 'radians' , 'unwrap' ,
11
+ # 'deg2rad', 'rad2deg',
8
12
9
- def wrapper (func ):
10
- def wrapped (* args , ** kwargs ):
13
+ # Hyperbolic functions
14
+ 'sinh' , 'cosh' , 'tanh' , 'arcsinh' , 'arccosh' , 'arctanh' ,
15
+
16
+ # Rounding
17
+ 'round' , 'around' , 'round_' , 'rint' , 'fix' , 'floor' , 'ceil' , 'trunc' ,
18
+
19
+ # Sums, products, differences
20
+ # 'prod', 'sum', 'nansum', 'cumprod', 'cumsum',
21
+
22
+ # cannot use a simple wrapped ufunc because those ufuncs do not preserve shape or dimensions so labels are wrong
23
+ # 'diff', 'ediff1d', 'gradient', 'cross', 'trapz',
24
+
25
+ # Exponents and logarithms
26
+ 'exp' , 'expm1' , 'exp2' , 'log' , 'log10' , 'log2' , 'log1p' , 'logaddexp' , 'logaddexp2' ,
27
+
28
+ # Other special functions
29
+ 'i0' , 'sinc' ,
30
+
31
+ # Floating point routines
32
+ 'signbit' , 'copysign' , 'frexp' , 'ldexp' ,
33
+
34
+ # Arithmetic operations
35
+ # 'add', 'reciprocal', 'negative', 'multiply', 'divide', 'power', 'subtract', 'true_divide', 'floor_divide',
36
+ # 'fmod', 'mod', 'modf', 'remainder',
37
+
38
+ # Handling complex numbers
39
+ 'angle' , 'real' , 'imag' , 'conj' ,
40
+
41
+ # Miscellaneous
42
+ 'convolve' , 'clip' , 'sqrt' ,
43
+ # 'square',
44
+ 'absolute' , 'fabs' , 'sign' , 'maximum' , 'minimum' , 'fmax' , 'fmin' , 'nan_to_num' , 'real_if_close' ,
45
+ 'interp' , 'where' , 'isnan' , 'isinf' ,
46
+ 'inverse' ,
47
+ ]
48
+
49
+
50
+ def broadcastify (func ):
51
+ # intentionally not using functools.wraps, because it does not work for wrapping a function from another module
52
+ def wrapper (* args , ** kwargs ):
11
53
# TODO: normalize args/kwargs like in LIAM2 so that we can also
12
54
# broadcast if args are given via kwargs (eg out=)
13
55
args , combined_axes = make_numpy_broadcastable (args )
@@ -34,130 +76,133 @@ def wrapped(*args, **kwargs):
34
76
return LArray (res_data , combined_axes )
35
77
else :
36
78
return res_data
37
- # return func(*args, **kwargs)
38
- wrapped .__name__ = func .__name__
39
- wrapped .__doc__ = func .__doc__
40
- return wrapped
79
+ # copy meaningful attributes (numpy ufuncs do not have __annotations__ nor __qualname__)
80
+ wrapper .__name__ = func .__name__
81
+ wrapper .__doc__ = func .__doc__
82
+ # set __qualname__ explicitly (all these functions are supposed to be top-level function in the ufuncs module)
83
+ wrapper .__qualname__ = func .__name__
84
+ # we should not copy __module__
85
+ return wrapper
41
86
42
87
43
88
# Trigonometric functions
44
89
45
- sin = wrapper (np .sin )
46
- cos = wrapper (np .cos )
47
- tan = wrapper (np .tan )
48
- arcsin = wrapper (np .arcsin )
49
- arccos = wrapper (np .arccos )
50
- arctan = wrapper (np .arctan )
51
- hypot = wrapper (np .hypot )
52
- arctan2 = wrapper (np .arctan2 )
53
- degrees = wrapper (np .degrees )
54
- radians = wrapper (np .radians )
55
- unwrap = wrapper (np .unwrap )
56
- # deg2rad = wrapper (np.deg2rad)
57
- # rad2deg = wrapper (np.rad2deg)
90
+ sin = broadcastify (np .sin )
91
+ cos = broadcastify (np .cos )
92
+ tan = broadcastify (np .tan )
93
+ arcsin = broadcastify (np .arcsin )
94
+ arccos = broadcastify (np .arccos )
95
+ arctan = broadcastify (np .arctan )
96
+ hypot = broadcastify (np .hypot )
97
+ arctan2 = broadcastify (np .arctan2 )
98
+ degrees = broadcastify (np .degrees )
99
+ radians = broadcastify (np .radians )
100
+ unwrap = broadcastify (np .unwrap )
101
+ # deg2rad = broadcastify (np.deg2rad)
102
+ # rad2deg = broadcastify (np.rad2deg)
58
103
59
104
# Hyperbolic functions
60
105
61
- sinh = wrapper (np .sinh )
62
- cosh = wrapper (np .cosh )
63
- tanh = wrapper (np .tanh )
64
- arcsinh = wrapper (np .arcsinh )
65
- arccosh = wrapper (np .arccosh )
66
- arctanh = wrapper (np .arctanh )
106
+ sinh = broadcastify (np .sinh )
107
+ cosh = broadcastify (np .cosh )
108
+ tanh = broadcastify (np .tanh )
109
+ arcsinh = broadcastify (np .arcsinh )
110
+ arccosh = broadcastify (np .arccosh )
111
+ arctanh = broadcastify (np .arctanh )
67
112
68
113
# Rounding
69
114
70
115
# all 3 are equivalent, I am unsure I should support around and round_
71
- round = wrapper (np .round )
72
- around = wrapper (np .around )
73
- round_ = wrapper (np .round_ )
74
- rint = wrapper (np .rint )
75
- fix = wrapper (np .fix )
76
- floor = wrapper (np .floor )
77
- ceil = wrapper (np .ceil )
78
- trunc = wrapper (np .trunc )
116
+ round = broadcastify (np .round )
117
+ around = broadcastify (np .around )
118
+ round_ = broadcastify (np .round_ )
119
+ rint = broadcastify (np .rint )
120
+ fix = broadcastify (np .fix )
121
+ floor = broadcastify (np .floor )
122
+ ceil = broadcastify (np .ceil )
123
+ trunc = broadcastify (np .trunc )
79
124
80
125
# Sums, products, differences
81
126
82
- # prod = wrapper (np.prod)
83
- # sum = wrapper (np.sum)
84
- # nansum = wrapper (np.nansum)
85
- # cumprod = wrapper (np.cumprod)
86
- # cumsum = wrapper (np.cumsum)
127
+ # prod = broadcastify (np.prod)
128
+ # sum = broadcastify (np.sum)
129
+ # nansum = broadcastify (np.nansum)
130
+ # cumprod = broadcastify (np.cumprod)
131
+ # cumsum = broadcastify (np.cumsum)
87
132
88
133
# cannot use a simple wrapped ufunc because those ufuncs do not preserve
89
134
# shape or dimensions so labels are wrong
90
- # diff = wrapper (np.diff)
91
- # ediff1d = wrapper (np.ediff1d)
92
- # gradient = wrapper (np.gradient)
93
- # cross = wrapper (np.cross)
94
- # trapz = wrapper (np.trapz)
135
+ # diff = broadcastify (np.diff)
136
+ # ediff1d = broadcastify (np.ediff1d)
137
+ # gradient = broadcastify (np.gradient)
138
+ # cross = broadcastify (np.cross)
139
+ # trapz = broadcastify (np.trapz)
95
140
96
141
# Exponents and logarithms
97
142
98
- exp = wrapper (np .exp )
99
- expm1 = wrapper (np .expm1 )
100
- exp2 = wrapper (np .exp2 )
101
- log = wrapper (np .log )
102
- log10 = wrapper (np .log10 )
103
- log2 = wrapper (np .log2 )
104
- log1p = wrapper (np .log1p )
105
- logaddexp = wrapper (np .logaddexp )
106
- logaddexp2 = wrapper (np .logaddexp2 )
143
+ exp = broadcastify (np .exp )
144
+ expm1 = broadcastify (np .expm1 )
145
+ exp2 = broadcastify (np .exp2 )
146
+ log = broadcastify (np .log )
147
+ log10 = broadcastify (np .log10 )
148
+ log2 = broadcastify (np .log2 )
149
+ log1p = broadcastify (np .log1p )
150
+ logaddexp = broadcastify (np .logaddexp )
151
+ logaddexp2 = broadcastify (np .logaddexp2 )
107
152
108
153
# Other special functions
109
154
110
- i0 = wrapper (np .i0 )
111
- sinc = wrapper (np .sinc )
155
+ i0 = broadcastify (np .i0 )
156
+ sinc = broadcastify (np .sinc )
112
157
113
158
# Floating point routines
114
159
115
- signbit = wrapper (np .signbit )
116
- copysign = wrapper (np .copysign )
117
- frexp = wrapper (np .frexp )
118
- ldexp = wrapper (np .ldexp )
160
+ signbit = broadcastify (np .signbit )
161
+ copysign = broadcastify (np .copysign )
162
+ frexp = broadcastify (np .frexp )
163
+ ldexp = broadcastify (np .ldexp )
119
164
120
165
# Arithmetic operations
121
166
122
- # add = wrapper (np.add)
123
- # reciprocal = wrapper (np.reciprocal)
124
- # negative = wrapper (np.negative)
125
- # multiply = wrapper (np.multiply)
126
- # divide = wrapper (np.divide)
127
- # power = wrapper (np.power)
128
- # subtract = wrapper (np.subtract)
129
- # true_divide = wrapper (np.true_divide)
130
- # floor_divide = wrapper (np.floor_divide)
131
- # fmod = wrapper (np.fmod)
132
- # mod = wrapper (np.mod)
133
- modf = wrapper (np .modf )
134
- # remainder = wrapper (np.remainder)
167
+ # add = broadcastify (np.add)
168
+ # reciprocal = broadcastify (np.reciprocal)
169
+ # negative = broadcastify (np.negative)
170
+ # multiply = broadcastify (np.multiply)
171
+ # divide = broadcastify (np.divide)
172
+ # power = broadcastify (np.power)
173
+ # subtract = broadcastify (np.subtract)
174
+ # true_divide = broadcastify (np.true_divide)
175
+ # floor_divide = broadcastify (np.floor_divide)
176
+ # fmod = broadcastify (np.fmod)
177
+ # mod = broadcastify (np.mod)
178
+ # modf = broadcastify (np.modf)
179
+ # remainder = broadcastify (np.remainder)
135
180
136
181
# Handling complex numbers
137
182
138
- angle = wrapper (np .angle )
139
- real = wrapper (np .real )
140
- imag = wrapper (np .imag )
141
- conj = wrapper (np .conj )
183
+ angle = broadcastify (np .angle )
184
+ real = broadcastify (np .real )
185
+ imag = broadcastify (np .imag )
186
+ conj = broadcastify (np .conj )
142
187
143
188
# Miscellaneous
144
189
145
- convolve = wrapper (np .convolve )
146
- clip = wrapper (np .clip )
147
- sqrt = wrapper (np .sqrt )
148
- # square = wrapper (np.square)
149
- absolute = wrapper (np .absolute )
150
- fabs = wrapper (np .fabs )
151
- sign = wrapper (np .sign )
152
- maximum = wrapper (np .maximum )
153
- minimum = wrapper (np .minimum )
154
- fmax = wrapper (np .fmax )
155
- fmin = wrapper (np .fmin )
156
- nan_to_num = wrapper (np .nan_to_num )
157
- real_if_close = wrapper (np .real_if_close )
158
- interp = wrapper (np .interp )
159
- where = wrapper (np .where )
160
- isnan = wrapper (np .isnan )
161
- isinf = wrapper (np .isinf )
162
-
163
- inverse = wrapper (np .linalg .inv )
190
+ convolve = broadcastify (np .convolve )
191
+ clip = broadcastify (np .clip )
192
+ sqrt = broadcastify (np .sqrt )
193
+ # square = broadcastify (np.square)
194
+ absolute = broadcastify (np .absolute )
195
+ fabs = broadcastify (np .fabs )
196
+ sign = broadcastify (np .sign )
197
+ maximum = broadcastify (np .maximum )
198
+ minimum = broadcastify (np .minimum )
199
+ fmax = broadcastify (np .fmax )
200
+ fmin = broadcastify (np .fmin )
201
+ nan_to_num = broadcastify (np .nan_to_num )
202
+ real_if_close = broadcastify (np .real_if_close )
203
+ interp = broadcastify (np .interp )
204
+ where = broadcastify (np .where )
205
+ isnan = broadcastify (np .isnan )
206
+ isinf = broadcastify (np .isinf )
207
+
208
+ inverse = broadcastify (np .linalg .inv )
0 commit comments