import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
# Image Preprocessing: convert MNIST PIL images to tensors in [-1, 1]
# (matches the generator's tanh output range).
# Fix: MNIST is single-channel, so Normalize takes one mean/std value —
# the 3-tuple (0.5, 0.5, 0.5) is CIFAR residue and fails on (1, 28, 28).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])
1413
# MNIST Dataset (downloaded to ../data/ on first run)
train_dataset = dsets.MNIST(root='../data/',
                            train=True,
                            transform=transform,
                            download=True)
2019
# Data Loader (Input Pipeline): shuffled mini-batches of 100 images
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100,
                                           shuffle=True)
2524
# Discriminator Model
class Discriminator(nn.Module):
    """Two-layer MLP mapping a flattened 28x28 image (784,) to a
    real/fake probability in (0, 1)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        # x: (batch, 784) flattened images
        h = F.relu(self.fc1(x))
        # torch.sigmoid: F.sigmoid is deprecated in modern PyTorch
        out = torch.sigmoid(self.fc2(h))
        return out
5136
# Generator Model
class Generator(nn.Module):
    """Three-layer MLP mapping a 128-d noise vector to a flattened
    28x28 image with values in (-1, 1) via tanh."""

    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 784)

    def forward(self, x):
        # x: (batch, 128) noise vectors
        h = F.leaky_relu(self.fc1(x))
        h = F.leaky_relu(self.fc2(h))
        # torch.tanh: F.tanh is deprecated in modern PyTorch
        out = torch.tanh(self.fc3(h))
        return out
7850
discriminator = Discriminator()
# NOTE(review): the source diff has a hunk gap here; the generator must be
# instantiated before the optimizers below reference it — confirm against
# the original commit.
generator = Generator()
# Loss and Optimizer: binary cross-entropy for both networks,
# separate Adam optimizers so D and G update independently.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
8960
# Training
for epoch in range(200):
    for i, (images, _) in enumerate(train_loader):
        # Build mini-batch dataset: flatten (B, 1, 28, 28) -> (B, 784)
        images = images.view(images.size(0), -1)
        images = Variable(images)
        # Labels shaped (B, 1) to match the discriminator's output;
        # modern BCELoss rejects a (B,) target against (B, 1) input.
        real_labels = Variable(torch.ones(images.size(0), 1))
        fake_labels = Variable(torch.zeros(images.size(0), 1))

        # NOTE(review): the source diff has a hunk gap here; the steps below
        # are the standard vanilla-GAN update reconstructed to be consistent
        # with the names used after the gap (d_loss, g_loss, real_score,
        # fake_score, fake_images) — confirm against the original commit.

        # Train discriminator: real images -> 1, generated images -> 0
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = Variable(torch.randn(images.size(0), 128))
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        discriminator.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train generator: try to make D label the fakes as real
        z = Variable(torch.randn(images.size(0), 128))
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)

        generator.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 300 == 0:
            # .item() replaces the deprecated 0-dim indexing d_loss.data[0]
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                  'D(x): %.2f, D(G(z)): %.2f'
                  % (epoch, 200, i + 1, 600, d_loss.item(), g_loss.item(),
                     real_score.data.mean(), fake_score.data.mean()))

    # Save a grid of sampled images once per epoch
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    # Bug fix: '...%d.png' % epoch + 1 parses as ('...%d.png' % epoch) + 1,
    # which raises TypeError — the epoch number must be grouped.
    torchvision.utils.save_image(fake_images.data,
                                 './data2/fake_samples_%d.png' % (epoch + 1))
131105
# Save the trained model parameters
torch.save(generator.state_dict(), './generator.pkl')
torch.save(discriminator.state_dict(), './discriminator.pkl')