1.はじめに
今回ご紹介するのは、画像からテキストを生成するモデルとテキストから音楽を生成するモデルを組み合わせた、画像から音楽を生成する img2music です。
実はこれ、Hugging Face (Webサービス)で動かせるんですが、人気のためか処理速度が遅い。なので今回は Google Colab で動かしてみます。
2.img2musicとは?
img2music は画像からテキストを生成するモデルとテキストから音楽を生成するモデル、これら2つのモデルから構成されていますので順番に見て行きましょう。
画像からテキストを生成するモデルは、画像からStable Diffusionのテキストを探索する CLIP Interrogator を使っています。これは、従来からの画像キャプションの生成結果に、CLIPを用いた様々な属性に関する結果を加えて、より適切なテキストを生成します。
テキストから音楽を生成するモデルは、AIで音楽生成するサービスを行なっているMubert社のAPIを使っています。具体的には、テキストからタグを抽出してサーバーへ送り、音楽を受け取る形で音楽を生成します。
3.コード
コードはGoogle Colabで動かす形にしてGithubに上げてありますので、それに沿って説明して行きます。自分で動かしてみたい方は、この「リンク」をクリックし表示されたノートブックの先頭にある「Open in Colab」ボタンをクリックすると動かせます。
まず、img2textをセットアップします。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
#@title **setup img2text (CLIP Interrogator)** # install library !pip3 install ftfy regex tqdm transformers==4.15.0 timm==0.4.12 fairscale==0.4.4 !pip3 install git+https://github.com/openai/CLIP.git !git clone -b v1 https://github.com/pharmapsychotic/clip-interrogator.git !git clone https://github.com/salesforce/BLIP %cd /content/BLIP # import library import clip import gc import numpy as np import os import pandas as pd import requests import torch import torchvision.transforms as T import torchvision.transforms.functional as TF from IPython.display import display from PIL import Image from torch import nn from torch.nn import functional as F from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from models.blip import blip_decoder device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') blip_image_eval_size = 384 blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base') blip_model.eval() blip_model = blip_model.to(device) # difine function def generate_caption(pil_image): gpu_image = transforms.Compose([ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) ])(image).unsqueeze(0).to(device) with torch.no_grad(): caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) return caption[0] def load_list(filename): with open(filename, 'r', encoding='utf-8', errors='replace') as f: items = [line.strip() for line in f.readlines()] return items def rank(model, image_features, text_array, top_count=1): top_count = min(top_count, len(text_array)) text_tokens = clip.tokenize([text for text in text_array]).cuda() with torch.no_grad(): text_features = model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True) similarity = torch.zeros((1, len(text_array))).to(device) for i in range(image_features.shape[0]): similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity /= image_features.shape[0] top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] def interrogate(image, models): caption = generate_caption(image) if len(models) == 0: print(f"\n\n{caption}") return table = [] bests = [[('',0)]]*5 for model_name in models: print(f"Interrogating with {model_name}...") model, preprocess = clip.load(model_name) model.cuda().eval() images = preprocess(image).unsqueeze(0).cuda() with torch.no_grad(): image_features = model.encode_image(images).float() image_features /= image_features.norm(dim=-1, keepdim=True) ranks = [ rank(model, image_features, mediums), rank(model, image_features, ["by "+artist for artist in artists]), rank(model, image_features, trending_list), rank(model, image_features, movements), rank(model, image_features, flavors, top_count=3) ] for i in range(len(ranks)): confidence_sum = 0 for ci in range(len(ranks[i])): confidence_sum += ranks[i][ci][1] if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))): bests[i] = ranks[i] row = [model_name] for r in ranks: row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r])) table.append(row) del model gc.collect() display(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"])) flaves = ', '.join([f"{x[0]}" for x in bests[4]]) medium = bests[0][0][0] if caption.startswith(medium): text = f"{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}" else: text = f"{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}" return text # setting data_path = "../clip-interrogator/data/" artists = load_list(os.path.join(data_path, 'artists.txt')) flavors = load_list(os.path.join(data_path, 'flavors.txt')) mediums = load_list(os.path.join(data_path, 'mediums.txt')) movements = load_list(os.path.join(data_path, 'movements.txt')) sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] trending_list = [site for site in sites] trending_list.extend(["trending on "+site for site in sites]) trending_list.extend(["featured on "+site for site in sites]) trending_list.extend([site+" contest winner" for site in sites]) # download sample pics import gdown gdown.download('https://drive.google.com/uc?id=1Mjwnr_m3pgxTPB7kusePjmKDHktGeV2w', 'pics.zip', quiet=False) ! unzip pics.zip |
次に、text2musicをセットアップします。このとき、emailには自分のemailアドレス(事前申請は不要)を記入して実行します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
#@title **setup text2music (Mubert)** # install library ! pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 ! pip install -U sentence-transformers ! pip install httpx # import library import numpy as np from sentence_transformers import SentenceTransformer minilm = SentenceTransformer('all-MiniLM-L6-v2') mubert_tags_string = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic' mubert_tags = np.array(mubert_tags_string.split(',')) mubert_tags_embeddings = minilm.encode(mubert_tags) from IPython.display import Audio, display import httpx import json # difine function def get_track_by_tags(tags, pat, duration, maxit=20, autoplay=False, loop=False): if loop: mode = "loop" else: mode = "track" r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM', json={ "method":"RecordTrackTTM", "params": { "pat": pat, "duration": duration, "tags": tags, "mode": mode } }) rdata = json.loads(r.text) assert rdata['status'] == 1, rdata['error']['text'] trackurl = rdata['data']['tasks'][0]['download_link'] print('Generating track ', end='') for i in range(maxit): r = httpx.get(trackurl) if r.status_code == 200: display(Audio(trackurl, autoplay=autoplay)) break time.sleep(1) print('.', end='') def find_similar(em, embeddings, method='cosine'): scores = [] for ref in embeddings: if method == 'cosine': scores.append(1 - np.dot(ref, em)/(np.linalg.norm(ref)*np.linalg.norm(em))) if method == 'norm': scores.append(np.linalg.norm(ref - em)) return np.array(scores), np.argsort(scores) def get_tags_for_prompts(prompts, top_n=3, debug=False): prompts_embeddings = minilm.encode(prompts) ret = [] for i, pe in enumerate(prompts_embeddings): scores, idxs = find_similar(pe, mubert_tags_embeddings) top_tags = mubert_tags[idxs[:top_n]] top_prob = 1 - scores[idxs[:top_n]] if debug: print(f"Prompt: {prompts[i]}\nTags: {', '.join(top_tags)}\nScores: {top_prob}\n\n\n") ret.append((prompts[i], list(top_tags))) return ret # get token email = "" #@param {type:"string"} r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess', json={ "method":"GetServiceAccess", "params": { "email": email, "license":"ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5", "token":"4951f6428e83172a4f39de05d5b3ab10d58560b8", "mode": "loop" } }) rdata = json.loads(r.text) assert rdata['status'] == 1, "probably incorrect e-mail" pat = rdata['data']['pat'] print(f'Got token: {pat}') |
それでは、画像からテキストを生成します。imgで画像ファイルを指定して実行します。ここでは、サンプル画像 01.jpg〜04.jpg の中から 02.jpgを指定します。
自分の用意した画像を使いたい場合は、事前に BLIP/pics フォルダその画像をアップロードしておいて下さい。
1 2 3 4 5 6 7 8 9 10 11 |
#@title **img2text** img = "02.jpg" #@param {type:"string"} image_path ='pics/'+ img image = Image.open(image_path).convert('RGB') thumb = image.copy() thumb.thumbnail([blip_image_eval_size, blip_image_eval_size]) display(thumb) text = interrogate(image, models=['ViT-L/14']) print('text = ', text) |
5つの属性(Medium, Artist, Trending, Movement, Flavors)に関して、類似性トップの内容が表示され、それを踏まえてテキストが生成されています。
そして、テキストから音楽を生成します。duration に曲の長さ(秒)、曲をループしたい場合には loop にチェックを入れて実行します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
#@title **text2music🎵** import time print('text = ', text) prompt = text duration = 30 #@param {type:"number"} loop = False #@param {type:"boolean"} def generate_track_by_prompt(prompt, duration, loop=False): _, tags = get_tags_for_prompts([prompt,])[0] print('tags = ', tags) try: get_track_by_tags(tags, pat, duration, autoplay=True, loop=loop) except Exception as e: print(str(e)) print('\n') generate_track_by_prompt(prompt, duration, loop) |
text からtags( ‘cyberpunk’, ‘art’, ‘artists’ )を抽出してサーバーへ送っていることが分かります。同じtagsでも、乱数処理を行なっているので、毎回ジャンルは同じですが異なる音楽が生成されます。
他のサンプルでもやってみましょう。img = 03.jpgです。
もう1つやってみましょう。今度は、img = 04.jpg です。
いかがだったでしょうか?自分のイメージにピッタリという訳にはいきませんが、画像から音楽を作るというのは刺激的なタスクですよね。
では、また。
(オリジナルgithub1)https://github.com/pharmapsychotic/clip-interrogator
(オリジナルgithub2)https://github.com/MubertAI/Mubert-Text-to-Music
最後にtext2musicを実行した後、作成された音楽ファイルはどこに保存されるのでしょうか?
GoogleColob上で動かしたのですが、そこから先が分からず、、、
佐藤さん
IPython.display.Audioを使って、Munbert社のサーバーで生成したmp3ファイルをテンポラリ(不明)にダウンロードして再生していますね。自分のPCに保存したい場合は、再生表示の右端をクリックすればダウンロード出来ます。
もし、コード内でそのmp3ファイルを使いたい場合は、get_track_tags()関数内にtrackurl変数があり、それがサーバーで生成したmp3ファイルのURLを示しているので、そこから別途ダウンロードすれば良いと思います。