1. Introduction
This post introduces MTTR, a technique that performs instance segmentation on a video for a target you specify with text.
*The paper was submitted in November 2021.
2. What is MTTR?
The flow of MTTR (Multimodal Tracking Transformer) is broadly made up of three blocks.
The first is Visual & Linguistic Feature Extraction: the input frames and the text query are each passed through an encoder, converted into feature vectors, and concatenated.
The second is the Multimodal Transformer, which encodes the relationships between the frame and text features and decodes them into a sequence of predictions from which the corresponding masks are generated.
The third is Instance Sequence Segmentation & Reference Prediction, which combines the segmentation obtained by decoding the image features with the predicted mask information and outputs the final result.
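To make these three blocks a bit more concrete, here is a rough conceptual sketch in PyTorch. Note that this is my own illustration and not the actual MTTR implementation: every module, dimension, and variable name below is made up, and the real model uses proper visual/text backbones, an FPN-like spatial decoder, and dynamic convolutions to produce the masks.

import torch
import torch.nn as nn

# Toy sizes, purely for illustration
T, H, W = 8, 30, 40      # frames and spatial size of the backbone feature map
L, D, N = 10, 256, 50    # text tokens, model width, number of instance queries

# 1) Visual & Linguistic Feature Extraction (stand-ins for the real backbones)
frame_feats = torch.randn(T, D, H, W)   # per-frame features from a video backbone
text_feats = torch.randn(L, D)          # token features from a text encoder

# 2) Multimodal Transformer: flatten the frame features, concatenate the text
#    features, and decode a fixed set of instance queries for every frame
visual_seq = frame_feats.flatten(2).permute(0, 2, 1)        # (T, H*W, D)
text_seq = text_feats.unsqueeze(0).expand(T, -1, -1)        # (T, L, D)
multimodal_seq = torch.cat([visual_seq, text_seq], dim=1)   # (T, H*W + L, D)

transformer = nn.Transformer(d_model=D, batch_first=True)
queries = torch.randn(N, D).unsqueeze(0).expand(T, -1, -1)  # instance queries shared across frames
decoded = transformer(multimodal_seq, queries)              # (T, N, D)

# 3) Instance Sequence Segmentation & Reference Prediction: a mask head turns each
#    decoded query into a per-frame mask, and a reference head scores how well each
#    instance sequence matches the text query
mask_head = nn.Linear(D, H * W)
ref_head = nn.Linear(D, 2)
masks = mask_head(decoded).reshape(T, N, H, W)   # one mask per query per frame
scores = ref_head(decoded).mean(dim=0)           # (N, 2): referred / not referred
print(masks.shape, scores.shape)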
3. Code
The code is set up to run on Google Colab and is available on GitHub, so I will walk through it below. If you want to try it yourself, click this "link" and then click the "Colab on Web" button at the top of the notebook that opens.
First, run the setup.
# Install the required packages
# %%capture
!pip install av moviepy yt-dlp ruamel.yaml einops timm transformers

# Import libraries
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from yt_dlp import YoutubeDL
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from IPython.display import HTML
from base64 import b64encode
from tqdm.notebook import trange, tqdm
from transformers import logging
# logging.set_verbosity_error()

# Initialize the MTTR model
model, postprocessor = torch.hub.load('mttr2021/MTTR:main', 'mttr_refer_youtube_vos', force_reload=True)
model = model.cuda()

# Helper definitions
class NestedTensor(object):
    def __init__(self, tensors, mask):
        self.tensors = tensors
        self.mask = mask

def nested_tensor_from_videos_list(videos_list):
    def _max_by_axis(the_list):
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    max_size = _max_by_axis([list(img.shape) for img in videos_list])
    padded_batch_shape = [len(videos_list)] + max_size
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)
    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))

def apply_mask(image, mask, color, transparency=0.7):
    mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)
    mask = mask * transparency
    color_matrix = np.ones(image.shape, dtype=float) * color
    out_image = color_matrix * mask + image * (1.0 - mask)
    return out_image
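As a quick aside (this toy check is my own addition, not part of the original notebook), the helper nested_tensor_from_videos_list defined above zero-pads a list of clips with different lengths and resolutions to a common shape and records where the padding was added:

# Two dummy "videos" with different lengths and spatial sizes
vid_a = torch.rand(4, 3, 100, 120)   # 4 frames, 3 channels, 100x120
vid_b = torch.rand(6, 3, 90, 150)    # 6 frames, 3 channels, 90x150

batch = nested_tensor_from_videos_list([vid_a, vid_b])
print(batch.tensors.shape)   # torch.Size([6, 2, 3, 100, 150]) -> (T, B, C, H, W), zero-padded
print(batch.mask.shape)      # torch.Size([6, 2, 100, 150])    -> True where padding was added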
This code takes a clip of up to 10 seconds downloaded from YouTube and combines it with up to two text queries.
The video and the text queries are specified as video_url, (start_pt, end_pt), text_queries = f'https://www.youtube.com/watch?v=???????', (start_pnt, end_pnt), ['text query 1', 'text query 2'], where ??????? is the YouTube share code, start_pnt is the point (in seconds) where the clip starts, end_pnt is the point where it ends, and 'text query 1' and 'text query 2' are the text queries.
Twelve YouTube videos with matching text queries are already prepared, so all you need to do is remove the leading comment mark (#) from the line you want to use. The 8th line of the cell below (the YThX7_8I3m0 clip from 67 to 77 seconds) is already un-commented, so let's run it as it is.
# Choose (by un-commenting) one of the following:
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=YThX7_8I3m0', (233, 243), ['guy in black performing tricks on a bike', 'a black bike used to perform tricks']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=hwLo7aU1Aas', (1144, 1152), ['a man riding a surfboard', 'a black and white surfboard']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=yvJDHbrumak', (48, 55), ['a red ball thrown in the air', 'a black horse playing with a person']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=L-Wd4A8ESyk', (289, 297), ['a guy performing tricks on a skateboard', 'a black skateboard']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=4iTiRvk4FHY', (24, 34), ['man in red shirt playing tennis', 'white tennis racket held by a man in a red shirt']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=ZHwlmvuW4NY', (115, 125), ['white dog playing', 'brown and black dog playing']
video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=YThX7_8I3m0', (67, 77), ['guy in white shirt performing tricks on a bike', 'a black bike used to perform tricks']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=C7TCH927--g', (3, 13), ['a dog to the right', 'a cat to the left']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=0Z_WAF1GKfk', (143.5, 147.5), ['a dog to the left playing with a toy', 'a dog to the right playing with a toy']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=aEJJmebTLEs', (70, 80), ['a person hugging a dog', 'a white dog sitting']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=8sDF8lflCTs' ,(15, 23), ['person in blue riding a bike']
# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=dQw4w9WgXcQ', (2.5, 7.5), ['a person dancing']

# OR - try using your own input in the following format: (but keep in mind that performance may be limited!)

# video_url, (start_pt, end_pt), text_queries = f'https://www.youtube.com/watch?v=???????' ,(start_pnt, end_pnt), ['text query 1', 'text query 2']

assert 0 < end_pt - start_pt <= 10, 'error - the subclip length must be 0-10 seconds long'
assert 1 <= len(text_queries) <= 2, 'error - 1-2 input text queries are expected'
Now let's download the video from YouTube as specified above, cut out the subclip, and play it back.
download_resolution = 360
full_video_path = 'full_video.mp4'
input_clip_path = 'input_clip.mp4'

# download parameters:
ydl_opts = {'format': f'best[height<={download_resolution}]', 'overwrites': True, 'outtmpl': full_video_path}

# download the whole video:
with YoutubeDL(ydl_opts) as ydl:
    ydl.download([video_url])

# extract the relevant subclip:
with VideoFileClip(full_video_path) as video:
    subclip = video.subclip(start_pt, end_pt)
    subclip.write_videofile(input_clip_path)

# visualize the input clip:
input_clip = open(input_clip_path, 'rb').read()
data_url = "data:video/mp4;base64," + b64encode(input_clip).decode()
HTML("""<video width=720 controls><source src="%s" type="video/mp4"></video>""" % data_url)
Next, we generate the instance masks from this clip and the text queries.
window_length = 24  # length of window during inference
window_overlap = 6  # overlap (in frames) between consecutive windows

with torch.inference_mode():
    # read and preprocess the video clip:
    video, audio, meta = torchvision.io.read_video(filename=input_clip_path)
    video = rearrange(video, 't h w c -> t c h w')
    input_video = F.resize(video, size=360, max_size=640).cuda()
    input_video = input_video.to(torch.float).div_(255)
    input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}

    # partition the clip into overlapping windows of frames:
    windows = [input_video[i:i+window_length] for i in range(0, len(input_video), window_length - window_overlap)]
    # clean up the text queries:
    text_queries = [" ".join(q.lower().split()) for q in text_queries]

    pred_masks_per_query = []
    t, _, h, w = video.shape
    for text_query in tqdm(text_queries, desc='text queries'):
        pred_masks = torch.zeros(size=(t, 1, h, w))
        for i, window in enumerate(tqdm(windows, desc='windows')):
            window = nested_tensor_from_videos_list([window])
            valid_indices = torch.arange(len(window.tensors)).cuda()
            outputs = model(window, valid_indices, [text_query])
            window_masks = postprocessor(outputs, [video_metadata], window.tensors.shape[-2:])[0]['pred_masks']
            win_start_idx = i*(window_length-window_overlap)
            pred_masks[win_start_idx:win_start_idx + window_length] = window_masks
        pred_masks_per_query.append(pred_masks)
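The clip is processed in overlapping windows of window_length = 24 frames, stepping forward by 18 frames (24 - 6) each time; since each window simply writes its predictions into pred_masks, the later window overwrites the 6 overlapping frames. Here is a small illustration of how the window boundaries fall, using a hypothetical 80-frame clip (my own example, not part of the original notebook):

window_length, window_overlap = 24, 6
num_frames = 80  # hypothetical clip length
starts = list(range(0, num_frames, window_length - window_overlap))
for i, s in enumerate(starts):
    print(f"window {i}: frames {s} .. {min(s + window_length, num_frames) - 1}")
# window 0: frames 0 .. 23
# window 1: frames 18 .. 41
# window 2: frames 36 .. 59
# window 3: frames 54 .. 77
# window 4: frames 72 .. 79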
Then we overlay the instance masks and the text queries on the video.
# RGB colors for instance masks:
light_blue = (41, 171, 226)
purple = (237, 30, 121)
dark_green = (35, 161, 90)
orange = (255, 148, 59)
colors = np.array([light_blue, purple, dark_green, orange])

# width (in pixels) of the black strip above the video on which the text queries will be displayed:
text_border_height_per_query = 35

video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0
# del video
pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy()

masked_video = []
for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):
    # apply the masks:
    for inst_mask, color in zip(frame_masks, colors):
        vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)
    vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))

    # visualize the text queries:
    vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries)*text_border_height_per_query, 0, 0))
    W, H = vid_frame.size
    draw = ImageDraw.Draw(vid_frame)
    font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30)
    for i, (text_query, color) in enumerate(zip(text_queries, colors), start=1):
        w, h = draw.textsize(text_query, font=font)
        draw.text(((W - w) / 2, (text_border_height_per_query * i) - h - 3),
                  text_query, fill=tuple(color) + (255,), font=font)
    masked_video.append(np.array(vid_frame))

# generate and save the output clip:
output_clip_path = 'output_clip.mp4'
clip = ImageSequenceClip(sequence=masked_video, fps=meta['video_fps'])
clip = clip.set_audio(AudioFileClip(input_clip_path))
clip.write_videofile(output_clip_path, fps=meta['video_fps'], audio=True)
del masked_video

# visualize the output clip:
output_clip = open(output_clip_path, 'rb').read()
data_url = "data:video/mp4;base64," + b64encode(output_clip).decode()
HTML("""<video width=720 controls><source src="%s" type="video/mp4"></video>""" % data_url)
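One note in case you run this in a newer environment: draw.textsize was removed in Pillow 10, so the text-drawing loop above may fail there. A possible workaround (my own untested adaptation) is to measure the text with draw.textbbox instead:

# replacement for: w, h = draw.textsize(text_query, font=font)
left, top, right, bottom = draw.textbbox((0, 0), text_query, font=font)
w, h = right - left, bottom - top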
Let's try one more. Comment out (#) all of the prepared samples, rewrite line 17 (the commented-out template line near the bottom) to video_url, (start_pt, end_pt), text_queries = f'https://www.youtube.com/watch?v=btMGKIXtyLo', (131, 135), ['person in riding a bike', 'a bike'] with its leading # removed, and run the cells again.
How was it? Being able to instance-segment only the objects you specify with text is pretty fun, isn't it?
See you next time.
(Original GitHub: https://github.com/mttr2021/MTTR)