@@ -64,22 +64,23 @@ def denorm(x):
         # Build mini-batch dataset
         batch_size = images.size(0)
         images = to_var(images.view(batch_size, -1))
+
         # Create the labels which are later used as input for the BCE loss
         real_labels = to_var(torch.ones(batch_size))
         fake_labels = to_var(torch.zeros(batch_size))
 
         #============= Train the discriminator =============#
-        # Compute loss with real images
+        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
+        # Second term of the loss is always zero since real_labels == 1
         outputs = D(images)
-        # Apply BCE loss. Second term is always zero since real_labels == 1
         d_loss_real = criterion(outputs, real_labels)
         real_score = outputs
 
-        # Compute loss with fake images
+        # Compute BCELoss using fake images
+        # First term of the loss is always zero since fake_labels == 0
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
-        # Apply BCE loss. First term is always zero since fake_labels == 0
         d_loss_fake = criterion(outputs, fake_labels)
         fake_score = outputs
 
@@ -94,11 +95,9 @@ def denorm(x):
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
-        # remember that min log(1-D(G(z))) has the same fix point as max log(D(G(z)))
-        # Here we maximize log(D(G(z))), which is exactly the first term in the BCE loss
-        # with t=1. (see definition of BCE for info on t)
-        # t==1 is valid for real_labels, thus we use them as input for the BCE loss.
-        # Don't get yourself confused by this. It is just convenient to use to the BCE loss.
+
+        # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
+        # For the reasoning behind this, see the last paragraph of Section 3 of https://arxiv.org/pdf/1406.2661.pdf
         g_loss = criterion(outputs, real_labels)
 
         # Backprop + Optimize
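
The updated comments describe the standard non-saturating GAN objective: the discriminator's BCE loss splits into a real term and a fake term, and the generator reuses real_labels so its loss reduces to -log(D(G(z))). Below is a minimal standalone sketch of that loss structure; the tiny D and G networks, the 784-dimensional images, and the 64-dimensional noise are illustrative assumptions, not the repository's actual models.

import torch
import torch.nn as nn

# Illustrative stand-ins for the real networks and data (assumptions, not the repo's models)
D = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())
G = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 784), nn.Tanh())
criterion = nn.BCELoss()  # BCE(o, y) = -y*log(o) - (1-y)*log(1-o)

batch_size = 16
images = torch.randn(batch_size, 784)       # stand-in for a flattened real mini-batch
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# Discriminator, real term: y == 1, so the (1-y)*log(1-o) term vanishes
d_loss_real = criterion(D(images), real_labels)

# Discriminator, fake term: y == 0, so the y*log(o) term vanishes
z = torch.randn(batch_size, 64)
fake_images = G(z)
d_loss_fake = criterion(D(fake_images), fake_labels)
d_loss = d_loss_real + d_loss_fake

# Generator: feeding real_labels makes the loss -log(D(G(z))),
# i.e. G maximizes log(D(G(z))) rather than minimizing log(1 - D(G(z)))
g_loss = criterion(D(fake_images), real_labels)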