Connect with us

人工智能

理解扩散模型:生成式AI深度解析

mm
Understanding Diffusion Models: A Deep Dive into Generative AI

扩散模型已成为生成式AI中一种强大的方法,在图像、音频和视频生成方面取得了最先进的结果。在这篇深入的技术文章中,我们将探讨扩散模型的工作原理、其关键创新以及它们为何如此成功。我们将涵盖这一激动人心的新技术的数学基础、训练过程、采样算法和前沿应用。

扩散模型简介

扩散模型是一类生成模型,通过学习逆转扩散过程来逐渐对数据进行去噪。其核心思想是从纯噪声开始,迭代地将其精炼成来自目标分布的高质量样本。

这种方法受到非平衡热力学的启发——特别是逆转扩散以恢复结构的过程。在机器学习的背景下,我们可以将其视为学习逆转向数据逐步添加噪声的过程。

扩散模型的一些关键优势包括:

  • 最先进的图像质量,在许多情况下超越了GANs
  • 稳定的训练,没有对抗性动态
  • 高度可并行化
  • 灵活的架构——任何将输入映射到相同维度输出的模型都可以使用
  • 坚实的理论基础

让我们更深入地了解扩散模型的工作原理。

来源:Song等人

来源:Song等人

随机微分方程支配着扩散模型中的正向和逆向过程。正向SDE向数据添加噪声,逐渐将其转化为噪声分布。逆向SDE在学习的分数函数引导下,逐步去除噪声,从而从随机噪声中生成逼真的图像。这种方法是在连续状态空间中实现高质量生成性能的关键。

正向扩散过程

正向扩散过程从真实数据分布中采样的数据点x₀开始,并在T个时间步长内逐渐添加高斯噪声,产生越来越嘈杂的版本x₁, x₂, …, xT。

在每个时间步t,我们根据以下公式添加少量噪声:

x_t = √(1 - β_t) * x_{t-1} + √(β_t) * ε

其中:

  • β_t 是一个方差调度,控制每个步骤添加多少噪声
  • ε 是随机高斯噪声

这个过程一直持续到xT几乎是纯高斯噪声为止。

从数学上讲,我们可以将其描述为一个马尔可夫链:

q(x_t | x_{t-1}) = N(x_t; √(1 - β_t) * x_{t-1}, β_t * I)

其中N表示高斯分布。

β_t调度通常选择为早期时间步长较小,并随时间增加。常见的选择包括线性、余弦或S形调度。

逆向扩散过程

扩散模型的目标是学习这个过程的逆过程——从纯噪声xT开始,逐步去噪以恢复干净的样本x₀。

我们将这个逆向过程建模为:

p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_θ^2(x_t, t))

其中μ_θ和σ_θ^2是由θ参数化的学习函数(通常是神经网络)。

关键的创新在于,我们不需要显式地建模完整的逆向分布。相反,我们可以根据我们已知的正向过程来参数化它。

具体来说,我们可以证明最优的逆向过程均值μ*是:

μ* = 1/√(1 - β_t) * (x_t - β_t/√(1 - α_t) * ε_θ(x_t, t))

其中:

  • α_t = 1 – β_t
  • ε_θ 是一个学习的噪声预测网络

这给了我们一个简单的目标——训练一个神经网络ε_θ来预测每个步骤添加的噪声。

训练目标

扩散模型的训练目标可以从变分推断中推导出来。经过一些简化后,我们得到一个简单的L2损失:

L = E_t,x₀,ε [ ||ε - ε_θ(x_t, t)||² ]

其中:

  • t 是从1到T均匀采样的
  • x₀ 是从训练数据中采样的
  • ε 是采样的高斯噪声
  • x_t 是通过根据正向过程向x₀添加噪声构建的

换句话说,我们正在训练模型来预测每个时间步添加的噪声。

模型架构

U-Net架构是扩散模型中降噪步骤的核心。它具有编码器-解码器结构,并带有跳跃连接,有助于在重建过程中保留精细细节。编码器逐步下采样输入图像,同时捕获高级特征,解码器对编码后的特征进行上采样以重建图像。这种架构在需要精确定位的任务中特别有效,例如图像分割。

噪声预测网络 ε_θ 可以使用任何将输入映射到相同维度输出的架构。U-Net风格的架构是一个流行的选择,特别是对于图像生成任务。

典型的架构可能如下所示:

 class DiffusionUNet(nn.Module): def __init__(self): super().__init__() # 下采样 self.down1 = UNetBlock(3, 64) self.down2 = UNetBlock(64, 128) self.down3 = UNetBlock(128, 256) # 瓶颈层 self.bottleneck = UNetBlock(256, 512) # 上采样 self.up3 = UNetBlock(512, 256) self.up2 = UNetBlock(256, 128) self.up1 = UNetBlock(128, 64) # 输出 self.out = nn.Conv2d(64, 3, 1) def forward(self, x, t): # 嵌入时间步 t_emb = self.time_embedding(t) # 下采样 d1 = self.down1(x, t_emb) d2 = self.down2(d1, t_emb) d3 = self.down3(d2, t_emb) # 瓶颈层 bottleneck = self.bottleneck(d3, t_emb) # 上采样 u3 = self.up3(torch.cat([bottleneck, d3], dim=1), t_emb) u2 = self.up2(torch.cat([u3, d2], dim=1), t_emb) u1 = self.up1(torch.cat([u2, d1], dim=1), t_emb) # 输出 return self.out(u1) 

关键组件包括:

  • 带有跳跃连接的U-Net风格架构
  • 时间嵌入以根据时间步进行条件化
  • 灵活的深度和宽度

采样算法

一旦我们训练好了噪声预测网络ε_θ,我们就可以用它来生成新的样本。基本的采样算法是:

  1. 从纯高斯噪声xT开始
  2. 对于 t = T 到 1:
    • 预测噪声:ε_θ(x_t, t)
    • 计算均值:μ = 1/√(1-β_t) * (x_t - β_t/√(1-α_t) * ε_θ(x_t, t))
    • 采样: x_{t-1} ~ N(μ, σ_t^2 * I)
  3. 返回 x₀

这个过程在我们学习的噪声预测网络的引导下,逐渐对样本进行去噪。

在实践中,有各种采样技术可以提高质量或速度:

  • DDIM采样:一种确定性变体,允许更少的采样步骤
  • 祖先采样:结合学习到的方差σ_θ^2
  • 截断采样:提前停止以加快生成速度

以下是采样算法的基本实现:

 def sample(model, n_samples, device): # 从纯噪声开始 x = torch.randn(n_samples, 3, 32, 32).to(device) for t in reversed(range(1000)): # 添加噪声以创建 x_t t_batch = torch.full((n_samples,), t, device=device) noise = torch.randn_like(x) x_t = add_noise(x, noise, t) # 预测并去除噪声 pred_noise = model(x_t, t_batch) x = remove_noise(x_t, pred_noise, t) # 为下一步添加噪声(t=0时除外) if t > 0: noise = torch.randn_like(x) x = add_noise(x, noise, t-1) return x 

扩散模型背后的数学原理

要真正理解扩散模型,深入研究支撑它们的数学原理至关重要。让我们更详细地探讨一些关键概念:

马尔可夫链与随机微分方程

扩散模型中的正向扩散过程可以视为一个马尔可夫链,或者在连续极限下,视为一个随机微分方程。SDE公式为分析和扩展扩散模型提供了一个强大的理论框架。

正向SDE可以写成:

dx = f(x,t)dt + g(t)dw

其中:

  • f(x,t) 是漂移项
  • g(t) 是扩散系数
  • dw 是一个维纳过程(布朗运动)

f和g的不同选择会导致不同类型的扩散过程。例如:

  • 方差爆炸(VE) SDE: dx = √(d/dt σ²(t)) dw
  • 方差保持(VP) SDE: dx = -0.5 β(t)xdt + √(β(t)) dw

理解这些SDE使我们能够推导出最优采样策略,并将扩散模型扩展到新的领域。

分数匹配与去噪分数匹配

扩散模型与分数匹配之间的联系提供了另一个有价值的视角。分数函数定义为对数概率密度的梯度:

s(x) = ∇x log p(x)

去噪分数匹配旨在通过训练一个模型去噪轻微扰动的数据点来估计这个分数函数。这个目标在连续极限下被证明等同于扩散模型的训练目标。

这种联系使我们能够利用基于分数的生成建模技术,例如用于采样的退火

I have spent the past five years immersing myself in the fascinating world of Machine Learning and Deep Learning. My passion and expertise have led me to contribute to over 50 diverse software engineering projects, with a particular focus on AI/ML. My ongoing curiosity has also drawn me toward Natural Language Processing, a field I am eager to explore further.