From 20c229e889e5283cdc18be4530ec2371628b9f2b Mon Sep 17 00:00:00 2001 From: Filya Geikyan Date: Mon, 28 Apr 2025 19:51:36 +0400 Subject: [PATCH] fix: skip round+normalize when preserve_symmetry=True (already in [-1,1]) --- vector_quantize_pytorch/finite_scalar_quantization.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vector_quantize_pytorch/finite_scalar_quantization.py b/vector_quantize_pytorch/finite_scalar_quantization.py index 21872f1..f097752 100644 --- a/vector_quantize_pytorch/finite_scalar_quantization.py +++ b/vector_quantize_pytorch/finite_scalar_quantization.py @@ -151,6 +151,9 @@ def quantize(self, z): offset = torch.rand_like(bounded_z) - 0.5 bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z) + if preserve_symmetry: + return bounded_z + return round_ste(bounded_z) / half_width def _scale_and_shift(self, zhat_normalized):