|
| 1 | +from typing import Callable |
| 2 | + |
| 3 | +import jax |
| 4 | +from flax import linen as nn |
| 5 | + |
| 6 | +from neurallogic import neural_logic_net, symbolic_generation |
| 7 | + |
| 8 | + |
| 9 | +def soft_real_encoder(t: float, x: float) -> float: |
| 10 | + eps = 0.0000001 |
| 11 | + # x should be in [0, 1] |
| 12 | + t = jax.numpy.clip(t, 0.0, 1.0) |
| 13 | + return jax.numpy.where( |
| 14 | + jax.numpy.isclose(t, x), |
| 15 | + 0.5, |
| 16 | + # t != x |
| 17 | + jax.numpy.where( |
| 18 | + x < t, |
| 19 | + (1.0 / (2.0 * t + eps)) * x, |
| 20 | + # x > t |
| 21 | + (1.0 / (2.0 * (1.0 - t) + eps)) * (x + 1.0 - 2.0 * t) |
| 22 | + ) |
| 23 | + ) |
| 24 | + |
| 25 | + |
| 26 | +def hard_real_encoder(t: float, x: float) -> bool: |
| 27 | + # t and x must be floats |
| 28 | + return jax.numpy.where(soft_real_encoder(t, x) > 0.5, True, False) |
| 29 | + |
| 30 | + |
| 31 | +soft_real_encoder_neuron = jax.vmap(soft_real_encoder, in_axes=(0, None)) |
| 32 | + |
| 33 | +hard_real_encoder_neuron = jax.vmap(hard_real_encoder, in_axes=(0, None)) |
| 34 | + |
| 35 | +soft_real_encoder_layer = jax.vmap(soft_real_encoder_neuron, (0, 0), 0) |
| 36 | + |
| 37 | +hard_real_encoder_layer = jax.vmap(hard_real_encoder_neuron, (0, 0), 0) |
| 38 | + |
| 39 | + |
| 40 | +class SoftRealEncoderLayer(nn.Module): |
| 41 | + bits_per_real: int |
| 42 | + thresholds_init: Callable = nn.initializers.uniform(1.0) |
| 43 | + dtype: jax.numpy.dtype = jax.numpy.float32 |
| 44 | + |
| 45 | + @nn.compact |
| 46 | + def __call__(self, x): |
| 47 | + thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real) |
| 48 | + thresholds = self.param( |
| 49 | + "thresholds", self.thresholds_init, thresholds_shape, self.dtype) |
| 50 | + x = jax.numpy.asarray(x, self.dtype) |
| 51 | + return soft_real_encoder_layer(thresholds, x) |
| 52 | + |
| 53 | + |
| 54 | +class HardRealEncoderLayer(nn.Module): |
| 55 | + bits_per_real: int |
| 56 | + |
| 57 | + @nn.compact |
| 58 | + def __call__(self, x): |
| 59 | + thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real) |
| 60 | + thresholds = self.param( |
| 61 | + "thresholds", nn.initializers.constant(0.0), thresholds_shape) |
| 62 | + return hard_real_encoder_layer(thresholds, x) |
| 63 | + |
| 64 | + |
| 65 | +class SymbolicRealEncoderLayer: |
| 66 | + def __init__(self, bits_per_real): |
| 67 | + self.bits_per_real = bits_per_real |
| 68 | + self.hard_real_encoder_layer = HardRealEncoderLayer(self.bits_per_real) |
| 69 | + |
| 70 | + def __call__(self, x): |
| 71 | + jaxpr = symbolic_generation.make_symbolic_flax_jaxpr( |
| 72 | + self.hard_real_encoder_layer, x |
| 73 | + ) |
| 74 | + return symbolic_generation.symbolic_expression(jaxpr, x) |
| 75 | + |
| 76 | + |
| 77 | +real_encoder_layer = neural_logic_net.select( |
| 78 | + lambda bits_per_real, weights_init=nn.initializers.uniform( |
| 79 | + 1.0 |
| 80 | + ), dtype=jax.numpy.float32: SoftRealEncoderLayer( |
| 81 | + bits_per_real, weights_init, dtype |
| 82 | + ), |
| 83 | + lambda bits_per_real, weights_init=nn.initializers.uniform( |
| 84 | + 1.0 |
| 85 | + ), dtype=jax.numpy.float32: HardRealEncoderLayer(bits_per_real), |
| 86 | + lambda bits_per_real, weights_init=nn.initializers.uniform( |
| 87 | + 1.0 |
| 88 | + ), dtype=jax.numpy.float32: SymbolicRealEncoderLayer(bits_per_real), |
| 89 | +) |
0 commit comments