1
1
from tqdm import tqdm
2
+ from matplotlib import pyplot as plt
2
3
import tensorflow as tf
3
4
import tensorflow_datasets as tfds
4
5
import jax
20
21
The data is loaded using tensorflow_datasets.
21
22
"""
22
23
23
- def nln (type , x ):
24
- x = hard_or .or_layer (type )(100 , nn .initializers .uniform (1.0 ), dtype = jnp .float32 )(x ) # >=1700 need for >98% accuracy
24
+ def nln (type , x , width ):
25
+ x = hard_or .or_layer (type )(width , nn .initializers .uniform (1.0 ), dtype = jnp .float32 )(x ) # >=1700 need for >98% accuracy
25
26
x = hard_not .not_layer (type )(10 , dtype = jnp .float32 )(x )
26
27
x = primitives .nl_ravel (type )(x ) # flatten the outputs of the not layer
27
28
x = harden_layer .harden_layer (type )(x ) # harden the outputs of the not layer
28
- x = primitives .nl_reshape (type )((10 , 100 ))(x ) # reshape to 10 ports, 100 bits each
29
+ x = primitives .nl_reshape (type )((10 , width ))(x ) # reshape to 10 ports, 100 bits each
29
30
x = primitives .nl_sum (type )(- 1 )(x ) # sum the 100 bits in each port
30
31
return x
31
32
32
- def batch_nln (type , x ):
33
- return jax .vmap (lambda x : nln (type , x ))(x )
33
+ def batch_nln (type , x , width ):
34
+ return jax .vmap (lambda x : nln (type , x , width ))(x )
34
35
35
36
class CNN (nn .Module ):
36
37
"""A simple CNN model."""
@@ -102,6 +103,23 @@ def get_datasets():
102
103
test_ds ['image' ] = jnp .round (test_ds ['image' ])
103
104
return train_ds , test_ds
104
105
106
+ def show_img (img , ax = None , title = None ):
107
+ """Shows a single image."""
108
+ if ax is None :
109
+ ax = plt .gca ()
110
+ ax .imshow (img .reshape (28 , 28 ), cmap = 'gray' )
111
+ ax .set_xticks ([])
112
+ ax .set_yticks ([])
113
+ if title :
114
+ ax .set_title (title )
115
+
116
+ def show_img_grid (imgs , titles ):
117
+ """Shows a grid of images."""
118
+ n = int (np .ceil (len (imgs )** .5 ))
119
+ _ , axs = plt .subplots (n , n , figsize = (3 * n , 3 * n ))
120
+ for i , (img , title ) in enumerate (zip (imgs , titles )):
121
+ show_img (img , axs [i // n ][i % n ], title )
122
+
105
123
def create_train_state (net , rng , config ):
106
124
"""Creates initial `TrainState`."""
107
125
# for CNN
@@ -216,7 +234,8 @@ def test_mnist():
216
234
217
235
# Define the model.
218
236
# soft = CNN()
219
- soft , _ , _ = neural_logic_net .net (batch_nln )
237
+ width = 100
238
+ soft , _ , _ = neural_logic_net .net (lambda type , x : batch_nln (type , x , width ))
220
239
221
240
# Get the MNIST dataset.
222
241
train_ds , test_ds = get_datasets ()
@@ -228,5 +247,5 @@ def test_mnist():
228
247
trained_state = train_and_evaluate (soft , (train_ds , test_ds ), config = config , workdir = "./mnist_metrics" )
229
248
230
249
# Check symbolic net
231
- _ , hard , symbolic = neural_logic_net .net (nln )
250
+ _ , hard , symbolic = neural_logic_net .net (lambda type , x : nln ( type , x , width ) )
232
251
check_symbolic ((soft , hard , symbolic ), (train_ds , test_ds ), trained_state )
0 commit comments