PyTorch まずMLPを使ってみる

今回は、フレームワークの「Hello world 」であるMLPを使って、PyTorch の特徴をみてみます。

こんにちは cedro です。

年末に、本屋で「PyTorch ニューラルネットワーク実装ハンドブック」という新刊本を見かけて、何となく気になりました。

後で、Webで調べてみると、PyTorch が中々魅力的なフレームワークなことが分かりました。特徴は3つ、

1)Tensorflow より簡潔にコードが書け、それでいて細かな操作ができるらしい
2)研究者に人気があり、論文発表後に、すぐ実装例がGithubに上がことが多いらしい
3)コミュニティが活発でネット上に参考資料が豊富にあるらしい

特に気に入ったのが、2)の最新の論文の実装が手に入る点で、これはとても刺激的です。今まで、Keras を極めようと思っていた気持ちは何処へやら、もうPyTorch の魔力にかかり、大晦日にこの本を買って帰りました。

ということで、今回は、フレームワークの「Hello world 」であるMLPを使って、PyTorch の特徴をみてみます。

 

PyTorch のインストール

まず、PyTorch をインストールするために、PyTorch のホームページに行きます。

 

自分のインストールしたい組み合わせを赤色で選ぶと、どういうコマンドでインストールすれば良いかが一目で分かります。特に、GPUを使う場合にCUDAのどのバージョンを使うのかを指定してインストールできるのが、安心です。

ちょっと前までは、Ver0.4がStable(安定板)だったのに、もうVer1.0がStableになっていますね。

 

私は、先程の画面で選んだ組み合わせから示されたコマンド、pip3 install torch torchvison でインストールしました。

 

MLP_MNISTのコードを順に見て行きます

まずは、MLPでMNISTの分類をするコードを順番に見て行きましょう。

データセットを読み込む部分です。KerasMNISTを読み込む時は、(x_train, y_train), (x_test, y_test)=  mnist.load_data() と1行で読めるので一見簡単にみえます。しかし、その後、型を変えたり、正規化したり、ラベルをOne_hotにしたりと、バラバラと色々な処理が必要です。

PyTorch の場合は、データセットを読み込む時に、transform = transforms.ToTensor() と引数で指定しておくだけで、その後の処理を自動で行ってくれますし、そもそもラベルをOne_hotにする必要がありません。しかも、データセットを数値とラベルに分けずにまとめて処理できるので、非常にすっきりします。

もちろん、データセットをそのまま分割する、torch.utils.data.random_split というのがあります。ここでは、train_Dataset(60000個) をさらに、train_dataset(48000個) とvalid_dataset(12000個) に分割しています。

 

データローダの部分です。データセットからミニバッチ単位でデータを取り出し、ネットワークへ供給することが出来ます。シンプルで分かりやすいです。

 

ネットワークを構築する部分です。まず、MLPNetクラスでネットワークを定義しています。ネットワークの記述には、色々な方法があるみたいですが、これは nn.Moduleを継承した記法で、前半にブロックを書いて、後半に接続を書いています。

後は、デバイス(GPUかCPU)を選択して、損失関数と最適化関数を設定しています。

 

学習部分です。ここは、細かな記述が必要ですが、こうなっているからこそ、複雑なネットワークを試す時に、細かくデバッグができるわけですね。

train_modevalid_mode に分けているのは、DropoutBatchNomalization などの学習の時には効かせて、評価の時には効かせないブロックがあるためです。

ログの出力や後でグラフを描かせるためのデータ保持をする部分があり、ここはもうちょっと簡略化しても良い気もしますが、まあ良いでしょう。

 

学習後の部分です。テストデータセットを使って、モデルの最終的な精度を計算します。その後、学習した重みファイルを保存し、ロスと精度の推移グラフを描きます。

とりあえず、動かしてみましょう。

50epoch後にテストデータを使って計測した精度は Test_accuracy = 98.36%でした。

 

オリジナルデータを読み込んでみる

あらかじめ用意されたデータセットは簡単に読み込めて当たり前です。問題は、オリジナルデータを読み込む場合がどうかが重要です。ここでは、具体的なオリジナルデータを読み込んでみます。

 

今回用意したデータは、NDL Lab の平仮名73文字から「あ、い、う、え、お、か、き、く、け、こ」の10種類を抜き出したものです。たぶん、オリジナルデータの典型的な形ではないでしょうか。

各文字の画像数は1,200枚前後で、で合計11,754枚。各画像は、カラー48×48ピクセルのPNG形式です。

 

 

root に、hiragana フォルダーを置いて、その下に0〜9のフォルダーを作成し、「あ」〜「こ」の文字画像を格納します。PyTorch では、このデータをどうやって読み込むかと言うと、

 

データの読み込は、実質8–10行目のたった3行だけです。これだけで、0〜9のフォルダー名をラベルとして認識し、データセットとして読み込みます。ふと思い出しましたが、これって、SONY Neural Network Console の画像データを読み込んでデータセットを作成する場合と同じですね。凄く懐かしい。

そして、データに何らかの前処理を加えたい場合は、その内容を3−6行目の様に  transforms_Comose で列挙しておけば、前処理も一気にやってくれます。これは便利!

では、実際にオリジナルデータを読み込んで表示させてみます。

 

オリジナルデータを読み込んで、データセットとデータローダを作り、内容を確認するコードです。

まず、データに必要な前処理を加えて読み込み hiragana_dataset を作り、troch.utils.random_split で  train_dataset, valid_dataset, test_dataset に分割し、それぞれデータローダ(これは先程と同じ)を作成します。そして、最後に train_dataset のimageを8×8のタイル状に表示させ、シェイプ等を確認します。

54行目の torchvision.utils.make_grid が優れもので、ミニバッチの画像入力を受け取って、N*Nのタイル状の画像を自動で作ってくれます。今回の様に入力画像の確認に使っても良いですし、特に生成系の画像の出力には重宝しそうです。

59-60行目は train_loaderのimagesの先頭のシェイプlabels の先頭のデータを確認する部分です。

では、コードを動かしてみます。

 

作成したデータセットが狙い通りになっているのか、train_loader のミニバッチ画像を可視化して確認が出来ます。これ、地味に嬉しくないですか。

そして、imagesの先頭のシェイプtroch.Size( [1, 28, 28] )labelの先頭は tensor(7) で、可視化した画像の左角の「く」のフォルダー名「7」と合っていますね。

それでは、改造したMLP全体のコードを動かしてみます。

 

20epoch後にテストデータを使って計測した精度は Test_accuracy = 99.4% でした。ひらがなのデータセットは、MNISTより簡単なようです。

今回、初めてPyTorchに触ってみたわけですが、メリハリが効いた良いフレームワークだなという感じがします。

データセットを読み込むとか、画像をタイル形状で表示するとか、開発に直接関係ないが良く使う部分は高度に自動化してある一方で、ネットワークを構築するとか、学習・評価するとか、開発に直接関係する部分は細かく記述出来るようになっていて、非常に好印象です。

しばらく、PyTorch を中心に触ってみて、早く慣れたいと思います。最後に mlp_mnist.py の全体のコードを載せておきます。

では、また。

 

 

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

ABOUTこの記事をかいた人

アバター

ディープラーニング・エンジニアを趣味でやってます。E資格ホルダー。 好きなものは、膨大な凡ショットから生まれる奇跡の1枚、右肩上がりのワクワク感、暑い国の新たな価値観、何もしない南の島、コード通りに動くチップ、完璧なハーモニー、仲間とのバンド演奏、数えきれない流れ星。