Skip to content

Commit c13b4cc

Browse files
authored
Merge pull request #51 from github/mnist-demo
Mnist demo
2 parents 82e5a29 + a03c722 commit c13b4cc

File tree

4 files changed

+1330
-15
lines changed

4 files changed

+1330
-15
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Questions? Ask @Z80coder
1818

1919
- [Neural logic nets](https://drive.google.com/file/d/1P25OxM7Af8ppUGOUhKd6psGHI0OVXIzw/view?usp=sharing) (15m)
2020
- [Neural logic nets for differentiable QL](https://drive.google.com/file/d/195r9Y08Q61V80f2Hqw62YuHpYsCzJCmZ/view?usp=sharing) (30m)
21+
- [Boolean logic nets and MNIST](https://drive.google.com/file/d/1dWAQfFWcOm1ORqfh62H66nigGy2wQ17G/view?usp=share_link) (18m)
2122

2223
## Development videos
2324

@@ -35,5 +36,6 @@ Questions? Ask @Z80coder
3536
- [If-Then-Else neuron](https://drive.google.com/file/d/1qelfWX6s2XhlHxFwUSV76tAS2tyDK3Q0/view?usp=sharing) (23m)
3637
- [Neural conditions and actions](https://drive.google.com/file/d/1nrn_4TlNCmdC1ZAlN9pKIOF2hEjtykuo/view?usp=sharing) (24m)
3738
- [Neural decision lists](https://drive.google.com/file/d/16F_2kpBaZO-qPQLX38Sar9pJfuunsVyO/view?usp=sharing) (15m)
39+
- [Boolean logic nets and MNIST](https://drive.google.com/file/d/1dWAQfFWcOm1ORqfh62H66nigGy2wQ17G/view?usp=share_link) (18m)
3840

3941
More to come!

demos/MNIST-JAX.ipynb

Lines changed: 947 additions & 0 deletions
Large diffs are not rendered by default.

scratchpad.ipynb

Lines changed: 355 additions & 8 deletions
Large diffs are not rendered by default.

tests/test_mnist.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from tqdm import tqdm
2+
from matplotlib import pyplot as plt
23
import tensorflow as tf
34
import tensorflow_datasets as tfds
45
import jax
@@ -20,17 +21,17 @@
2021
The data is loaded using tensorflow_datasets.
2122
"""
2223

23-
def nln(type, x):
24-
x = hard_or.or_layer(type)(100, nn.initializers.uniform(1.0), dtype=jnp.float32)(x) # >=1700 need for >98% accuracy
24+
def nln(type, x, width):
25+
x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0), dtype=jnp.float32)(x) # >=1700 need for >98% accuracy
2526
x = hard_not.not_layer(type)(10, dtype=jnp.float32)(x)
2627
x = primitives.nl_ravel(type)(x) # flatten the outputs of the not layer
2728
x = harden_layer.harden_layer(type)(x) # harden the outputs of the not layer
28-
x = primitives.nl_reshape(type)((10, 100))(x) # reshape to 10 ports, 100 bits each
29+
x = primitives.nl_reshape(type)((10, width))(x) # reshape to 10 ports, 100 bits each
2930
x = primitives.nl_sum(type)(-1)(x) # sum the 100 bits in each port
3031
return x
3132

32-
def batch_nln(type, x):
33-
return jax.vmap(lambda x: nln(type, x))(x)
33+
def batch_nln(type, x, width):
34+
return jax.vmap(lambda x: nln(type, x, width))(x)
3435

3536
class CNN(nn.Module):
3637
"""A simple CNN model."""
@@ -102,6 +103,23 @@ def get_datasets():
102103
test_ds['image'] = jnp.round(test_ds['image'])
103104
return train_ds, test_ds
104105

106+
def show_img(img, ax=None, title=None):
107+
"""Shows a single image."""
108+
if ax is None:
109+
ax = plt.gca()
110+
ax.imshow(img.reshape(28, 28), cmap='gray')
111+
ax.set_xticks([])
112+
ax.set_yticks([])
113+
if title:
114+
ax.set_title(title)
115+
116+
def show_img_grid(imgs, titles):
117+
"""Shows a grid of images."""
118+
n = int(np.ceil(len(imgs)**.5))
119+
_, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
120+
for i, (img, title) in enumerate(zip(imgs, titles)):
121+
show_img(img, axs[i // n][i % n], title)
122+
105123
def create_train_state(net, rng, config):
106124
"""Creates initial `TrainState`."""
107125
# for CNN
@@ -216,7 +234,8 @@ def test_mnist():
216234

217235
# Define the model.
218236
# soft = CNN()
219-
soft, _, _ = neural_logic_net.net(batch_nln)
237+
width = 100
238+
soft, _, _ = neural_logic_net.net(lambda type, x: batch_nln(type, x, width))
220239

221240
# Get the MNIST dataset.
222241
train_ds, test_ds = get_datasets()
@@ -228,5 +247,5 @@ def test_mnist():
228247
trained_state = train_and_evaluate(soft, (train_ds, test_ds), config=config, workdir="./mnist_metrics")
229248

230249
# Check symbolic net
231-
_, hard, symbolic = neural_logic_net.net(nln)
250+
_, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type,x, width))
232251
check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)

0 commit comments

Comments
 (0)