Skip to content

Commit 49eaeda

Browse files
committed
successfully interpret the entire NN architecture
1 parent 5032757 commit 49eaeda

File tree

8 files changed

+195
-135
lines changed

8 files changed

+195
-135
lines changed

neurallogic/hard_and.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def soft_and_include(w: float, x: float) -> float:
2222
def hard_and_include(w: bool, x: bool) -> bool:
2323
return x | ~w
2424

25-
def symbolic_and_include_deprecated(w, x):
25+
def symbolic_and_include(w, x):
2626
expression = f"({x} or not({w}))"
2727
# Check if w is of type bool
2828
if isinstance(w, bool) and isinstance(x, bool):
@@ -39,13 +39,13 @@ def hard_and_neuron(w, x):
3939
x = jax.vmap(hard_and_include, 0, 0)(w, x)
4040
return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])
4141

42-
def symbolic_and_neuron_deprecated(w, x):
42+
def symbolic_and_neuron(w, x):
4343
# TODO: ensure that this implementation has the same generality over tensors as vmap
4444
if not isinstance(w, list):
4545
raise TypeError(f"Input {x} should be a list")
4646
if not isinstance(x, list):
4747
raise TypeError(f"Input {x} should be a list")
48-
y = [symbolic_and_include_deprecated(wi, xi) for wi, xi in zip(w, x)]
48+
y = [symbolic_and_include(wi, xi) for wi, xi in zip(w, x)]
4949
expression = "(" + str(reduce(lambda a, b: f"{a} and {b}", y)) + ")"
5050
if all(isinstance(yi, bool) for yi in y):
5151
# We know the value of all yis, so we can evaluate the expression
@@ -56,13 +56,13 @@ def symbolic_and_neuron_deprecated(w, x):
5656

5757
hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)
5858

59-
def symbolic_and_layer_deprecated(w, x):
59+
def symbolic_and_layer(w, x):
6060
# TODO: ensure that this implementation has the same generality over tensors as vmap
6161
if not isinstance(w, list):
6262
raise TypeError(f"Input {x} should be a list")
6363
if not isinstance(x, list):
6464
raise TypeError(f"Input {x} should be a list")
65-
return [symbolic_and_neuron_deprecated(wi, x) for wi in w]
65+
return [symbolic_and_neuron(wi, x) for wi in w]
6666

6767
# TODO: investigate better initialization
6868
def initialize_near_to_zero():
@@ -111,7 +111,7 @@ def __call__(self, x):
111111
weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
112112
return hard_and_layer(weights, x)
113113

114-
class SymbolicAndLayer_deprecated(nn.Module):
114+
class SymbolicAndLayer(nn.Module):
115115
"""A symbolic And layer than transforms its inputs along the last dimension.
116116
Attributes:
117117
layer_size: The number of neurons in the layer.
@@ -125,9 +125,9 @@ def __call__(self, x):
125125
weights = weights.tolist()
126126
if not isinstance(x, list):
127127
raise TypeError(f"Input {x} should be a list")
128-
return symbolic_and_layer_deprecated(weights, x)
128+
return symbolic_and_layer(weights, x)
129129

130130
and_layer = neural_logic_net.select(
131131
lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
132132
lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
133-
lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer_deprecated(layer_size))
133+
lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))

neurallogic/primitives.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,20 @@
88
symbolic shape transformations
99
"""
1010

11-
def symbolic_ravel_deprecated(x):
11+
def symbolic_ravel(x):
1212
return numpy.array(x).ravel().tolist()
1313

14-
nl_ravel_deprecated = neural_logic_net.select(jnp.ravel, jnp.ravel, symbolic_ravel_deprecated)
14+
nl_ravel = neural_logic_net.select(jnp.ravel, jnp.ravel, symbolic_ravel)
1515

16-
def symbolic_reshape_deprecated(x, newshape):
16+
def symbolic_reshape(x, newshape):
1717
return numpy.array(x).reshape(newshape).tolist()
1818

19-
nl_reshape_deprecated = neural_logic_net.select(lambda newshape: lambda x: jnp.reshape(x, newshape), lambda newshape: lambda x: jnp.reshape(x, newshape), lambda newshape: lambda x: symbolic_reshape_deprecated(x, newshape))
19+
nl_reshape = neural_logic_net.select(lambda newshape: lambda x: jnp.reshape(x, newshape), lambda newshape: lambda x: jnp.reshape(x, newshape), lambda newshape: lambda x: symbolic_reshape(x, newshape))
2020

2121
"""
2222
symbolic computations
2323
"""
24-
def symbolic_reduce_impl_deprecated(op, x, axis):
24+
def symbolic_reduce_impl(op, x, axis):
2525
"""
2626
Cannot support multiple axes due to limitations of numpy.
2727
"""
@@ -36,17 +36,17 @@ def op_xy(x, y):
3636
x = x.tolist()
3737
return x
3838

39-
def symbolic_reduce_deprecated(op, x, axis=None):
39+
def symbolic_reduce(op, x, axis=None):
4040
if axis is None:
4141
# Special case for reducing all elements in a tensor
4242
while isinstance(x, list) and len(x) > 1:
43-
x = symbolic_reduce_impl_deprecated(op, x, 0)
43+
x = symbolic_reduce_impl(op, x, 0)
4444
return x
4545
else:
46-
x = symbolic_reduce_impl_deprecated(op, x, axis)
46+
x = symbolic_reduce_impl(op, x, axis)
4747
return x
4848

49-
def symbolic_sum_deprecated(x, axis=None):
50-
return symbolic_reduce_deprecated((operator.add, "+"), x, axis)
49+
def symbolic_sum(x, axis=None):
50+
return symbolic_reduce((operator.add, "+"), x, axis)
5151

52-
nl_sum_deprecated = neural_logic_net.select(lambda axis=None: lambda x: jnp.sum(x, axis), lambda axis=None: lambda x: jnp.sum(x, axis), lambda axis=None: lambda x: symbolic_sum_deprecated(x, axis))
52+
nl_sum = neural_logic_net.select(lambda axis=None: lambda x: jnp.sum(x, axis), lambda axis=None: lambda x: jnp.sum(x, axis), lambda axis=None: lambda x: symbolic_sum(x, axis))

neurallogic/sym_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def symbolic_bind(prim, *args, **params):
1818
'not': symbolic_primitives.symbolic_not,
1919
'reshape': lax_reference.reshape,
2020
'reduce_or': symbolic_primitives.symbolic_reduce_or,
21+
'reduce_sum': symbolic_primitives.symbolic_reduce_sum,
22+
'convert_element_type': symbolic_primitives.symbolic_convert_element_type
2123
}[prim.name](*args, **params)
2224
return symbolic_outvals
2325

neurallogic/symbolic_primitives.py

Lines changed: 80 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from plum import dispatch
33
import jax
44
import jax._src.lax_reference as lax_reference
5+
from neurallogic import primitives
56

67

78
def to_boolean_value_string(x):
@@ -110,64 +111,119 @@ def symbolic_eval(x):
110111
return numpy.vectorize(eval)(x)
111112

112113

113-
def all_boolean(data):
114-
if isinstance(data, bool):
115-
return True
114+
def all_concrete_values(data):
115+
if isinstance(data, str):
116+
return False
116117
if isinstance(data, (list, tuple)):
117-
return all(all_boolean(x) for x in data)
118+
return all(all_concrete_values(x) for x in data)
118119
if isinstance(data, dict):
119-
return all(all_boolean(v) for v in data.values())
120+
return all(all_concrete_values(v) for v in data.values())
120121
if isinstance(data, numpy.ndarray):
121-
return all_boolean(data.tolist())
122+
return all_concrete_values(data.tolist())
122123
if isinstance(data, jax.numpy.ndarray):
123-
return all_boolean(data.tolist())
124-
return False
124+
return all_concrete_values(data.tolist())
125+
return True
125126

126127

127128
def symbolic_and(*args, **kwargs):
128-
if all_boolean([*args]):
129+
if all_concrete_values([*args]):
129130
return numpy.logical_and(*args, **kwargs)
130131
else:
131132
return binary_infix_operator("and", *args, **kwargs)
132133

133134

134135
def symbolic_not(*args, **kwargs):
135-
if all_boolean([*args]):
136+
if all_concrete_values([*args]):
136137
return numpy.logical_not(*args, **kwargs)
137138
else:
138139
return unary_operator("not", *args, **kwargs)
139140

140141

141142
def symbolic_xor(*args, **kwargs):
142-
if all_boolean([*args]):
143+
if all_concrete_values([*args]):
143144
return numpy.logical_xor(*args, **kwargs)
144145
else:
145146
return binary_infix_operator("^", *args, **kwargs, bracket=True)
146147

147148

148149
def symbolic_or(*args, **kwargs):
149-
if all_boolean([*args]):
150+
if all_concrete_values([*args]):
150151
return numpy.logical_or(*args, **kwargs)
151152
else:
152153
return binary_infix_operator("or", *args, **kwargs)
153154

154155

156+
def symbolic_sum(*args, **kwargs):
157+
if all_concrete_values([*args]):
158+
return numpy.sum(*args, **kwargs)
159+
else:
160+
return binary_infix_operator("+", *args, **kwargs)
161+
155162
# Uses the lax reference implementation of broadcast_in_dim to
156163
# implement a symbolic version of broadcast_in_dim
157164

158165

159166
def symbolic_broadcast_in_dim(*args, **kwargs):
160167
return lax_reference.broadcast_in_dim(*args, **kwargs)
161168

169+
170+
def is_iterable(obj):
171+
try:
172+
iter(obj)
173+
return True
174+
except TypeError:
175+
return False
176+
177+
# TODO: unify this way of walking a nested iterable with the code above
178+
def apply_func_to_nested_impl(iterable, func):
179+
if isinstance(iterable, (numpy.ndarray, jax.numpy.ndarray)):
180+
iterable = iterable.tolist()
181+
if is_iterable(iterable):
182+
transformed = []
183+
for item in iterable:
184+
if isinstance(item, list):
185+
transformed.append(apply_func_to_nested_impl(item, func))
186+
else:
187+
transformed.append(func(item))
188+
return transformed
189+
else:
190+
return func(iterable)
191+
192+
def apply_func_to_nested(iterable, func):
193+
iterable_type = type(iterable)
194+
r = apply_func_to_nested_impl(iterable, func)
195+
if iterable_type == numpy.ndarray:
196+
r = numpy.array(r, dtype=object)
197+
assert type(r) == iterable_type
198+
return r
199+
200+
def symbolic_convert_element_type_impl(x, dtype):
201+
if dtype == numpy.int32 or dtype == numpy.int64:
202+
dtype = "int"
203+
def convert(x):
204+
return f"{dtype}({x})"
205+
return apply_func_to_nested(x, convert)
206+
207+
208+
# TODO: add a test for this
209+
def symbolic_convert_element_type(*args, **kwargs):
210+
# Check if all the boolean arguments are True or False
211+
if all_concrete_values([*args]):
212+
# If so, we can use the lax reference implementation
213+
return lax_reference.convert_element_type(*args, dtype=kwargs['new_dtype'])
214+
else:
215+
# Otherwise, we use the symbolic implementation
216+
return symbolic_convert_element_type_impl(*args, dtype=kwargs['new_dtype'])
217+
218+
162219
# This function is a hack to get around the fact that JAX doesn't
163220
# support symbolic reduction operations. It takes a symbolic reduction
164221
# operation and a symbolic initial value and returns a function that
165222
# performs the reduction operation on a numpy array.
166223

167224

168225
def make_symbolic_reducer(py_binop, init_val):
169-
def reducer(operand, axis=0):
170-
# axis=0 means we are reducing over the first axis (i.e. the rows) of the operand.
226+
def reducer(operand, axis):
171227
# axis=None means we are reducing over all axes of the operand.
172228
axis = range(numpy.ndim(operand)) if axis is None else axis
173229

@@ -193,10 +249,14 @@ def symbolic_reduce(operand, init_value, computation, dimensions):
193249

194250

195251
def symbolic_reduce_or(*args, **kwargs):
196-
# Check if all the boolean arguments are True or False
197-
if all_boolean([*args]):
198-
# If so, use the numpy function reduce to reduce the logical_or operator
199-
return lax_reference.reduce(*args, init_value=False, dimensions=kwargs['axes'], computation=numpy.logical_or)
252+
if all_concrete_values([*args]):
253+
return lax_reference.reduce(*args, init_value=False, computation=numpy.logical_or, dimensions=kwargs['axes'])
254+
else:
255+
return symbolic_reduce(*args, init_value='False', computation=symbolic_or, dimensions=kwargs['axes'])
256+
257+
258+
def symbolic_reduce_sum(*args, **kwargs):
259+
if all_concrete_values([*args]):
260+
return lax_reference.reduce(*args, init_value=0, computation=numpy.add, dimensions=kwargs['axes'])
200261
else:
201-
# Otherwise, we use the symbolic_reduce function to reduce the symbolic_or operator
202-
return symbolic_reduce(*args, init_value='False', dimensions=kwargs['axes'], computation=symbolic_or)
262+
return symbolic_reduce(*args, init_value='0', computation=symbolic_sum, dimensions=kwargs['axes'])

0 commit comments

Comments
 (0)