Commit 8a82438

Author: PhysikerErlangen

Update main.py
I got confused by the use of binary cross-entropy here. In particular, it wasn't clear to me why the variable real_labels is used when training the generator. I have added some comments; I am not sure they are correct, so you may want to double-check them.
1 parent 6f255de commit 8a82438

File tree (1 file changed: +9, −1 lines)
  • tutorials/02-intermediate/generative_adversarial_network


tutorials/02-intermediate/generative_adversarial_network/main.py

Lines changed: 9 additions & 1 deletion
@@ -64,19 +64,22 @@ def denorm(x):
         # Build mini-batch dataset
         batch_size = images.size(0)
         images = to_var(images.view(batch_size, -1))
+        # Create the labels which are later used as input for the BCE loss
         real_labels = to_var(torch.ones(batch_size))
         fake_labels = to_var(torch.zeros(batch_size))
 
         #============= Train the discriminator =============#
         # Compute loss with real images
         outputs = D(images)
+        # Apply BCE loss. The second term is always zero since real_labels == 1
         d_loss_real = criterion(outputs, real_labels)
         real_score = outputs
 
         # Compute loss with fake images
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
+        # Apply BCE loss. The first term is always zero since fake_labels == 0
         d_loss_fake = criterion(outputs, fake_labels)
         fake_score = outputs
 
@@ -91,6 +94,11 @@ def denorm(x):
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
+        # Remember that min log(1 - D(G(z))) has the same fixed point as max log(D(G(z))).
+        # Here we maximize log(D(G(z))), which is exactly the first term of the BCE loss
+        # with t = 1 (see the definition of BCE for the meaning of t).
+        # t == 1 holds for real_labels, so we use them as the target for the BCE loss.
+        # Don't get yourself confused by this; it is just convenient to reuse the BCE loss.
         g_loss = criterion(outputs, real_labels)
 
         # Backprop + Optimize
@@ -116,4 +124,4 @@ def denorm(x):
 
 # Save the trained parameters
 torch.save(G.state_dict(), './generator.pkl')
-torch.save(D.state_dict(), './discriminator.pkl')
+torch.save(D.state_dict(), './discriminator.pkl')
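For anyone double-checking the new comments: PyTorch's nn.BCELoss computes BCE(o, t) = -[t*log(o) + (1 - t)*log(1 - o)], averaged over the batch. The standalone sketch below is not part of this repository; the outputs values are made up, and it uses plain tensors instead of the tutorial's to_var/Variable wrappers. It verifies that a target of all ones reduces the loss to -mean(log(o)) and a target of all zeros to -mean(log(1 - o)), which is exactly why real_labels serves as the generator's target.

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Made-up discriminator outputs for a batch of 4 fake images,
# i.e. stand-ins for D(G(z)); any probabilities in (0, 1) work.
outputs = torch.tensor([0.2, 0.7, 0.9, 0.4])
real_labels = torch.ones(4)
fake_labels = torch.zeros(4)

# With t == 1 (real_labels) the second BCE term vanishes, leaving
# -mean(log(o)): minimizing this maximizes log(D(G(z))), as the
# comments in the diff describe.
g_loss = criterion(outputs, real_labels)
print(torch.allclose(g_loss, -torch.log(outputs).mean()))  # True

# With t == 0 (fake_labels) only the second term survives,
# leaving -mean(log(1 - o)), i.e. the discriminator's fake-image loss.
d_loss_fake = criterion(outputs, fake_labels)
print(torch.allclose(d_loss_fake, -torch.log(1 - outputs).mean()))  # True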

0 commit comments
