Do you have any suggestions for changes needed to train BYOL on the CIFAR10 dataset?
Here is how I am doing it in main.py (I am also training my own custom models, but I don't think that is relevant):
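```python
from torchvision import datasets

# MultiViewDataInjector and data_transform are defined/imported earlier in main.py.
DATASET = 'CIFAR10'  # Can change to 'STL10'

# Each sample is returned as two independently augmented views of the same image.
two_views = MultiViewDataInjector([data_transform, data_transform])

if DATASET == 'STL10':
    train_dataset = datasets.STL10('/workspace/STLDataset', split='train+unlabeled',
                                   download=True, transform=two_views)
elif DATASET == 'CIFAR10':
    train_dataset = datasets.CIFAR10('/workspace/CIFAR10Dataset', train=True,
                                     download=True, transform=two_views)
else:
    # Fail loudly; the original exit(0) reported success on an invalid choice.
    raise ValueError("Dataset not supported, choose CIFAR10 or STL10")
```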
I also changed the config to use `input_shape: (32,32,3)`.
Also, I may not have looked very deeply into this code base, but how are the 'STL10 Top 1' accuracies (75.2%) produced after training the model on the self-supervised task? Do we take the trained model and fine-tune it on the labelled STL10 dataset? I assume that code is not included in this library?
Thank you!
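Top-1 numbers for BYOL-style methods are usually reported with the linear evaluation protocol rather than full fine-tuning: the pretrained encoder is frozen, a single linear classifier is trained on its features from the labelled training split, and top-1 accuracy is measured on the test split. A minimal sketch of that protocol in PyTorch follows, assuming `encoder` is any `nn.Module` mapping images to feature vectors (e.g. the trained online network without its projection head) and that loaders yield `(image, label)` batches. The function names are hypothetical, not this repository's API, and full-batch training assumes the extracted features fit in memory.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def extract_features(encoder, loader, device="cpu"):
    """Run a frozen encoder over a labelled loader and collect (features, labels)."""
    encoder.eval()
    feats, labels = [], []
    for x, y in loader:
        # Flatten e.g. (N, 512, 1, 1) pooled ResNet outputs to (N, 512).
        feats.append(encoder(x.to(device)).flatten(start_dim=1).cpu())
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)


def train_linear_probe(feats, labels, num_classes, epochs=100, lr=1e-2):
    """Fit one linear layer on frozen features with full-batch Adam."""
    probe = nn.Linear(feats.shape[1], num_classes)
    opt = torch.optim.Adam(probe.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(probe(feats), labels)
        loss.backward()
        opt.step()
    return probe


@torch.no_grad()
def top1_accuracy(probe, feats, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    return (probe(feats).argmax(dim=1) == labels).float().mean().item()
```

Only the linear layer is trained here; the encoder's weights never receive gradients, which is what distinguishes linear evaluation from fine-tuning.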
Hi Akhauriyash, you can just modify the input shape and the name of the dataset.
I am testing the model, but it doesn't work well on CIFAR10: I get ~54% top-1 accuracy. Should the config (in particular the learning rate) be the same as for STL10, or different?
Thank you!
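On the learning-rate question: the BYOL paper decays the learning rate with a cosine schedule and raises the target-network momentum from τ_base = 0.996 towards 1 with τ = 1 − (1 − τ_base)(cos(πk/K) + 1)/2, where k is the current training step and K the total number of steps. Whether this repository's trainer does the same should be checked in its code. The sketch below only illustrates the two schedules; `cosine_lr` and `byol_tau` are hypothetical names, and `base_lr` and `total_steps` are placeholders to tune for CIFAR10, not known-good values.

```python
import math


def cosine_lr(step, total_steps, base_lr):
    """Cosine decay from base_lr at step 0 down to 0 at total_steps."""
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * step / total_steps))


def byol_tau(step, total_steps, tau_base=0.996):
    """Target-network EMA momentum, raised from tau_base to 1 with a cosine schedule."""
    return 1.0 - (1.0 - tau_base) * 0.5 * (math.cos(math.pi * step / total_steps) + 1.0)
```

At each step the learning rate would be set on the optimizer's parameter groups and the target weights updated as `target = tau * target + (1 - tau) * online`, so the target network changes ever more slowly as training approaches K.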