PyTorch 新たなクラスの物体検出をSSDでやってみる

今回は、血液の顕微鏡画像から細胞を検出するSSDモデルを作ってみたいと思います。

こんにちは cedro です。

先回、SSDの学習済みモデルを使った物体検出を行ってみましたが、物体検出できるのはあらかじめ学習した20クラスだけです。

新たなクラスの物体検出をするには、どうしたら良いのでしょうか。大量のアノテーション付きデータを用意してゼロから学習するしかないのでしょうか。

そんなことはありません。SSDのネットワークの前半はVGG16の学習済みモデルの一部をベースネットワークとして使えるので、後半部分のみ新たに検出したいクラスのデータセットを少量学習させることによって、新しいSSDモデルを構築できます。

ということで、今回は、血液の顕微鏡画像から細胞を検出するSSDモデルを作ってみたいと思います。

 

データセットを準備します

今回、使用するデータセットは、BCCD Dataset という血液の顕微鏡写真で、白血球、赤血球、血小板の3つについてバウンディングボックスのアノテーションデータが付いたものです。

データセットの仕様が、PASCAL Visual Object Classes ですので、PyTorch のSSDモデルで簡単に読み込むことが出来ます。

 

BCCD Dataset の検出イメージ

物体検出をした時のイメージは、こんな感じ。wbc が白血球、rbc が赤血球、platelets が血小板です。それにしても、なんともマニアックなデータセットですよね。

BCCD Dataset はこういった画像とアノテーションデータのセットが、全部で364個(trainval:292個、test:72個)しかない非常に小さなデータセットです。

  

  

BCCD_Dataset

Github から BCCD Datset をダウンロードします。実際に使用するのは、赤枠で囲ったBCCDフォルダーの部分だけです。

 

 

コードを書きます

今回も、PyTorchニューラルネットワーク実装ハンドブックのお世話になります。Github からサンプルコードをクローンあるいはダウンロードします。今回使用するのは、Chapter7です。

   

   

chapter7フォルダーの中身です。今回は、この chapter7フォルダーの中で、コードの追加・修正を行います。

まず、VOCdevkitフォルダーを追加し、その中に先程ダウンロードした BCCD フォルダーを格納します。

そして、CNNのベースネットワーク( vgg16_reducedfc.pth )をダウンロードし、weights フォルダーに格納しておきます。

 

これが、SSDのネットワーク構成です。前半は、VGG16ネットワークの学習済みの重みの一部をそのまま利用し、後半のExtra Feature Layers のみを新たなデータセットを使って学習することで、 新たなクラスの物体検出が出来るSSDネットワークが出来ます。早速、コードを書いてみましょう。

  

  

学習するためのコードです。train.py という名前で Chapter7 に保存します。

今回は、Windowsで動かす想定をしています。もし、Macで動かす場合は、24行目を ‘cuda’ : False に変更して下さい。

  

  

dataフォルダーの中にある、voc0712.py のコードを一部修正します。13行目のVOC_CLASSES はBCCDのデータセットに合わせて修正します。

48行目 image_sets の指定を’BCCD’, ‘trainval’ のみにします。59行目の year は dir に変更し、60行目の ‘VOC’+year はdir に変更します。

 

 

dataフォルダー内にある、config.py の SSD300 CONFIGS の一部を修正します。8行目 ‘max_iter’ :12000 → 3000 に変更し、18行目 ‘VOC’ → ‘BCCD’ に変更します。

  

学習・推論を実行します

train.py を実行します。trainvalデータ292個、3000 epoch をGTX1060で学習を行ったところ、約1時間で完了しました。

学習が完了したら、weights フォルダーに保存される重みファイル(BCCD.pth)を使って、テスト画像の推論を実行します。

 

  

推論を実行するコードです。inference.py という名前で、chapter7フォルダーに保存します。

26行目でBCCD_test を読み込んでいますので、27行目のimg_id を指定することで、物体検出に使うテストデータを選択できます。

それでは、inference.py を実行してみます。

 

 

BCCD Dataset :No.42 of test

テストデータの42番目の画像です。これを、物体検出させると、

 

 

BCCD Dataset :No.42 of test (detected)

若干取りこぼしがありそうですが、まあまあ物体検出できている感じです。

 

 

BCCD Dataset :No.55 of test

もう1つ行ってみましょう。テストデータの55番目の画像です。これを、物体検出させると、

 

 

BCCD Dataset :No.55 of test(detected)

これも、まずまずでしょうか。

たった、292個のデータを学習させただけですが、ベースネットワークは1000クラスの分類を学習したVGG16ネットワークなので、結構物体検出ができるものですね。

では、また。

google colab バージョンを追加

2021/1

 上記でご説明したコードをGoogle Colabで動かす形にしてGithubに上げました。この「リンク」をクリックし表示されたノートブックの先頭にある「Colab on Web」ボタンをクリックすると動かせます。BCCDデータセット、CNNのベースネットワークの重みもコードに保存済みなので、手軽に試せると思います。

2 件のコメント

  • 突然の質問失礼します。

    train.py

    のコード1行目の
    from data import *
    はどういう意味でしょうか?

    下記の様なエラーが出ており、ご紹介していただいている学習を進めることができずに困っています。
    importの先に何か、importするデータを指定しないといけないと思うのですが、初学者でして、自力で解決できず困っています。
    よろしければご教授願えれば幸いです。

    No module named ‘torch’
    File “/Users/myname/Desktop/pytorch_handbook/chapter7/data/voc0712.py”, line 17, in
    import torch
    File “/Users/myname/Desktop/pytorch_handbook/chapter7/data/__init__.py”, line 7, in
    from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
    File “/Users/myname/Desktop/pytorch_handbook/chapter7/train.py”, line 1, in
    from data import *

    • takumaさん
      コメントありがとうございます。

      from data import* は、dataフォルダーにある全てのファイルをインポートするという意味で、これは問題ありません。

      エラーリストを拝見すると、No module named ‘torch’ とありますので、Pytorchが上手くインストール出来ていないようです。下記の様に、チェックして問題があれば、再インストールしてみて下さい。

      (Pytorchとcudaのチェック)
      import torch
      print(torch.__version__)
      print(torch.cuda.is_available())

      (Torchvisionのチェック)
      import torchvision
      print(torchvision.__version__)

  • コメントを残す

    メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

    日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)