@@ -25,14 +25,16 @@ def training(opt):
25
25
FEATURE_D = 128
26
26
Z_DIM = 100
27
27
BATCH_SIZE = opt .batch_size
28
+
28
29
# ~~~~~~~~~~~~~~~~~~~ as per WGAN paper ~~~~~~~~~~~~~~~~~~~ #
29
30
30
31
lr = opt .lr
31
32
CRITIC_TRAIN_STEPS = 5
32
33
WEIGHT_CLIP = 0.01
33
34
34
- print (f"Epochs: { EPOCHS } | lr: { lr } | batch size { BATCH_SIZE } " +
35
- f"device: { work_device } " )
35
+ print (f"Epochs: { EPOCHS } | lr: { lr } | batch size { BATCH_SIZE } |" +
36
+ f" device: { work_device } " )
37
+
36
38
# ~~~~~~~~~~~ creating directories for weights ~~~~~~~~~~~ #
37
39
38
40
if opt .logs :
@@ -88,26 +90,31 @@ def training(opt):
88
90
89
91
# ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #
90
92
93
+ # loss variables
91
94
C_loss_prev = math .inf
92
95
G_loss_prev = math .inf
93
96
C_loss = 0
94
97
G_loss = 0
98
+ C_loss_avg = 0
99
+ G_loss_avg = 0
95
100
96
101
print_gpu_details ()
102
+
103
+ # setting the models to train mode
104
+ critic .train ()
105
+ gen .train ()
106
+
97
107
for epoch in range (EPOCHS ):
98
- C_loss_avg = 0
99
- G_loss_avg = 0
100
108
101
109
print_memory_utilization ()
102
110
103
111
for batch_idx , (real , _ ) in enumerate (tqdm (loader )):
104
- critic .train ()
105
- gen .train ()
112
+
106
113
real = real .to (work_device )
107
114
fixed_noise = torch .rand (
108
115
real .shape [0 ], Z_DIM , 1 , 1 ).to (work_device )
109
116
110
- # ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #
117
+ # ~~~~~~~~~~~~~~~~~~~ critic loop ~~~~~~~~~~~~~~~~~~~ #
111
118
112
119
fake = gen (fixed_noise ) # dim of (N,1,W,H)
113
120
@@ -124,6 +131,7 @@ def training(opt):
124
131
125
132
C_loss = - (torch .mean (real_predict ) - torch .mean (fake_predict ))
126
133
C_loss_avg += C_loss
134
+
127
135
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
128
136
129
137
critic .zero_grad ()
@@ -145,6 +153,7 @@ def training(opt):
145
153
146
154
G_loss = - (torch .mean (fake_predict ))
147
155
G_loss_avg += G_loss
156
+
148
157
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
149
158
150
159
gen .zero_grad ()
@@ -153,7 +162,10 @@ def training(opt):
153
162
154
163
# ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
155
164
156
- if batch_idx == 0 and epoch >= 1 :
165
+ if batch_idx == len (loader )- 1 : # will execute at the last batch
166
+
167
+ # ~~~~~~~~~~~~ calculate average loss ~~~~~~~~~~~~~ #
168
+
157
169
C_loss_avg = C_loss_avg / (CRITIC_TRAIN_STEPS * BATCH_SIZE )
158
170
G_loss_avg = G_loss_avg / (BATCH_SIZE )
159
171
@@ -162,6 +174,8 @@ def training(opt):
162
174
+ f"Loss D: { C_loss_avg :.4f} , loss G: { G_loss_avg :.4f} "
163
175
)
164
176
177
+ # ~~~~~~~~~~~~ send data to tensorboard ~~~~~~~~~~~~~ #
178
+
165
179
with torch .no_grad ():
166
180
critic .eval ()
167
181
gen .eval ()
@@ -186,6 +200,14 @@ def training(opt):
186
200
loss_writer .add_scalar (
187
201
'generator' , G_loss , global_step = epoch )
188
202
203
+ # reset the average loss to zero
204
+ C_loss_avg = 0
205
+ G_loss_avg = 0
206
+
207
+ # changing back the model to train mode
208
+ critic .train ()
209
+ gen .train ()
210
+
189
211
# ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
190
212
191
213
if opt .weights :
0 commit comments