Skip to content

Commit be4c08b

Browse files
authored
Update main.py
1 parent 8a82438 commit be4c08b

File tree

1 file changed

+8
-9
lines changed
  • tutorials/02-intermediate/generative_adversarial_network

1 file changed

+8
-9
lines changed

tutorials/02-intermediate/generative_adversarial_network/main.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,23 @@ def denorm(x):
6464
# Build mini-batch dataset
6565
batch_size = images.size(0)
6666
images = to_var(images.view(batch_size, -1))
67+
6768
# Create the labels which are later used as input for the BCE loss
6869
real_labels = to_var(torch.ones(batch_size))
6970
fake_labels = to_var(torch.zeros(batch_size))
7071

7172
#============= Train the discriminator =============#
72-
# Compute loss with real images
73+
# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
74+
# Second term of the loss is always zero since real_labels == 1
7375
outputs = D(images)
74-
# Apply BCE loss. Second term is always zero since real_labels == 1
7576
d_loss_real = criterion(outputs, real_labels)
7677
real_score = outputs
7778

78-
# Compute loss with fake images
79+
# Compute BCELoss using fake images
80+
# First term of the loss is always zero since fake_labels == 0
7981
z = to_var(torch.randn(batch_size, 64))
8082
fake_images = G(z)
8183
outputs = D(fake_images)
82-
# Apply BCE loss. First term is always zero since fake_labels == 0
8384
d_loss_fake = criterion(outputs, fake_labels)
8485
fake_score = outputs
8586

@@ -94,11 +95,9 @@ def denorm(x):
9495
z = to_var(torch.randn(batch_size, 64))
9596
fake_images = G(z)
9697
outputs = D(fake_images)
97-
# remember that min log(1-D(G(z))) has the same fixed point as max log(D(G(z)))
98-
# Here we maximize log(D(G(z))), which is exactly the first term in the BCE loss
99-
# with t=1. (see definition of BCE for info on t)
100-
# t==1 is valid for real_labels, thus we use them as input for the BCE loss.
101-
# Don't get yourself confused by this. It is just convenient to use the BCE loss.
98+
99+
# We train G to maximize log(D(G(z))) instead of minimizing log(1-D(G(z)))
100+
# For the reasoning, see the last paragraph of section 3 of https://arxiv.org/pdf/1406.2661.pdf
102101
g_loss = criterion(outputs, real_labels)
103102

104103
# Backprop + Optimize

0 commit comments

Comments
 (0)