Training on a Custom Dataset

Annotate the Data

LabelImg

# Directory structure after annotation
project
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
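
Before converting, it can help to confirm that every image has a matching annotation file. The following is a minimal sketch, assuming the flat `project/labelimg` layout shown above; the function name `find_unpaired` is hypothetical and not part of the repository.

```python
from pathlib import Path


def find_unpaired(data_dir):
    """Return (image stems without .xml, .xml stems without image), both sorted."""
    data_dir = Path(data_dir)
    images = {p.stem for p in data_dir.glob("*.jpg")}
    labels = {p.stem for p in data_dir.glob("*.xml")}
    return sorted(images - labels), sorted(labels - images)


if __name__ == "__main__":
    missing_xml, missing_jpg = find_unpaired("project/labelimg")
    print("images without annotation:", missing_xml)
    print("annotations without image:", missing_jpg)
```

Both lists should be empty for a fully annotated dataset.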

Build the Image

  • Pull the prebuilt image
    $ sudo docker pull gouchicao/keras-retinanet:latest
    
  • Build manually from a Dockerfile
    FROM gouchicao/tensorflow:2.2.0-gpu-jupyter-opencv4-pillow-wget-curl-git-nano
    LABEL maintainer="wang-junjian@qq.com"
    WORKDIR /
    RUN mkdir -p /root/.keras/models/ && \
      wget -O /root/.keras/models/ResNet-50-model.keras.h5 https://github.com/fizyr/keras-models/releases/download/v0.0.1/ResNet-50-model.keras.h5
    RUN git clone --depth 1 --recurse-submodules https://github.com/gouchicao/keras-retinanet.git
    WORKDIR /keras-retinanet/keras-retinanet
    # Install keras==2.3.1 first to avoid the error: TypeError: type object got multiple values for keyword argument 'training'
    RUN pip install keras==2.3.1 && \
      pip install . && \
      python setup.py build_ext --inplace
    WORKDIR /keras-retinanet
    

Train the Model

  • Run the container
    $ sudo docker run -it --runtime=nvidia --name=keras-retinanet -p 8888:8888 -p 6006:6006 \
                  -v /home/wjunjian/ailab/datasets/helmet:/keras-retinanet/project \
                  gouchicao/keras-retinanet bash
    
  • Convert the VOC annotations to CSV and split the dataset
    $ python voc2csv.py --data_dir=project/labelimg/ --output_dir=project/dataset
    
# Generated directory structure
project
├── dataset
│   ├── class.csv
│   ├── train
│   │   ├── 20190128155421222575013.jpg
│   │   ├── 20190128155421222575013.xml
│   │   ├── 20190129091126392737624.jpg
│   │   └── 20190129091126392737624.xml
│   ├── train.csv
│   ├── val
│   │   ├── 20190128155703035712899.jpg
│   │   └── 20190128155703035712899.xml
│   └── val.csv
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
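
The conversion step can be pictured roughly as follows. This is a minimal sketch, not the repository's `voc2csv.py`, which may differ: it assumes the keras-retinanet CSV annotation layout `path,x1,y1,x2,y2,class_name` (one row per object, no header), and the function names `voc_to_rows` and `write_annotations` are hypothetical.

```python
import csv
import xml.etree.ElementTree as ET


def voc_to_rows(xml_path, image_path):
    """Yield one (image_path, x1, y1, x2, y2, class_name) row per VOC object."""
    root = ET.parse(xml_path).getroot()
    for obj in root.iter("object"):
        box = obj.find("bndbox")
        yield (
            image_path,
            int(float(box.findtext("xmin"))),
            int(float(box.findtext("ymin"))),
            int(float(box.findtext("xmax"))),
            int(float(box.findtext("ymax"))),
            obj.findtext("name"),
        )


def write_annotations(rows, csv_path):
    """Write annotation rows to csv_path without a header line."""
    with open(csv_path, "w", newline="") as f:
        csv.writer(f).writerows(rows)
```

The class mapping file (`class.csv` here) is expected to hold one `class_name,id` pair per line, with ids starting at 0.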
  • Train
    $ python keras-retinanet/keras_retinanet/bin/train.py --tensorboard-dir=project/logs --snapshot-path project/snapshots \
      csv project/dataset/train.csv project/dataset/class.csv --val-annotations project/dataset/val.csv
    $ ll -h project/snapshots/resnet50_csv_01.h5
    -rw-r--r-- 1 root     root     417M Jul 27 22:58 resnet50_csv_01.h5
    

Visualize Training with TensorBoard

$ tensorboard --logdir=project/logs --bind_all

Open http://localhost:6006 in a browser on the host machine.

Evaluate the Model

$ python keras-retinanet/keras_retinanet/bin/evaluate.py csv project/dataset/val.csv project/dataset/class.csv \
    project/snapshots/resnet50_csv_01.h5 --convert-model
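
The evaluation script reports average precision per class, which depends on matching detections to ground-truth boxes by intersection over union (IoU). The following is a minimal sketch of that overlap measure for boxes in `(x1, y1, x2, y2)` form; it illustrates the metric only and is not the library's implementation.

```python
def iou(box_a, box_b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    # Corners of the overlapping rectangle, if any.
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```

A detection typically counts as correct only when its IoU with a ground-truth box of the same class reaches a chosen threshold.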

Convert the Model

$ mkdir project/inference
$ python keras-retinanet/keras_retinanet/bin/convert_model.py --no-class-specific-filter \
    project/snapshots/resnet50_csv_01.h5 project/inference/model.h5
$ ll -h project/inference/model.h5
-rw-r--r-- 1 root     root     140M Jul 27 23:14 model.h5

Run Predictions

$ python predict.py --model project/inference/model.h5 \
    --class_csv project/dataset/class.csv \
    --data_dir project/test \
    --predict_dir project/predict
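
What a prediction script does per image can be sketched as below. This is an assumption-laden illustration, not the repository's `predict.py`: the function names `load_class_names` and `detect` and the `score_threshold` parameter are hypothetical, while `read_image_bgr`, `preprocess_image`, and `resize_image` come from `keras_retinanet.utils.image`.

```python
import csv


def load_class_names(class_csv):
    """Read a `class_name,id` mapping file into {id: class_name}."""
    with open(class_csv, newline="") as f:
        return {int(row[1]): row[0] for row in csv.reader(f) if row}


def detect(model, image_path, score_threshold=0.5):
    """Run a converted inference model on one image.

    Returns a list of (box, score, label_id) tuples above the threshold.
    """
    # Imported lazily so load_class_names stays usable without the library.
    import numpy as np
    from keras_retinanet.utils.image import (
        preprocess_image, read_image_bgr, resize_image)

    image = preprocess_image(read_image_bgr(image_path))
    image, scale = resize_image(image)
    boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
    boxes /= scale  # map boxes back to the original image size
    return [
        (box.tolist(), float(score), int(label))
        for box, score, label in zip(boxes[0], scores[0], labels[0])
        if score >= score_threshold
    ]
```

Inside the container, the model would be loaded with `keras_retinanet.models.load_model("project/inference/model.h5", backbone_name="resnet50")` and passed to `detect` together with the path of an image under `project/test`.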

References