VQPytorch's FSQ with symmetry preservation on and noise dropout set to 0.5 seems to perform significantly better than the reference implementation in reconstruction loss with the same settings, so I set out to figure out why, suspecting one of the two implementations may be broken.
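For context, this is roughly the setup being compared. The preserve_symmetry and noise_dropout kwarg names are my assumption based on the attributes referenced below, so treat this as a sketch rather than verified usage:

import torch
from vector_quantize_pytorch import FSQ

# hypothetical configuration: kwarg names are assumed from the attributes
# (self.noise_dropout, symmetry_preserving_bound) referenced in this issue
quantizer = FSQ(
    levels = [5, 5, 5, 5],
    preserve_symmetry = True,   # assumed kwarg enabling the symmetry-preserving bound
    noise_dropout = 0.5,        # assumed kwarg for the noise-dropout probability
)

z = torch.randn(1, 1024, 4)          # (batch, seq, len(levels))
quantized, indices = quantizer(z)    # recon loss is then measured on a decoder fed `quantized`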
First, the current symmetry_preserving_bound:
def symmetry_preserving_bound(self, z):
"""
QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1
"""
levels_minus_1 = (self._levels - 1)
scale = 2.0 / levels_minus_1
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
return scale * bracket - 1.0
This simplifies to:
def symmetry_preserving_bound(self, z):
return torch.tanh(z) + 1.0 / (self._levels - 1)
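Expanding the original expression: 2/(L-1) * [(L-1)(tanh(z)+1)/2 + 0.5] - 1 = (tanh(z) + 1) + 1/(L-1) - 1 = tanh(z) + 1/(L-1). A quick standalone check, with a plain tensor standing in for self._levels:

import torch

levels = torch.tensor([5.0, 5.0, 8.0])   # stand-in for self._levels
z = torch.randn(4, 3)

original = 2.0 / (levels - 1) * ((levels - 1) * (torch.tanh(z) + 1) / 2.0 + 0.5) - 1.0
simplified = torch.tanh(z) + 1.0 / (levels - 1)

assert torch.allclose(original, simplified, atol = 1e-6)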
Second, the quantization built on this bound doesn't seem to match the reference implementation.
You do:
symmetry_preserved_bound = torch.tanh(z) + 1.0 / (self._levels - 1)
rounded = round_ste(symmetry_preserved_bound) / (self._levels // 2)
They do (also simplified):
dfsq_scale_shift = (torch.tanh(z)/self.scale + 1) * (self._levels - 1) / 2
rounded = round_ste(dfsq_scale_shift)
dfsq_inverse_scale_shift = (rounded * self.scale * 2 / (self._levels - 1)) - self.scale
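To make the mismatch concrete, here is a side-by-side evaluation of the two paths for a single level count, with torch.round standing in for round_ste (gradients don't matter for this check) and scale as a placeholder for whatever the reference implementation uses:

import torch

L = 5                        # level count for one dimension
scale = 1.0                  # placeholder for the reference implementation's self.scale
z = torch.linspace(-3, 3, 7)

# vq-pytorch path (using the simplified bound from above)
vq_out = torch.round(torch.tanh(z) + 1.0 / (L - 1)) / (L // 2)

# reference path (as summarized above)
shifted = (torch.tanh(z) / scale + 1) * (L - 1) / 2
ref_out = (torch.round(shifted) * scale * 2 / (L - 1)) - scale

print(torch.stack([z, vq_out, ref_out], dim = -1))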
Third, the noise scaling is slightly different; I'm not sure how much this matters:
Yours:
offset = (torch.rand_like(z) - 0.5) / (self._levels // 2)
quantized = torch.where(offset_mask, unquantized + offset, quantized)
Theirs:
offset = (torch.rand_like(z) - 0.5) * (self.scale * 2 / (self._levels - 1))
quantized = torch.where(mask, quantized, z + offset)
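For what it's worth, the two noise widths coincide for odd L when scale = 1 (since L // 2 == (L - 1) / 2 there) and drift apart for even L or scale != 1; scale below is again just a placeholder:

# noise step width per quantization bin, for a few level counts
scale = 1.0                              # placeholder for the reference implementation's self.scale
for L in (3, 5, 8, 16):
    yours  = 1.0 / (L // 2)              # width of (rand - 0.5) / (L // 2)
    theirs = scale * 2.0 / (L - 1)       # width of (rand - 0.5) * scale * 2 / (L - 1)
    print(L, yours, theirs)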
Fourth, you pass the non-tanh'ed input straight through in the noise-dropout branch, which could allow the encoder to scale it arbitrarily to fight off the noise.
offset_mask = torch.bernoulli(
torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
).bool().expand_as(z)
offset = (torch.rand_like(z) - 0.5) / half_width
quantized = torch.where(offset_mask, unquantized + offset, quantized)
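If that is indeed the problem, one direction for a fix is to squash the pass-through value before adding the noise, so the encoder can't just inflate z to drown the offset out. This is a rough sketch only, reusing the names from the snippet above; the exact bound/rescale that matches the rest of the quantize path would still need to be worked out:

offset = (torch.rand_like(z) - 0.5) / half_width
# bound the pass-through value so it cannot grow arbitrarily to fight off the noise
quantized = torch.where(offset_mask, torch.tanh(unquantized) + offset, quantized)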
I suspect the last one is the cause of the performance difference and will check later.