@@ -82,8 +82,8 @@ def __init__(self, block, layers, num_classes=10):
8282 self .bn = nn .BatchNorm2d (16 )
8383 self .relu = nn .ReLU (inplace = True )
8484 self .layer1 = self .make_layer (block , 16 , layers [0 ])
85- self .layer2 = self .make_layer (block , 32 , layers [0 ], 2 )
86- self .layer3 = self .make_layer (block , 64 , layers [1 ], 2 )
85+ self .layer2 = self .make_layer (block , 32 , layers [1 ], 2 )
86+ self .layer3 = self .make_layer (block , 64 , layers [2 ], 2 )
8787 self .avg_pool = nn .AvgPool2d (8 )
8888 self .fc = nn .Linear (64 , num_classes )
8989
@@ -112,7 +112,7 @@ def forward(self, x):
112112 out = self .fc (out )
113113 return out
114114
115- model = ResNet (ResidualBlock , [2 , 2 , 2 , 2 ]).to (device )
115+ model = ResNet (ResidualBlock , [2 , 2 , 2 ]).to (device )
116116
117117
118118# Loss and optimizer
@@ -166,4 +166,4 @@ def update_lr(optimizer, lr):
166166 print ('Accuracy of the model on the test images: {} %' .format (100 * correct / total ))
167167
168168# Save the model checkpoint
169- torch .save (model .state_dict (), 'resnet.ckpt' )
169+ torch .save (model .state_dict (), 'resnet.ckpt' )
0 commit comments