
Input preprocessing for VGG #5

Closed
borisdayma opened this issue Sep 24, 2021 · 4 comments

Comments

@borisdayma

Hi,

The README says that the input should be between 0 and 1.

In the training code, the inputs seem to be between -1 and 1.

The torchvision docs seem to load them between 0 and 1 and then normalize them with:

```python
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
```

Should they be preprocessed as per the torchvision docs?
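For reference, the torchvision normalization quoted above computes (x - mean) / std per channel on images already scaled to [0, 1]. Here is a minimal NumPy sketch of the same operation. It assumes channels-last (NHWC) arrays, which Flax models commonly use, and the helper name is hypothetical:

```python
import numpy as np

# ImageNet statistics from the torchvision docs quoted above (RGB order).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def imagenet_normalize(x):
    """Normalize images in [0, 1] with shape (..., H, W, 3) (hypothetical helper).

    Broadcasting over the last axis applies the per-channel statistics.
    """
    x = np.asarray(x, dtype=np.float32)
    return (x - IMAGENET_MEAN) / IMAGENET_STD


# A batch of two 4x4 RGB images with values in [0, 1].
images = np.random.default_rng(0).uniform(0.0, 1.0, size=(2, 4, 4, 3))
normalized = imagenet_normalize(images)
print(normalized.shape)  # (2, 4, 4, 3)
```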

@borisdayma
Author

Interestingly, in the lpips module, the inputs are just normalized to [-1, 1].

@matthias-wright
Owner

Hi @borisdayma!
When you are using the pretrained weights (pretrained='imagenet'), the input should be between 0 and 1.
The torchvision normalization is applied to the input in the __call__ method, so you don't have to do it yourself; see here.
I am considering adding an additional argument normalization so that people can use the ImageNet weights without the torchvision normalization. What do you think about that?

For the training code, I used the range [-1, 1] because that range has worked better for me in the past.
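Since the pretrained model expects inputs in [0, 1] (per the first paragraph of this comment), images prepared in the training code's [-1, 1] range would need rescaling before being fed to it. A minimal sketch of the two affine maps between the ranges; the helper names are hypothetical:

```python
import numpy as np


def to_unit_range(x):
    """Map values from [-1, 1] to [0, 1] (hypothetical helper)."""
    return (np.asarray(x, dtype=np.float32) + 1.0) / 2.0


def to_symmetric_range(x):
    """Map values from [0, 1] to [-1, 1] (hypothetical helper)."""
    return np.asarray(x, dtype=np.float32) * 2.0 - 1.0


x = np.array([-1.0, 0.0, 1.0])
print(to_unit_range(x))
# Round trip recovers the original values.
print(to_symmetric_range(to_unit_range(x)))
```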

Yeah, for lpips the input is in [-1, 1] and is then normalized using mean=[-.030,-.088,-.188] and std=[.458,.448,.450].
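Worth noting: these lpips constants are the torchvision ImageNet statistics rescaled to [-1, 1] (shift = 2 * mean - 1, scale = 2 * std). Because ((2u - 1) - (2m - 1)) / (2s) = (u - m) / s, normalizing a [-1, 1] image with the lpips constants gives the same result as normalizing the same image in [0, 1] with the torchvision constants. A quick NumPy check, assuming channels-last arrays (function names are illustrative, not from either library):

```python
import numpy as np

# Constants quoted above for lpips; inputs are expected in [-1, 1].
LPIPS_SHIFT = np.array([-0.030, -0.088, -0.188], dtype=np.float32)
LPIPS_SCALE = np.array([0.458, 0.448, 0.450], dtype=np.float32)

# torchvision ImageNet statistics; inputs are expected in [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def lpips_normalize(x):
    """Per-channel scaling of channels-last images in [-1, 1]."""
    return (np.asarray(x, dtype=np.float32) - LPIPS_SHIFT) / LPIPS_SCALE


def imagenet_normalize(x):
    """Per-channel scaling of channels-last images in [0, 1]."""
    return (np.asarray(x, dtype=np.float32) - IMAGENET_MEAN) / IMAGENET_STD


# Same images expressed in both ranges: u in [0, 1], 2u - 1 in [-1, 1].
u = np.random.default_rng(0).uniform(0.0, 1.0, size=(2, 4, 4, 3))
print(np.allclose(lpips_normalize(2.0 * u - 1.0), imagenet_normalize(u), atol=1e-5))  # True
```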

@borisdayma
Author

Thanks, this is much clearer.
I think it's a good idea to add normalize as a parameter.

@matthias-wright
Owner

Thanks for the feedback! I added the argument:

```python
import flaxmodels as fm

vgg16 = fm.VGG16(output='logits', pretrained='imagenet', normalize=False)
```

This way the ImageNet weights are used, but the images are not normalized internally.


2 participants