今回ご紹介するのは、顔画像の特徴だけを別の画像のものに置き換えることができる Retrieve in Style(RIS)という技術です。
2.Retrieve in Style(RIS)とは?
ソース画像に参照画像のある特徴 k を転送するときに、今までは大規模なデータからチャンネル c の寄与スコア Mk,c を平均的に求めて、補完ベクトルを取得していました。しかし、この方法はある特徴(髪の毛・ポーズ)では上手く行かないことが分かっていました。
これは、チャンネルの寄与スコア Mk,c は画像によって異なるため、平均すると固有の特徴が破壊されることを示しています。そこで、サブメンバーシップという概念を導入して、ソース画像と参照画像のみ(N=2)で寄与スコア Mk,c を求めて、補完ベクトルを取得するのが本論文の提案です。
寄与スコア Mk,c は下記の様に表されます。ここで、Aは活性化テンソル、Uはクラスターメンバーシップ、sはソース、rは参照、kは特徴、cはチャンネル数、hは高さ、wは幅です。

補完ベクトル q は下記の様に表されます。ここで、M は全ての特徴 k を積み重ねた寄与スコア、τ は温度です。

コードはGoogle Colabで動かす形にしてGithubに上げてありますので、それに沿って説明して行きます。自分で動かしてみたい方は、この「リンク」をクリックし表示されたノートブックの先頭にある「Colab on Web」ボタンをクリックすると動かせます。
# githubからコードをコピー !git clone https://github.com/mchong6/RetrieveInStyle.git %cd RetrieveInStyle # ライブラリーのインストール !pip install tqdm gdown scikit-learn scipy lpips dlib opencv-python # ライブラリーのインポート import torch from torch import nn import numpy as np import torch.backends.cudnn as cudnn cudnn.benchmark = True import matplotlib.pyplot as plt import torch.nn.functional as F from model import * from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans from tqdm import tqdm as tqdm import pickle import warnings warnings.filterwarnings("ignore", category=UserWarning) # get rid of interpolation warning from util import * from google.colab import files from util import align_face import os from e4e_projection import projection %matplotlib inline # 学習済みモデルのロード device = 'cuda' # if GPU memory is low, use cpu instead generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval() ensure_checkpoint_exists('stylegan2-ffhq-config-f.pt') ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage) generator.load_state_dict(ckpt["g_ema"], strict=False) with torch.no_grad(): mean_latent = generator.mean_latent(50000) # カタログのロード truncation = 0.5 stop_idx = 11 # choose 32x32 layer to do kmeans clustering n_clusters = 18 # Number of Kmeans cluster clusterer = pickle.load(open("catalog.pkl", "rb")) |
# Preview: generate one random face and overlay its k-means segmentation.
plt.rcParams['figure.dpi'] = 150

with torch.no_grad():
    sample_z = torch.randn([1, 512]).to(device)
    sample_w = generator.get_latent(
        sample_z,
        truncation=truncation,
        mean_latent=mean_latent,
    )
    sample, outputs = generator(sample_w, is_cluster=1)  # [b, c, h, w]

# Take the activations of the layer chosen for clustering (stop_idx) and
# classify them with the k-means clusterer.
act = flatten_act(outputs[stop_idx][0])
b, c, h, w = outputs[stop_idx][0].size()

alpha = 0.5  # blend weight between the segmentation colours and the image
seg_mask = torch.from_numpy(clusterer.predict(act)).view(1, h, w)
seg_out = decode_segmap(seg_mask)

# Bring both the image and the colour-coded mask to 256x256 before blending.
sample_d = F.interpolate(sample, size=(256, 256), mode='bilinear').cpu()
seg_out_d = F.interpolate(seg_out, size=(256, 256), mode='nearest')

out = alpha * seg_out_d + (1 - alpha) * sample_d
display_image(out)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
# Gives an index to each feature we care about
labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheeks': 5,
    'neck': 6,
    'clothes': 7,
}

# Assign to each feature the cluster index from segmentation
labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}

# Reverse lookup: feature index -> feature name
idx2labels = {idx: label for label, idx in labels2idx.items()}
n_class = len(labels2idx)


# compute M given a style code.
@torch.no_grad()
def compute_M(w, device='cuda'):
    """Compute the thresholded channel-contribution mask M for a style code.

    The generator is run on ``w``; the activation at ``stop_idx`` is
    segmented with the k-means ``clusterer`` and the clusters are grouped
    into the ``n_class`` features listed in ``labels_map``. For every
    activation layer the squared, instance-normalised activations are
    averaged under each feature mask. The scores are soft-maxed over the
    feature axis (temperature 0.1), thresholded at 0.8, and the entries of
    the layers listed in ``rgb_layer_idx`` are zeroed.

    Relies on module-level names that are not defined in this cell:
    ``generator``, ``clusterer``, ``stop_idx``, ``flatten_act``,
    ``style2list``, ``list2style`` and ``rgb_layer_idx`` (the latter ones
    presumably come from the star imports - verify).

    Args:
        w: Style code(s) passed to ``generator``.
        device: Device on which the masks and scores are computed.

    Returns:
        Float tensor holding only 0/1 values, shaped [b, k, c] according to
        the shape comments below.
    """
    # get segmentation
    _, outputs = generator(w, is_cluster=1)
    cluster_layer = outputs[stop_idx][0]
    activation = flatten_act(cluster_layer)
    seg_mask = clusterer.predict(activation)
    # Dedicated names so the style-code parameter `w` is not overwritten.
    batch, _, seg_h, seg_w = cluster_layer.size()

    # create masks for each feature
    seg_mask = torch.from_numpy(seg_mask).view(batch, 1, seg_h, seg_w, 1).to(device)
    all_seg_mask = []
    for key in range(n_class):
        # combine masks for all indices for a particular segmentation class
        indices = labels_map[key].view(1, 1, 1, 1, -1).to(device)
        key_mask = (seg_mask == indices).any(-1)  # [b,1,h,w]
        all_seg_mask.append(key_mask)
    all_seg_mask = torch.stack(all_seg_mask, 1)
    # Identical for every layer, so flatten/convert once outside the loop.
    flat_seg_mask = all_seg_mask.flatten(0, 1).float()

    # go through each activation layer and compute M
    M = []
    for output in outputs:
        layer = output[1].to(device)
        layer_b, _, layer_h, layer_w = layer.size()
        layer = F.instance_norm(layer).pow(2)

        # resize the segmentation masks to current activations' resolution
        layer_seg_mask = F.interpolate(
            flat_seg_mask,
            size=(layer_h, layer_w),
            mode='bilinear',
            align_corners=False,
        ).view(layer_b, -1, 1, layer_h, layer_w)

        masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b,k,c,h,w]
        masked_layer = masked_layer.sum([3, 4]) / (layer_h * layer_w)  # [b,k,c]
        M.append(masked_layer)

    M = torch.cat(M, -1)  # [b, k, c]

    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / .1, 1)
    # simple thresholding
    M = (M > .8).float()

    # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j, layer_M in enumerate(part_M):
            if j in rgb_layer_idx:
                layer_M.zero_()
        M[:, i] = list2style(part_M)

    return M
自分で用意した画像を使用する場合は、PCからその画像をドラッグ&ドロップで RetrieveInStyle/images フォルダにアップロードしてください(下記のコードは images フォルダ内の .jpg ファイルを処理します)。
1 2 3 4 5 6 |
import glob
import os

# Align and project every JPEG found in the images/ folder.
# The list is deliberately not called `files`, so the `google.colab.files`
# module imported in the setup cell is not overwritten.
image_paths = glob.glob('images/*.jpg')
for image_path in image_paths:
    # 'images/brad_pitt.jpg' -> 'brad_pitt', without hard-coded slice offsets
    filename = os.path.splitext(os.path.basename(image_path))[0]
    cropped_face = align_face(image_path)
    projection(cropped_face, filename, generator, device)
plt.rcParams['figure.dpi'] = 75

# Load the latent codes of inverted real images using our projection code.
with torch.no_grad():
    # If you ran the GAN inversion in the previous cell, you can load its
    # result here through the variable `filename`; otherwise randomly
    # generate an image or use one of the pre-inverted images.
    # source = load_source([filename], generator, device)
    source = load_source(['brad_pitt'], generator, device)
    source_im, _ = generator(source)
    display_image(source_im, size=256)

    # The reference images are loaded as one batch.
    ref = load_source(
        ['emma_watson', 'emma_stone', 'jennie'],
        generator,
        device,
    )
    ref_im, _ = generator(ref)
    ref_im = downsample(ref_im)

    # Move the channel axis last before handing the batch to show().
    ref_grid = normalize_im(ref_im).permute(0, 2, 3, 1).cpu()
    show(ref_grid, title='References')

ソース画像は 'brad_pitt'、参照画像は 'emma_watson'、'emma_stone'、'jennie' の3枚です。次に、参照画像の 'eyes'、'nose'、'mouth'、'hair'、'pose' の各特徴をソース画像に転送します。
# Compute M for both the source and the reference images.
# The CPU is used here to save memory.
source_M = compute_M(source, device='cpu')
ref_M = compute_M(ref, device='cpu')

# Find relevant channels for source and reference by taking max over their individual M
max_M = torch.max(source_M.expand_as(ref_M), ref_M)
max_M = add_pose(max_M, labels2idx)

all_im = {}
with torch.no_grad():
    # features we are interested in transferring
    parts = ('eyes', 'nose', 'mouth', 'hair', 'pose')
    for label in parts:
        # 'pose' is not in labels2idx; it is addressed as the last entry
        idx = -1 if label == 'pose' else labels2idx[label]

        part_M = max_M[:, idx].to(device)
        blend = style2list(add_direction(source, ref, part_M, 1.3))
        blend_im, _ = generator(blend)
        blend_im = downsample(blend_im).cpu()
        all_im[label] = normalize_im(blend_im)

# The trailing semicolon suppresses the notebook's echo of the return value.
part_grid(
    normalize_im(source_im.detach()),
    normalize_im(ref_im.detach()),
    all_im,
);

ソース画像 'brad_pitt'
# Swap the roles: 'emma_stone' becomes the source, 'brad_pitt' the reference.
torch.manual_seed(3913)
plt.rcParams['figure.dpi'] = 75

with torch.no_grad():
    # Source image
    source = load_source(['emma_stone'], generator, device)
    source_im, _ = generator(source)
    display_image(source_im, size=256)

    # Reference image (downsampled before display)
    ref = load_source(['brad_pitt'], generator, device)
    ref_im, _ = generator(ref)
    ref_im = downsample(ref_im)
    display_image(ref_im, title='reference')

ソース画像が 'emma_stone'、参照画像が 'brad_pitt' です。
plt.rcParams['figure.dpi'] = 200

# Channel masks for the current source / reference pair (computed on the CPU).
source_M = compute_M(source, device='cpu')
ref_M = compute_M(ref, device='cpu')
max_M = torch.max(source_M.expand_as(ref_M), ref_M)
max_M = add_pose(max_M, labels2idx)

labels = ('eyes', 'hair')  # choose what feature to interpolate {eyes/nose/mouth/hair/pose}
max_alpha = 1.5  # max range to interpolate

all_im = []
with torch.no_grad():
    for label in labels:
        # 'pose' is not in labels2idx; it is addressed as the last entry
        idx = -1 if label == 'pose' else labels2idx[label]
        # The mask does not depend on alpha, so move it to the device once
        # per feature instead of once per interpolation step.
        part_M = max_M[:, idx].to(device)

        row = []
        for alpha in np.linspace(-max_alpha, max_alpha, 5):
            blend = style2list(add_direction(source, ref, part_M, alpha))
            blend_im, _ = generator(blend)
            blend_im = downsample(blend_im).cpu()
            row.append(blend_im)

        # The reference image closes each row.
        row.append(ref_im.cpu())
        row = torch.cat(row, -1)
        all_im.append(row)

# Stack one row per feature along dim 2.
all_im = torch.cat(all_im, 2)
display_image(all_im, size=None)

今度は、大量の顔画像データベースの中から、ソース画像の指定した特徴に近い画像を検索してみましょう。まず、generator にランダムベクトルを入力し、5000個の顔画像データベースを作成します。
# Build the face database
torch.manual_seed(12390)
num_data = 5000
dataset = torch.randn([num_data, 512]).to(device)

with torch.no_grad():
    dataset_w = generator.get_latent(
        dataset,
        truncation=truncation,
        mean_latent=mean_latent,
    )

    # One M per generated face.
    # have to use cuda for this or it will be very slow
    dataset_M = [
        compute_M(index_layers(dataset_w, i), device='cuda')
        for i in tqdm(range(num_data))
    ]

    dataset_M = torch.cat(dataset_M, 0)
    dataset_M = remove_2048(dataset_M, labels2idx).to(device)  # [N, K, C]
# Display the query image used for the retrieval
plt.rcParams['figure.dpi'] = 75

with torch.no_grad():
    query_w = load_source(
        ['tom_hiddleston'],
        generator,
        device,
    )
    query_im, _ = generator(query_w)
    display_image(query_im)

# Run the retrieval and display the results
plt.rcParams['figure.dpi'] = 300
num_nn = 6
all_im = []

query_M = remove_2048(compute_M(query_w, device=device), labels2idx).to(device)
r_query_w = list2style(query_w)
r_dataset_w = list2style(dataset_w)

# normalize each style dimension by its largest magnitude in the dataset
largest = r_dataset_w.abs().max(0, keepdim=True)[0] + 1e-8
norm_query_w = r_query_w / largest
norm_target_w = r_dataset_w / largest

# choose what features to perform retrieval on
# parts = ('eyes', 'nose', 'mouth', 'hair')
parts = ('eyes', 'mouth', 'hair',)

# perform cosine similarity w.r.t a given feature
with torch.no_grad():
    for part in parts:
        idx = labels2idx[part]
        source_part = norm_query_w * query_M[:, idx].to(device)
        target_part = norm_target_w * dataset_M[:, idx].to(device)
        distance = cos_dist(target_part, source_part)
        nearest_neighbors = torch.sort(distance)[1][:num_nn]

        # One row per feature: the query first, then its nearest neighbours.
        row = [query_im.cpu()]
        for neighbor_idx in nearest_neighbors:
            nn_w = index_layers(dataset_w, int(neighbor_idx))
            nn_image, _ = generator(nn_w)
            row.append(nn_image.cpu())

        row = [downsample(a) for a in row]
        row = torch.cat(row, -1)
        all_im.append(row)

all_im = torch.cat(all_im, -2)
display_image(all_im, size=None)
