Skip to content

Commit 01fb633

Browse files
code optimization
1 parent dfe2df4 commit 01fb633

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

GAN/WGAN/training.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def training(opt):
125125
# make it one dimensional array
126126
real_predict = critic(real).view(-1)
127127
# make it one dimensional array
128-
fake_predict = critic(fake).view(-1)
128+
fake_predict = critic(fake.detach()).view(-1)
129129

130130
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
131131

@@ -135,7 +135,7 @@ def training(opt):
135135
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
136136

137137
critic.zero_grad()
138-
C_loss.backward(retain_graph=True)
138+
C_loss.backward()
139139
critic_optim.step()
140140

141141
# ~~~~~~~~~~~ weight cliping as per WGAN paper ~~~~~~~~~~ #
@@ -170,7 +170,7 @@ def training(opt):
170170
G_loss_avg = G_loss_avg/(BATCH_SIZE)
171171

172172
print(
173-
f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)}"
173+
f"Epoch [{epoch}/{EPOCHS}] "
174174
+ f"Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}"
175175
)
176176

@@ -179,11 +179,13 @@ def training(opt):
179179
with torch.no_grad():
180180
critic.eval()
181181
gen.eval()
182-
fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
183-
data = real.reshape(-1, CHANNELS, H, W)
184182
if BATCH_SIZE > 32:
185-
fake = fake[:32]
186-
data = data[:32]
183+
fake = gen(fixed_noise[32]).reshape(-1, CHANNELS, H, W)
184+
data = real[32].reshape(-1, CHANNELS, H, W)
185+
else:
186+
fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
187+
data = real.reshape(-1, CHANNELS, H, W)
188+
187189
img_grid_fake = torchvision.utils.make_grid(
188190
fake, normalize=True)
189191
img_grid_real = torchvision.utils.make_grid(

0 commit comments

Comments
 (0)