Skip to content

Commit 69145a6

Browse files
committed
vanilla gan added
1 parent a438f8e commit 69145a6

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

tutorials/10 - Generative Adversarial Network/main-gpu.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,22 @@ class Discriminator(nn.Module):
2727
def __init__(self):
2828
super(Discriminator, self).__init__()
2929
self.fc1 = nn.Linear(784, 256)
30-
self.fc2 = nn.Linear(256, 1)
30+
self.fc2 = nn.Linear(256, 256)
31+
self.fc3 = nn.Linear(256, 1)
3132

3233
def forward(self, x):
3334
h = F.relu(self.fc1(x))
34-
out = F.sigmoid(self.fc2(h))
35+
h = F.relu(self.fc2(h))
36+
out = F.sigmoid(self.fc3(h))
3537
return out
3638

3739
# Generator Model
3840
class Generator(nn.Module):
3941
def __init__(self):
4042
super(Generator, self).__init__()
4143
self.fc1 = nn.Linear(128, 256)
42-
self.fc2 = nn.Linear(256, 512)
43-
self.fc3 = nn.Linear(512, 784)
44+
self.fc2 = nn.Linear(256, 256)
45+
self.fc3 = nn.Linear(256, 784)
4446

4547
def forward(self, x):
4648
h = F.leaky_relu(self.fc1(x))
@@ -101,7 +103,7 @@ def forward(self, x):
101103
# Save the sampled images
102104
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
103105
torchvision.utils.save_image(fake_images.data,
104-
'./data2/fake_samples_%d.png' %epoch+1)
106+
'./data/fake_samples_%d.png' %(epoch+1))
105107

106108
# Save the Models
107109
torch.save(generator.state_dict(), './generator.pkl')

tutorials/10 - Generative Adversarial Network/main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,22 @@ class Discriminator(nn.Module):
2727
def __init__(self):
2828
super(Discriminator, self).__init__()
2929
self.fc1 = nn.Linear(784, 256)
30-
self.fc2 = nn.Linear(256, 1)
30+
self.fc2 = nn.Linear(256, 256)
31+
self.fc3 = nn.Linear(256, 1)
3132

3233
def forward(self, x):
3334
h = F.relu(self.fc1(x))
34-
out = F.sigmoid(self.fc2(h))
35+
h = F.relu(self.fc2(h))
36+
out = F.sigmoid(self.fc3(h))
3537
return out
3638

3739
# Generator Model
3840
class Generator(nn.Module):
3941
def __init__(self):
4042
super(Generator, self).__init__()
4143
self.fc1 = nn.Linear(128, 256)
42-
self.fc2 = nn.Linear(256, 512)
43-
self.fc3 = nn.Linear(512, 784)
44+
self.fc2 = nn.Linear(256, 256)
45+
self.fc3 = nn.Linear(256, 784)
4446

4547
def forward(self, x):
4648
h = F.leaky_relu(self.fc1(x))
@@ -101,7 +103,7 @@ def forward(self, x):
101103
# Save the sampled images
102104
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
103105
torchvision.utils.save_image(fake_images.data,
104-
'./data2/fake_samples_%d.png' %epoch+1)
106+
'./data/fake_samples_%d.png' %(epoch+1))
105107

106108
# Save the Models
107109
torch.save(generator.state_dict(), './generator.pkl')

0 commit comments

Comments
 (0)