1.はじめに
GAN(Generative Adversarial Network)は、あるドメインの画像データを大量に学習すると、学習しなかった画像も生成できる面白いモデルなのですが、1つ大きな欠点があります。それは、学習に非常に手間が掛かることです。
特に、画素数の大きいものをゼロから作る場合、多量の画像データ・並列処理用の多くのGPU・時間という多大な学習コストが掛かり、個人で手を出すのは無理です。なので、通常は自分が生成したい画像に似たものを生成する学習済みモデルを見つけて転移学習などをして利用します。
今回は、この様な学習コストの高いGANを改良し、少量の画像データ・1つのGPU・短時間でゼロからの学習が行える、Lightweight GANをご紹介します。
2.Lightweight GANとは?
Generatorは以下のような構造になっています。
基本的に青線で順次アップサンプリングを行って行くわけですが、ポイントは赤線の Skip-layer excitation によって直近の特徴マップと4段階下の特徴マップを融合させることです。右側にあるその詳細図で分かる様に、直近の特徴マップの方がスキップ構造になっているので、計算量をあまり増やさず安定性を向上させることが出来ます。
次に、Discriminatorです。
基本的に青線で順次ダウンサンプリングを行い、最終的に5×5でリアル画像とGenerater画像との差を計算しこれを Adversarial loss とします。ここでのポイントは、16×16の特徴マップから8×8をランダムクロップしたものと8×8の特徴マップの差を Reconstruction loss とし、より画像の特徴を捉えやすくしていることです。
そして、Loss は以下の様にします。基本は Adversarial loss で、リアル画像を入力した時には Discriminator の側に Reconstruction loss を加えます。
この他、以前のブログでもご紹介した様に、学習の際に Differentiable-Augmentationを使用することで、少量のデータでの学習を可能しています。
3.コード
コードはGoogle Colabで動かす形にしてGithubに上げてありますので、それに沿って説明して行きます。自分で動かしてみたい方は、この「リンク」をクリックし表示されたノートブックの先頭にある「Colab on Web」ボタンをクリックすると動かせます。
今回のコンセプトは、色々なラーメンの画像をGANに学習させて、様々なラーメンを生成するモデルを作るです。ラーメン画像は300枚(カラー256×256ピクセル)だけ用意しました。
まず、google driveと接続します。理由は、作成したファイルを簡単にダウンロードするためと、もし途中で接続が切れても過去の学習データを保存しておくためです。
1 2 3 |
# google drive に接続 from google.colab import drive drive.mount('/content/drive') |
1 2 |
# MyDrive に移動 %cd ./drive/MyDrive |
次に、セットアップを行います。なんと、lightweight_gan は pip でインストールでき、それ以後コマンドとして使えるようになります。超便利!そして、ラーメン画像のデータをダウンロードします。
1 2 3 4 5 6 |
# lightweight_ganをインストール !pip install lightweight-gan # githubから学習用データを取得 !git clone https://github.com/cedro3/lightweight_gan.git % cd lightweight_gan/ramen |
それでは、モデルの学習を開始します。学習時間は、GPUの種類によって変わり、V100で3時間、P100で5時間くらいです。
なお、学習途中の生成画像は、./results/ramen/ の中に順次保存されますので随時見ることが出来ます。
1 2 3 4 5 6 7 |
# 学習の実行 !lightweight_gan \ --data ./data \ --name 'ramen' \ --batch-size 16 \ --gradient-accumulate-every 4 \ --num-train-steps 15000 |
学習が完了したら、学習済みモデルを使って、様々なラーメンの画像を生成してみます。下記のコードを実行すると、生成したJPG画像が ./results/ramen-generated-1/ に保存されます。
–load-from ‘*’ (*は数字)オプションを付けると、./models/ramen/ に保存されているどの重みを使うかを指定出来ます。省略すると、最新のものが使われます。
1 2 |
# 学習済みモデルでJPG画像を生成 !lightweight_gan --name 'ramen' --generate |
それでは、複数の生成画像の間を補完する動画を作成してみましょう。下記のコードを実行すると、GIF動画が、./results/ramen/ に保存されます。
JPG画像生成の場合と同様に、–load-from ‘*’ (*は数字)オプションが使えます。
1 2 |
# 学習済みモデルでGIF動画を生成 !lightweight_gan --name 'ramen' --generate-interpolation |
様々なラーメン画像が連続的に生成されます。この中には、今まで無かったようなラーメンが含まれているはずです。
それにしても、ラーメンが無性に食べたくなって来ました(笑)。これから、ラーメン屋に直行します!
では、また。
(オリジナルgithub)https://github.com/lucidrains/lightweight-gan
コメントを残す