使用MMDetection训练自己的数据集
前言
本文主要阐述如何使用mmdetection训练自己的数据,包括配置文件的修改,训练时的数据增强,加载预训练权重以及绘制损失函数图等。这里承接上一篇文章,默认已经准备好了COCO格式数据集且已安装mmdetection,环境也已经配置完成。
这里说明一下,因为mmdetection更新至2.x版本之后有些用法不一样了,所以对本文重新更新一下,这里使用的mmdetection的版本是2.27.0,只要是2.x版本本文都适用的。
1、配置文件修改
配置文件获取方式一:
首先就是根据任务选定一个训练模型,这里我选用yolox-s作为我的训练模型,进入mmdetection/configs/yolox文件夹,可以看到有以下文件:

这里可以看到有yolox-s的配置文件yolox_s_8x8_300e_coco.py,这里请注意:不要在默认的配置文件中修改内容,最好将要修改的配置文件复制一份,在副本文件中修改内容!!! 复制一份配置文件之后,就可以根据需要进行修改了!
配置文件获取方式二(推荐):
方式二就是使用openmmlab的包管理工具mim来获取配置文件和预训练权重,这里的mim在安装mmdetection时就安装上了,就不进行安装说明了,如果有问题请参考mmdetection文档。
这里以jupyternotebook为例进行讲述,在终端使用时请去掉感叹号!
# 以yolox为例获取配置文件 --model后面就写你想获取的模型配置文件
!mim search mmdet --model 'yolox'
当上述运行后,会出现下面的内容:

这里我要使用yolox-s,那么我就选中对应的config id:yolox_s_8x8_300e_coco,然后执行下面的代码:
# --dest后面有个空格,然后再加一个点,这个得是英文的点
!mim download mmdet --config yolox_s_8x8_300e_coco --dest .
这样就将配置文件和预训练权重下载到你的当前文件执行的目录下了:

之后在这个配置文件里面改你所需的东西就行了
1.1 model部分

这里大部分参数可以沿用默认的,或许有修改的是bbox_head中的num_classes=80,这个是类别数,COCO数据集是80类,你可以看自己数据集是多少类别,然后改成对应的,比如我的数据集有2类,那么就改成num_classes=2。
另外就是test_cfg下的nms=dict(type='nms', iou_threshold=0.65),iou_threshold=0.65可以修改,你可以把iou阈值改成你想要的,比如iou_threshold=0.40。
如果想使用预训练权重,那么可以这样设置,就是在model字典开头,加上init_cfg=dict(type='Pretrained', checkpoint='这里输入你的预训练权重文件路径')

一些分类头和FPN的修改和BackBone的替换并不在本文之内。
1.2 dataset部分
2.x之后的mmdetection在dataset部分有一些不同,这里重新说明一下自定义数据集的设置
在mmdetection文件夹中创建data文件夹,然后创建子文件夹,把子文件夹的名称设为coco,将你的训练、验证、测试数据导入其中。具体样式如下:


返回配置文件,然后在下列填入你的数据集路径:

你数据集的类别可能不是coco的80类,那么就需要把类别给改了,具体操作如下:
- 进入
mmdetection/mmdet/datasets,打开coco.py,我们要修改其内容(这里我们默认数据集格式是COCO)

- 按照下图样式,把原来的CLASSES注掉,新起一个CLASSES,里面填你的类别,这里需要注意:如果你的数据集只有一个类别,那么记得在类别后面加一个逗号,不然会报错!!! 请注意:这里类别的名字得和你的图片目标标签名字一样,别你的标签是 Cat和Dog,然后在这里变成了 cat和dog!!! 其他的地方都不需要动!!!

这里对coco.py修改之后,还需修改一个地方,请把目光转到mmdetection-2.27.0/mmdet/core/evaluation/class_names.py这个文件下面,将你的类别数量也进行修改,找到def coco_classes():,改成你自己的类别:

当我们把上述两个文件修改之后记得重新编译一下代码:
!python setup.py install
这样dataset初步构建完成,下面针对train dataset进行修改
1.2.1 train dataset部分
训练部分数据增强
说起train dataset肯定离不开数据增强,我这里没有使用mmdetection内置的数据增强,如果你想看其内置哪些增强,可以在mmdetection/mmdet/datasets/pipelines/transforms.py中查看。我这里使用albumentations库进行数据增强(主要是功能真的很强大,太香了),如果你也想使用这个开源库,那么请先安装它:
pip install -U albumentations
然后在train_pipelines添加或修改你的增强策略。具体可以参考我的代码:
- 首先在配置文件开头添加如下代码:
### Albumentations Start ###
img_norm_cfg = dict(
mean=[95.4554, 107.3959, 69.8863], std=[56.0811, 55.2941, 55.2364], to_rgb=True)
albu_train_transforms = [
dict(
type='RandomBrightnessContrast',
brightness_limit=[-0.2, 0.3],
contrast_limit=[-0.2, 0.3],
p=0.5),
dict(type='RandomRotate90', p=0.5),
dict(type='GaussianBlur', blur_limit=(3, 7), sigma_limit=(0, 0.99), p=0.5),
dict(type='MotionBlur', blur_limit=(3, 7), p=0.5)
]
### Albumentations End ###

type='RandomBrightnessContrast',type='RandomRotate90' 都是增强策略,这个可以查看albumentations官方文档,根据自己需求添加,添加格式和我上面的代码一样。
然后 mean=[95.4554, 107.3959, 69.8863], std=[56.0811, 55.2941, 55.2364] ,这个是你数据集的均值和标准差,可以自己编写一个Python程序自动计算一下,如果你懒得编写,那么可以参考我的这个,就是计算起来稍微有点慢。
import torch
from torch.utils.data import DataLoader, Dataset
import os
from pathlib import Path
import numpy as np
from PIL import Image
def cal_mean_std(path: str):
channels_sum, channels_squared_sum, nums = 0, 0, 0
path_list = os.listdir(path)
for img_path in path_list:
image_path = os.path.join(path, img_path)
# image = torch.from_numpy(np.array(Image.open(image_path)) / 255).permute([2, 0, 1]).float()
image = torch.from_numpy(np.array(Image.open(image_path))).permute([2, 0, 1]).float()
channels_sum += torch.mean(image, dim=[1, 2])
channels_squared_sum += torch.mean(image**2, dim=[1, 2])
nums += 1
mean = channels_sum / nums
std = (channels_squared_sum / nums - mean**2)**0.5
return mean, std
if __name__ == '__main__':
path = os.path.abspath("F:/VOC2012/VOCdevkit/VOC2012/JPEGImages")
mean, std = cal_mean_std(path=path)
print(f'mean : {
mean}, std : {
std}')
到这里train_pipelines添加完成
train dataset后续
作完数据增强后就应该将其添加到train_dataset中了,照我这样添加就好了:
train_pipeline = [
dict(
type='Albu',
transforms=albu_train_transforms,
bbox_params=dict(
type='BboxParams',
format='pascal_voc',
label_fields=['gt_labels'],
min_visibility=0.1,
filter_lost_elements=True),
keymap={
'img': 'image',
'gt_bboxes': 'bboxes'
},
update_pad_shape=False,
skip_img_without_anno=True),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1,

8472

被折叠的 条评论
为什么被折叠?



