@@ -125,7 +125,7 @@ def training(opt):
            # make it one dimensional array
            real_predict = critic(real).view(-1)
            # make it one dimensional array
-            fake_predict = critic(fake).view(-1)
+            fake_predict = critic(fake.detach()).view(-1)

            # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #

@@ -135,7 +135,7 @@ def training(opt):
            # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #

            critic.zero_grad()
-            C_loss.backward(retain_graph=True)
+            C_loss.backward()
            critic_optim.step()

            # ~~~~~~~~~~~ weight clipping as per WGAN paper ~~~~~~~~~~ #
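The two backward-pass changes above go together: detaching `fake` before the critic forward pass keeps `C_loss.backward()` inside the critic's graph, so the generator's graph survives intact for the generator update that follows, which is what previously forced `retain_graph=True`. A minimal, self-contained sketch of that pattern (stand-in `nn.Linear` modules and toy shapes, not the repo's actual `Generator`/`Critic` classes):

```python
import torch
import torch.nn as nn

Z_DIM, BATCH_SIZE = 8, 16
gen = nn.Linear(Z_DIM, 4)      # stand-in generator
critic = nn.Linear(4, 1)       # stand-in critic
critic_optim = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
gen_optim = torch.optim.RMSprop(gen.parameters(), lr=5e-5)

real = torch.randn(BATCH_SIZE, 4)
noise = torch.randn(BATCH_SIZE, Z_DIM)
fake = gen(noise)  # generator graph is built here

# critic step: detach() so backward() only walks the critic's graph,
# leaving the generator graph untouched -> retain_graph=True is unnecessary
real_predict = critic(real).view(-1)
fake_predict = critic(fake.detach()).view(-1)
C_loss = -(real_predict.mean() - fake_predict.mean())
critic.zero_grad()
C_loss.backward()
critic_optim.step()

# generator step: the same (non-detached) fake still has its graph intact
G_loss = -critic(fake).view(-1).mean()
gen.zero_grad()
G_loss.backward()
gen_optim.step()
```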
@@ -170,7 +170,7 @@ def training(opt):
            G_loss_avg = G_loss_avg / (BATCH_SIZE)

            print(
-                f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx} / {len(loader)} "
+                f"Epoch [{epoch}/{EPOCHS}] "
                + f"Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}"
            )

@@ -179,11 +179,13 @@ def training(opt):
            with torch.no_grad():
                critic.eval()
                gen.eval()
-                fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
-                data = real.reshape(-1, CHANNELS, H, W)
                if BATCH_SIZE > 32:
-                    fake = fake[:32]
-                    data = data[:32]
+                    fake = gen(fixed_noise[:32]).reshape(-1, CHANNELS, H, W)
+                    data = real[:32].reshape(-1, CHANNELS, H, W)
+                else:
+                    fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
+                    data = real.reshape(-1, CHANNELS, H, W)
+
                img_grid_fake = torchvision.utils.make_grid(
                    fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(
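One detail worth flagging in the logging block: slicing with `[:32]` keeps the batch dimension that `gen` and `torchvision.utils.make_grid` expect, whereas plain integer indexing (e.g. `fixed_noise[32]`) would select a single sample and drop that dimension. A quick shape check, assuming a hypothetical noise batch of shape `(64, 100, 1, 1)`:

```python
import torch

fixed_noise = torch.randn(64, 100, 1, 1)  # hypothetical (N, Z_DIM, 1, 1) noise batch
print(fixed_noise[:32].shape)  # torch.Size([32, 100, 1, 1]) -> first 32 samples, batch dim kept
print(fixed_noise[32].shape)   # torch.Size([100, 1, 1])     -> one sample, batch dim dropped
```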