1- # Implementation of https://arxiv.org/pdf/1512.03385.pdf/
1+ # Implementation of https://arxiv.org/pdf/1512.03385.pdf
22# See section 4.2 for model architecture on CIFAR-10.
33# Some part of the code was referenced below.
44# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
5- import torch
5+ import torch
66import torch .nn as nn
77import torchvision .datasets as dsets
88import torchvision .transforms as transforms
99from torch .autograd import Variable
1010
11- # Image Preprocessing
11+ # Image Preprocessing
1212transform = transforms .Compose ([
1313 transforms .Scale (40 ),
1414 transforms .RandomHorizontalFlip (),
1717
1818# CIFAR-10 Dataset
1919train_dataset = dsets .CIFAR10 (root = './data/' ,
20- train = True ,
20+ train = True ,
2121 transform = transform ,
2222 download = True )
2323
2424test_dataset = dsets .CIFAR10 (root = './data/' ,
25- train = False ,
25+ train = False ,
2626 transform = transforms .ToTensor ())
2727
2828# Data Loader (Input Pipeline)
2929train_loader = torch .utils .data .DataLoader (dataset = train_dataset ,
30- batch_size = 100 ,
30+ batch_size = 100 ,
3131 shuffle = True )
3232
3333test_loader = torch .utils .data .DataLoader (dataset = test_dataset ,
34- batch_size = 100 ,
34+ batch_size = 100 ,
3535 shuffle = False )
3636
3737# 3x3 Convolution
3838def conv3x3 (in_channels , out_channels , stride = 1 ):
39- return nn .Conv2d (in_channels , out_channels , kernel_size = 3 ,
39+ return nn .Conv2d (in_channels , out_channels , kernel_size = 3 ,
4040 stride = stride , padding = 1 , bias = False )
4141
4242# Residual Block
@@ -49,7 +49,7 @@ def __init__(self, in_channels, out_channels, stride=1, downsample=None):
4949 self .conv2 = conv3x3 (out_channels , out_channels )
5050 self .bn2 = nn .BatchNorm2d (out_channels )
5151 self .downsample = downsample
52-
52+
5353 def forward (self , x ):
5454 residual = x
5555 out = self .conv1 (x )
@@ -76,7 +76,7 @@ def __init__(self, block, layers, num_classes=10):
7676 self .layer3 = self .make_layer (block , 64 , layers [1 ], 2 )
7777 self .avg_pool = nn .AvgPool2d (8 )
7878 self .fc = nn .Linear (64 , num_classes )
79-
79+
8080 def make_layer (self , block , out_channels , blocks , stride = 1 ):
8181 downsample = None
8282 if (stride != 1 ) or (self .in_channels != out_channels ):
@@ -89,7 +89,7 @@ def make_layer(self, block, out_channels, blocks, stride=1):
8989 for i in range (1 , blocks ):
9090 layers .append (block (out_channels , out_channels ))
9191 return nn .Sequential (* layers )
92-
92+
9393 def forward (self , x ):
9494 out = self .conv (x )
9595 out = self .bn (out )
@@ -101,36 +101,36 @@ def forward(self, x):
101101 out = out .view (out .size (0 ), - 1 )
102102 out = self .fc (out )
103103 return out
104-
104+
105105resnet = ResNet (ResidualBlock , [3 , 3 , 3 ])
106106resnet .cuda ()
107107
108108# Loss and Optimizer
109109criterion = nn .CrossEntropyLoss ()
110110lr = 0.001
111111optimizer = torch .optim .Adam (resnet .parameters (), lr = lr )
112-
113- # Training
112+
113+ # Training
114114for epoch in range (80 ):
115115 for i , (images , labels ) in enumerate (train_loader ):
116116 images = Variable (images .cuda ())
117117 labels = Variable (labels .cuda ())
118-
118+
119119 # Forward + Backward + Optimize
120120 optimizer .zero_grad ()
121121 outputs = resnet (images )
122122 loss = criterion (outputs , labels )
123123 loss .backward ()
124124 optimizer .step ()
125-
125+
126126 if (i + 1 ) % 100 == 0 :
127127 print ("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f" % (epoch + 1 , 80 , i + 1 , 500 , loss .data [0 ]))
128128
129129 # Decaying Learning Rate
130130 if (epoch + 1 ) % 20 == 0 :
131131 lr /= 3
132- optimizer = torch .optim .Adam (resnet .parameters (), lr = lr )
133-
132+ optimizer = torch .optim .Adam (resnet .parameters (), lr = lr )
133+
134134# Test
135135correct = 0
136136total = 0
@@ -144,4 +144,4 @@ def forward(self, x):
144144print ('Accuracy of the model on the test images: %d %%' % (100 * correct / total ))
145145
146146# Save the Model
147- torch .save (resnet .state_dict (), 'resnet.pkl' )
147+ torch .save (resnet .state_dict (), 'resnet.pkl' )
0 commit comments