cedro-blog

Keras AnoGAN で異常検知をやってみる

今回は、Keras AnoGANでMNISTの異常検知をしてみたいと思います。

先回、VAEによる異常検知をやってみました。最近発表された論文を分かり易く解説したブログがあったので、それをトレースしただけなのですが、私にとっては結構歯ごたえがあり、その分面白かったです。

そうした中、他の異常検知の手法が知りたくなり調べてみると、GANによる異常検知の手法があることが分かりました。

えっ?GANって生成モデルじゃなかったのと思いましたが、AnoGANというGANは異常検知が出来るらしく、GAN好きの私にとってはとても興味深く思えました。

というわけで、今回は、Keras AnoGANでMNISTの異常検知をしてみたいと思います。

 

AnoGANとは?

AnoGANとは、Anomaly Detection with Generative Adversarial Networksの略で、文字通りGANを使って異常検知をするという意味です。

GANは、多量の正常画像を学習すると、潜在空間の中に学習した正常画像を覚え込み、入力(ランダムノイズ)に応じて覚え込んだ様々な正常画像を生成出来るようになります。

AnoGANは、このGANの学習済みモデルを利用して、GANの入力を適切に変化させ、検査する画像に出来るだけ近い画像を生成させます。この時、検査する画像が潜在空間に含まれていないと上手く画像生成が出来ないので、異常と判断するわけです。

模式図を使って、もう少し詳しく説明しましょう。

 

これが、GANの模式図です。Generatorは Noiseを入力としてFakeImage(偽画像)を生成することを学習します。Discriminatorは RealImage(本物画像)とFakeImage(偽画像)を間違えないように識別することを学習します。つまり、GeneratorとDiscriminator が切磋琢磨(Leran weights=重み学習)することによって、最終的には Generator が本物そっくりな偽画像が生成出来るようになります。

 

これが、AnoGANの模式図です。GANの学習済みモデルを利用して、GeneratorとDiscriminatorは新たに学習させず、FC layer を追加してその重みを調整し TargetImage(検査する画像)に出来るだけ近い FakeImage(偽画像)を生成させます。この時、TargetImageが今まで学習したものでないと上手くFakeImageが生成出来ません。従って、FakeImageと TargetImageを比較して、その差が小さければ正常、その差が大きければ異常と判断するわけです。

 

実装します

今回も先回同様、シンプルに mnist で実装します。正常画像は「1」、異常画像は「9」とします。

Generator と Discriminator のコードです。このままだと分かり難いので、実際の入力を入れた場合のサマリーを下記に記載します。

Generator のサマリーです。入力は30次元のランダムノイズ、これを6272個の全結合層で受けて、7×7画像128枚→14×14画像64枚→28×28画像1枚という逆畳み込みを行って偽画像を生成します。

 

Discriminator のサマリーです。こちらは先程の逆で、28×28画像1枚→14×14画像64枚→7×7画像128枚と畳み込みを行って、6227個の全結合層に繋ぎ、画像判定を行います。

 

DCGANのコードです。3−14行目は、Generator とDiscriminator を組み合わせて、DCGANを構成する関数です。

16−48行目は、DCGANを学習させる関数です。25行目で入力のランダムノイズを発生させ、28−38行目で学習画像と偽画像を比較、41−48行目で2つのロス計算をします。

49−53行目は generatror の生成画像を10epoch毎に4×4に連結して result フォルダーに保存する部分で、途中で65−76行目の画像を連結させる関数を呼び出しています。

56−59行目は、 generator と discriminator が学習した重みファイルを50epoch毎に weights フォルダーに格納する部分です。

61−63行目は、generator と discriminator が学習した重みファイルを読み込む関数 で、これは学習後のテスト時に使います。

 

AnoGANのコードです。3−20行目はDCGANの学習済みモデルを利用してAnoGANを構成する部分、22−30行目は検査する画像と生成画像の違いのスコアを計算する部分です。

 

学習を実行するコードです。3−8行目は各種設定部分、10−44行目はデータセット作成部分(学習データは6742個、評価データは100個)、47−52行目はDCGANを学習させ、ロスの推移をファイル(’loss.csv’)に書き込む部分です。

 

評価画像をテストするコードです。15−16行目はGeneratorとDiscriminator の学習した重みを読み込む部分、18−22行目は評価画像と生成画像の違いのスコアを計算する部分、24−27行目は評価画像と生成画像をペアにした画像predict フォルダーに保存する部分です。

29−34行目はスコアを保存する部分です。評価画像が「1」の場合のスコア ’scores_1.txt’ に、「9」の場合のスコア ‘scores_9.txt’ に保存します。

37−61行目はスコア・ヒストグラムを作成する部分です。検査する画像が「1」の場合と「9」の場合のスコア・ヒストグラムを重ね描きし、’histgram.png’ で保存します。

 

コードを実行します

適当なフォルダーに、コード全体(train.pyとか名前を付ける)、result フォルダーweights フォルダーpredict フォルダーを格納し、コードを実行します。なお、コード全体はブログの最後に記載してあります。

DCGANを学習させる部分は割と重いですが、ノートパソコンでもなんとかやれるレベルです。私の MabookAir だと200sec/epoch で、100epoch で5時間半くらいというところ。学習が完了すると、自動的にテストを開始しますが、こちらは軽快に動きます。

 

学習時、result フォルダーに保存されるGeneratorの生成画像です。100epochで、色々な「1」を生成出来る様になります。。

 

テスト時、predict フォルダーに保存される 画像です。各画像の左が評価画像右が生成画像です。学習した「1」に関しては、色々な「1」を見事に再現出来ています。

 

一方、学習をしてない「9」に関しては、全く再現が出来ていません。

 

評価画像100個のスコア・ヒストグラム(’histgram.png’で保存されていますです。「1」は赤「9」は青でプロットしています。「1」と「9」のヒストグラムの重なりはほとんどなく、上手く異常検知出来てますね。今回の場合、スコアの閾値を28〜36の間に設定すると、97%くらいの精度で異常検知が出来ることが分かります。

AnoGANはテスト時に評価画像に出来るだけ近い画像を探索するトライが必要(今回は300回)なので、その分応答性は今一つですが、精度はまずまずという感触です。

また、本ブログ作成に際して、参考にさせて頂いたブログは「旅行好きなソフトエンジニアの備忘録」です。感謝致します!

最後に、コード全体を載せておきます。

では、また。