@@ -64,16 +64,20 @@ def denorm(x):
         # Build mini-batch dataset
         batch_size = images.size(0)
         images = to_var(images.view(batch_size, -1))
+
+        # Create the labels which are later used as input for the BCE loss
         real_labels = to_var(torch.ones(batch_size))
         fake_labels = to_var(torch.zeros(batch_size))
 
         #============= Train the discriminator =============#
-        # Compute loss with real images
+        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
+        # Second term of the loss is always zero since real_labels == 1
         outputs = D(images)
         d_loss_real = criterion(outputs, real_labels)
         real_score = outputs
 
-        # Compute loss with fake images
+        # Compute BCELoss using fake images
+        # First term of the loss is always zero since fake_labels == 0
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
@@ -91,6 +95,9 @@ def denorm(x):
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
+
+        # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
+        # For the reason, see the last paragraph of Section 3 of https://arxiv.org/pdf/1406.2661.pdf
         g_loss = criterion(outputs, real_labels)
 
         # Backprop + Optimize
@@ -116,4 +123,4 @@ def denorm(x):
 
 # Save the trained parameters
 torch.save(G.state_dict(), './generator.pkl')
-torch.save(D.state_dict(), './discriminator.pkl')
+torch.save(D.state_dict(), './discriminator.pkl')
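
A quick sanity check of the new comments (not part of the commit; a minimal sketch with made-up discriminator outputs, assuming the same nn.BCELoss criterion as the script): with real_labels == 1 the loss collapses to -log(D(x)), with fake_labels == 0 it collapses to -log(1 - D(x)), and feeding real labels for fake images gives the non-saturating generator loss -log(D(G(z))).

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator probabilities in (0, 1); any values would do.
d_real = torch.tensor([0.9, 0.7])   # D(x) on real images
d_fake = torch.tensor([0.2, 0.4])   # D(G(z)) on fake images

ones = torch.ones_like(d_real)
zeros = torch.zeros_like(d_fake)

# Discriminator, real images: BCE(D(x), 1) == -log(D(x)) since the (1 - y) term vanishes.
assert torch.allclose(criterion(d_real, ones), -torch.log(d_real).mean())

# Discriminator, fake images: BCE(D(G(z)), 0) == -log(1 - D(G(z))) since the y term vanishes.
assert torch.allclose(criterion(d_fake, zeros), -torch.log(1 - d_fake).mean())

# Generator: real labels on fake images give -log(D(G(z))), the non-saturating loss
# from the GAN paper, rather than minimizing log(1 - D(G(z))).
assert torch.allclose(criterion(d_fake, ones), -torch.log(d_fake).mean())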