下記の図が、DualStyleGANのGenerator Networkで、左側がコンテンツ(Intrinsic Style Code)のネットワーク、右側がスタイル(Extrinsic Style Code)のネットワークです。そして、スタイルの内、基本的な構造に影響する低解像度レイヤ(〜32×32)への入力を構造、詳細部分に影響する高解像度レイヤ(64×64〜)への入力をカラーと呼んでいます。

コードはGoogle Colabで動かす形にしてGithubに上げてありますので、それに沿って説明して行きます。自分で動かしてみたい方は、この「リンク」をクリックし表示されたノートブックの先頭にある「Colab on Web」ボタンをクリックすると動かせます。
#@title セットアップ !git clone https://github.com/cedro3/DualStyleGAN.git %cd DualStyleGAN !wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip !sudo unzip ninja-linux.zip -d /usr/local/bin/ !sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force !pip install faiss-cpu !pip install wget %load_ext autoreload %autoreload 2 import numpy as np import torch from util import save_image, load_image, visualize import argparse from argparse import Namespace from torchvision import transforms from torch.nn import functional as F import torchvision import matplotlib.pyplot as plt from model.dualstylegan import DualStyleGAN from model.sampler.icp import ICPTrainer from model.encoder.psp import pSp |
次に、学習済みパラメータのダウンロードを style_type を指定して実行します。gdownを使ってGoogle Driveからダウンロードしますが、最近はダウンロードに失敗することが時々あるので、最大10回までリトライする仕様にしています。ここでは、 style_type は cartoon を選択しています。
#@title 学習済みパラメータのダウンロード style_type = 'cartoon'#@param ['cartoon', 'caricature', 'anime', 'arcane', 'comic', 'pixar', 'slamdunk'] MODEL_PATHS = { "encoder": {"id": "1NgI4mPkboYvYw3MWcdUaQhkr0OWgs9ej", "name": "encoder.pt"}, "cartoon-G": {"id": "1exS9cSFkg8J4keKPmq2zYQYfJYC5FkwL", "name": "generator.pt"}, "cartoon-N": {"id": "1JSCdO0hx8Z5mi5Q5hI9HMFhLQKykFX5N", "name": "sampler.pt"}, "cartoon-S": {"id": "1ce9v69JyW_Dtf7NhbOkfpH77bS_RK0vB", "name": "refined_exstyle_code.npy"}, "caricature-G": {"id": "1BXfTiMlvow7LR7w8w0cNfqIl-q2z0Hgc", "name": "generator.pt"}, "caricature-N": {"id": "1eJSoaGD7X0VbHS47YLehZayhWDSZ4L2Q", "name": "sampler.pt"}, "caricature-S": {"id": "1-p1FMRzP_msqkjndRK_0JasTdwQKDsov", "name": "refined_exstyle_code.npy"}, "anime-G": {"id": "1BToWH-9kEZIx2r5yFkbjoMw0642usI6y", "name": "generator.pt"}, "anime-N": {"id": "19rLqx_s_SUdiROGnF_C6_uOiINiNZ7g2", "name": "sampler.pt"}, "anime-S": {"id": "17-f7KtrgaQcnZysAftPogeBwz5nOWYuM", "name": "refined_exstyle_code.npy"}, "arcane-G": {"id": "15l2O7NOUAKXikZ96XpD-4khtbRtEAg-Q", "name": "generator.pt"}, "arcane-N": {"id": "1fa7p9ZtzV8wcasPqCYWMVFpb4BatwQHg", "name": "sampler.pt"}, "arcane-S": {"id": "1z3Nfbir5rN4CrzatfcgQ8u-x4V44QCn1", "name": "exstyle_code.npy"}, "comic-G": {"id": "1_t8lf9lTJLnLXrzhm7kPTSuNDdiZnyqE", "name": "generator.pt"}, "comic-N": {"id": "1RXrJPodIn7lCzdb5BFc03kKqHEazaJ-S", "name": "sampler.pt"}, "comic-S": {"id": "1ZfQ5quFqijvK3hO6f-YDYJMqd-UuQtU-", "name": "exstyle_code.npy"}, "pixar-G": {"id": "1TgH7WojxiJXQfnCroSRYc7BgxvYH9i81", "name": "generator.pt"}, "pixar-N": {"id": "18e5AoQ8js4iuck7VgI3hM_caCX5lXlH_", "name": "sampler.pt"}, "pixar-S": {"id": "1I9mRTX2QnadSDDJIYM_ntyLrXjZoN7L-", "name": "exstyle_code.npy"}, "slamdunk-G": {"id": "1MGGxSCtyf9399squ3l8bl0hXkf5YWYNz", "name": "generator.pt"}, "slamdunk-N": {"id": "1-_L7YVb48sLr_kPpOcn4dUq7Cv08WQuG", "name": "sampler.pt"}, "slamdunk-S": {"id": "1Dgh11ZeXS2XIV2eJZAExWMjogxi_m_C8", "name": "exstyle_code.npy"}, } import os 
os.makedirs('checkpoint/'+style_type, exist_ok=True) ! pip install --upgrade gdown import gdown # download pSp encoder for i in range(10): path = MODEL_PATHS["encoder"] if os.path.isfile('checkpoint/encoder.pt'): break else: path['name'] == 'encoder.pt' gdown.download('https://drive.google.com/uc?id='+path['id'], 'checkpoint/'+path['name'], quiet=False) # download dualstylegan for i in range(10): path = MODEL_PATHS[style_type+'-G'] if os.path.isfile('checkpoint/'+style_type+'/'+path['name']): break else: gdown.download('https://drive.google.com/uc?id='+path['id'], 'checkpoint/'+style_type+'/'+path['name'], quiet=False) # download sampler for i in range(10): path = MODEL_PATHS[style_type+'-N'] if os.path.isfile('checkpoint/'+style_type+'/'+path['name']): break else: gdown.download('https://drive.google.com/uc?id='+path['id'], 'checkpoint/'+style_type+'/'+path['name'], quiet=False) # download extrinsic style code for i in range(10): path = MODEL_PATHS[style_type+'-S'] if os.path.isfile('checkpoint/'+style_type+'/'+path['name']): break else: gdown.download('https://drive.google.com/uc?id='+path['id'], 'checkpoint/'+style_type+'/'+path['name'], quiet=False) # --- モデルのロード --- MODEL_DIR = 'checkpoint' DATA_DIR = 'data' device = 'cuda' transform = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ] ) # load DualStyleGAN generator = DualStyleGAN(1024, 512, 8, 2, res_index=6) generator.eval() ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'generator.pt'), map_location=lambda storage, loc: storage) generator.load_state_dict(ckpt["g_ema"]) generator = generator.to(device) # load encoder model_path = os.path.join(MODEL_DIR, 'encoder.pt') ckpt = torch.load(model_path, map_location='cpu') opts = ckpt['opts'] opts['checkpoint_path'] = model_path opts = Namespace(**opts) opts.device = device encoder = pSp(opts) encoder.eval() encoder = encoder.to(device) # load extrinsic 
style code exstyles = np.load(os.path.join(MODEL_DIR, style_type, MODEL_PATHS[style_type+'-S']["name"]), allow_pickle='TRUE').item() # load sampler network icptc = ICPTrainer(np.empty([0,512*11]), 128) icpts = ICPTrainer(np.empty([0,512*7]), 128) ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'sampler.pt'), map_location=lambda storage, loc: storage) icptc.icp.netT.load_state_dict(ckpt['color']) icpts.icp.netT.load_state_dict(ckpt['structure']) icptc.icp.netT = icptc.icp.netT.to(device) icpts.icp.netT = icpts.icp.netT.to(device) print('Model successfully loaded!') |
#@title 画像入力 %matplotlib inline image_path = './data/content/02.jpg' #@param {type:"string"} original_image = load_image(image_path) plt.figure(figsize=(10,10),dpi=30) visualize(original_image[0]) plt.show() |

#@title align処理
if_align_face = True


def run_alignment(image_path):
    """Return the aligned face image for the picture at ``image_path``.

    The dlib 68-point landmark model is downloaded and decompressed into
    MODEL_DIR the first time it is needed, then used by ``align_face``.

    Args:
        image_path: path of the image file to align.

    Returns:
        Whatever ``align_face`` returns for the given image.
    """
    import dlib
    from model.encoder.align_all_parallel import align_face

    modelname = os.path.join(MODEL_DIR, 'shape_predictor_68_face_landmarks.dat')
    if not os.path.exists(modelname):
        import bz2
        import wget

        archive = modelname + '.bz2'
        wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', archive)
        # Context managers make sure both files are flushed and closed before
        # dlib opens the decompressed model.
        with bz2.BZ2File(archive) as packed, open(modelname, 'wb') as unpacked:
            unpacked.write(packed.read())
    predictor = dlib.shape_predictor(modelname)
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    return aligned_image


if if_align_face:
    # Align the face, apply the preprocessing transform and add a batch axis.
    I = transform(run_alignment(image_path)).unsqueeze(dim=0).to(device)
else:
    # Skip alignment and only pool the loaded image down to 256x256.
    I = F.adaptive_avg_pool2d(load_image(image_path).to(device), 256)

plt.figure(figsize=(10,10),dpi=30)
visualize(I[0].cpu())
plt.show()

画像ファイルは、data/cartoon/images/train にあるものを使っています。cartoon 以外の style_type を使う場合は、該当するフォルダへこの情報を参考に画像をアップロードする必要がありますのでご注意下さい。
#@title スタイルの選択
style_id = 26 #@param {type:"slider", min:0, max:316, step:1}

# Map the slider index to a style name (a key of exstyles) and to the path
# where the matching style image is expected.
stylename = list(exstyles)[style_id]
stylepath = os.path.join(DATA_DIR, style_type, 'images/train', stylename)
print('loading %s'%stylepath)

# The style code works without the picture, so a missing image is only reported.
if not os.path.exists(stylepath):
    print('%s is not found'%stylename)
else:
    S = load_image(stylepath)
    plt.figure(figsize=(10,10),dpi=30)
    visualize(S[0])
    plt.show()

#@title スタイル転送
#@markdown From left to right:
#@markdown 1. **pSp recontructed content image**
#@markdown 2. **style transfer result**: both color and strcture styles are transferred
#@markdown 3. **structure transfer result**: preserve the color of the content image by replacing the extrinsic color codes with intrinsic color codes
#@markdown 4. **structure transfer result**: preserve the color of the content image by deactivating color-related layers
with torch.no_grad():
    # Encode the content image: reconstruction plus its intrinsic style code.
    img_rec, instyle = encoder(
        I,
        randomize_noise=False,
        return_latents=True,
        z_plus_latent=True,
        return_z_plus_latent=True,
        resize=False,
    )
    img_rec = torch.clamp(img_rec.detach(), -1, 1)

    # Two copies of the selected extrinsic style code:
    # latent[0] transfers both color and structure, latent[1] only structure
    # (its entries 7..17 are overwritten with the intrinsic code).
    latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).to(device)
    latent[1,7:18] = instyle[0,7:18]

    # Run every code through the generator's style network, then restore the
    # original (batch, codes, dim) layout.
    flat_latent = latent.reshape(-1, latent.shape[2])
    exstyle = generator.generator.style(flat_latent).reshape(latent.shape)

    # interp_weights: first 7 entries weight structure, last 11 weight color.
    img_gen, _ = generator(
        [instyle.repeat(2,1,1)],
        exstyle,
        z_plus_latent=True,
        truncation=0.7,
        truncation_latent=0,
        use_res=True,
        interp_weights=[0.6]*7+[1]*11,
    )
    img_gen = torch.clamp(img_gen.detach(), -1, 1)

    # deactivate color-related layers by setting w_c = 0
    img_gen2, _ = generator(
        [instyle],
        exstyle[0:1],
        z_plus_latent=True,
        truncation=0.7,
        truncation_latent=0,
        use_res=True,
        interp_weights=[0.6]*7+[0]*11,
    )
    img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)

# Show the four images side by side at 256x256.
panels = torch.cat([img_rec, img_gen, img_gen2], dim=0)
vis = torchvision.utils.make_grid(F.adaptive_avg_pool2d(panels, 256), nrow=4, padding=1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

#@title 色と構造のスタイル転送マトリックス
results = []
# Inference only: without no_grad each of the 36 forward passes would build an
# autograd graph that is thrown away by .detach() right after.
with torch.no_grad():
    for i in range(6): # change weights of structure codes
        for j in range(6): # change weights of color codes
            # 7 structure weights followed by 11 color weights, each in 0..1.
            w = [i/5.0]*7+[j/5.0]*11
            img_gen, _ = generator([instyle], exstyle[0:1], z_plus_latent=True,
                                   truncation=0.7, truncation_latent=0, use_res=True, interp_weights=w)
            img_gen = torch.clamp(F.adaptive_avg_pool2d(img_gen.detach(), 128), -1, 1)
            results.append(img_gen)

# 6x6 grid: rows follow the structure weight, columns the color weight.
vis = torchvision.utils.make_grid(torch.cat(results, dim=0), 6, 1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

#@title 新たなスタイルの選択
# The slider range matches the first style selector (0..316): both index the
# same list of exstyles keys, so a larger maximum could raise IndexError.
style_id2 = 82 #@param {type:"slider", min:0, max:316, step:1}

# try to load the style image
stylename2 = list(exstyles.keys())[style_id2]
stylepath = os.path.join(DATA_DIR, style_type, 'images/train', stylename2)
print('loading %s'%stylepath)
if os.path.exists(stylepath):
    S = load_image(stylepath)
    plt.figure(figsize=(10,10),dpi=30)
    visualize(S[0])
    plt.show()
else:
    # Only the preview is skipped; the style code itself is still available.
    print('%s is not found'%stylename2)

#@title トランジション
with torch.no_grad():
    # Six copies of each of the two selected extrinsic style codes.
    latent = torch.tensor(exstyles[stylename]).repeat(6,1,1).to(device)
    latent2 = torch.tensor(exstyles[stylename2]).repeat(6,1,1).to(device)

    # Blend weights 0.0, 0.2, ..., 1.0: weight 0 gives the second style only,
    # weight 1 gives the first style only.
    fuse_weight = torch.arange(6).reshape(6,1,1).to(device) / 5.0
    fuse_latent = latent * fuse_weight + latent2 * (1-fuse_weight)

    # Run the blended codes through the generator's style network and restore
    # the original layout.
    flat_latent = fuse_latent.reshape(-1, fuse_latent.shape[2])
    exstyle = generator.generator.style(flat_latent).reshape(fuse_latent.shape)

    img_gen, _ = generator(
        [instyle.repeat(6,1,1)],
        exstyle,
        z_plus_latent=True,
        truncation=0.7,
        truncation_latent=0,
        use_res=True,
        interp_weights=[0.6]*7+[1]*11,
    )
    img_gen = torch.clamp(img_gen.detach(), -1, 1)
    img_gen = F.adaptive_avg_pool2d(img_gen, 128)

# One row of six 128x128 images, ordered by increasing blend weight.
vis = torchvision.utils.make_grid(img_gen, nrow=6, padding=1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

#@title コンテンツとスタイルを乱数で指定
seed = 123
torch.manual_seed(seed)
batch = 6 # sample 6 style codes

with torch.no_grad():
    # Random intrinsic (content) codes; sized by `batch` so that changing the
    # batch size keeps content and style tensors consistent.
    instyle = torch.randn(batch, 512).to(device)
    # sample structure codes
    res_in = icpts.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,7,512)
    # sample color codes
    ada_in = icptc.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,11,512)
    # concatenate two codes to form the complete extrinsic style code
    latent = torch.cat((res_in, ada_in), dim=1)
    # map into W+ space
    exstyle = generator.generator.style(
        latent.reshape(latent.shape[0]*latent.shape[1], latent.shape[2])
    ).reshape(latent.shape)

    img_gen, _ = generator([instyle], exstyle, input_is_latent=False, truncation=0.7,
                           truncation_latent=0, use_res=True, interp_weights=[0.6]*7+[1]*11)
    img_gen = F.adaptive_avg_pool2d(torch.clamp(img_gen.detach(), -1, 1), 128)

# One row with `batch` generated images at 128x128.
vis = torchvision.utils.make_grid(img_gen, batch, 1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

#@title スタイル一定
with torch.no_grad():
    # Every sampled content code is paired with the same extrinsic style
    # (index 4 of the sampled styles).
    img_gen, _ = generator(
        [instyle],
        exstyle[4:5].repeat(batch, 1, 1),
        input_is_latent=False,
        truncation=0.7,
        truncation_latent=0,
        use_res=True,
        interp_weights=[0.6]*7+[1]*11,
    )
    img_gen = torch.clamp(img_gen.detach(), -1, 1)
    img_gen = F.adaptive_avg_pool2d(img_gen, 128)

# One row with `batch` images at 128x128.
vis = torchvision.utils.make_grid(img_gen, nrow=batch, padding=1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

#@title コンテンツ一定
with torch.no_grad():
    # The same content code (index 4) is paired with every sampled style.
    img_gen, _ = generator(
        [instyle[4:5].repeat(batch,1)],
        exstyle,
        input_is_latent=False,
        truncation=0.7,
        truncation_latent=0,
        use_res=True,
        interp_weights=[0.6]*7+[1]*11,
    )
    img_gen = torch.clamp(img_gen.detach(), -1, 1)
    img_gen = F.adaptive_avg_pool2d(img_gen, 128)

# One row with `batch` images at 128x128.
vis = torchvision.utils.make_grid(img_gen, nrow=batch, padding=1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

新しい論文を次々と試されていてすごいと思います。この論文は私が1月ごろに実装した機能と原理はほぼ同じなのかなと。以下のURLから操作可能です。ページ中央のStyle Layerの横にあるスライダーで、色、顔、ポーズの3つに分類したレイヤーの範囲や位置を自由に動かして同時にMIXできます。スライダーを有効にするには左横にある色、顔、ポーズの文字をクリック(色が変わる)する必要があり、その後、スライダーを移動させると選択されているスライダーで指定したレイヤーすべてがmixされた画像が8枚生成されます。
Deep Learningの世界はジャーナル誌や査読なんて言うかったるいプロセスはないので、思いっきりやっちゃって下さい!