cedro-blog

Keras VAEの画像異常検出を理解する

今回は、VAEを使った最新の画像異常検知について理解してみます。

こんにちは cedro です。

以前 、オートエンコーダーを使って、ノイズ除去メガネ女子のメガネ除去をやってみました。いずれも、入力画像に今まで学習したことがない物が含まれていると、その物は出力画像に再現出来ないという性質を利用したものです。

この性質を発展させると、正常品の画像のみオートエンコーダーに学習させておくと、入力に異常品の画像が入った時に異常検知が出来ることになります。しかも、学習時の正常品の画像は、ラベル付けの必要がなく、とても使い勝手が良いと思っていました。

そうした中、最近、VAEを使った画像の異常検知について最新の論文の内容を実装している興味深いブログを見つけました。まだ、自分にとってはとても難しいレベルなのですが、自分なりに読み解いてなんとか理解したいと思います。

ということで、今回は、VAEを使った最新の画像異常検知について理解してみます。

 

論文のポイント

今回題材にする論文は、深層生成モデルによる非正則化異常度を用いた工業製品の異常検知 です。内容は、複雑な工業製品の画像を対象にした高精度な異常検知方法の提案です。

 

これは論文からの引用です。VAE は、損失関数 Lvae(x)  =  Dave(x) + Avae(x) + Mvae(x) と3つの項から構成されています。1つ目 Dvae(x) は、与えられたデータの中から頻度の高い特徴を学習し、頻度の少ない特徴を無視する性質を持つ項。2つ目 Avae(x) は、与えられたデータの中の特徴の複雑さに応じて調整する性質を持つ項。そして3つ目 Mvae(x) 直接的に再現誤差に関係している項です。

ご存知の様に、VAE は高次元の空間にあるデータを低次元の潜在空間に、特徴量を連続的に滑らかに写像させることを狙いとしています。こういう場合は、Dave(x) Avae(x) が有効に機能するわけです。

しかし、複雑な工業製品の画像の異常検知を行う場合は、検出したい異常は頻度が低く、様々な種類の要素(平らな表面、曲がった部分、ネジの穴など)から構成されているため複雑さも高いですなので、Dave(x) Avae(x) の項目は異常かどうかを判定するための閾値を変化させてしまいます。従って、こういう場合は、Dave(x) Avae(x) の項目は削除し、損失関数 Lvae(x) = Mvae(x) だけで良いと言うのが、論文の主張です。

 

これも論文からの引用です。これは、Dave(x) Avae(x) があると、特徴量の出現頻度によって異常検知の閾値が変化してしまうということを示す概念図です。出現頻度が高い場合の閾値は相対的に高く、出現頻度が低い場合の閾値は相対的に低くなります。

 

実装します

論文では、モノクロ640×480のネジ穴画像を使って実験しています。学習時は、640×640の正常画像からランダム96×96のサイズを切り出して学習を行います。テスト時は、640×640の正常・異常画像から96×96サイズ16ピクセル間隔で切り出して異常度を算出し、少なくとも1枚が閾値を越えていたら異常と判断しています。

今回は、これをmnist で実装してみます。正常画像は「1」異常画像を「9」とします。学習時は、28×28の正常画像「1」の画像からランダム8×8サイズを100,000枚切り出して学習を行います。テスト時は、28×28の正常画像「1」1枚と異常画像「9」1枚から8×8サイズ2ピクセル間隔で切り出して異常度を算出します。まだ閾値は設定せず、上手く検出できるかどうか見るために可視化のみ行います。

 

28×28画像から8×8サイズを切り出す関数 cut_img() です。引数は、x = 元画像データ、number = 切り出す枚数です。

8-10行目で、shape_0 、shape_1、 shape_2 に適切な範囲のランダムな整数を入れます。11行目で、temp = x [ shape_0, shape_1:shape_1+height, shape_2:shape_2+width, 0 ] によって、temp にx画像から8×8サイズに切り出したデータが入ります。12行目で、x_out temp の結果をアペンドします。これをnumber 回繰り返すわけです。

この理屈は、単純化したサンプルコードを見ると分かりやすいので、下記に補足します。

 

5行5列のデータから2行3列のデータを切り出すサンプルコードです。これを、実行すると、

こんな結果になります。つまり、x = x [ 0, 0:2, 0:3, 0 ] で、左上角から2行×3列分データを切り出すことになるわけです。

 

データセットを作成する部分です。3−8行目で、minist のデータセットを読み込みます。

11行目から学習データの作成です。学習には「1」のデータしか使わないので、データセットから「1」のみ抽出し、19行目で先程の関数を呼んで、x_train_1 「1」からランダムに切り出した8×8サイズの画像100,000枚格納します。

24行目から評価データの作成です。評価には「1」と「9」のデータを使うので、データセットから「1」と「9」を抽出します。最終的な評価には、test_normal にランダムに選んだ「1」を1枚格納し、test_anomaly にランダムに選んだ「9」を1枚格納します。

 

学習後に、テスト画像の評価結果を可視化する関数 evaluate_img() です。引数は、model = モデル名、x_noramal = 正常画像データ(1枚)、x_anomaly = 異常画像データ(1枚)、name =  従来手法(old)か提案手法(new)の選択です。

4−5行目で、8×8サイズの判定結果を順次上書きするための28×28サイズを「0」で埋めた配列 img_normalimg_anomaly を作ります。

7−12行目で、x_nomal (正常画像)から8×8サイズを切り出したものx_sub_nomal に格納し、x_anomaly(異常画像) から8×8サイズを切り出したものx_sub_anomaly に格納します。そして、これを行方向、列方向の両方で、2ピクセル間隔で順次行います。

15-22行目は、従来手法による計算です。17行目で x_sub_normal のスコア(nomal_score)を計算し、18行目で先程作成した img_normal に上書きします(8行8列には同じスコアが加算されます)。x_sub_anomalyの場合も同様です。

27−40行目は、提案手法による計算です。29−31行目で、Mvae(x) を計算するために、x_sub_normal について変数 loss  = 0.5*(x_sub_normal – mu) **2 / sigma1ピクセル単位で計算し、loss に累計します。32行目で、先程作成した img_normal にこれを上書きします(8行8列には同じloss値が加算されます)。x_sub_anomalyの場合も同様です。

 

コードを動かしてみます。

学習データが100,000個ありますが、モノクロ28×28なので、ノートパソコンでも軽快に動きます。私のMacbookAirでは、26 sec/epoch で、10epoch を4分ちょっとで完了しました。

 

コードを実行すると、まずモデルサマリーが、Encoder、Decoder、VAE と表示され、学習を開始します。

 

従来手法(old)による異常検出結果です。normal(正常画像)とanomaly(異常画像)による差があまりありません。これでは、適切な閾値の設定は難しく、検出精度はあまり期待出来ません。

 

提案手法(new)による異常検出結果です。normal(正常画像)とanomaly(異常画像)による差が明確です。これなら、適切な閾値を設定すれば、結構高い検出精度が期待出来そうです。

コードのデータセットの読み込み部分を少し変更すると、Fashion _mnist でも直ぐ試せますので、やってみましょう。

 

従来手法(old)による異常検出結果(7.スニーカーと9.ブーツ)です。やはり、normal(正常画像)とanomaly(異常画像)による差があまり無いです。

 

提案手法(new)による異常検出結果(7.スニーカーと9.ブーツ)です。normal(正常画像)とanomaly(異常画像)による差が明確です。

 

従来手法(old)による異常検出結果(2.セーターと3.ドレス)です。やはりイマイチです。

 

提案手法(new)による異常検出結果(2.セーターと3.ドレス)です。素晴らしい! 本当にこれ結構実用になりそうな気がします。

最後に、コード全体を載せておきます。

 

 

さて、今回大変お世話になったブログは、「Variational Autoencoder を使った画像の異常検知 前編」です。深く感謝致します! いやー、それにしても、私の実力ではトレースするだけで精一杯でした(笑)。

なお、コードは自分で理解しやすい様に、オリジナルに対して順番を入れ替えたり表現を一部修正したりしています。

いつかは、自ら論文を読んで実装できる様になりたいものですが、それは見果ての夢かなー。

では、また。