PyTorch 新たなクラスの物体検出をSSDでやってみる

今回は、血液の顕微鏡画像から細胞を検出するSSDモデルを作ってみたいと思います。

こんにちは cedro です。

先回、SSDの学習済みモデルを使った物体検出を行ってみましたが、物体検出できるのはあらかじめ学習した20クラスだけです。

新たなクラスの物体検出をするには、どうしたら良いのでしょうか。大量のアノテーション付きデータを用意してゼロから学習するしかないのでしょうか。

そんなことはありません。SSDのネットワークの前半はVGG16の学習済みモデルの一部をベースネットワークとして使えるので、後半部分のみ新たに検出したいクラスのデータセットを少量学習させることによって、新しいSSDモデルを構築できます。

ということで、今回は、血液の顕微鏡画像から細胞を検出するSSDモデルを作ってみたいと思います。

 

データセットを準備します

今回、使用するデータセットは、BCCD Dataset という血液の顕微鏡写真で、白血球、赤血球、血小板の3つについてバウンディングボックスのアノテーションデータが付いたものです。

データセットの仕様が、PASCAL Visual Object Classes ですので、PyTorch のSSDモデルで簡単に読み込むことが出来ます。

 

BCCD Dataset の検出イメージ

物体検出をした時のイメージは、こんな感じ。wbc が白血球、rbc が赤血球、platelets が血小板です。それにしても、なんともマニアックなデータセットですよね。

BCCD Dataset はこういった画像とアノテーションデータのセットが、全部で364個(trainval:292個、test:72個)しかない非常に小さなデータセットです。

  

  

BCCD_Dataset

Github から BCCD Datset をダウンロードします。実際に使用するのは、赤枠で囲ったBCCDフォルダーの部分だけです。

 

 

コードを書きます

今回も、PyTorchニューラルネットワーク実装ハンドブックのお世話になります。Github からサンプルコードをクローンあるいはダウンロードします。今回使用するのは、Chapter7です。

   

   

chapter7フォルダーの中身です。今回は、この chapter7フォルダーの中で、コードの追加・修正を行います。

まず、VOCdevkitフォルダーを追加し、その中に先程ダウンロードした BCCD フォルダーを格納します。

そして、CNNのベースネットワーク( vgg16_reducedfc.pth )をダウンロードし、weights フォルダーに格納しておきます。

 

これが、SSDのネットワーク構成です。前半は、VGG16ネットワークの学習済みの重みの一部をそのまま利用し、後半のExtra Feature Layers のみを新たなデータセットを使って学習することで、 新たなクラスの物体検出が出来るSSDネットワークが出来ます。早速、コードを書いてみましょう。

  

  

学習するためのコードです。train.py という名前で Chapter7 に保存します。

今回は、Windowsで動かす想定をしています。もし、Macで動かす場合は、24行目を ‘cuda’ : False に変更して下さい。

  

  

dataフォルダーの中にある、voc0712.py のコードを一部修正します。13行目のVOC_CLASSES はBCCDのデータセットに合わせて修正します。

48行目 image_sets の指定を’BCCD’, ‘trainval’ のみにします。59行目の year は dir に変更し、60行目の ‘VOC’+year はdir に変更します。

 

 

dataフォルダー内にある、config.py の SSD300 CONFIGS の一部を修正します。8行目 ‘max_iter’ :12000 → 3000 に変更し、18行目 ‘VOC’ → ‘BCCD’ に変更します。

  

学習・推論を実行します

train.py を実行します。trainvalデータ292個、3000 epoch をGTX1060で学習を行ったところ、約1時間で完了しました。

学習が完了したら、weights フォルダーに保存される重みファイル(BCCD.pth)を使って、テスト画像の推論を実行します。

 

  

推論を実行するコードです。inference.py という名前で、chapter7フォルダーに保存します。

26行目でBCCD_test を読み込んでいますので、27行目のimg_id を指定することで、物体検出に使うテストデータを選択できます。

それでは、inference.py を実行してみます。

 

 

BCCD Dataset :No.42 of test

テストデータの42番目の画像です。これを、物体検出させると、

 

 

BCCD Dataset :No.42 of test (detected)

若干取りこぼしがありそうですが、まあまあ物体検出できている感じです。

 

 

BCCD Dataset :No.55 of test

もう1つ行ってみましょう。テストデータの55番目の画像です。これを、物体検出させると、

 

 

BCCD Dataset :No.55 of test(detected)

これも、まずまずでしょうか。

たった、292個のデータを学習させただけですが、ベースネットワークは1000クラスの分類を学習したVGG16ネットワークなので、結構物体検出ができるものですね。

では、また。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

ABOUTこの記事をかいた人

2017年8月に、SONY Neural Network Console に一目惚れして、ディープラーニングを始めました。初心者の試行錯誤をブログにしています。