たった100枚の画像でもGANの学習を可能にする技術

1.はじめに

 画像生成するGANの学習には大量のデータが必要で、データが少ないとオーバーフィッティングによるモード崩壊が起こり易く、そうしたことが発生した場合の対策はデータ数の増加しかありませんでした。

 一方、画像分類では、データが少ない場合の対策として、データを変換(拡大、回転、シフトなど)して水増しする Data Augmentation が有効であることが以前から知られていました。

 今回は、2020/6に発表されたGANの学習における Data Augmentation とも言える、Differentiable Augmentation についてご説明します。

2.Differentiable Augmentationとは

 単純に画像分類の場合と同様にリアル画像 x だけをT(x)に変換すると、D (Discriminator) が変換後の画像も実際の画像の分布に含まれていると勘違いしてしまうため、あまり強い変換をかけることが出来ません。

 そこで、上記の図の様に、フェイク画像G(z)もT(G(z))に変換することで、強い変換をかけられるようになったというのがポイントです。式で表すと以下の様です。

 ここで使われる変換T(x)は、カットアウト(画像をランダムな正方形でマスキング), 色・コントラスト・彩度のランダム変化などが使われています。

 さらに、T(x)による変換の影響がより小さくなるConsistency正則化という手法も使われています。

 それでは、本当に効果があるのかをテストしてみましょう。

3.コード

 コードはGoogle Colabで動かす形にしてGithubに上げてありますので、それに沿って説明して行きます。自分で動かしてみたい方は、この「リンク」をクリックし表示されたノートブックの先頭にある「Colab on Web」ボタンをクリックすると動かせます。

 今回、GANはStyleGAN2、データはWebから適当に拾った新垣結衣さんの顔画像100枚を使って Differential Augmentation を試してみます。

 最初に、tensorflow 1.15.0をインストールし、Githubからコードをコピーします。そして、新垣結衣さんの画像100枚(64×64)と学習済みの重みをダウンロードします。詳細はGoogle colabを参照下さい。

 まず、関数を定義します。

 先程ダウンロードした新垣結衣の画像100枚(100-shot-gakki)からデータセットを作成します。

 ご自分のオリジナルデータを使う場合は、同様に 64×64 の画像を100枚集めてフォルダーに入れて使えばOKです。100枚であれば、簡単に集められると思いますので、ぜひトライしてみて下さい。

 さあ、こんな100枚の画像だけで本当にGANの学習ができるのでしょうか?

***************************************

 学習に入る前に、ご自分のGoogle Colabに割り当てられているGPUのタイプを確認します。

 試行回数(kimg)を300とした場合に、P100で7.3時間、V100で4.1時間かかりますので、それを確認後下記のコードを実行し学習を開始して下さい。
 なお、GPUがK80の場合や学習に時間を掛けたくない方は、下記のコードの実行はパスして下さい。

※学習をした場合は、学習完了後 resultsフォルダーの中に( network-snapshot-XXXXXX.pkl)が作成されますので、それをDiffAugment-stylegan2のディレクトリーに移動して下さい。そして、この後の generate() , generate_gif.py の引数をそのファイル名に変更して下さい。

**************************************

 先程ダウンロードした学習済みの重みを使って画像生成します。

 さすがに学習画像100枚では生成画像の質は高くないですが、なんとかガッキーを再現しています。

 先程ダウンロードした学習済みの重みを使ってGIF動画を作成し interp.gif で保存します

 GIF動画を作成してみました。それにしても、以前は事前学習無しで100枚の画像だけでGANをやるなんて考えれなかったわけですが、技術の進歩は凄いですね。

 では、また。

(参考) https://github.com/mit-han-lab/data-efficient-gans

2 件のコメント

  • すいません、別の画像群に変えようと思い、URLをその画像のZIPファイルに変更しました

    ただ、100-shot-gakki という名称がどこで指定された値かがわからずご教示お願いできますでしょうか。

    • horieighturtleさん
      コメントありがとうございます。

      まず、コードにある画像群のダウンロード方法は、Google drive 専用であることにご注意下さい(https://qiita.com/jun40vn/items/66fff06abe48e01e23e3)

      以下の3ステップで進めればOKです。
      1)指定された方法で、画像群のzipファイルを DiffAugment-stylegan2 の下にダウンロードし、解凍する
      2)解凍して出来たフォルダー名でデータセットを作成する(data_dir = dataset_tool.create_dataset(‘***’) の ***にフォルダー名を入れる)
      3)解凍して出来たフォルダー名で学習を行う(!python3 run_few_shot.py –dataset=*** –resolution=△△ –total-kimg=□□□ の ***にフォルダー名を入れる)

  • コメントを残す

    メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

    日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

    ABOUTこの記事をかいた人

    アバター

    ディープラーニング・エンジニアを趣味でやってます。E資格ホルダー。 好きなものは、膨大な凡ショットから生まれる奇跡の1枚、右肩上がりのワクワク感、暑い国の新たな価値観、何もしない南の島、コード通りに動くチップ、完璧なハーモニー、仲間とのバンド演奏、数えきれない流れ星。