1.はじめに
今まで、ある画像から別の画像へ色やテクスチャを転送する「スタイル転送」という技術がありましたが、今回ご紹介するのはもっと自然なスタイル転送を可能にする「TargetCLIP」と言う技術です。
*この論文は、2021.10に提出されました。
2.TargetCLIPとは?
TargetCLIPは、ある画像の概念的な「スタイル」と別の画像の客観的な「コンテンツ」を全く新しい画像に組み合わせる技術です。この技術は、画像を生成するStyleGAN2とCLIPが持っている画像を特徴ベクトルに変換するモジュールを組み合わせることによって実現されています。
2つの損失関数を使用します。1つ目は、ターゲット画像とソース画像をそれぞれCLIPでエンコードしたベクトルのCOS類似度をできるだけ上げるような関数。2つ目は、ソースの変換前後の画像をそれぞれCLIPでエンコードしたベクトルのCOS類似度をできるだけ上げるような関数です。
この2つの損失関数を最適化することによって、下記のような結果が得られます。
では、早速コードを動かしてみましょう。
3.コード
コードはGoogle Colabで動かす形にしてGithubに上げてありますので、それに沿って説明して行きます。自分で動かしてみたい方は、この「リンク」をクリックし表示されたノートブックの先頭にある「Colab on Web」ボタンをクリックすると動かせます.
最初に、セットアップをおこないます。まず、画像から逆算してStyleGAN2のベクトルを求めるためのフレームワークe4eをセットアップします。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
#@title e4e Setup (may take a few minutes) import os os.chdir('/content') CODE_DIR = 'encoder4editing' %tensorflow_version 1.x ! pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html !git clone https://github.com/omertov/encoder4editing.git $CODE_DIR !wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip !sudo unzip ninja-linux.zip -d /usr/local/bin/ !sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force os.chdir(f'./{CODE_DIR}') from argparse import Namespace import time import os import sys import numpy as np from PIL import Image import torch import torchvision.transforms as transforms sys.path.append(".") sys.path.append("..") from utils.common import tensor2im from models.psp import pSp # we use the pSp framework to load the e4e encoder. %load_ext autoreload %autoreload 2 # --- Download e4e model --- experiment_type = 'ffhq_encode' def get_download_model_command(file_id, file_name): """ Get wget download command for downloading the desired model and save to directory pretrained_models. """ current_directory = os.getcwd() save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, "pretrained_models") if not os.path.exists(save_path): os.makedirs(save_path) url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path) return url MODEL_PATHS = { "ffhq_encode": {"id": "1cUv_reLE6k3604or78EranS7XzuVMWeO", "name": "e4e_ffhq_encode.pt"}, "cars_encode": {"id": "17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV", "name": "e4e_cars_encode.pt"}, "horse_encode": {"id": "1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX", "name": "e4e_horse_encode.pt"}, "church_encode": {"id": "1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa", "name": "e4e_church_encode.pt"} } path = MODEL_PATHS[experiment_type] download_command = get_download_model_command(file_id=path["id"], file_name=path["name"]) !wget {download_command} # --- e4e setup --- import gdown gdown.download('https://drive.google.com/u/0/uc?id=1jZnwdPOXhte2gseETvTwRtx8r8TT196J', '/content/encoder4editing/e4e_ffhq_encode.pt', quiet=False) experiment_type = 'ffhq_encode' os.chdir('/content/encoder4editing') EXPERIMENT_ARGS = { "model_path": "e4e_ffhq_encode.pt" } EXPERIMENT_ARGS['transform'] = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) resize_dims = (256, 256) model_path = EXPERIMENT_ARGS['model_path'] ckpt = torch.load(model_path, map_location='cpu') opts = ckpt['opts'] # pprint.pprint(opts) # Display full options used # update the training options opts['checkpoint_path'] = model_path opts= Namespace(**opts) net = pSp(opts) net.eval() net.cuda() print('Model successfully loaded!') |
続いて、TargetCLIP本体をセットアップします。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
#@title TargetCLIP Setup import os !git clone https://github.com/hila-chefer/TargetCLIP os.chdir(f'./TargetCLIP') !pip install ftfy regex tqdm !pip install git+https://github.com/openai/CLIP.git from pydrive.auth import GoogleAuth from pydrive.drive import GoogleDrive from google.colab import auth from oauth2client.client import GoogleCredentials # Authenticate and create the PyDrive client. auth.authenticate_user() gauth = GoogleAuth() gauth.credentials = GoogleCredentials.get_application_default() drive = GoogleDrive(gauth) # downloads StyleGAN's weights ids = ['1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT'] for file_id in ids: downloaded = drive.CreateFile({'id':file_id}) downloaded.FetchMetadata(fetch_all=True) downloaded.GetContentFile(downloaded.metadata['title']) import argparse import os os.chdir(f'../TargetCLIP') import numpy as np import torch import torchvision from torch import optim from tqdm import tqdm import clip # from criteria.clip_loss import CLIPLoss from models.stylegan2.model import Generator import math import copy # aux function def get_latent(args, g_ema): mean_latent = g_ema.mean_latent(4096) latent_code_init_not_trunc = torch.randn(1, 512).cuda() with torch.no_grad(): # _, latent_code_init = g_ema([latent_code_init_not_trunc], return_latents=True, # truncation=args.truncation, truncation_latent=mean_latent) _, latent_code_init,_ = g_ema([latent_code_init_not_trunc], return_latents=True, truncation=args.truncation, truncation_latent=mean_latent) direction = latent_code_init.detach().clone() direction.requires_grad = True return direction def load_model(args): g_ema = Generator(args.stylegan_size, 512, 8) g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) g_ema.eval() g_ema = g_ema.cuda() return g_ema def get_lr(t, initial_lr, rampdown=0.75, rampup=0.005): lr_ramp = min(1, (1 - t) / rampdown) lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) lr_ramp = lr_ramp * min(1, t / rampup) return initial_lr * lr_ramp args = { "ckpt": "stylegan2-ffhq-config-f.pt", "stylegan_size": 1024, "lr": 0.1, "truncation": 0.7, "save_intermediate_image_every": 1, "results_dir": "results", "dir_name": "results", "num_batches": 1, "real_images": True, "data_path": "train_faces.pt", } from argparse import Namespace a=Namespace(**args) g_ema = load_model(a) |
サンプル画像を読み込みます。自分の用意した画像を読み込む場合は、TargetCLIP/picフォルダーに画像を保存し、image_nameをそのファイル名に変更してください。ここでは、’./TargetCLIP/pic/01.png’ を指定しています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
#@markdown Upload an image to the encoder4editing folder and set the image_name into the image name image_name = './TargetCLIP/pic/01.png' #@param {type:"string"} os.chdir('/content/encoder4editing') EXPERIMENT_DATA_ARGS = { "ffhq_encode": { "model_path": "pretrained_models/e4e_ffhq_encode.pt", "image_path": image_name } } # Setup required image transformations EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type] if experiment_type == 'cars_encode': EXPERIMENT_ARGS['transform'] = transforms.Compose([ transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) resize_dims = (256, 192) else: EXPERIMENT_ARGS['transform'] = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) resize_dims = (256, 256) |
顔を所定の位置で切り取り水平にする処理(align処理)を行い、その画像を表示します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
#@title Show aligned original image image_path = EXPERIMENT_DATA_ARGS[experiment_type]["image_path"] original_image = Image.open(image_path) original_image = original_image.convert("RGB") if experiment_type == "ffhq_encode" and 'shape_predictor_68_face_landmarks.dat' not in os.listdir(): !wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 !bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2 def run_alignment(image_path): import dlib from utils.alignment import align_face predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") aligned_image = align_face(filepath=image_path, predictor=predictor) print("Aligned image has shape: {}".format(aligned_image.size)) return aligned_image if experiment_type == "ffhq_encode": input_image = run_alignment(image_path) else: input_image = original_image input_image.resize(resize_dims) |
align処理した画像からStyleGAN2のベクトルを逆算で求めます。そして、そのベクトルで生成した画像とオリジナル画像を並べて表示します。右が逆算、左がオリジナルの画像です。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
#@title Invert img_transforms = EXPERIMENT_ARGS['transform'] transformed_image = img_transforms(input_image) def display_alongside_source_image(result_image, source_image): res = np.concatenate([np.array(source_image.resize(resize_dims)), np.array(result_image.resize(resize_dims))], axis=1) return Image.fromarray(res) def run_on_batch(inputs, net): images, latents = net(inputs.to("cuda").float(), randomize_noise=False, return_latents=True) if experiment_type == 'cars_encode': images = images[:, :, 32:224, :] return images, latents with torch.no_grad(): tic = time.time() images, latents = run_on_batch(transformed_image.unsqueeze(0), net) result_image, latent = images[0], latents[0] toc = time.time() print('Inference took {:.4f} seconds.'.format(toc - tic)) # Display inversion: display_alongside_source_image(tensor2im(result_image), input_image) |
スタイルを転送するためのターゲットの画像を選択します。’Elsa’, ‘Joker’など9種類のターゲットが用意されています。ここでは、’Elsa’を選択します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
#@title Chose target os.chdir('/content/encoder4editing/TargetCLIP') dirs = { 'Elsa': 'dirs/elsa.npy', 'Pocahontas': 'dirs/pocahontas.npy', 'Keanu Reeves': 'dirs/keanu.npy', 'Trump': 'dirs/trump.npy', 'Joker': 'dirs/joker.npy', 'Ariel': 'dirs/ariel.npy', 'Doc Brown': 'dirs/doc.npy', 'Beyonce': 'dirs/beyonce.npy', 'Morgan Freeman': 'dirs/morgan.npy', } targets = { 'Elsa': 'dirs/targets/elsa.jpg', 'Pocahontas': 'dirs/targets/pocahontas.jpg', 'Keanu Reeves': 'dirs/targets/keanu.jpg', 'Trump': 'dirs/targets/trump.jpg', 'Joker': 'dirs/targets/joker.jpg', 'Ariel': 'dirs/targets/ariel.jpeg', 'Doc Brown': 'dirs/targets/doc_brown.jpg', 'Beyonce': 'dirs/targets/beyonce.jpg', 'Morgan Freeman': 'dirs/targets/morgan_freeman.jpg', } sources_ids = { 'Taylor Swift': 67, 'Elon Musk': 4, 'Hillary Clinton': 9, 'Alfie Allen': 34, 'Obama': 61 } target = 'Ariel' #@param ['Trump','Keanu Reeves', 'Elsa', 'Pocahontas', 'Joker', 'Ariel', 'Doc Brown', 'Beyonce', 'Morgan Freeman'] source = latents dir = torch.from_numpy(np.load(dirs[target])) #dir = torch.load('dirs/10me.pt') target_path = targets[target] # title Show target image assert(target_path is not None) import matplotlib.pyplot as plt import matplotlib.image as mpimg img = mpimg.imread(target_path) imgplot = plt.imshow(img) plt.xticks([]) plt.yticks([]) plt.show() |
変換係数を設定してスタイルを転送します。そしてその結果を表示します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
#@title Show manipulation on source from PIL import Image import matplotlib.pyplot as plt import matplotlib.image as mpimg #alpha=1 alpha = 1.2 #@param {type:"slider", min:0, max:2, step:0.1} dir = dir.cuda() source = source.cuda() source_img, _ = g_ema([source], input_is_latent=True, randomize_noise=False) source_amp, _ = g_ema([source + dir * alpha], input_is_latent=True, randomize_noise=False) torchvision.utils.save_image(source_img, f"results_orig.png", normalize=True, range=(-1, 1)) torchvision.utils.save_image(source_amp, f"results_manipulated.png", normalize=True, range=(-1, 1)) plt.figure(figsize=(14,7), dpi= 100) plt.subplot(1,2,1) plt.imshow(mpimg.imread('results_orig.png')) plt.title('original') plt.axis('off') plt.subplot(1,2,2) plt.imshow(mpimg.imread('results_manipulated.png')) plt.title('manipulated') plt.axis('off') plt.tight_layout() |
変換係数を徐々に変化させたときの静止画から動画を作成します。作成した動画は、./TargetCLIP/output.mp4 に保存されます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
#@title Make manipulated movie import os import shutil from tqdm import trange # reset images folder if os.path.isdir('images'): shutil.rmtree('images') os.makedirs('images', exist_ok=True) # delete output.mp4 if os.path.exists('./output.mp4'): os.remove('./output.mp4') alpha = 1.2 #@param {type:"slider", min:0, max:2, step:0.1} dir = dir.cuda() source = source.cuda() end = int(alpha * 100) for i in trange(0, end): source_amp, _ = g_ema([source + dir * i/100], input_is_latent=True, randomize_noise=False) torchvision.utils.save_image(source_amp, 'images/'+str(i).zfill(4)+'.png', normalize=True, range=(-1, 1)) # make movie from png in images folder ! ffmpeg -r 30 -i images/%4d.png\ -vcodec libx264 -pix_fmt yuv420p output.mp4 |
作成した動画を再生します。
1 2 3 4 5 6 7 8 9 10 |
#@title play movie from IPython.display import HTML from base64 import b64encode mp4 = open('./output.mp4', 'rb').read() data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode() HTML(f""" <video width="70%" height="70%" controls> <source src="{data_url}" type="video/mp4"> </video>""") |
どうでしょうか。実写にディズニープリンセスのスタイルが上手く転送されているのが分かると思います。
では、また。
(オリジナルgithub) https://github.com/hila-chefer/TargetCLIP