Stable Diffusion は、文章のプロンプトから画像を生成します。

ここで2つのプロンプトを考えます。

p1 = バナナの写真(”High-resolution photo of bananas”)
p2 = オレンジの写真(”High-resolution photo of oranges”)

このp1とp2の中間のプロンプトで画像を生成させる方法があります。p1とp2の混合比率を変えることも可能です。KerasCVでのstable Diffusion を使った実装がこちらで紹介されています。

具体的には、p1とp2から変換された行列c1 とc2に対して、その中間表現c_tを生成し(c_t = c1 + t(c2 – c1), t = 0, 0.1, 0.2, …, 1.0 というイメージです) 、c_tから画像を作るという方法です。

これと同様なことをローカルなStable Diffusionで試してみたのが以下の画像です。ここでは、c1(バナナ)からc2(オレンジ)までの7点を作成し、画像にしました。

生成された画像は、バナナから徐々にオレンジに変わっていっているように見えます。特に3枚目でバナナともオレンジとも似ているような中間の物体が出てきているのが面白いです。


コードは以下です。

import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from einops import rearrange
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
# from ldm.models.diffusion.ddim import DDIMSampler


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    model.cuda()
    model.eval()
    return model

def concat_images_with_border(imgs, border_width=2):
    """
    複数の画像を縦に連結して一つの画像を生成し、各画像に3ピクセルの黒色縁を追加する関数。
    """
    width, height = imgs[0].size
    concatenated_height = sum(img.size[1] for img in imgs) + (len(imgs) - 1) * border_width
    concatenated_image = Image.new("RGB", (width, concatenated_height), (0, 0, 0))
    y_offset = 0
    for i, img in enumerate(imgs):
        # 画像に縁を追加
        img_with_border = ImageOps.expand(img, border=border_width, fill=(0, 0, 0))
        concatenated_image.paste(img_with_border, (0, y_offset))
        y_offset += img_with_border.size[1]
        # 最後の画像以外には縦の空白スペースを追加
        if i < len(imgs) - 1:
            y_offset += border_width
    return concatenated_image    

class opt:
    seed=1
    prompt=''
    H=256   # 出力画像の高さ
    W=512   # 出力画像の幅
    ddim_steps=50  # ddimサンプル数 n_iterとの違い不明
    ddim_eta=0.0  # 0.0でサンプリングが決定的になる
    f=8              # ダウンサンプリングファクター
    n_iter=2  # サンプル数、画像の数
    scale=7.5  # unconditional guidance scale
    C=4     # 潜在変数のチャンネル数
    n_samples=1  # 一つのプロンプトから生成する画像の数
    ckpt='sd-v1-4.ckpt'
    config='configs/stable-diffusion/v1-inference.yaml'
    dpm_solver=False
    fixed_code=False
    from_file=None
    laion400m=False
    n_rows=0
    outdir='outputs/txt2img-samples'
    precision='autocast'
    skip_grid=False
    skip_save=False

print("モデルのロード----------")
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, 'sd-v1-4.ckpt') 
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

## --------ここから
p1 = "High-resolution photo of bananas"
p2 = "High-resolution photo of oranges"
n_divisions = 7
opt.seed = 1 # 13

print("プロンプトから潜在変数を作成----------") 
seed_everything(opt.seed) # seed値
c1 = model.get_learned_conditioning(p1) # promptからc[1,77,768]を生成
c2 = model.get_learned_conditioning(p2) # promptからc[1,77,768]を生成
steps = np.linspace(0, 1, n_divisions)
cs = []
for t in steps:
    cs.append(torch.lerp(c1, c2, t))

print("潜在変数を条件にしてサンプリング----------")  # 時間がかかる
# sampler = DDIMSampler(model)
sampler = PLMSSampler(model)

shape = [opt.C, opt.H // opt.f, opt.W // opt.f]  # # 出力のサイズ [4, 32, 64]
batch_size = opt.n_samples
uc = model.get_learned_conditioning(batch_size * [""])
start_code = None

imgs = []
for c in cs:
    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                        conditioning=c,
                                        batch_size=opt.n_samples,
                                        shape=shape,
                                        verbose=False,
                                        unconditional_guidance_scale=opt.scale,
                                        unconditional_conditioning=uc,
                                        eta=opt.ddim_eta,
                                        x_T=start_code)
    # samples_ddimの次元は [1, 4 ,32, 64]

    print("潜在変数から、画像を作成----------")  # 一瞬
    x_samples_ddim = model.decode_first_stage(samples_ddim)
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
    # x_samples_ddimの次元は (1, 256, 512, 3)

    print("画像表示----------")  # 一瞬
    x_checked_image = x_samples_ddim
    x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

    x_sample = x_checked_image_torch[0]
    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
    img = Image.fromarray(x_sample.astype(np.uint8))

    # 結果の表示
    display(img)
    imgs.append(img)

concat_img = concat_images_with_border(imgs)
display(concat_img)

# 解放
del model, sampler, device