2
2
from plum import dispatch
3
3
import jax
4
4
import jax ._src .lax_reference as lax_reference
5
+ from neurallogic import primitives
5
6
6
7
7
8
def to_boolean_value_string (x ):
@@ -110,64 +111,119 @@ def symbolic_eval(x):
110
111
return numpy .vectorize (eval )(x )
111
112
112
113
113
- def all_boolean (data ):
114
- if isinstance (data , bool ):
115
- return True
114
+ def all_concrete_values (data ):
115
+ if isinstance (data , str ):
116
+ return False
116
117
if isinstance (data , (list , tuple )):
117
- return all (all_boolean (x ) for x in data )
118
+ return all (all_concrete_values (x ) for x in data )
118
119
if isinstance (data , dict ):
119
- return all (all_boolean (v ) for v in data .values ())
120
+ return all (all_concrete_values (v ) for v in data .values ())
120
121
if isinstance (data , numpy .ndarray ):
121
- return all_boolean (data .tolist ())
122
+ return all_concrete_values (data .tolist ())
122
123
if isinstance (data , jax .numpy .ndarray ):
123
- return all_boolean (data .tolist ())
124
- return False
124
+ return all_concrete_values (data .tolist ())
125
+ return True
125
126
126
127
127
128
def symbolic_and (* args , ** kwargs ):
128
- if all_boolean ([* args ]):
129
+ if all_concrete_values ([* args ]):
129
130
return numpy .logical_and (* args , ** kwargs )
130
131
else :
131
132
return binary_infix_operator ("and" , * args , ** kwargs )
132
133
133
134
134
135
def symbolic_not (* args , ** kwargs ):
135
- if all_boolean ([* args ]):
136
+ if all_concrete_values ([* args ]):
136
137
return numpy .logical_not (* args , ** kwargs )
137
138
else :
138
139
return unary_operator ("not" , * args , ** kwargs )
139
140
140
141
141
142
def symbolic_xor (* args , ** kwargs ):
142
- if all_boolean ([* args ]):
143
+ if all_concrete_values ([* args ]):
143
144
return numpy .logical_xor (* args , ** kwargs )
144
145
else :
145
146
return binary_infix_operator ("^" , * args , ** kwargs , bracket = True )
146
147
147
148
148
149
def symbolic_or (* args , ** kwargs ):
149
- if all_boolean ([* args ]):
150
+ if all_concrete_values ([* args ]):
150
151
return numpy .logical_or (* args , ** kwargs )
151
152
else :
152
153
return binary_infix_operator ("or" , * args , ** kwargs )
153
154
154
155
156
+ def symbolic_sum (* args , ** kwargs ):
157
+ if all_concrete_values ([* args ]):
158
+ return numpy .sum (* args , ** kwargs )
159
+ else :
160
+ return binary_infix_operator ("+" , * args , ** kwargs )
161
+
155
162
# Uses the lax reference implementation of broadcast_in_dim to
156
163
# implement a symbolic version of broadcast_in_dim
157
164
158
165
159
166
def symbolic_broadcast_in_dim (* args , ** kwargs ):
160
167
return lax_reference .broadcast_in_dim (* args , ** kwargs )
161
168
169
+
170
+ def is_iterable (obj ):
171
+ try :
172
+ iter (obj )
173
+ return True
174
+ except TypeError :
175
+ return False
176
+
177
+ # TODO: unify this way of walking a nested iterable with the code above
178
+ def apply_func_to_nested_impl (iterable , func ):
179
+ if isinstance (iterable , (numpy .ndarray , jax .numpy .ndarray )):
180
+ iterable = iterable .tolist ()
181
+ if is_iterable (iterable ):
182
+ transformed = []
183
+ for item in iterable :
184
+ if isinstance (item , list ):
185
+ transformed .append (apply_func_to_nested_impl (item , func ))
186
+ else :
187
+ transformed .append (func (item ))
188
+ return transformed
189
+ else :
190
+ return func (iterable )
191
+
192
+ def apply_func_to_nested (iterable , func ):
193
+ iterable_type = type (iterable )
194
+ r = apply_func_to_nested_impl (iterable , func )
195
+ if iterable_type == numpy .ndarray :
196
+ r = numpy .array (r , dtype = object )
197
+ assert type (r ) == iterable_type
198
+ return r
199
+
200
+ def symbolic_convert_element_type_impl (x , dtype ):
201
+ if dtype == numpy .int32 or dtype == numpy .int64 :
202
+ dtype = "int"
203
+ def convert (x ):
204
+ return f"{ dtype } ({ x } )"
205
+ return apply_func_to_nested (x , convert )
206
+
207
+
208
+ # TODO: add a test for this
209
+ def symbolic_convert_element_type (* args , ** kwargs ):
210
+ # Check if all the boolean arguments are True or False
211
+ if all_concrete_values ([* args ]):
212
+ # If so, we can use the lax reference implementation
213
+ return lax_reference .convert_element_type (* args , dtype = kwargs ['new_dtype' ])
214
+ else :
215
+ # Otherwise, we use the symbolic implementation
216
+ return symbolic_convert_element_type_impl (* args , dtype = kwargs ['new_dtype' ])
217
+
218
+
162
219
# This function is a hack to get around the fact that JAX doesn't
163
220
# support symbolic reduction operations. It takes a symbolic reduction
164
221
# operation and a symbolic initial value and returns a function that
165
222
# performs the reduction operation on a numpy array.
166
223
167
224
168
225
def make_symbolic_reducer (py_binop , init_val ):
169
- def reducer (operand , axis = 0 ):
170
- # axis=0 means we are reducing over the first axis (i.e. the rows) of the operand.
226
+ def reducer (operand , axis ):
171
227
# axis=None means we are reducing over all axes of the operand.
172
228
axis = range (numpy .ndim (operand )) if axis is None else axis
173
229
@@ -193,10 +249,14 @@ def symbolic_reduce(operand, init_value, computation, dimensions):
193
249
194
250
195
251
def symbolic_reduce_or (* args , ** kwargs ):
196
- # Check if all the boolean arguments are True or False
197
- if all_boolean ([* args ]):
198
- # If so, use the numpy function reduce to reduce the logical_or operator
199
- return lax_reference .reduce (* args , init_value = False , dimensions = kwargs ['axes' ], computation = numpy .logical_or )
252
+ if all_concrete_values ([* args ]):
253
+ return lax_reference .reduce (* args , init_value = False , computation = numpy .logical_or , dimensions = kwargs ['axes' ])
254
+ else :
255
+ return symbolic_reduce (* args , init_value = 'False' , computation = symbolic_or , dimensions = kwargs ['axes' ])
256
+
257
+
258
+ def symbolic_reduce_sum (* args , ** kwargs ):
259
+ if all_concrete_values ([* args ]):
260
+ return lax_reference .reduce (* args , init_value = 0 , computation = numpy .add , dimensions = kwargs ['axes' ])
200
261
else :
201
- # Otherwise, we use the symbolic_reduce function to reduce the symbolic_or operator
202
- return symbolic_reduce (* args , init_value = 'False' , dimensions = kwargs ['axes' ], computation = symbolic_or )
262
+ return symbolic_reduce (* args , init_value = '0' , computation = symbolic_sum , dimensions = kwargs ['axes' ])
0 commit comments