diff --git a/vector_quantize_pytorch/finite_scalar_quantization.py b/vector_quantize_pytorch/finite_scalar_quantization.py index 21872f1..f097752 100644 --- a/vector_quantize_pytorch/finite_scalar_quantization.py +++ b/vector_quantize_pytorch/finite_scalar_quantization.py @@ -151,6 +151,9 @@ def quantize(self, z): offset = torch.rand_like(bounded_z) - 0.5 bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z) + if preserve_symmetry: + return bounded_z + return round_ste(bounded_z) / half_width def _scale_and_shift(self, zhat_normalized):