Skip to content

Commit 4b95d27

Browse files
added argparse functionality
1 parent e1098f0 commit 4b95d27

File tree

1 file changed

+42
-8
lines changed

1 file changed

+42
-8
lines changed

GAN/DCGAN/training.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
import torch
23
import torchvision
34
import torch.optim as optim
@@ -10,26 +11,27 @@
1011
import math
1112
from tqdm import tqdm
1213

13-
if __name__ == "__main__":
14+
15+
def training(opt):
1416

1517
# ~~~~~~~~~~~~~~~~~~~ hyper parameters ~~~~~~~~~~~~~~~~~~~ #
16-
EPOCHS = 20
18+
EPOCHS = opt.epochs
1719
CHANNELS = 1
1820
H, W = 64, 64
19-
IMG_SIZE = CHANNELS * H * W
20-
lr = 2e-4
21+
lr = opt.lr
2122
work_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2223
FEATURE_D = 128
2324
Z_DIM = 100
24-
GEN_TRAIN_STEPS = 5
25-
BATCH_SIZE = 128
26-
Weight_dir = Path('/weights').resolve()
25+
GEN_TRAIN_STEPS = 2
26+
BATCH_SIZE = opt.batch_size
27+
Weight_dir = Path(f'{opt.weights}').resolve()
2728
if not Weight_dir.exists():
2829
Weight_dir.mkdir()
2930
# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
3031

3132
trans = transforms.Compose(
32-
[transforms.Resize((H, W)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
33+
[transforms.Resize((H, W)),
34+
transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
3335

3436
MNIST_data = MNIST('./data', True, transform=trans, download=True)
3537

@@ -46,6 +48,19 @@
4648
feature_d=FEATURE_D).to(work_device)
4749
gen = Faker(Z_DIM, CHANNELS, FEATURE_D).to(work_device)
4850

51+
if opt.resume:
52+
if Path(Weight_dir/'dirscriminator.pth').exists():
53+
54+
disc = disc.load_state_dict(torch.load(
55+
str(Weight_dir/'dirscriminator.pth'),
56+
map_location=work_device))
57+
58+
if Path(Weight_dir/'generator.pth').exists():
59+
60+
gen = gen.load_state_dict(torch.load(
61+
str(Weight_dir/'generator.pth'),
62+
map_location=work_device))
63+
4964
# ~~~~~~~~~~~~~~~~~~~ create optimizer and loss ~~~~~~~~~~~~~~~~~~~ #
5065

5166
disc_optim = optim.Adam(disc.parameters(), lr, (0.5, 0.999))
@@ -136,3 +151,22 @@
136151
G_loss_prev = G_loss
137152
weight_path = str(Weight_dir/'generator.pth')
138153
torch.save(gen.state_dict(), weight_path)
154+
155+
156+
if __name__ == "__main__":
157+
158+
parser = argparse.ArgumentParser()
159+
# initial pre training weights
160+
parser.add_argument('--weights', type=str,
161+
default='', help='save and load location of weights')
162+
parser.add_argument("epochs", type=int, default=20,
163+
help='number of epochs to train')
164+
parser.add_argument('--batch-size', type=int, default=128,
165+
help='total batch size for all GPUs')
166+
parser.add_argument('lr', type=float, default=2e-4,
167+
help='learning rate to use')
168+
169+
parser.add_argument('resume', type=bool, default=True,
170+
help='should use the last saved weights')
171+
opt = parser.parse_args()
172+
training(opt)

0 commit comments

Comments
 (0)