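Preface
I am learning PyTorch, so I am reading through the source code involved in training a model with ResNet.
The code was picked more or less at random: it comes from this video by the Bilibili uploader Zomin, and the project code is pinned at the top of the video's comment section.
Overall structure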

Looking at a single training pass on its own:
Core code:
# Iterate over the batches in the data loader and train on each one.
# Zero the gradients - forward pass (compute outputs) - compute the loss - backward pass
for data, target in train_loader:
    data = data.to(device)
    target = target.to(device)
    # clear the gradients of all optimized variables
    optimizer.zero_grad()
    # forward pass: compute predicted outputs by passing inputs to the model
    # (model(data) goes through Module.__call__; it gives the same result as
    # model.forward(data) when no hooks are registered)
    output = model(data).to(device)
    # calculate the batch loss
    loss = criterion(output, target)
    # backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()
    # perform a single optimization step (parameter update)
    optimizer.step()
    # update the running training loss
    train_loss += loss.item() * data.size(0)
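The last line of the loop multiplies `loss.item()` by `data.size(0)`. Assuming `criterion` averages over the batch (the default `reduction='mean'` of losses such as `nn.CrossEntropyLoss`; the snippet does not show how `criterion` was built), this turns each batch mean back into a batch sum, so dividing the running total by the dataset size gives a per-sample average even when the last batch is smaller. A minimal plain-Python sketch of that arithmetic, with made-up loss values:

```python
# Hypothetical per-sample losses for a dataset of 5 samples, batch size 2,
# so the last batch holds only one sample.
per_sample_losses = [0.9, 0.7, 0.4, 0.6, 2.0]
batch_size = 2

batches = [per_sample_losses[i:i + batch_size]
           for i in range(0, len(per_sample_losses), batch_size)]

train_loss = 0.0    # weighted accumulation, as in the training loop
naive_total = 0.0   # unweighted accumulation, for comparison
for batch in batches:
    batch_mean = sum(batch) / len(batch)    # what a mean-reducing criterion returns
    train_loss += batch_mean * len(batch)   # loss.item() * data.size(0)
    naive_total += batch_mean

epoch_loss = train_loss / len(per_sample_losses)  # per-sample average
naive_loss = naive_total / len(batches)           # average of batch means
true_mean = sum(per_sample_losses) / len(per_sample_losses)
```

Here `epoch_loss` matches `true_mean`, while `naive_loss` is pulled upward by the single-sample last batch.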
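The _call_impl() function of the Module class
Writing model(data) invokes Module.__call__, which is bound to _call_impl; this method in turn calls the model's corresponding function, forward.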
def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks
            or _global_backward_hooks or _global_forward_hooks
            or _global_forward_pre_hooks):
        return forward_call(*input, **kwargs)
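To make the dispatch concrete, here is a toy sketch in plain Python. `TinyModule` and `Doubler` are hypothetical names, and this is a simplified imitation of the fast path and forward hooks only, not PyTorch's actual implementation:

```python
class TinyModule:
    """Toy stand-in for nn.Module; NOT PyTorch's implementation."""

    def __init__(self):
        self._forward_hooks = []  # callables with signature hook(module, inputs, output)

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *inputs):
        raise NotImplementedError

    def _call_impl(self, *inputs, **kwargs):
        # Fast path: no hooks registered, so just call forward.
        if not self._forward_hooks:
            return self.forward(*inputs, **kwargs)
        # Slow path: run forward, then let each hook inspect or replace the output.
        result = self.forward(*inputs, **kwargs)
        for hook in self._forward_hooks:
            hook_result = hook(self, inputs, result)
            if hook_result is not None:
                result = hook_result
        return result

    __call__ = _call_impl  # makes instance(x) route through _call_impl


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


model = Doubler()
plain = model(3)  # dispatches to Doubler.forward

seen = []
model.register_forward_hook(lambda mod, inp, out: seen.append(out))
hooked = model(5)  # forward runs, then the hook records the output
```

This is why calling `model(data)` is generally preferred over `model.forward(data)`: the direct call would skip any registered hooks.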
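Forward-pass code
The _forward_impl(self, x) function of the ResNet class performs one complete pass through the residual network. It can be divided into three main parts.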
def _forward_impl(self, x):
    # See note [TorchScript super()]
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    #x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)

    return x
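The `torch.flatten(x, 1)` step collapses every dimension from index 1 onward, so the batch dimension survives and the result can be fed to the fully connected layer. A shape-only sketch in plain Python follows; the helper `flatten_shape` is hypothetical, and the example shape assumes a batch of 128, a 512-channel final stage (as in ResNet-18/34) and an `avgpool` that reduces the feature map to 1×1:

```python
from math import prod


def flatten_shape(shape, start_dim=0):
    """Shape produced by flattening all dimensions from start_dim to the end."""
    return tuple(shape[:start_dim]) + (prod(shape[start_dim:]),)


after_avgpool = (128, 512, 1, 1)  # assumed (N, C, H, W) after avgpool
fc_input = flatten_shape(after_avgpool, 1)  # batch dimension is kept
```

Under these assumptions `fc_input` is `(128, 512)`, which is the two-dimensional input that a linear layer with 512 input features expects.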
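1. Downsampling convolution
The original ResNet starts with a 7×7 downsampling convolution. Because CIFAR-10 images are very small (32×32), the code replaces it with a 3×3 convolution and removes the max-pooling layer, which reduces the loss of information.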
x = self.conv1(x)  # convolution layer
x = self
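The effect of the stem change on feature-map size can be checked with the standard output-size formula for convolution and pooling layers. The sketch below is plain Python; `conv_out` is a hypothetical helper, the original stem's settings are those of the standard ImageNet ResNet, and stride 1 with padding 1 for the 3×3 convolution is an assumption, since the article does not state them:

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a square conv/pool layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32  # CIFAR-10 images are 32x32

# Original stem: 7x7 conv (stride 2, padding 3), then 3x3 max pool (stride 2, padding 1).
after_conv7 = conv_out(size, kernel=7, stride=2, padding=3)
after_pool = conv_out(after_conv7, kernel=3, stride=2, padding=1)

# Modified stem: 3x3 conv, no max pool. Stride 1 and padding 1 are assumed here.
after_conv3 = conv_out(size, kernel=3, stride=1, padding=1)
```

With these settings the original stem would shrink a 32×32 image to 8×8 before the first residual stage, whereas the modified stem keeps it at 32×32.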




