PyTorch そして LSGANをやってみる

今回は、安定した学習を可能にしたLSGANを試してみます

こんにちは cedro です。

私が、ディープラーニングの中で一番好きなのは生成系。中でも、単なるノイズから本物の様な画像を生成するGANは大好物です。

特に、最初何も存在しない画像から、少しづつ本物に近い画像に変化して行くプロセスを見るのが好き。

但し、DCGANは色々やってみましたが、大きな画像になると直ぐ mode collapse (生成が途中で失敗し砂嵐に戻る)になってしまい、パラメータ設定が微妙で学習が安定しないのが難点でした。

今回、例のハンドブックの第6章には、安定した学習を可能にしたLSGANというのが載っていたので、早速試してみたくなりました。

ということで、今回は、安定した学習を可能にしたLSGANを試してみます。

 

 

GANのお勉強

これは GAN の模式図です。Generator は Noise (乱数)を入力として、Discriminator に本物と間違わせるような偽物を作成することを学習します。

一方、Discriminator は本物と偽物を間違えないように学習します。この2つのネットワークが切磋琢磨することで高度な画像生成ができる様になります。

 

LSGANの論文に出て来る目的関数です。a,b,cは定数で、論文ではa,b,c=−1,1,0 または a,b,c=0,1,1 が推奨されていて、ハンドブックではa,b,c=0,1,1 の方が使われています。

式(1)も(2)も最小化することが目的です。それにはどうすれば良いかと言うと、式(1)のDiscriminatorの方は D(x) が1、D(G(z)) が0になれば良い。式(2)のGeneratorの方はD(G(z))が1になれば良い。

言い換えると、Discriminator は自らが本物を本物と判断し偽物を偽物と判断することが目標となり、Generator は Discriminator が偽物を本物と間違えることが目標になります。

 

これは、DCGANの論文に載っている Generator のネットワークです。100次元ベクトルを入力に、 1024×4×4 → 512×8×8 → 256×16×16 → 128×32×32 → 3×64×64と転置畳み込みを行います。その結果、3×64×64のフェイク画像を生成します。

Discriminator は基本的にこの逆で、3×64×64のフェイク画像や実画像を入力に、128×32×32 → 256×16×16 → 512×8×8 → 1026×4×4 → 1×1×1と畳み込みを行います。その結果、1×1×1の判定結果を出力します。

 

コードを書きます

Generator のコードです。論文のままのスペックだと重量級になり、GPUを使っても学習にかなり時間が掛かるので、入出力以外のチャンネル数は半減しています。

GANの安定化にはなくてはならないのが Batch_Normalization で、これは Generator と Discriminator の両方に入っています。

 

Discriminator のコードです。基本的に、Generator の逆です。活性化関数は学習が安定するように、Discriminator の方だけ LeakyReULを使うのがお約束です。

 

メイン部分です。データは、CelebAの画像約20万枚をセンターから160×160でクロップしてから128×128にリサイズしたものを celeba フォルダーの下の「0」フォルダーにまとめて格納しました。

データを読み込み必要な処理を行ったら、データローダに渡します。Generator と Discriminator の2つのネットワークをインスタンス化し、重みの初期化を行います。損失関数は MSELoss(平均二乗誤差)、最適化関数は これもお約束 Adam を使います。

 

学習ループです。先程の論文にある様に、Discriminator は本物画像の判定結果と目標値(本物)の二乗誤差と偽物画像の判定結果と目標値(偽物)の二乗誤差の最小化を図ります。そして、Generator は、Discriminator の偽物画像の判定結果と目標値(本物)の二乗誤差の最小化を図ります。

それでは、コードを動かしてみます。batch_size 64、20 epochを GTX1060で動かすと、所用時間は2時間強くらいでした。

ロスDとロスGの推移グラフです。横軸の単位は iter (×100)で、バッチ単位で学習を行った回数です。いやー、ビックリしました。もう何の不安も感じさせないほど安定してます。LSGAN凄いです。

 

1500 iter 毎に生成した fake_image をGIF動画にしたものです。毎回入力するベクトルを固定しているので、同じ画像の質がだんだん上がって行く状態が見えて興味深いです。

それにしても、以前やってみたDCGANとは比べ物にならないくらい安定しています。技術革新の早さを感じますね。最後に、コード全体を載せておきます。

では、また。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

ABOUTこの記事をかいた人

アバター

ディープラーニング・エンジニアを趣味でやってます。E資格ホルダー。 好きなものは、膨大な凡ショットから生まれる奇跡の1枚、右肩上がりのワクワク感、暑い国の新たな価値観、何もしない南の島、コード通りに動くチップ、完璧なハーモニー、仲間とのバンド演奏、数えきれない流れ星。