@@ -64,16 +64,20 @@ def denorm(x):
         # Build mini-batch dataset
         batch_size = images.size(0)
         images = to_var(images.view(batch_size, -1))
+
+        # Create the labels which are later used as input for the BCE loss
         real_labels = to_var(torch.ones(batch_size))
         fake_labels = to_var(torch.zeros(batch_size))
 
         #============= Train the discriminator =============#
-        # Compute loss with real images
+        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
+        # Second term of the loss is always zero since real_labels == 1
         outputs = D(images)
         d_loss_real = criterion(outputs, real_labels)
         real_score = outputs
 
-        # Compute loss with fake images
+        # Compute BCELoss using fake images
+        # First term of the loss is always zero since fake_labels == 0
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
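The comments added in this hunk describe how the two BCE terms collapse for all-real and all-fake targets. As a quick standalone check (not part of the commit, and assuming the script's criterion is nn.BCELoss, as in the surrounding tutorial code), the reduction can be verified directly:

# Illustrative sketch only, not part of the diff: with y == 1 the
# -(1 - y) * log(1 - D(x)) term vanishes, and with y == 0 the
# -y * log(D(x)) term vanishes.
import torch
import torch.nn as nn

criterion = nn.BCELoss()
outputs = torch.tensor([0.9, 0.2, 0.7])   # hypothetical discriminator probabilities D(x)
real_labels = torch.ones(3)
fake_labels = torch.zeros(3)

# Real targets: BCELoss reduces to the mean of -log(D(x))
loss_real = criterion(outputs, real_labels)
assert torch.allclose(loss_real, (-torch.log(outputs)).mean())

# Fake targets: BCELoss reduces to the mean of -log(1 - D(x))
loss_fake = criterion(outputs, fake_labels)
assert torch.allclose(loss_fake, (-torch.log(1 - outputs)).mean())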
@@ -91,6 +95,9 @@ def denorm(x):
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
+
+        # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
+        # For the reasoning, see the last paragraph of Section 3 of https://arxiv.org/pdf/1406.2661.pdf
         g_loss = criterion(outputs, real_labels)
 
         # Backprop + Optimize
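The comment added in this hunk refers to the non-saturating generator loss: feeding real_labels (all ones) to BCELoss on D(G(z)) minimizes -log(D(G(z))), which is equivalent to maximizing log(D(G(z))) and gives G a much stronger gradient early in training than minimizing log(1 - D(G(z))). A minimal sketch of that difference follows (illustrative only, not part of the commit, assuming nn.BCELoss and hypothetical early-training discriminator outputs):

# Illustrative sketch only, not part of the diff.
import torch
import torch.nn as nn

# Hypothetical D(G(z)) values early in training, when D easily rejects the fakes
d_on_fake = torch.tensor([0.01, 0.02, 0.05], requires_grad=True)

# Non-saturating loss used in the commit: minimize -log(D(G(z)))
ns_loss = nn.BCELoss()(d_on_fake, torch.ones(3))
ns_grad, = torch.autograd.grad(ns_loss, d_on_fake)

# Saturating loss from the original minimax objective: minimize log(1 - D(G(z)))
d_on_fake2 = torch.tensor([0.01, 0.02, 0.05], requires_grad=True)
sat_loss = torch.log(1 - d_on_fake2).mean()
sat_grad, = torch.autograd.grad(sat_loss, d_on_fake2)

print(ns_grad)    # large-magnitude gradients: strong learning signal for G
print(sat_grad)   # small, nearly flat gradients: the signal vanishes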
@@ -116,4 +123,4 @@ def denorm(x):
 
 # Save the trained parameters
 torch.save(G.state_dict(), './generator.pkl')
-torch.save(D.state_dict(), './discriminator.pkl')
+torch.save(D.state_dict(), './discriminator.pkl')