Keras MLPを改造して定番パターンを勉強する

今回は、Keras のサンプルプログラム MLPを改造してみることで、新たな定番パターンを勉強したいと思います。

こんにちは cedro です。

最近、Keras をよく触るようになりました。

なぜかというと、KerasWeb に様々な情報が溢れていて、欲しい情報が直ぐ手に入るからです。

例えば、サンプルプログラムは Keras_team の公式版の他にも、色々な方々が作ったものが沢山見つかりますし、個別の機能を紹介するブログにも事欠きません。

また、何かエラーが発生した場合でも、多くの場合その処方箋が簡単に Web で探すことができます。これはありがたいことです。

Keras に触り始めて約2ヶ月経ち、そろそろ色々な 定番の処理パターンを集中してチェックしてみようかなと思っているところです。

ということで、今回は、Keras のサンプルプログラム MLPを改造してみることで、新たな定番処理パターンを勉強したいと思います。

 

Keras MLPとは

今回、改造するサンプルプログラムは、mnist_mlp.py で、0〜9の数字のデータセット MNISTMLP(多層パーセプトロン)で分類する基礎的なものです。

MacbookAir でサンプルプログラムをそのまま動かすと、ターミナルに Epoch毎のデータが次々と軽快に表示され、3分かからずに終了します。

しかし出て来る結論は 、Test_loss : 0.1285  (評価ロスは1.285%) ,  Test_accuracy : 0.9814 (分類精度は98.14%)のたった2行だけ。

えっ!?これだけなの?という感じですよね。

特に、データセット(MNIST)は何処か知らないところから自動でダウンロードされて来るので、益々「これだけ?」感が増します。

ということで、これからこのプログラムを改造して行きます。

 

新たなデータセットを準備します

今回使うデータセットは、NDL Lab の文字画像データセット(平仮名73文字版)で、グレースケール48×48平仮名画像PNG形式で計 80,000枚あるものです。この中から、「あ、い、う、え、お、か、き、く、け、こ」の10種類だけ抜き出します。

各文字の画像数は1,200枚前後で、合計約12,000枚です。MNISTのデータ数は70,000枚なので、その1/6くらいと少ないですが、なんとかなるでしょう。

 

プログラムと同じところに、hiragana フォルダーを作成し、その下に0〜9のフォルダーを作成し、「あ」〜「こ」の文字をそのまま格納します(前処理は一切不要です)。

 

データセットを入れ替えて動かします

新たに必要なライブラリーをインポートします。

 

このブログでは何度も登場しているデータセット読み込みの定番部分です。ラベル数が多いデータセットの読み込みは、この方法が便利ですねー。

なお、ネットワークの入力が全結合のため、x_trainshape[ 学習データ数, 28, 28, 1 ]  から [ 学習データ数, 784 ] にリシェイプしています。x_test も同様です。

 

さて、動かしてみると、データ数が少ない(MNISTの1/6)ので、処理が早いです。約20秒で処理が完了です。

さすがに、自分が準備したデータセットを使ったので、Test_accuracy : 0.9910 (識別精度99.10%)を見ると、「データが少ない割には結構良い精度じゃん」、と実感がわいて来ます(笑)。

 

ロス、精度の時系列グラフが欲しい

ロス、精度については、リアルタイムにターミナルに数字は表示されますが、それをを見るだけではピンと来ません。やはり、時系列推移はグラフで見たいところ。

プログラムの最後に、これを追加します。

Keras は、学習時の様々なデータを history ディレクトリに保存し、そこから必要なデータを読み出すことができる機能を標準で持っていますので、これを活用します。但し、学習部分に、history = model.fit ( x_train, y_train,・・・という様に、history が記述されている必要があります。

インポートしているのは、グラフを表示させるためのライブラリ Matplotlib です。さて、これでプログラムを動かすと、

ロスの推移グラフです。

精度の推移グラフです。

 

Confusion Matrix が欲しい

私がディープラーニングを始めたきっかけになった SONY Neural Network Console には、データセットの識別をする場合、どのデータと間違えたかが一目で分かる Confusion Matrix 機能がありました。Keras にも、これが欲しいということで、追加します。

プログラムの最後に、さらにこれを追加します(「静かなる名辞」さんのブログを参考にさせて頂きました。感謝です。)

まず、confusion_matrix というズバリのライブラリーをインポートします。しかし、これだけでは行列の形で数字を返してくれるだけなので、見栄えが悪いです。

そこで、pandasseaborn のライブラリーをインポートし、直感的に分かり易いヒートマップ形式の画像で保存します。

ヒートマップの仕様は、sn.heatmap(df_cmx, annot=True, fmt=””)の各引数によって、指定しています。df_cmx は表示するデータ、annnot=True はセルに値を表示fmt=”d”整数で表示。

予測したラベルの取得には、model.predict_classes メソッドを使っています。さて、これでプログラムを動かすと、

 

Confusion Matrix です。X軸予測したラベルY軸実際のラベルです。マスの中の数字は、そこに該当したデータ数を表しています。

例えば、実際の0をどう予測したかを見てみると、正解の0と予測したのが235個4と間違えたのが1個8と間違えたのが2個あったことになります。

つまり、左上角から右下角への斜めのマスの数字以外は、全てゼロになるのが理想の状態です。今回の結果は、ほぼ理想に近い状態ではないでしょうか。

 

ネットワークモデルを可視化したい

プログラムの最後に、さらにこれを追加します、

Keras にはネットワークモデルを可視化するための plot_model という ライブラリーがあって、これをインポートしておけば、わずか1行でネットワークモデルを画像ファイルで出力できます。便利ですねー。

但し、あらかじめ pydot graphviz をインストールしておく必要があります。さて、これでプログラムを動かすと、

 

モデルを可視化した結果です。こういう形にすると、モデルが分かりやすいですね。

では、また。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

ABOUTこの記事をかいた人

アバター

ディープラーニング・エンジニアを趣味でやってます。E資格ホルダー。 好きなものは、膨大な凡ショットから生まれる奇跡の1枚、右肩上がりのワクワク感、暑い国の新たな価値観、何もしない南の島、コード通りに動くチップ、完璧なハーモニー、仲間とのバンド演奏、数えきれない流れ星。