FSQ Oddness #194

@zaptrem

VQPytorch's FSQ with symmetry enabled and noise dropout set to 0.5 seems to get significantly better recon loss than the reference implementation with the same settings, so I set out to figure out why, suspecting one of the two implementations may be broken.

First, here is the current symmetry_preserving_bound:

    def symmetry_preserving_bound(self, z):
        """
        QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1
        """
        levels_minus_1 = (self._levels - 1)
        scale = 2.0 / levels_minus_1
        bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
        return scale * bracket - 1.0

This simplifies: the 2 / (L - 1) scale times the (L - 1) * (tanh(x) + 1) / 2 term gives tanh(x) + 1, times the 0.5 term gives 1 / (L - 1), and the trailing - 1 cancels the + 1, leaving:

    def symmetry_preserving_bound(self, z):
        return torch.tanh(z) + 1.0 / (self._levels - 1)
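
As a sanity check, here is a quick standalone comparison (my own sketch, using a plain int for levels where the repo keeps self._levels as a tensor) confirming the two forms agree numerically:

    import torch

    def original_bound(z, levels):
        levels_minus_1 = levels - 1
        scale = 2.0 / levels_minus_1
        bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
        return scale * bracket - 1.0

    def simplified_bound(z, levels):
        return torch.tanh(z) + 1.0 / (levels - 1)

    z = torch.randn(1024)
    for levels in (3, 4, 5, 8, 11):
        assert torch.allclose(original_bound(z, levels), simplified_bound(z, levels), atol = 1e-5)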

Second, this version doesn't seem to match what they do in the reference implementation.

You do:

# bound to (-1 + 1 / (L - 1), 1 + 1 / (L - 1)), round, then rescale by L // 2
symmetry_preserved_bound = torch.tanh(z) + 1.0 / (self._levels - 1)
rounded = round_ste(symmetry_preserved_bound) / (self._levels // 2)

They do (also simplified):

# scale/shift into roughly [0, L - 1], round, then invert the scale/shift
dfsq_scale_shift = (torch.tanh(z) / self.scale + 1) * (self._levels - 1) / 2
rounded = round_ste(dfsq_scale_shift)
dfsq_inverse_scale_shift = (rounded * self.scale * 2 / (self._levels - 1)) - self.scale
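
To make the mismatch concrete, here is a toy side-by-side (my own sketch; I assume scale = 1.0 and use plain torch.round in place of round_ste, which only differs in the backward pass) running both simplified paths on the same inputs:

    import torch

    levels, scale = 5, 1.0
    z = torch.linspace(-1, 1, 9)

    # VQPytorch path (simplified): bound, round, rescale by levels // 2
    yours = torch.round(torch.tanh(z) + 1.0 / (levels - 1)) / (levels // 2)

    # reference path (simplified): scale/shift into roughly [0, levels - 1], round, invert
    idx = torch.round((torch.tanh(z) / scale + 1) * (levels - 1) / 2)
    theirs = (idx * scale * 2 / (levels - 1)) - scale

    print(yours.unique())   # tensor([-0.5000,  0.0000,  0.5000])
    print(theirs.unique())  # tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])

Same inputs, different sets of output codes, so the two paths can't be equivalent.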

Third, the noise scaling is slightly different; I don't know how much this matters:

Yours:

offset = (torch.rand_like(z) - 0.5) / (self._levels // 2)    # uniform in ±0.5 / (levels // 2)
quantized = torch.where(offset_mask, unquantized + offset, quantized)

Theirs:

offset = (torch.rand_like(z) - 0.5) * (self.scale * 2 / (self._levels - 1))    # uniform in ±scale / (levels - 1)
quantized = torch.where(mask, quantized, z + offset)    # note the flipped mask polarity
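
For concreteness, yours is uniform in ±0.5 / (levels // 2) while theirs is uniform in ±scale / (levels - 1); these only coincide for odd level counts with scale = 1. A quick check (my own sketch, with a plain int for levels):

    def your_noise_half_width(levels):
        # (rand - 0.5) / (levels // 2)  ->  uniform in ±0.5 / (levels // 2)
        return 0.5 / (levels // 2)

    def their_noise_half_width(levels, scale = 1.0):
        # (rand - 0.5) * scale * 2 / (levels - 1)  ->  uniform in ±scale / (levels - 1)
        return scale / (levels - 1)

    for levels in (3, 4, 5, 8):
        print(levels, your_noise_half_width(levels), their_noise_half_width(levels))
    # 3 0.5 0.5
    # 4 0.25 0.333...
    # 5 0.25 0.25
    # 8 0.125 0.142...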

Fourth, you pass the raw (non-tanh'ed) input through to the noise-dropout branch of the quantization, which could allow the network to scale it arbitrarily to fight off the noise.

    offset_mask = torch.bernoulli(
        torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
    ).bool().expand_as(z)

    # note: `unquantized` here is the raw z, not the tanh-bounded value
    offset = (torch.rand_like(z) - 0.5) / half_width
    quantized = torch.where(offset_mask, unquantized + offset, quantized)
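
If that's the problem, a minimal fix (my own sketch, untested; preserve_symmetry and bound are my assumed attribute/method names) would be to bound z first and feed the bounded value into the noise branch, so the network can't inflate z to dodge the fixed-width noise:

    # bound first, so the noise branch sees the same range the offsets were scaled for
    bounded_z = self.symmetry_preserving_bound(z) if self.preserve_symmetry else self.bound(z)
    quantized = round_ste(bounded_z)

    if self.training and self.noise_dropout > 0.:
        offset_mask = torch.bernoulli(
            torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
        ).bool().expand_as(z)

        offset = (torch.rand_like(z) - 0.5) / half_width
        quantized = torch.where(offset_mask, bounded_z + offset, quantized)  # bounded_z, not raw z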

I suspect the last one is the cause of the performance difference and will check later.
