Stable Diffusion 3及3.5开始,包括flux编辑模型,都不再采用传统DDPM的噪声扩散模式,而是采用流匹配的形式。Flow Matching总结是:简单有效。
- 这篇文章里都采用sd3.5作为backbone进行研究。
原理
Flow Matching 核心机制 —— 从“预测噪声”到“预测速度场”
在 SD 1.5 和 SDXL 时代,扩散模型建立在马尔可夫链和随机微分方程(SDE)之上。模型通过预测“噪声(Noise)”来一步步艰难地将高斯分布还原为真实图像。这种模式不仅数学推导极其繁琐,而且生成轨迹往往是一条极其曲折的曲线,导致采样效率低下。
从 Stable Diffusion 3/3.5 到 Flux,行业全面转向了 Flow Matching(流匹配),更具体的实现形式是 Rectified Flow(校正流)。它的核心思想可以用六个字概括:两点之间,直线最短。
1. 数学直觉:强制拉直的生成轨迹
如果我们有一个纯噪声张量 $x_0 \sim \mathcal{N}(0, I)$,和一个目标真实图像 $x_1 \sim P_{data}$。传统 DDPM 会在两者之间构建一个复杂的加噪/去噪随机漫步过程。
而 Flow Matching 选择了一种极其暴力的美学:直接在噪声 $x_0$ 和图像 $x_1$ 之间连一条直线。
我们可以定义一个最简单的线性插值路径 $x_t$:
$$x_t = (1-t)x_0 + t x_1$$
(注:此处 $t$ 从 0 走向 1,代表从噪声走向图像。在具体的代码实现如 diffusers 中,$t$ 通常对应时间步或噪声比例 $\sigma$ 的衰减)。
2. 模型在学什么?—— 预测速度 (Velocity)
既然轨迹是预先定义好的直线,我们对上面的路径公式求时间 $t$ 的导数,就能得到一个恒定的速度场 (Vector Field):
$$\frac{dx_t}{dt} = x_1 – x_0$$
在 Flow Matching 架构下,以 SD3.5 的 DiT (Diffusion Transformer) 为例,模型不再预测图像上的噪声残差,而是直接预测这个流向目标图像的速度向量 $v_\theta(x_t, t)$。
换句话说,模型在学习:当你站在潜在空间(Latent Space)的任意一个点 $x_t$ 时,风是往哪个方向吹的,风速有多大。
3. 采样过程:极简的常微分方程 (ODE) 求解
由于放弃了复杂的随机微分方程,流匹配的生成过程退化为了求解一个简单的常微分方程(ODE)。这使得我们可以使用极其轻量级的数值求解器(如 Euler 法)。
在代码层面,前向采样的核心逻辑被简化为极其优雅的一行欧拉步进:
$$x_{t+dt} = x_t + v_\theta(x_t, t) \cdot dt$$
因为 Rectified Flow 在训练时强迫轨迹尽可能笔直(曲率极低),这种最基础的一阶 Euler 求解器就能在非常少的步数(如 20-28 步)内完美收敛,直接走到终点生成高清图像。
4. CFG (无分类器引导) 的几何意义
在 Flow Matching 中,引入 Prompt 控制的 Classifier-Free Guidance (CFG) 不再是对噪声的加减,而是速度向量的外推。
模型会同时预测无条件风向 $v_{uncond}$ 和有条件风向 $v_{text}$,最终的引导速度为:
$$v_{guided} = v_{uncond} + \omega \cdot (v_{text} – v_{uncond})$$
其中 $\omega$ 即为 guidance_scale。这相当于在多维空间中,强行将速度向量朝着提示词指引的方向做延长和扭曲,从而让最终的落点(生成的图像)更加贴合文本语义。
Inverse过程:理论上的无损反演
1. ODE 带来的绝对确定性 (Deterministic Trajectory)
在传统 DDPM 的 SDE (随机微分方程) 框架下,正向加噪和逆向去噪都伴随着随机采样。这就好比在暴风雪中寻路,即便你想原路返回,走过的脚印也已经被随机的新雪覆盖,无法精确溯源。但 Flow Matching 采用的是常微分方程 (ODE)。ODE 的最大魅力在于其绝对的确定性——只要给定了初始条件和速度场,粒子的运动轨迹是唯一且可逆的。理论上,正向生成是顺着时间步前进:$x_{t+dt} = x_t + v_\theta(x_t, t) \cdot dt$,那么逆向反演只需要把时间步和速度场反转,就能严丝合缝地退回原始噪声。这种无损反演的特性在精确图像编辑、隐形水印的嵌入与提取等要求极高保真度的底层视觉任务中,是不可或缺的基石。
2. 欧拉求解器的离散化截断误差 (Discretization Error)
理论虽然完美,但实际代码中使用的 Euler(欧拉)一阶步进法打破了无损的神话。在离散的计算中,时间步 $dt$ 是有限大小的。当模型在正向起点 $A$ 计算出切线速度 $v_A$ 并直行走到 $B$ 点时,由于真实的流匹配轨迹带有曲率,$B$ 点的切线速度 $v_B$ 与 $v_A$ 并不相等。因此,当我们试图从终点 $B$ 反推时,模型只能基于当前的 $v_B$ 的切线速度给指路:$A’ = B – v_B \cdot dt$。由于 $v_A \neq v_B$,反推出来的起点 $A’$ 必然偏离真正的 $A$。这种局部截断误差在多次循环中不断累加,导致单纯依靠颠倒循环的朴素反演(Naive Inversion)无法做到真正意义上的 $100\%$ 无损。但是增加步数(step)会有极大的提升。因为step越多,间隔越小,一次step走的路径偏差也就小。
3. CFG (无分类器引导) 带来的轨迹扭曲。
如果说欧拉截断误差是原罪,那么 CFG 就是将误差急剧放大的杠杆。要沿着原路退回,数学上要求逆向反推必须使用与正向生成完全一致的 CFG Scale。然而,高 CFG 会极大地加剧速度场向量的非线性程度,把原本平缓的“微弯直道”强行扭曲成“急转弯”。这种高曲率使得前文提到的截断误差被指数级放大。如果为了降低曲率、平滑轨迹而在反推时关闭 CFG(设为 1.0),则相当于直接换了一套完全不同的速度场,从根本上违背了原路径,导致算出的噪声完全是另一回事。
4. VAE 编解码与 8-bit 量化的“信息黑洞”
抛开纯 Latent 空间的数学博弈,在工程落地时还横亘着无法逾越的物理屏障。第一层是 VAE 自身的非完美对称性作为有损自编码器,$Encode(Decode(z))$ 永远存在微小畸变。第二层则是致命的图像格式截断:将 Float16 精度的高维张量解码,并强制 Clamp 截断后保存为 0-255 的 8-bit RGB 图像(如 PNG)时,小数点后的海量微观特征瞬间湮灭。当这张图像被重新读取并 Encode 回去寻找起点时,坐标早已发生了数十 dB 的灾难性偏移。基于这样一个千疮百孔的起点去逆推 ODE,无异于刻舟求剑。
代码
SD3.5-medium
Forward(不用Diffusers)
from tqdm import tqdm
import torch
import torch.nn.functional as F
from diffusers import (
SD3Transformer2DModel,
FlowMatchEulerDiscreteScheduler,
AutoencoderKL
)
from transformers import (
CLIPTokenizer,
CLIPTextModelWithProjection,
T5TokenizerFast,
T5EncoderModel
)
from utils import calculate_latent_metrics
from PIL import Image
import numpy as np
device = "cuda"
dtype = torch.float16
model_id = "stabilityai/stable-diffusion-3.5-medium"
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
negative_prompt = "" # 引入负向 prompt(空字符串)
guidance_scale = 4.5 # SD3.5 Medium 默认的 CFG scale
steps = 28
def init_models():
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler"
)
transformer = SD3Transformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
torch_dtype=dtype
).to(device)
vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=dtype
).to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
tokenizer_3 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_3")
text_encoder = CLIPTextModelWithProjection.from_pretrained(
model_id, subfolder="text_encoder", torch_dtype=dtype
).to(device)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
model_id, subfolder="text_encoder_2", torch_dtype=dtype
).to(device)
text_encoder_3 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_3", torch_dtype=dtype
).to(device)
return scheduler, transformer, vae, tokenizer, tokenizer_2, tokenizer_3, text_encoder, text_encoder_2, text_encoder_3
scheduler, transformer, vae, tokenizer, tokenizer_2, tokenizer_3, text_encoder, text_encoder_2, text_encoder_3 = init_models()
# ==========================================
# 1. 文本编码 (正向 + 负向)
# ==========================================
def encode_prompt(text):
# Tokenize
t1 = tokenizer(text, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
t2 = tokenizer_2(text, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
t3 = tokenizer_3(text, padding="max_length", max_length=256, truncation=True, return_tensors="pt")
# Encode
with torch.no_grad():
out1 = text_encoder(t1.input_ids.to(device), output_hidden_states=True)
out2 = text_encoder_2(t2.input_ids.to(device), output_hidden_states=True)
out3 = text_encoder_3(t3.input_ids.to(device))
# CLIP 取倒数第二层,T5 取最后一层
emb1 = out1.hidden_states[-2]
emb2 = out2.hidden_states[-2]
emb3 = out3.last_hidden_state
pooled1 = out1.text_embeds
pooled2 = out2.text_embeds
# 维度对齐与拼接
clip_embeds = torch.cat([emb1, emb2], dim=-1) # (1, 77, 2048)
clip_embeds = F.pad(clip_embeds, (0, 4096 - 2048)) # (1, 77, 4096)
text_embeddings = torch.cat([clip_embeds, emb3], dim=1) # (1, 333, 4096)
pooled_embeds = torch.cat([pooled1, pooled2], dim=-1) # (1, 2048)
return text_embeddings, pooled_embeds
# 分别获取正向和无条件(负向)的 Embeddings
pos_text_emb, pos_pooled_emb = encode_prompt(prompt)
neg_text_emb, neg_pooled_emb = encode_prompt(negative_prompt)
# 在 Batch 维度拼接 (Uncond 在前,Cond 在后,这是 diffusers 的习惯)
# shape: (2, 333, 4096) 和 (2, 2048)
batched_text_embeddings = torch.cat([neg_text_emb, pos_text_emb], dim=0)
batched_pooled_embeds = torch.cat([neg_pooled_emb, pos_pooled_emb], dim=0)
# ==========================================
# 2. Latent 初始化与 Scheduler 设置
# ==========================================
ori_latents = torch.randn((1, 16, 128, 128), device=device, dtype=dtype)
latents = ori_latents.clone()
scheduler.set_timesteps(steps)
timesteps = scheduler.timesteps
sigmas = scheduler.sigmas
print(f"len(scheduler.timesteps)) # 输出: {steps}")
print(f"len(scheduler.sigmas)) # 输出: {steps+1}")
# ==========================================
# 3. Flow Matching 采样循环 (带 CFG)
# ==========================================
print("---Forward Sampling... ---")
for i, t in tqdm(enumerate(timesteps)):
# 为了 CFG,将 latents 复制一份以匹配 batch size = 2
latent_model_input = torch.cat([latents] * 2, dim=0)
timestep = torch.full((latent_model_input.shape[0],), int(t), device=device, dtype=torch.long)
with torch.no_grad():
velocity_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=batched_text_embeddings,
pooled_projections=batched_pooled_embeds,
return_dict=False
)[0]
# 执行 Classifier-Free Guidance
velocity_uncond, velocity_text = velocity_pred.chunk(2)
velocity_pred = velocity_uncond + guidance_scale * (velocity_text - velocity_uncond)
# Euler 更新步进
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
dt = sigma_next - sigma
latents = latents + velocity_pred * dt
# ==========================================
# 4. VAE Decode (shift_factor)
# ==========================================
latents_for_vae = (latents / vae.config.scaling_factor) + vae.config.shift_factor
with torch.no_grad():
image = vae.decode(latents_for_vae).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
image = Image.fromarray((image * 255).astype(np.uint8))
image.save("output/sd3_forward_with_cfg.png")
直接从Latent Inverse
# ==========================================
# 高精度反演测试 (接在正向生成循环结束后,VAE Decode 之前)
# ==========================================
print("\n--- 开始纯 Latent 闭环反演测试 ---")
# 1. 最初的噪声ori_latents
# 2. 直接获取正向生成结束时的 Latent (不经过 VAE)
latents_inv = latents.clone()
# 重置为更高的timestep,以确保足够的反演步骤
scheduler.set_timesteps(steps)
# 3. 翻转时间步和 sigma
timesteps_rev = scheduler.timesteps.flip(0)
sigmas_rev = scheduler.sigmas.flip(0)
# 【关键修改】:必须使用完全相同的 CFG 才能原路返回
inv_guidance_scale = 4.5
for i, t in tqdm(enumerate(timesteps_rev)):
latent_model_input = torch.cat([latents_inv] * 2, dim=0)
timestep = torch.full((latent_model_input.shape[0],), int(t), device=device, dtype=torch.long)
with torch.no_grad():
velocity_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=batched_text_embeddings,
pooled_projections=batched_pooled_embeds,
return_dict=False
)[0]
velocity_uncond, velocity_text = velocity_pred.chunk(2)
# 使用和正向一致的 CFG
velocity_pred = velocity_uncond + inv_guidance_scale * (velocity_text - velocity_uncond)
sigma = sigmas_rev[i]
sigma_next = sigmas_rev[i + 1]
dt = sigma_next - sigma # 正向反推,dt 为正数
latents_inv = latents_inv + velocity_pred * dt
# 计算指标
print("\n--- 闭环反演相似度结果 ---")
# 调用你之前的 calculate_latent_metrics 函数
mse, cos_sim, psnr = calculate_latent_metrics(ori_latents, latents_inv)
print(f"PSNR: {psnr:.2f} dB, MSE: {mse:.6f}, Cosine Similarity: {cos_sim:.6f}")
从图片Inverse
'''
Inverse
'''
from torchvision import transforms
print("\n--- 开始逆向反演 (Inversion) ---")
init_image = Image.open("output/sd3_forward_with_cfg.png").convert("RGB")
image_tensor = transforms.ToTensor()(init_image).unsqueeze(0).to(device, dtype)
image_tensor = image_tensor * 2.0 - 1.0
with torch.no_grad():
latents_inv = vae.encode(image_tensor).latent_dist.mean
latents_inv = (latents_inv - vae.config.shift_factor) * vae.config.scaling_factor
print("\n--- 纯 VAE + PNG 压缩带来的误差 ---")
mse, cos_sim, psnr = calculate_latent_metrics(latents_inv, latents_for_vae)
print(f"PSNR: {psnr:.2f} dB, MSE: {mse:.6f}, Cosine Similarity: {cos_sim:.6f}")
scheduler.set_timesteps(steps)
timesteps_rev = scheduler.timesteps.flip(0)
sigmas_rev = scheduler.sigmas.flip(0)
inv_guidance_scale = 4.5
for i, t in tqdm(enumerate(timesteps_rev)):
latent_model_input = torch.cat([latents_inv] * 2, dim=0)
timestep = torch.full((latent_model_input.shape[0],), int(t), device=device, dtype=torch.long)
with torch.no_grad():
velocity_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=batched_text_embeddings,
pooled_projections=batched_pooled_embeds,
return_dict=False
)[0]
velocity_uncond, velocity_text = velocity_pred.chunk(2)
velocity_pred = velocity_uncond + inv_guidance_scale * (velocity_text - velocity_uncond)
sigma = sigmas_rev[i]
sigma_next = sigmas_rev[i + 1]
dt = sigma_next - sigma
latents_inv = latents_inv + velocity_pred * dt
print("逆向完成!获得的初始噪声 shape:", latents_inv.shape)
mse, cos_sim, psnr = calculate_latent_metrics(ori_latents, latents_inv)
print(f"PSNR: {psnr:.2f} dB, MSE: {mse:.6f}, Cosine Similarity: {cos_sim:.6f}")