@@ -106,6 +106,10 @@ def training(opt):
 
     for epoch in range(EPOCHS):
 
+        # reset the average loss to zero
+        C_loss_avg = 0
+        G_loss_avg = 0
+
         print_memory_utilization()
 
         for batch_idx, (real, _) in enumerate(tqdm(loader)):
@@ -147,7 +151,8 @@ def training(opt):
 
             # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
 
-            # re-using the fake_predict from critic forward
+            # make it a one-dimensional array
+            fake_predict = critic(fake).view(-1)
 
             # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
 
@@ -162,16 +167,17 @@ def training(opt):
 
             # ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
 
-            if batch_idx == len(loader) - 1:  # will execute at the last batch
+            # will execute at every 50 steps
+            if (batch_idx + 1) % 50 == 0:
 
                 # ~~~~~~~~~~~~ calculate average loss ~~~~~~~~~~~~~ #
 
-                C_loss_avg = C_loss_avg / (CRITIC_TRAIN_STEPS * BATCH_SIZE)
-                G_loss_avg = G_loss_avg / (BATCH_SIZE)
+                C_loss_avg_ = C_loss_avg / (CRITIC_TRAIN_STEPS * batch_idx)
+                G_loss_avg_ = G_loss_avg / (batch_idx)
 
                 print(
-                    f"Epoch [{epoch}/{EPOCHS}] "
-                    + f"Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}"
+                    f"Epoch [{epoch}/{EPOCHS}] | batch {batch_idx} "
+                    + f"Loss C: {C_loss_avg_:.4f}, loss G: {G_loss_avg_:.4f}"
                 )
 
                 # ~~~~~~~~~~~~ send data to tensorboard ~~~~~~~~~~~~~ #
@@ -191,20 +197,18 @@ def training(opt):
                     img_grid_real = torchvision.utils.make_grid(
                         data, normalize=True)
 
+                    step = (epoch + 1) * (batch_idx + 1)
+
                     writer_fake.add_image(
-                        "Mnist Fake Images", img_grid_fake, global_step=epoch
+                        "Mnist Fake Images", img_grid_fake, global_step=step
                     )
                     writer_real.add_image(
-                        "Mnist Real Images", img_grid_real, global_step=epoch
+                        "Mnist Real Images", img_grid_real, global_step=step
                     )
                     loss_writer.add_scalar(
-                        'Critic', C_loss, global_step=epoch)
+                        'Critic', C_loss, global_step=step)
                     loss_writer.add_scalar(
-                        'generator', G_loss, global_step=epoch)
-
-                # reset the average loss to zero
-                C_loss_avg = 0
-                G_loss_avg = 0
+                        'generator', G_loss, global_step=step)
 
                 # changing back the model to train mode
                 critic.train()
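
For context, here is a minimal, self-contained sketch of the logging pattern this commit introduces: the running loss sums are reset at the start of each epoch, turned into averages every 50 batches, and the latest batch losses are written to TensorBoard with a per-step `global_step`. The dummy loader, the loss values, and the `CRITIC_TRAIN_STEPS` value below are stand-ins for illustration, not the repository's actual objects.

```python
from torch.utils.tensorboard import SummaryWriter

EPOCHS = 2
CRITIC_TRAIN_STEPS = 5  # critic updates per generator update (assumed value)
loss_writer = SummaryWriter("logs/loss")

# stand-in "loader": 100 batches of pre-computed (critic_loss, generator_loss)
dummy_loader = [(0.7, 1.3)] * 100

for epoch in range(EPOCHS):
    # reset the running sums at the start of every epoch
    C_loss_avg = 0.0
    G_loss_avg = 0.0

    for batch_idx, (C_loss, G_loss) in enumerate(dummy_loader):
        # the critic loss is accumulated once per inner critic step,
        # the generator loss once per batch
        C_loss_avg += C_loss * CRITIC_TRAIN_STEPS
        G_loss_avg += G_loss

        # log every 50 batches instead of only at the last one
        if (batch_idx + 1) % 50 == 0:
            C_loss_avg_ = C_loss_avg / (CRITIC_TRAIN_STEPS * batch_idx)
            G_loss_avg_ = G_loss_avg / batch_idx

            step = (epoch + 1) * (batch_idx + 1)
            # as in the commit, the most recent batch losses go to TensorBoard
            loss_writer.add_scalar('Critic', C_loss, global_step=step)
            loss_writer.add_scalar('generator', G_loss, global_step=step)
            print(
                f"Epoch [{epoch}/{EPOCHS}] | batch {batch_idx} "
                + f"Loss C: {C_loss_avg_:.4f}, loss G: {G_loss_avg_:.4f}"
            )
```

The scalars can then be inspected with `tensorboard --logdir logs`.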