Skip to content

Commit 318cf63

Browse files
code optimization
1 parent 97952e4 commit 318cf63

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

GAN/WGAN/training.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ def training(opt):
109109

110110
# ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #
111111

112+
fake = gen(fixed_noise) # dim of (N,1,W,H)
113+
112114
for _ in range(CRITIC_TRAIN_STEPS):
113-
fake = gen(fixed_noise) # dim of (N,1,W,H)
114115

115116
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
116117

@@ -126,7 +127,7 @@ def training(opt):
126127
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
127128

128129
critic.zero_grad()
129-
C_loss.backward()
130+
C_loss.backward(retain_graph=True)
130131
critic_optim.step()
131132

132133
# ~~~~~~~~~~~ weight cliping as per WGAN paper ~~~~~~~~~~ #
@@ -136,12 +137,9 @@ def training(opt):
136137

137138
# ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
138139

139-
fake = gen(fixed_noise)
140-
141140
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
142141

143-
# make it one dimensional array
144-
fake_predict = critic(fake).view(-1)
142+
# re using the fake_predict from cirtic forward
145143

146144
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
147145

0 commit comments

Comments
 (0)