Stable Diffusion は、文章のプロンプトから画像を生成します。
ここで2つのプロンプトを考えます。
p1 = バナナの写真(”High-resolution photo of bananas”)
p2 = オレンジの写真(”High-resolution photo of oranges”)
このp1とp2の中間のプロンプトで画像を生成させる方法があります。p1とp2の混合比率を変えることも可能です。KerasCVでのstable Diffusion を使った実装がこちらで紹介されています。
具体的には、p1とp2から変換された行列c1 とc2に対して、その中間表現c_tを生成し(c_t = c1 + t(c2 – c1), t = 0, 0.1, 0.2, …, 1.0 というイメージです) 、c_tから画像を作るという方法です。
これと同様なことをローカルなStable Diffusionで試してみたのが以下の画像です。ここでは、c1(バナナ)からc2(オレンジ)までの7点を作成し、画像にしました。
生成された画像は、バナナから徐々にオレンジに変わっていっているように見えます。特に3枚目でバナナともオレンジとも似ているような中間の物体が出てきているのが面白いです。
コードは以下です。
import torch import numpy as np from omegaconf import OmegaConf from PIL import Image, ImageOps from einops import rearrange from pytorch_lightning import seed_everything from ldm.util import instantiate_from_config from ldm.models.diffusion.plms import PLMSSampler # from ldm.models.diffusion.ddim import DDIMSampler def load_model_from_config(config, ckpt, verbose=False): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) model.cuda() model.eval() return model def concat_images_with_border(imgs, border_width=2): """ 複数の画像を縦に連結して一つの画像を生成し、各画像に3ピクセルの黒色縁を追加する関数。 """ width, height = imgs[0].size concatenated_height = sum(img.size[1] for img in imgs) + (len(imgs) - 1) * border_width concatenated_image = Image.new("RGB", (width, concatenated_height), (0, 0, 0)) y_offset = 0 for i, img in enumerate(imgs): # 画像に縁を追加 img_with_border = ImageOps.expand(img, border=border_width, fill=(0, 0, 0)) concatenated_image.paste(img_with_border, (0, y_offset)) y_offset += img_with_border.size[1] # 最後の画像以外には縦の空白スペースを追加 if i < len(imgs) - 1: y_offset += border_width return concatenated_image class opt: seed=1 prompt='' H=256 # 出力画像の高さ W=512 # 出力画像の幅 ddim_steps=50 # ddimサンプル数 n_iterとの違い不明 ddim_eta=0.0 # 0.0でサンプリングが決定的になる f=8 # ダウンサンプリングファクター n_iter=2 # サンプル数、画像の数 scale=7.5 # unconditional guidance scale C=4 # 潜在変数のチャンネル数 n_samples=1 # 一つのプロンプトから生成する画像の数 ckpt='sd-v1-4.ckpt' config='configs/stable-diffusion/v1-inference.yaml' dpm_solver=False fixed_code=False from_file=None laion400m=False n_rows=0 outdir='outputs/txt2img-samples' precision='autocast' skip_grid=False skip_save=False print("モデルのロード----------") config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, 'sd-v1-4.ckpt') device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) ## --------ここから p1 = "High-resolution photo of bananas" p2 = "High-resolution photo of oranges" n_divisions = 7 opt.seed = 1 # 13 print("プロンプトから潜在変数を作成----------") seed_everything(opt.seed) # seed値 c1 = model.get_learned_conditioning(p1) # promptからc[1,77,768]を生成 c2 = model.get_learned_conditioning(p2) # promptからc[1,77,768]を生成 steps = np.linspace(0, 1, n_divisions) cs = [] for t in steps: cs.append(torch.lerp(c1, c2, t)) print("潜在変数を条件にしてサンプリング----------") # 時間がかかる # sampler = DDIMSampler(model) sampler = PLMSSampler(model) shape = [opt.C, opt.H // opt.f, opt.W // opt.f] # # 出力のサイズ [4, 32, 64] batch_size = opt.n_samples uc = model.get_learned_conditioning(batch_size * [""]) start_code = None imgs = [] for c in cs: samples_ddim, _ = sampler.sample(S=opt.ddim_steps, conditioning=c, batch_size=opt.n_samples, shape=shape, verbose=False, unconditional_guidance_scale=opt.scale, unconditional_conditioning=uc, eta=opt.ddim_eta, x_T=start_code) # samples_ddimの次元は [1, 4 ,32, 64] print("潜在変数から、画像を作成----------") # 一瞬 x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() # x_samples_ddimの次元は (1, 256, 512, 3) print("画像表示----------") # 一瞬 x_checked_image = x_samples_ddim x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) x_sample = x_checked_image_torch[0] x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') img = Image.fromarray(x_sample.astype(np.uint8)) # 結果の表示 display(img) imgs.append(img) concat_img = concat_images_with_border(imgs) display(concat_img) # 解放 del model, sampler, device
コメントを残す