|
| 1 | +import argparse |
1 | 2 | import torch
|
2 | 3 | import torchvision
|
3 | 4 | import torch.optim as optim
|
|
10 | 11 | import math
|
11 | 12 | from tqdm import tqdm
|
12 | 13 |
|
13 |
| -if __name__ == "__main__": |
| 14 | + |
| 15 | +def training(opt): |
14 | 16 |
|
15 | 17 | # ~~~~~~~~~~~~~~~~~~~ hyper parameters ~~~~~~~~~~~~~~~~~~~ #
|
16 |
| - EPOCHS = 20 |
| 18 | + EPOCHS = opt.epochs |
17 | 19 | CHANNELS = 1
|
18 | 20 | H, W = 64, 64
|
19 |
| - IMG_SIZE = CHANNELS * H * W |
20 |
| - lr = 2e-4 |
| 21 | + lr = opt.lr |
21 | 22 | work_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
22 | 23 | FEATURE_D = 128
|
23 | 24 | Z_DIM = 100
|
24 |
| - GEN_TRAIN_STEPS = 5 |
25 |
| - BATCH_SIZE = 128 |
26 |
| - Weight_dir = Path('/weights').resolve() |
| 25 | + GEN_TRAIN_STEPS = 2 |
| 26 | + BATCH_SIZE = opt.batch_size |
| 27 | + Weight_dir = Path(f'{opt.weights}').resolve() |
27 | 28 | if not Weight_dir.exists():
|
28 | 29 | Weight_dir.mkdir()
|
29 | 30 | # ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
|
30 | 31 |
|
31 | 32 | trans = transforms.Compose(
|
32 |
| - [transforms.Resize((H, W)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) |
| 33 | + [transforms.Resize((H, W)), |
| 34 | + transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) |
33 | 35 |
|
34 | 36 | MNIST_data = MNIST('./data', True, transform=trans, download=True)
|
35 | 37 |
|
|
46 | 48 | feature_d=FEATURE_D).to(work_device)
|
47 | 49 | gen = Faker(Z_DIM, CHANNELS, FEATURE_D).to(work_device)
|
48 | 50 |
|
| 51 | + if opt.resume: |
| 52 | + if Path(Weight_dir/'dirscriminator.pth').exists(): |
| 53 | + |
| 54 | + disc = disc.load_state_dict(torch.load( |
| 55 | + str(Weight_dir/'dirscriminator.pth'), |
| 56 | + map_location=work_device)) |
| 57 | + |
| 58 | + if Path(Weight_dir/'generator.pth').exists(): |
| 59 | + |
| 60 | + gen = gen.load_state_dict(torch.load( |
| 61 | + str(Weight_dir/'generator.pth'), |
| 62 | + map_location=work_device)) |
| 63 | + |
49 | 64 | # ~~~~~~~~~~~~~~~~~~~ create optimizer and loss ~~~~~~~~~~~~~~~~~~~ #
|
50 | 65 |
|
51 | 66 | disc_optim = optim.Adam(disc.parameters(), lr, (0.5, 0.999))
|
|
136 | 151 | G_loss_prev = G_loss
|
137 | 152 | weight_path = str(Weight_dir/'generator.pth')
|
138 | 153 | torch.save(gen.state_dict(), weight_path)
|
| 154 | + |
| 155 | + |
| 156 | +if __name__ == "__main__": |
| 157 | + |
| 158 | + parser = argparse.ArgumentParser() |
| 159 | + # initial pre training weights |
| 160 | + parser.add_argument('--weights', type=str, |
| 161 | + default='', help='save and load location of weights') |
| 162 | + parser.add_argument("epochs", type=int, default=20, |
| 163 | + help='number of epochs to train') |
| 164 | + parser.add_argument('--batch-size', type=int, default=128, |
| 165 | + help='total batch size for all GPUs') |
| 166 | + parser.add_argument('lr', type=float, default=2e-4, |
| 167 | + help='learning rate to use') |
| 168 | + |
| 169 | + parser.add_argument('resume', type=bool, default=True, |
| 170 | + help='should use the last saved weights') |
| 171 | + opt = parser.parse_args() |
| 172 | + training(opt) |
0 commit comments