Skip to content

Commit dfe2df4

Browse files
code optimization
1 parent 318cf63 commit dfe2df4

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

GAN/WGAN/training.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ def training(opt):
2525
FEATURE_D = 128
2626
Z_DIM = 100
2727
BATCH_SIZE = opt.batch_size
28+
2829
# ~~~~~~~~~~~~~~~~~~~ as per WGAN paper ~~~~~~~~~~~~~~~~~~~ #
2930

3031
lr = opt.lr
3132
CRITIC_TRAIN_STEPS = 5
3233
WEIGHT_CLIP = 0.01
3334

34-
print(f"Epochs: {EPOCHS}| lr: {lr}| batch size {BATCH_SIZE}" +
35-
f"device: {work_device}")
35+
print(f"Epochs: {EPOCHS}| lr: {lr}| batch size {BATCH_SIZE}|" +
36+
f" device: {work_device}")
37+
3638
# ~~~~~~~~~~~ creating directories for weights ~~~~~~~~~~~ #
3739

3840
if opt.logs:
@@ -88,26 +90,31 @@ def training(opt):
8890

8991
# ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #
9092

93+
# loss variables
9194
C_loss_prev = math.inf
9295
G_loss_prev = math.inf
9396
C_loss = 0
9497
G_loss = 0
98+
C_loss_avg = 0
99+
G_loss_avg = 0
95100

96101
print_gpu_details()
102+
103+
# setting the models to train mode
104+
critic.train()
105+
gen.train()
106+
97107
for epoch in range(EPOCHS):
98-
C_loss_avg = 0
99-
G_loss_avg = 0
100108

101109
print_memory_utilization()
102110

103111
for batch_idx, (real, _) in enumerate(tqdm(loader)):
104-
critic.train()
105-
gen.train()
112+
106113
real = real.to(work_device)
107114
fixed_noise = torch.rand(
108115
real.shape[0], Z_DIM, 1, 1).to(work_device)
109116

110-
# ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #
117+
# ~~~~~~~~~~~~~~~~~~~ critic loop ~~~~~~~~~~~~~~~~~~~ #
111118

112119
fake = gen(fixed_noise) # dim of (N,1,W,H)
113120

@@ -124,6 +131,7 @@ def training(opt):
124131

125132
C_loss = -(torch.mean(real_predict) - torch.mean(fake_predict))
126133
C_loss_avg += C_loss
134+
127135
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
128136

129137
critic.zero_grad()
@@ -145,6 +153,7 @@ def training(opt):
145153

146154
G_loss = -(torch.mean(fake_predict))
147155
G_loss_avg += G_loss
156+
148157
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
149158

150159
gen.zero_grad()
@@ -153,7 +162,10 @@ def training(opt):
153162

154163
# ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
155164

156-
if batch_idx == 0 and epoch >= 1:
165+
if batch_idx == len(loader)-1: # will execute at the last batch
166+
167+
# ~~~~~~~~~~~~ calculate average loss ~~~~~~~~~~~~~ #
168+
157169
C_loss_avg = C_loss_avg/(CRITIC_TRAIN_STEPS*BATCH_SIZE)
158170
G_loss_avg = G_loss_avg/(BATCH_SIZE)
159171

@@ -162,6 +174,8 @@ def training(opt):
162174
+ f"Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}"
163175
)
164176

177+
# ~~~~~~~~~~~~ send data to tensorboard ~~~~~~~~~~~~~ #
178+
165179
with torch.no_grad():
166180
critic.eval()
167181
gen.eval()
@@ -186,6 +200,14 @@ def training(opt):
186200
loss_writer.add_scalar(
187201
'generator', G_loss, global_step=epoch)
188202

203+
# reset the average loss to zero
204+
C_loss_avg = 0
205+
G_loss_avg = 0
206+
207+
# changing back the model to train mode
208+
critic.train()
209+
gen.train()
210+
189211
# ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
190212

191213
if opt.weights:

0 commit comments

Comments (0)