
Training RetinaNet on a Custom Dataset

Training on Your Own Dataset

Annotating the Data

Annotate the images with LabelImg, which saves one Pascal VOC XML file next to each image.

# Directory structure after annotation
project
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
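Each LabelImg XML file records the image's file name and one object element per bounding box. The following is a minimal sketch, using only the Python standard library, of reading one such file into a list of boxes. The helper name parse_voc_xml is hypothetical (it is not part of any tool used in this article), and the class name "helmet", the image size and the box coordinates in SAMPLE are made-up values for illustration.

```python
import xml.etree.ElementTree as ET


def parse_voc_xml(xml_text):
    """Hypothetical helper: return (filename, boxes) from one Pascal VOC annotation.

    Each box is (xmin, ymin, xmax, ymax, class_name) in pixel coordinates.
    """
    root = ET.fromstring(xml_text)
    filename = root.findtext("filename")
    boxes = []
    for obj in root.iter("object"):
        bndbox = obj.find("bndbox")
        # int(float(...)) accepts both "100" and "100.0"
        coords = [int(float(bndbox.findtext(tag))) for tag in ("xmin", "ymin", "xmax", "ymax")]
        boxes.append((*coords, obj.findtext("name")))
    return filename, boxes


# Hypothetical annotation content for illustration (class "helmet", made-up box).
SAMPLE = """<annotation>
  <folder>labelimg</folder>
  <filename>20190128155421222575013.jpg</filename>
  <size><width>1920</width><height>1080</height><depth>3</depth></size>
  <object>
    <name>helmet</name>
    <difficult>0</difficult>
    <bndbox><xmin>100</xmin><ymin>50</ymin><xmax>180</xmax><ymax>120</ymax></bndbox>
  </object>
</annotation>"""

if __name__ == "__main__":
    print(parse_voc_xml(SAMPLE))
    # → ('20190128155421222575013.jpg', [(100, 50, 180, 120, 'helmet')])
```

To read a real file, pass its text, for example parse_voc_xml(Path("project/labelimg/20190128155421222575013.xml").read_text(encoding="utf-8")) after importing Path from pathlib. The difficult flag is ignored in this sketch.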

Building the Image

  • Pull the prebuilt image
$ sudo docker pull gouchicao/keras-retinanet:latest
  • Or build it manually from this Dockerfile
FROM gouchicao/tensorflow:2.2.0-gpu-jupyter-opencv4-pillow-wget-curl-git-nano
LABEL maintainer="wang-junjian@qq.com"
WORKDIR /
RUN mkdir -p /root/.keras/models/ && \
    wget -O /root/.keras/models/ResNet-50-model.keras.h5 https://github.com/fizyr/keras-models/releases/download/v0.0.1/ResNet-50-model.keras.h5
RUN git clone --depth 1 --recurse-submodules https://github.com/gouchicao/keras-retinanet.git
WORKDIR /keras-retinanet/keras-retinanet
# Install the pinned keras==2.3.1 first to avoid the error: TypeError: type object got multiple values for keyword argument 'training'
RUN pip install keras==2.3.1 && \
    pip install . && \
    python setup.py build_ext --inplace
WORKDIR /keras-retinanet

Model Training

  • Run the container
$ sudo docker run -it --runtime=nvidia --name=keras-retinanet -p 8888:8888 -p 6006:6006 \
                -v /home/wjunjian/ailab/datasets/helmet:/keras-retinanet/project \
                gouchicao/keras-retinanet bash
  • Convert the VOC annotations to CSV format and split them into training and validation sets
$ python voc2csv.py --data_dir=project/labelimg/ --output_dir=project/dataset
# Resulting directory structure
project
├── dataset
│   ├── class.csv
│   ├── train
│   │   ├── 20190128155421222575013.jpg
│   │   ├── 20190128155421222575013.xml
│   │   ├── 20190129091126392737624.jpg
│   │   └── 20190129091126392737624.xml
│   ├── train.csv
│   ├── val
│   │   ├── 20190128155703035712899.jpg
│   │   └── 20190128155703035712899.xml
│   └── val.csv
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
  • Train
$ python keras-retinanet/keras_retinanet/bin/train.py --tensorboard-dir=project/logs --snapshot-path project/snapshots \
    csv project/dataset/train.csv project/dataset/class.csv --val-annotations project/dataset/val.csv
$ ll -h project/snapshots/resnet50_csv_01.h5
-rw-r--r-- 1 root     root     417M Jul 27 22:58 resnet50_csv_01.h5
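The voc2csv.py step produces the CSV layout that keras-retinanet's csv dataset type expects: one "path,x1,y1,x2,y2,class_name" row per box (an image without boxes becomes "path,,,,,"), plus a class.csv file of "class_name,id" rows with ids starting at 0. The sketch below is a hypothetical reimplementation of only the split-and-write part, not the actual voc2csv.py. The function name write_retinanet_csvs, the 20% validation ratio and the fixed random seed are assumptions, and copying the image and XML files into train/ and val/ is left out.

```python
import csv
import random
from pathlib import Path


def write_retinanet_csvs(annotations, output_dir, val_ratio=0.2, seed=0):
    """Hypothetical helper: split images into train/val and write keras-retinanet CSVs.

    annotations maps an image file name to a list of (x1, y1, x2, y2, class_name).
    Returns {"train": [...], "val": [...]} with the image names in each split.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    images = sorted(annotations)
    random.Random(seed).shuffle(images)  # deterministic shuffle for a reproducible split
    n_val = max(1, round(len(images) * val_ratio)) if len(images) > 1 else 0
    splits = {"train": images[n_val:], "val": images[:n_val]}

    for split, names in splits.items():
        with open(out / f"{split}.csv", "w", newline="") as f:
            writer = csv.writer(f)
            for name in names:
                rel_path = f"{split}/{name}"  # relative to the CSV file's directory
                boxes = annotations[name]
                if not boxes:
                    # An image with no objects is written as "path,,,,,"
                    writer.writerow([rel_path, "", "", "", "", ""])
                for x1, y1, x2, y2, cls in boxes:
                    writer.writerow([rel_path, x1, y1, x2, y2, cls])

    # class.csv: one "class_name,id" row per class, ids starting at 0
    classes = sorted({box[4] for boxes in annotations.values() for box in boxes})
    with open(out / "class.csv", "w", newline="") as f:
        csv.writer(f).writerows((cls, i) for i, cls in enumerate(classes))
    return splits
```

With three images and the default ratio, this gives two training images and one validation image, matching the tree above. Writing image paths such as train/20190128155421222575013.jpg relative to the CSV file should work because keras-retinanet's CSV generator resolves relative paths against the directory that contains the CSV file by default.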

Visualizing Training with TensorBoard

$ tensorboard --logdir=project/logs --bind_all

Then open http://localhost:6006 in a browser on the host machine; the container publishes this port through -p 6006:6006.

Model Evaluation

$ python keras-retinanet/keras_retinanet/bin/evaluate.py csv project/dataset/val.csv project/dataset/class.csv \
    project/snapshots/resnet50_csv_01.h5 --convert-model
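evaluate.py matches detections to ground-truth boxes by intersection over union (IoU) and reports an average precision (AP) per class plus their mean (mAP). The plain-Python sketch below shows VOC-style AP for a single class with all-point interpolation. It is similar in spirit to keras-retinanet's evaluation code but is not taken from it; the function names are hypothetical, and the 0.5 IoU threshold is assumed to match the script's default.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def average_precision(detections, ground_truth, iou_threshold=0.5):
    """Hypothetical helper: VOC-style AP for one class.

    detections: list of (image_id, score, box); ground_truth: {image_id: [box, ...]}.
    """
    num_gt = sum(len(boxes) for boxes in ground_truth.values())
    if num_gt == 0:
        return 0.0
    matched = {img: [False] * len(boxes) for img, boxes in ground_truth.items()}

    hits = []
    # Visit detections from highest to lowest score
    for img, _score, box in sorted(detections, key=lambda d: d[1], reverse=True):
        gts = ground_truth.get(img, [])
        overlaps = [iou(box, g) for g in gts]
        best = max(range(len(gts)), key=overlaps.__getitem__, default=None)
        is_tp = best is not None and overlaps[best] >= iou_threshold and not matched[img][best]
        if is_tp:
            matched[img][best] = True  # each ground-truth box can be matched only once
        hits.append(is_tp)

    # Cumulative recall and precision after each detection
    recall, precision, tp = [], [], 0
    for n, hit in enumerate(hits, start=1):
        tp += hit
        recall.append(tp / num_gt)
        precision.append(tp / n)

    # All-point interpolation: make precision non-increasing, then sum the area
    mrec = [0.0] + recall + [1.0]
    mpre = [0.0] + precision + [0.0]
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    return sum((mrec[i + 1] - mrec[i]) * mpre[i + 1] for i in range(len(mrec) - 1))
```

A false positive ranked below every true positive does not lower AP, while one ranked above them does: with a single ground-truth box, a wrong detection scored higher than the correct one halves the AP to 0.5.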

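Model Conversion

Training snapshots contain only the layers needed for training (raw regression and classification outputs), so convert the snapshot into an inference model before using it for prediction: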

$ mkdir project/inference
$ python keras-retinanet/keras_retinanet/bin/convert_model.py --no-class-specific-filter \
    project/snapshots/resnet50_csv_01.h5 project/inference/model.h5
$ ll -h project/inference/model.h5
-rw-r--r-- 1 root     root     140M Jul 27 23:14 model.h5
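The inference model ends with a detection filter: a score threshold followed by non-maximum suppression (NMS). By default NMS runs separately for each class; with --no-class-specific-filter, each box keeps only its highest-scoring class and a single NMS runs over all boxes, so overlapping boxes of different classes also suppress each other. The plain-Python sketch below illustrates that difference. It approximates keras-retinanet's filtering logic rather than reproducing it; the function names are hypothetical, and the 0.05 score and 0.5 IoU thresholds are assumed defaults.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold=0.5):
    """Greedy NMS: return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


def filter_detections(boxes, class_scores, class_specific=True,
                      score_threshold=0.05, iou_threshold=0.5):
    """Hypothetical helper returning (box_index, label, score), best score first.

    class_scores[i][c] is the score of box i for class c.
    """
    results = []
    num_classes = len(class_scores[0]) if class_scores else 0
    if class_specific:
        # One NMS pass per class: boxes only suppress boxes of the same class
        for c in range(num_classes):
            idx = [i for i in range(len(boxes)) if class_scores[i][c] > score_threshold]
            kept = nms([boxes[i] for i in idx], [class_scores[i][c] for i in idx], iou_threshold)
            results += [(idx[k], c, class_scores[idx[k]][c]) for k in kept]
    else:
        # Keep each box's best class, then one NMS pass across all classes
        labels = [max(range(num_classes), key=row.__getitem__) for row in class_scores]
        best = [row[label] for row, label in zip(class_scores, labels)]
        idx = [i for i in range(len(boxes)) if best[i] > score_threshold]
        kept = nms([boxes[i] for i in idx], [best[i] for i in idx], iou_threshold)
        results += [(idx[k], labels[idx[k]], best[idx[k]]) for k in kept]
    return sorted(results, key=lambda r: r[2], reverse=True)
```

For two heavily overlapping boxes whose best classes differ, the class-specific filter keeps both, while the class-agnostic filter keeps only the higher-scoring one. Class-agnostic filtering is useful when one object should never be reported under two labels at once.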

Prediction

$ python predict.py --model project/inference/model.h5 \
    --class_csv project/dataset/class.csv \
    --data_dir project/test \
    --predict_dir project/predict
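The contents of predict.py are not shown here. With a keras-retinanet inference model, a typical post-processing step divides the predicted boxes by the scale factor used to resize the input image, skips padded detections (label -1) and low-scoring ones, and maps label ids back to class names via class.csv. The sketch below illustrates that step in plain Python. The function names load_class_map and postprocess and the 0.5 score threshold are assumptions, not predict.py's actual interface, and the example inputs are made up.

```python
import csv
import io


def load_class_map(csv_file):
    """Hypothetical helper: read class.csv ("class_name,id" rows) into {id: class_name}."""
    return {int(row[1]): row[0] for row in csv.reader(csv_file) if row}


def postprocess(boxes, scores, labels, scale, class_map, score_threshold=0.5):
    """Hypothetical helper: undo the resize scale and keep confident detections.

    boxes/scores/labels are one image's outputs from an inference model;
    scale is the factor the image was resized by before prediction.
    """
    results = []
    for box, score, label in zip(boxes, scores, labels):
        if label < 0 or score < score_threshold:
            continue  # padded or low-confidence detection
        x1, y1, x2, y2 = (round(v / scale) for v in box)
        results.append((class_map[label], score, (x1, y1, x2, y2)))
    return results


if __name__ == "__main__":
    # Made-up class.csv content and model outputs, for illustration only
    class_map = load_class_map(io.StringIO("helmet,0\nperson,1\n"))
    print(postprocess([(100, 50, 200, 150)], [0.9], [1], 0.5, class_map))
    # → [('person', 0.9, (200, 100, 400, 300))]
```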
