cedro-blog

PyTorch で Conditional GAN をやってみる

今回は、生成するクラスのコントロールが可能な Conditional GAN をやってみたいと思います。

こんにちは cedro です。

先回、LSGANをやってみて、その学習安定性に驚きました。ただ、LSGANは自分の生成したいクラスを指定することは出来ません。

CelebAのデータで言えば、 ランダムベクトルの入力に対して、 ブロンドの女性が出て来るか、黒髪の男性が出てくるか、メガネを掛けているのかいないのか、は成り行きです。

つまり生成するクラスはコントロールできません。そこで、生成するクラスをコントロールする機能を加えたGANの登場です。

ということで、今回は、生成するクラスのコントロールが可能な Conditional GAN をやってみたいと思います。

 

Conditional GANの仕組み

n_class = 3 の場合の概念図

これが、Conditional GAN が学習する時の概念図です。n_class = 3 で、ラベル番号1の画像を学習する時を想定しています。

ランダムベクトル入力には、ラベル番号1をOne-Hot形式にした [ 0 , 1, 0 ] を加算します。 その結果、ランダムベクトル入力は 100次元+3 = 103 次元となります。

画像の方は、ラベル番号1をOne-Hot形式にした [ 0, 1, 0 ] を、さらに画像のサイズに拡大します。0だけで埋まった64×64の画像(真っ黒な画像)と1だけで埋まった64×64の画像(真っ白な画像)を作って、チャンネルに追加します。従って、 3ch +3 = 6 ch となります。

このような形で、画像を学習する時に、ラベル情報も合わせて学習することで、学習後は、Generator のランダムベクトル入力に、指定したラベル番号を付加することで、生成画像のクラスのコントロールが可能になります。

 

コードで書いてみる

学習ループのところです。基本的にLSGANと同じで、ランダムベクトル、本物画像、偽物画像、それぞれにラベルの情報を加算する部分だけを追加します。なお、One-Hot形式への変換 、画像とラベルの連結、ランダムベクトルとラベルの連結 は、関数を呼び出して使っています。

 

データセットを準備する

データセットは、CelebAから 属性ファイルを使って 抽出します。ブログ「CelebA データセットから好みのデータセットを抽出する」を参照下さい。

 

今回の抽出条件がこれです。フォルダー0~6まで7つのクラスを設定します。やっぱり見たいのは女性の顔なので、女性限定で、髪の色、笑っているかいないかで分けて、最後にメガネを掛けている人(これは性別・髪の色・笑顔かどうか不問)を加えています。

ルートに、celeba フォルダーを作成し、その下に0~6のフォルダーを置きます。先程の条件で画像を抽出しながら、センターから160×160でクロップし、128×128にリサイズしたものを各フォルダーへ順次格納します。

各フォルダーで枚数がバラツキますので、10,000枚で揃えます(画像枚数が不揃いだと画像生成が上手く行かないようです)。

動かしてみます

ロス推移グラフです。LSGAN同様、学習安定性は非常に高いです。 70,000枚の画像、batch_size49、50 epochをGTX1060 で学習させるのに、2時間強かかりました。

 

1 epoch 毎に生成した fake_image をGIF動画にしたものです。fake_imageを生成するランダムベクトルには、適切なラベル情報(0~6の繰り返し)を加えています。

その結果、左から縦2列が黒髪の女性(笑っている、いない)、次の縦2列がブロンドの女性(笑っている、いない)、その次の縦2列がブラウンの女性(笑っている、いない)、そして一番右1列がメガネを掛けた人、という様に、狙ったクラスの画像生成が出来ていることが分かると思います。

狙ったクラスだけの画像を生成できるのは、中々興味深いですね。最後に、コード全体を載せておきます。

では、また。