您现在的位置是:首页 > 技术教程 正文

【MMDetection】——训练个人数据集

admin 阅读: 2024-03-19
后台-插件-广告管理-内容页头部广告(手机)

文章目录

  • 1、数据集格式及存放
  • 2、修改两处
  • 3、用训练命令生成配置文件
  • 4、正式训练开始
  • 5、报错记录
  • 6、模型评价测试(VOC指标mAP、COCO指标AP)
  • 7、绘制每个类别bbox 的结果曲线图并保存
  • 8、统计模型参数量和FLOPs
  • 9 计算混淆矩阵
  • 10 画PR曲线
  • 11 查看完整config配置文件
  • 12 核查数据增强的结果是否正确
  • 8、参考链接

1、数据集格式及存放

mmdet支持COCO格式和VOC格式,能用COCO格式,还是建议COCO的。网上有YOLO转COCO,VOC转COCO,可以自己转换。

在mmdetection代码的根目录下,创建 data/coco 文件夹,按照coco的格式排放好数据集。annotations下面是标签文件,train2017、val2017、test2017是图片。
在这里插入图片描述
在这里插入图片描述

2、修改两处

第一处: mmdet/core/evalution/class_names.py 代码下的 def coco_classes() 的 return 内容改为自己数据集的类别;

第二处:mmdet/datasets/coco.py 代码下的 class CocoDataset(CustomDataset) 的 CLASSES 改为自己数据集的类别;

注意:修改两处后,一定要在根目录下,输入命令:
python setup.py install build
重新编译代码,要不然类别会没有载入,还是原coco类别,训练异常。

3、用训练命令生成配置文件

python tools/train.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py --work-dir work_dirs
  • 1

其中,work_dirs是自己在根目录新建的工作目录,训练文件存储在这里。

注意,此时运行命令之后,并不是直接训练就可以不管了!我们还有参数设置没改!这里输入训练命令,只是需要它生成一个配置文件,便于我们改参数!在这里插入图片描述

打开配置文件 cascade_rcnn_r50_fpn_1x_coco.py :
(1)修改 num_classes ,将其改为自己数据类别(直接全局搜索,有3处,都要改);

(2)修改 data_root 路径和训练集、验证集、测试集的图片和标签路径,如下图:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

(3)修改训练图片大小和学习率

修改下处代码,可以更改图片大小

img_scale = (1333, 800),
  • 1

batch_size, mmdet默认的方式是由 GPU 数量与 samples_per_gpu 参数决定:
samples_per_gpu: 每个gpu读取的图像数量(意思不就是batch_size=2),该参数和训练时的gpu数量决定了训练时的batch_size。(为什么这么说呢?因为mmdet是8个GPU训练的,那么总的batch就是 8 *samples_per_gpu=16,即训练时是batch_size为16) 。
但我们通常是只有一个gpu, 该参数设置为 2, 意思就是我们训练的 batch_size为2;

workers_per_gpu: 读取数据时每个gpu分配的线程数 ,一般设置为 2即可;(我感觉既然用单个GPU,设置到8也无妨吧?我还没试)

在这里插入图片描述

学习率设置:
mmdet 默认的学习率是基于8个gpu,而且默认是1个GPU处理2个图像(就上面说的samples_per_gpu为2),可以这样理解:
8个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括16张图片,学习率为0.02;
4个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括8张图片,学习率为0.01;
1个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括2张图片,学习率为0.0025;
1个GPU,每个GPU处理1张图片,那么真实训练总的一个batch就包括1张图片,学习率为0.00125;
在这里插入图片描述
(4)使用预训练模型
提前从github上下载预训练模型,新建一个checkpoints文件夹下,放到里面。(模型下载链接:https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md)
然后修改以下代码:

# 原本是 load_from = None ,修改为 load_from = 'checkpoints/fcascade_rcnn_r50_fpn_1x_coco_20200316-3dc56deb.pth’
  • 1
  • 2

(5)训练轮数,保存模型间隔,日志保存参数
在这里插入图片描述

4、正式训练开始

!!!看清楚路径!使用的是更改过的配置文件训练!!!

python tools/train.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py
  • 1

5、报错记录

在第三步生成配置文件时,遇到以下报错:

AssertionError: The num_classes (10) in Shared2FCBBoxHead of
MMDataParallel does not matches the length of CLASSES 80) in
CocoDataset

即使在修改 coco.py 和 class_names.py 后运行 python setup.py install仍然无法解决;

解决方法:
根据报错信息,找到自己虚拟环境的/mmdet/datasets/coco.py和mmdet/core/evaluation/class_names.py,再次修改
CocoDataset()和 coco_classes()l两处(跟第二步一样,其实打开,就能看到虚拟环境下的并没有修改成功)

参考链接:AssertionError: The num_classes (3) in Shared2FCBBoxHead of
MMDataParallel does not
matches

6、模型评价测试(VOC指标mAP、COCO指标AP)

(1)生成中间件

python tools/test.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth --out results.pkl
  • 1
  • work_dirs/cascade_rcnn_r50_fpn_1x_coco.py 模型配置文件(跟训练时的一样)
  • work_dirs/epoch_20.pth: 训练好的模型(我是训练了20epoch)
  • --out 指定 results.pkl 输出目录,可以自己指定输出目录

(2)使用COCO标准评估指标

python tools/analysis_tools/eval_metric.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl --eval=bbox
  • 1
  • 2
  • --eval,COCO数据集可选参数有:bbox 、segm、proposal ;对VOC数据集可选参数有:mAP

(3)使用VOC标准评估指标

# results.pkl 的顺序别放错,在中间。 python tools/voc_eval.py results.pkl work_dirs/cascade_rcnn_r50_fpn_1x_coco.py
  • 1
  • 2
  • voc_eval.py 文件 mmdetection 2.X 版本删除了,可以去老版本1.X 找找

7、绘制每个类别bbox 的结果曲线图并保存

(1)使用 test.py 生成 results.bbox.json 文件(在根目录下,路径可自己指定)

python tools/test.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth --format-only --options "jsonfile_prefix=./results"
  • 1

(2)获得COCO bbox错误结果每个类别,保存分析结果图像到目录results/

python tools/analysis_tools/coco_error_analysis.py results.bbox.json results --ann=data/coco/annotations/instances_val2017.json
  • 1
  • results.bbox.json:上一步生成的文件
  • results: 结果曲线图的生成目录, 此处将生成到results/ 目录下
  • –ann=data/coco/annotations/instances_val2017.json: 数据集标注文件存放路径

8、统计模型参数量和FLOPs

python tools/analysis_tools/get_flops.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py --shape 640 640
  • 1
  • --shape 参数指定输入图片尺寸

9 计算混淆矩阵

python tools/analysis_tools/confusion_matrix.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl coco_confusion_matrix/
  • 1
  • 需要三个参数,配置文件、pkl文件、输出目录

10 画PR曲线

plot_pr_curve.py 代码来自:https://blog.csdn.net/weixin_44966641/article/details/124558532

import os import sys import mmcv import numpy as np import argparse import matplotlib.pyplot as plt from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval from mmcv import Config from mmdet.datasets import build_dataset def plot_pr_curve(config_file, result_file, out_pic, metric="bbox"): """plot precison-recall curve based on testing results of pkl file. Args: config_file (list[list | tuple]): config file path. result_file (str): pkl file of testing results path. metric (str): Metrics to be evaluated. Options are 'bbox', 'segm'. """ cfg = Config.fromfile(config_file) # turn on test mode of dataset if isinstance(cfg.data.test, dict): cfg.data.test.test_mode = True elif isinstance(cfg.data.test, list): for ds_cfg in cfg.data.test: ds_cfg.test_mode = True # build dataset dataset = build_dataset(cfg.data.test) # load result file in pkl format pkl_results = mmcv.load(result_file) # convert pkl file (list[list | tuple | ndarray]) to json json_results, _ = dataset.format_results(pkl_results) # initialize COCO instance coco = COCO(annotation_file=cfg.data.test.ann_file) coco_gt = coco coco_dt = coco_gt.loadRes(json_results[metric]) # initialize COCOeval instance coco_eval = COCOeval(coco_gt, coco_dt, metric) coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() # extract eval data precisions = coco_eval.eval["precision"] ''' precisions[T, R, K, A, M] T: iou thresholds [0.5 : 0.05 : 0.95], idx from 0 to 9 R: recall thresholds [0 : 0.01 : 1], idx from 0 to 100 K: category, idx from 0 to ... A: area range, (all, small, medium, large), idx from 0 to 3 M: max dets, (1, 10, 100), idx from 0 to 2 ''' pr_array1 = precisions[0, :, 0, 0, 2] pr_array2 = precisions[1, :, 0, 0, 2] pr_array3 = precisions[2, :, 0, 0, 2] pr_array4 = precisions[3, :, 0, 0, 2] pr_array5 = precisions[4, :, 0, 0, 2] pr_array6 = precisions[5, :, 0, 0, 2] pr_array7 = precisions[6, :, 0, 0, 2] pr_array8 = precisions[7, :, 0, 0, 2] pr_array9 = precisions[8, :, 0, 0, 2] pr_array10 = precisions[9, :, 0, 0, 2] x = np.arange(0.0, 1.01, 0.01) # plot PR curve plt.plot(x, pr_array1, label="iou=0.5") plt.plot(x, pr_array2, label="iou=0.55") plt.plot(x, pr_array3, label="iou=0.6") plt.plot(x, pr_array4, label="iou=0.65") plt.plot(x, pr_array5, label="iou=0.7") plt.plot(x, pr_array6, label="iou=0.75") plt.plot(x, pr_array7, label="iou=0.8") plt.plot(x, pr_array8, label="iou=0.85") plt.plot(x, pr_array9, label="iou=0.9") plt.plot(x, pr_array10, label="iou=0.95") plt.xlabel("recall") plt.ylabel("precison") plt.xlim(0, 1.0) plt.ylim(0, 1.01) plt.grid(True) plt.legend(loc="lower left") plt.savefig(out_pic) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('config', help='config file path') parser.add_argument('pkl_result_file', help='pkl result file path') parser.add_argument('--out', default='pr_curve.png') parser.add_argument('--eval', default='bbox') cfg = parser.parse_args() plot_pr_curve(config_file=cfg.config, result_file=cfg.pkl_result_file, out_pic=cfg.out, metric=cfg.eval)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100

输入命令:

python plot_pr_curve.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl
  • 1

11 查看完整config配置文件

python tools/misc/print_config.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py
  • 1

12 核查数据增强的结果是否正确

python tools/misc/browse_dataset.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py --output-dir work_dirs/
  • 1

8、参考链接

https://blog.csdn.net/qq_35077107/article/details/124768460?spm=1001.2014.3001.5502

https://blog.csdn.net/weixin_44966641/article/details/124558532

标签:
声明

1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,请转载时务必注明文章作者和来源,不尊重原创的行为我们将追究责任;3.作者投稿可能会经我们编辑修改或补充。

在线投稿:投稿 站长QQ:1888636

后台-插件-广告管理-内容页尾部广告(手机)
关注我们

扫一扫关注我们,了解最新精彩内容

搜索