Skip to content

Commit 6d0c3c8

Browse files
committed
add denormalization function
1 parent 47d70f9 commit 6d0c3c8

File tree

4 files changed

+29
-17
lines changed

4 files changed

+29
-17
lines changed

tutorials/10 - Generative Adversarial Network/main-gpu.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
transforms.ToTensor(),
1212
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1313

14+
def denorm(x):
15+
return (x + 1) / 2
16+
1417
# MNIST Dataset
15-
train_dataset = dsets.MNIST(root='../data/',
18+
train_dataset = dsets.MNIST(root='./data/',
1619
train=True,
1720
transform=transform,
1821
download=True)
@@ -66,16 +69,16 @@ def forward(self, x):
6669
# Build mini-batch dataset
6770
images = images.view(images.size(0), -1)
6871
images = Variable(images.cuda())
69-
real_labels = Variable(torch.ones(images.size(0))).cuda()
70-
fake_labels = Variable(torch.zeros(images.size(0))).cuda()
72+
real_labels = Variable(torch.ones(images.size(0)).cuda())
73+
fake_labels = Variable(torch.zeros(images.size(0)).cuda())
7174

7275
# Train the discriminator
7376
discriminator.zero_grad()
7477
outputs = discriminator(images)
7578
real_loss = criterion(outputs, real_labels)
7679
real_score = outputs
7780

78-
noise = Variable(torch.randn(images.size(0), 128)).cuda()
81+
noise = Variable(torch.randn(images.size(0), 128).cuda())
7982
fake_images = generator(noise)
8083
outputs = discriminator(fake_images.detach())
8184
fake_loss = criterion(outputs, fake_labels)
@@ -87,7 +90,7 @@ def forward(self, x):
8790

8891
# Train the generator
8992
generator.zero_grad()
90-
noise = Variable(torch.randn(images.size(0), 128)).cuda()
93+
noise = Variable(torch.randn(images.size(0), 128).cuda())
9194
fake_images = generator(noise)
9295
outputs = discriminator(fake_images)
9396
g_loss = criterion(outputs, real_labels)
@@ -98,13 +101,13 @@ def forward(self, x):
98101
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
99102
'D(x): %.2f, D(G(z)): %.2f'
100103
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
101-
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
104+
real_score.data.mean(), fake_score.cpu().data.mean()))
102105

103106
# Save the sampled images
104107
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
105-
torchvision.utils.save_image(fake_images.data,
108+
torchvision.utils.save_image(denorm(fake_images.data),
106109
'./data/fake_samples_%d.png' %(epoch+1))
107110

108111
# Save the Models
109112
torch.save(generator.state_dict(), './generator.pkl')
110-
torch.save(discriminator.state_dict(), './discriminator.pkl')
113+
torch.save(discriminator.state_dict(), './discriminator.pkl')

tutorials/10 - Generative Adversarial Network/main.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
transforms.ToTensor(),
1212
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1313

14+
def denorm(x):
15+
return (x + 1) / 2
16+
1417
# MNIST Dataset
15-
train_dataset = dsets.MNIST(root='../data/',
18+
train_dataset = dsets.MNIST(root='./data/',
1619
train=True,
1720
transform=transform,
1821
download=True)
@@ -102,9 +105,9 @@ def forward(self, x):
102105

103106
# Save the sampled images
104107
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
105-
torchvision.utils.save_image(fake_images.data,
108+
torchvision.utils.save_image(denorm(fake_images.data),
106109
'./data/fake_samples_%d.png' %(epoch+1))
107110

108111
# Save the Models
109112
torch.save(generator.state_dict(), './generator.pkl')
110-
torch.save(discriminator.state_dict(), './discriminator.pkl')
113+
torch.save(discriminator.state_dict(), './discriminator.pkl')

tutorials/11 - Deep Convolutional Generative Adversarial Network/main-gpu.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
transforms.ToTensor(),
1313
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1414

15+
def denorm(x):
16+
return (x + 1) / 2
17+
1518
# CIFAR-10 Dataset
16-
train_dataset = dsets.CIFAR10(root='../data/',
19+
train_dataset = dsets.CIFAR10(root='./data/',
1720
train=True,
1821
transform=transform,
1922
download=True)
@@ -126,9 +129,9 @@ def forward(self, x):
126129
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
127130

128131
# Save the sampled images
129-
torchvision.utils.save_image(fake_images.data,
132+
torchvision.utils.save_image(denorm(fake_images.data),
130133
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
131134

132135
# Save the Models
133136
torch.save(generator.state_dict(), './generator.pkl')
134-
torch.save(discriminator.state_dict(), './discriminator.pkl')
137+
torch.save(discriminator.state_dict(), './discriminator.pkl')

tutorials/11 - Deep Convolutional Generative Adversarial Network/main.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
transforms.ToTensor(),
1313
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1414

15+
def denorm(x):
16+
return (x + 1) / 2
17+
1518
# CIFAR-10 Dataset
16-
train_dataset = dsets.CIFAR10(root='../data/',
19+
train_dataset = dsets.CIFAR10(root='./data/',
1720
train=True,
1821
transform=transform,
1922
download=True)
@@ -126,9 +129,9 @@ def forward(self, x):
126129
real_score.data.mean(), fake_score.data.mean()))
127130

128131
# Save the sampled images
129-
torchvision.utils.save_image(fake_images.data,
132+
torchvision.utils.save_image(denorm(fake_images.data),
130133
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
131134

132135
# Save the Models
133136
torch.save(generator.state_dict(), './generator.pkl')
134-
torch.save(discriminator.state_dict(), './discriminator.pkl')
137+
torch.save(discriminator.state_dict(), './discriminator.pkl')

0 commit comments

Comments
 (0)