11import torch
22import torchvision
33import torch .nn as nn
4+ import torch .nn .functional as F
45import torchvision .datasets as dsets
56import torchvision .transforms as transforms
67from torch .autograd import Variable
78
89# Image Preprocessing
910transform = transforms .Compose ([
10- transforms .Scale (36 ),
11- transforms .RandomCrop (32 ),
1211 transforms .ToTensor (),
1312 transforms .Normalize (mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ))])
1413
15- # CIFAR-10 Dataset
16- train_dataset = dsets .CIFAR10 (root = '../data/' ,
17- train = True ,
18- transform = transform ,
19- download = True )
14+ # MNIST Dataset
15+ train_dataset = dsets .MNIST (root = '../data/' ,
16+ train = True ,
17+ transform = transform ,
18+ download = True )
2019
2120# Data Loader (Input Pipeline)
2221train_loader = torch .utils .data .DataLoader (dataset = train_dataset ,
2322 batch_size = 100 ,
2423 shuffle = True )
2524
26- # 4x4 Convolution
27- def conv4x4 (in_channels , out_channels , stride ):
28- return nn .Conv2d (in_channels , out_channels , kernel_size = 4 ,
29- stride = stride , padding = 1 , bias = False )
30-
3125# Discriminator Model
3226class Discriminator (nn .Module ):
3327 def __init__ (self ):
3428 super (Discriminator , self ).__init__ ()
35- self .model = nn .Sequential (
36- conv4x4 (3 , 16 , 2 ),
37- nn .LeakyReLU (0.2 , inplace = True ),
38- conv4x4 (16 , 32 , 2 ),
39- nn .BatchNorm2d (32 ),
40- nn .LeakyReLU (0.2 , inplace = True ),
41- conv4x4 (32 , 64 , 2 ),
42- nn .BatchNorm2d (64 ),
43- nn .LeakyReLU (0.2 , inplace = True ),
44- nn .Conv2d (64 , 1 , kernel_size = 4 ),
45- nn .Sigmoid ())
46-
29+ self .fc1 = nn .Linear (784 , 256 )
30+ self .fc2 = nn .Linear (256 , 1 )
31+
4732 def forward (self , x ):
48- out = self .model ( x )
49- out = out . view ( out . size ( 0 ), - 1 )
33+ h = F . relu ( self .fc1 ( x ) )
34+ out = F . sigmoid ( self . fc2 ( h ) )
5035 return out
5136
52- # 4x4 Transpose convolution
53- def conv_transpose4x4 (in_channels , out_channels , stride = 1 , padding = 1 , bias = False ):
54- return nn .ConvTranspose2d (in_channels , out_channels , kernel_size = 4 ,
55- stride = stride , padding = padding , bias = bias )
56-
5737# Generator Model
5838class Generator (nn .Module ):
5939 def __init__ (self ):
6040 super (Generator , self ).__init__ ()
61- self .model = nn .Sequential (
62- conv_transpose4x4 (128 , 64 , padding = 0 ),
63- nn .BatchNorm2d (64 ),
64- nn .ReLU (inplace = True ),
65- conv_transpose4x4 (64 , 32 , 2 ),
66- nn .BatchNorm2d (32 ),
67- nn .ReLU (inplace = True ),
68- conv_transpose4x4 (32 , 16 , 2 ),
69- nn .BatchNorm2d (16 ),
70- nn .ReLU (inplace = True ),
71- conv_transpose4x4 (16 , 3 , 2 , bias = True ),
72- nn .Tanh ())
73-
41+ self .fc1 = nn .Linear (128 , 256 )
42+ self .fc2 = nn .Linear (256 , 512 )
43+ self .fc3 = nn .Linear (512 , 784 )
44+
7445 def forward (self , x ):
75- x = x .view (x .size (0 ), 128 , 1 , 1 )
76- out = self .model (x )
46+ h = F .leaky_relu (self .fc1 (x ))
47+ h = F .leaky_relu (self .fc2 (h ))
48+ out = F .tanh (self .fc3 (h ))
7749 return out
7850
7951discriminator = Discriminator ()
@@ -83,13 +55,14 @@ def forward(self, x):
8355
8456# Loss and Optimizer
8557criterion = nn .BCELoss ()
86- lr = 0.0002
87- d_optimizer = torch .optim .Adam (discriminator .parameters (), lr = lr )
88- g_optimizer = torch .optim .Adam (generator .parameters (), lr = lr )
58+ d_optimizer = torch .optim .Adam (discriminator .parameters (), lr = 0.0005 )
59+ g_optimizer = torch .optim .Adam (generator .parameters (), lr = 0.0005 )
8960
9061# Training
91- for epoch in range (50 ):
62+ for epoch in range (200 ):
9263 for i , (images , _ ) in enumerate (train_loader ):
64+ # Build mini-batch dataset
65+ images = images .view (images .size (0 ), - 1 )
9366 images = Variable (images )
9467 real_labels = Variable (torch .ones (images .size (0 )))
9568 fake_labels = Variable (torch .zeros (images .size (0 )))
@@ -119,16 +92,17 @@ def forward(self, x):
11992 g_loss .backward ()
12093 g_optimizer .step ()
12194
122- if (i + 1 ) % 100 == 0 :
95+ if (i + 1 ) % 300 == 0 :
12396 print ('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
12497 'D(x): %.2f, D(G(z)): %.2f'
125- % (epoch , 50 , i + 1 , 500 , d_loss .data [0 ], g_loss .data [0 ],
126- real_score .data .mean (), fake_score .data .mean ()))
98+ % (epoch , 200 , i + 1 , 600 , d_loss .data [0 ], g_loss .data [0 ],
99+ real_score .data .mean (), fake_score .cpu (). data .mean ()))
127100
128- # Save the sampled images
129- torchvision .utils .save_image (fake_images .data ,
130- './data/fake_samples_%d_%d.png' % (epoch + 1 , i + 1 ))
101+ # Save the sampled images
102+ fake_images = fake_images .view (fake_images .size (0 ), 1 , 28 , 28 )
103+ torchvision .utils .save_image (fake_images .data ,
104+ './data2/fake_samples_%d.png' % epoch + 1 )
131105
132- # Save the Models
106+ # Save the Models
133107torch .save (generator .state_dict (), './generator.pkl' )
134108torch .save (discriminator .state_dict (), './discriminator.pkl' )
0 commit comments