1+ import  torch 
2+ import  torchvision 
3+ import  torch .nn  as  nn 
4+ import  torchvision .datasets  as  dsets 
5+ import  torchvision .transforms  as  transforms 
6+ from  torch .autograd  import  Variable 
7+ 
8+ # Image Preprocessing 
9+ transform  =  transforms .Compose ([
10+         transforms .Scale (36 ),
11+         transforms .RandomCrop (32 ),
12+         transforms .ToTensor (),
13+         transforms .Normalize (mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ))])
14+ 
15+ # CIFAR-10 Dataset 
16+ train_dataset  =  dsets .CIFAR10 (root = '../data/' ,
17+                                train = True , 
18+                                transform = transform ,
19+                                download = True )
20+ 
21+ # Data Loader (Input Pipeline) 
22+ train_loader  =  torch .utils .data .DataLoader (dataset = train_dataset ,
23+                                            batch_size = 100 , 
24+                                            shuffle = True )
25+ 
26+ # 4x4 Convolution 
27+ def  conv4x4 (in_channels , out_channels , stride ):
28+     return  nn .Conv2d (in_channels , out_channels , kernel_size = 4 , 
29+                      stride = stride , padding = 1 , bias = False )
30+ 
31+ # Discriminator Model 
32+ class  Discriminator (nn .Module ):
33+     def  __init__ (self ):
34+         super (Discriminator , self ).__init__ ()
35+         self .model  =  nn .Sequential (
36+             conv4x4 (3 , 16 , 2 ),
37+             nn .LeakyReLU (0.2 , inplace = True ),
38+             conv4x4 (16 , 32 , 2 ),
39+             nn .BatchNorm2d (32 ),
40+             nn .LeakyReLU (0.2 , inplace = True ),
41+             conv4x4 (32 , 64 , 2 ), 
42+             nn .BatchNorm2d (64 ),
43+             nn .LeakyReLU (0.2 , inplace = True ),
44+             nn .Conv2d (64 , 1 , kernel_size = 4 ),
45+             nn .Sigmoid ())
46+     
47+     def  forward (self , x ):
48+         out  =  self .model (x )
49+         out  =  out .view (out .size (0 ), - 1 )
50+         return  out 
51+ 
52+ # 4x4 Transpose convolution 
53+ def  conv_transpose4x4 (in_channels , out_channels , stride = 1 , padding = 1 , bias = False ):
54+     return  nn .ConvTranspose2d (in_channels , out_channels , kernel_size = 4 , 
55+                               stride = stride , padding = padding , bias = bias )
56+ 
57+ # Generator Model 
58+ class  Generator (nn .Module ):
59+     def  __init__ (self ):
60+         super (Generator , self ).__init__ ()
61+         self .model  =  nn .Sequential (
62+             conv_transpose4x4 (128 , 64 , padding = 0 ),
63+             nn .BatchNorm2d (64 ),
64+             nn .ReLU (inplace = True ),
65+             conv_transpose4x4 (64 , 32 , 2 ),
66+             nn .BatchNorm2d (32 ),
67+             nn .ReLU (inplace = True ),
68+             conv_transpose4x4 (32 , 16 , 2 ),
69+             nn .BatchNorm2d (16 ),
70+             nn .ReLU (inplace = True ),
71+             conv_transpose4x4 (16 , 3 , 2 , bias = True ),
72+             nn .Tanh ())
73+     
74+     def  forward (self , x ):
75+         x  =  x .view (x .size (0 ), 128 , 1 , 1 )
76+         out  =  self .model (x )
77+         return  out 
78+ 
79+ discriminator  =  Discriminator ()
80+ generator  =  Generator ()
81+ discriminator .cuda ()
82+ generator .cuda ()
83+ 
84+ # Loss and Optimizer 
85+ criterion  =  nn .BCELoss ()
86+ lr  =  0.002 
87+ d_optimizer  =  torch .optim .Adam (discriminator .parameters (), lr = lr )
88+ g_optimizer  =  torch .optim .Adam (generator .parameters (), lr = lr )
89+ 
90+ # Training  
91+ for  epoch  in  range (50 ):
92+     for  i , (images , _ ) in  enumerate (train_loader ):
93+         images  =  Variable (images .cuda ())
94+         real_labels  =  Variable (torch .ones (images .size (0 ))).cuda ()
95+         fake_labels  =  Variable (torch .zeros (images .size (0 ))).cuda ()
96+         
97+         # Train the discriminator 
98+         discriminator .zero_grad ()
99+         outputs  =  discriminator (images )
100+         real_loss  =  criterion (outputs , real_labels )
101+         real_score  =  outputs 
102+         
103+         noise  =  Variable (torch .randn (images .size (0 ), 128 )).cuda ()
104+         fake_images  =  generator (noise )
105+         outputs  =  discriminator (fake_images ) 
106+         fake_loss  =  criterion (outputs , fake_labels )
107+         fake_score  =  outputs 
108+         
109+         d_loss  =  real_loss  +  fake_loss 
110+         d_loss .backward ()
111+         d_optimizer .step ()
112+         
113+         # Train the generator  
114+         generator .zero_grad ()
115+         noise  =  Variable (torch .randn (images .size (0 ), 128 )).cuda ()
116+         fake_images  =  generator (noise )
117+         outputs  =  discriminator (fake_images )
118+         g_loss  =  criterion (outputs , real_labels )
119+         g_loss .backward ()
120+         g_optimizer .step ()
121+         
122+         if  (i + 1 ) %  100  ==  0 :
123+             print ('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '  
124+                   'D(x): %.2f, D(G(z)): %.2f'  
125+                   % (epoch , 50 , i + 1 , 500 , d_loss .data [0 ], g_loss .data [0 ],
126+                     real_score .cpu ().data .mean (), fake_score .cpu ().data .mean ()))
127+             
128+             # Save the sampled images 
129+             torchvision .utils .save_image (fake_images .data , 
130+                 './data/fake_samples_%d_%d.png'  % (epoch + 1 , i + 1 ))
131+ 
132+ # Save the Models  
133+ torch .save (generator .state_dict (), './generator.pkl' )
134+ torch .save (discriminator .state_dict (), './discriminator.pkl' )
0 commit comments