Training a Custom Dataset with the RetinaNet Algorithm
Category: AI, Image Recognition
Tags: Object Detection, TensorFlow, Docker, RetinaNet, Keras, TensorBoard, Dockerfile
Training Your Own Dataset
Annotating the Data
# Directory structure after annotation
project
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
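Before converting the annotations, it can be worth checking that every image has a matching annotation file. The following is a minimal sketch, assuming the flat `labelimg` layout shown above; `find_unpaired` is a hypothetical helper and not part of the repository.

```python
from pathlib import Path


def find_unpaired(data_dir):
    """Return two sorted lists of file stems:
    images that have no .xml, and .xml files that have no image."""
    data_dir = Path(data_dir)
    images = {p.stem for p in data_dir.glob("*.jpg")}
    labels = {p.stem for p in data_dir.glob("*.xml")}
    return sorted(images - labels), sorted(labels - images)


# Example call (the path is the one used in this post):
# missing_xml, missing_jpg = find_unpaired("project/labelimg")
```

Both lists should be empty before the dataset is split.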
Building the Image
- Pull
$ sudo docker pull gouchicao/keras-retinanet:latest
- Build manually
FROM gouchicao/tensorflow:2.2.0-gpu-jupyter-opencv4-pillow-wget-curl-git-nano
LABEL maintainer="wang-junjian@qq.com"
WORKDIR /
RUN mkdir -p /root/.keras/models/ && \
    wget -O /root/.keras/models/ResNet-50-model.keras.h5 https://github.com/fizyr/keras-models/releases/download/v0.0.1/ResNet-50-model.keras.h5
RUN git clone --depth 1 --recurse-submodules https://github.com/gouchicao/keras-retinanet.git
WORKDIR /keras-retinanet/keras-retinanet
# Install keras==2.3.1 first to avoid the error:
# TypeError: type object got multiple values for keyword argument 'training'
RUN pip install keras==2.3.1 && \
    pip install . && \
    python setup.py build_ext --inplace
WORKDIR /keras-retinanet
Model Training
- Run the container
$ sudo docker run -it --runtime=nvidia --name=keras-retinanet -p 8888:8888 -p 6006:6006 \
    -v /home/wjunjian/ailab/datasets/helmet:/keras-retinanet/project \
    gouchicao/keras-retinanet bash
- Convert VOC to CSV format and split the dataset
$ python voc2csv.py --data_dir=project/labelimg/ --output_dir=project/dataset
# Generated directory structure
project
├── dataset
│   ├── class.csv
│   ├── train
│   │   ├── 20190128155421222575013.jpg
│   │   ├── 20190128155421222575013.xml
│   │   ├── 20190129091126392737624.jpg
│   │   └── 20190129091126392737624.xml
│   ├── train.csv
│   ├── val
│   │   ├── 20190128155703035712899.jpg
│   │   └── 20190128155703035712899.xml
│   └── val.csv
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
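The CSV generator in keras-retinanet reads annotation rows of the form `path/to/image.jpg,x1,y1,x2,y2,class_name` (no header) and a class file with rows of the form `class_name,id`. The source of `voc2csv.py` is not shown in this post, so the following is only a sketch of how one Pascal VOC file written by labelImg might be turned into such rows. The function names and the `helmet` class name in the usage note are assumptions for illustration.

```python
import csv
import io
import xml.etree.ElementTree as ET


def voc_to_rows(xml_text, image_path):
    """Turn one Pascal VOC annotation into rows of
    [image_path, x1, y1, x2, y2, class_name], one row per object."""
    root = ET.fromstring(xml_text)
    rows = []
    for obj in root.iter("object"):
        box = obj.find("bndbox")
        rows.append([
            image_path,
            int(float(box.findtext("xmin"))),
            int(float(box.findtext("ymin"))),
            int(float(box.findtext("xmax"))),
            int(float(box.findtext("ymax"))),
            obj.findtext("name"),
        ])
    return rows


def rows_to_csv(rows):
    """Serialize rows as CSV text without a header line."""
    buf = io.StringIO()
    csv.writer(buf, lineterminator="\n").writerows(rows)
    return buf.getvalue()
```

For an annotation with a single `helmet` object, `rows_to_csv(voc_to_rows(xml_text, "train/a.jpg"))` would produce one line such as `train/a.jpg,10,20,110,220,helmet`.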
- Train
$ python keras-retinanet/keras_retinanet/bin/train.py --tensorboard-dir=project/logs --snapshot-path project/snapshots \
    csv project/dataset/train.csv project/dataset/class.csv --val-annotations project/dataset/val.csv
$ ll -h project/snapshots/resnet50_csv_01.h5
-rw-r--r-- 1 root root 417M Jul 27 22:58 resnet50_csv_01.h5
Visualizing Training with TensorBoard
$ tensorboard --logdir=project/logs --bind_all
Open http://localhost:6006 in a browser on the local machine.
Model Evaluation
$ python keras-retinanet/keras_retinanet/bin/evaluate.py csv project/dataset/val.csv project/dataset/class.csv \
project/snapshots/resnet50_csv_01.h5 --convert-model
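The evaluation script reports average precision, which depends on matching each predicted box to a ground-truth box by intersection over union (IoU). As far as I know the script's default threshold is 0.5, but treat that as an assumption and check `evaluate.py --help`. The sketch below is the plain textbook IoU definition, not the library's own implementation.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    # Corners of the overlapping rectangle, if the boxes overlap at all.
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```

Two 10x10 boxes that overlap on half their width share 50 of 150 total pixels, so their IoU is one third and they would not count as a match at a 0.5 threshold.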
Model Conversion
$ mkdir project/inference
$ python keras-retinanet/keras_retinanet/bin/convert_model.py --no-class-specific-filter \
project/snapshots/resnet50_csv_01.h5 project/inference/model.h5
$ ll -h project/inference/model.h5
-rw-r--r-- 1 root root 140M Jul 27 23:14 model.h5
Model Prediction
$ python predict.py --model project/inference/model.h5 \
--class_csv project/dataset/class.csv \
--data_dir project/test \
--predict_dir project/predict
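The source of `predict.py` is not shown in this post. A converted keras-retinanet inference model returns boxes, scores, and labels for each image, with unused slots padded with -1, and the boxes refer to the resized input image. The sketch below shows one way the raw output might be post-processed, using plain Python lists. `keep_detections`, `scale`, and `score_threshold` are hypothetical names chosen for illustration.

```python
def keep_detections(boxes, scores, labels, scale=1.0, score_threshold=0.5):
    """Rescale boxes to the original image size and drop weak detections.

    boxes  : list of [x1, y1, x2, y2] in resized-image coordinates
    scores : list of confidence scores, one per box
    labels : list of class ids, -1 for padding entries
    scale  : factor by which the image was resized before inference
    """
    kept = []
    for box, score, label in zip(boxes, scores, labels):
        if label < 0 or score < score_threshold:
            continue  # skip padding entries and low-confidence detections
        kept.append(([coord / scale for coord in box], score, label))
    return kept
```

The class ids that remain can then be mapped back to names using the rows of `class.csv`.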