Skip to content

Commit 5662aa4

Browse files
code optimization
1 parent 01fb633 commit 5662aa4

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

GAN/WGAN/training.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def training(opt):
106106

107107
for epoch in range(EPOCHS):
108108

109+
# reset the average loss to zero
110+
C_loss_avg = 0
111+
G_loss_avg = 0
112+
109113
print_memory_utilization()
110114

111115
for batch_idx, (real, _) in enumerate(tqdm(loader)):
@@ -147,7 +151,8 @@ def training(opt):
147151

148152
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
149153

150-
# re using the fake_predict from cirtic forward
154+
# make it one dimensional array
155+
fake_predict = critic(fake).view(-1)
151156

152157
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
153158

@@ -162,16 +167,17 @@ def training(opt):
162167

163168
# ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
164169

165-
if batch_idx == len(loader)-1: # will execute at the last batch
170+
# will execute at every 50 steps
171+
if (batch_idx+1) % 50 == 0:
166172

167173
# ~~~~~~~~~~~~ calculate average loss ~~~~~~~~~~~~~ #
168174

169-
C_loss_avg = C_loss_avg/(CRITIC_TRAIN_STEPS*BATCH_SIZE)
170-
G_loss_avg = G_loss_avg/(BATCH_SIZE)
175+
C_loss_avg_ = C_loss_avg/(CRITIC_TRAIN_STEPS*batch_idx)
176+
G_loss_avg_ = G_loss_avg/(batch_idx)
171177

172178
print(
173-
f"Epoch [{epoch}/{EPOCHS}] "
174-
+ f"Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}"
179+
f"Epoch [{epoch}/{EPOCHS}] | batch size {batch_idx}"
180+
+ f"Loss C: {C_loss_avg_:.4f}, loss G: {G_loss_avg_:.4f}"
175181
)
176182

177183
# ~~~~~~~~~~~~ send data to tensorboard ~~~~~~~~~~~~~ #
@@ -191,20 +197,18 @@ def training(opt):
191197
img_grid_real = torchvision.utils.make_grid(
192198
data, normalize=True)
193199

200+
step = (epoch+1)*(batch_idx+1)
201+
194202
writer_fake.add_image(
195-
"Mnist Fake Images", img_grid_fake, global_step=epoch
203+
"Mnist Fake Images", img_grid_fake, global_step=step
196204
)
197205
writer_real.add_image(
198-
"Mnist Real Images", img_grid_real, global_step=epoch
206+
"Mnist Real Images", img_grid_real, global_step=step
199207
)
200208
loss_writer.add_scalar(
201-
'Critic', C_loss, global_step=epoch)
209+
'Critic', C_loss, global_step=step)
202210
loss_writer.add_scalar(
203-
'generator', G_loss, global_step=epoch)
204-
205-
# reset the average loss to zero
206-
C_loss_avg = 0
207-
G_loss_avg = 0
211+
'generator', G_loss, global_step=step)
208212

209213
# changing back the model to train mode
210214
critic.train()

0 commit comments

Comments
 (0)