人工智能
理解扩散模型:生成式AI深度解析
扩散模型已成为生成式AI中一种强大的方法,在图像、音频和视频生成方面取得了最先进的结果。在这篇深入的技术文章中,我们将探讨扩散模型的工作原理、其关键创新以及它们为何如此成功。我们将涵盖这一激动人心的新技术的数学基础、训练过程、采样算法和前沿应用。
扩散模型简介
扩散模型是一类生成模型,通过学习逆转扩散过程来逐渐对数据进行去噪。其核心思想是从纯噪声开始,迭代地将其精炼成来自目标分布的高质量样本。
这种方法受到非平衡热力学的启发——特别是逆转扩散以恢复结构的过程。在机器学习的背景下,我们可以将其视为学习逆转向数据逐步添加噪声的过程。
扩散模型的一些关键优势包括:
- 最先进的图像质量,在许多情况下超越了GANs
- 稳定的训练,没有对抗性动态
- 高度可并行化
- 灵活的架构——任何将输入映射到相同维度输出的模型都可以使用
- 坚实的理论基础
让我们更深入地了解扩散模型的工作原理。

来源:Song等人
随机微分方程支配着扩散模型中的正向和逆向过程。正向SDE向数据添加噪声,逐渐将其转化为噪声分布。逆向SDE在学习的分数函数引导下,逐步去除噪声,从而从随机噪声中生成逼真的图像。这种方法是在连续状态空间中实现高质量生成性能的关键。
正向扩散过程
正向扩散过程从真实数据分布中采样的数据点x₀开始,并在T个时间步长内逐渐添加高斯噪声,产生越来越嘈杂的版本x₁, x₂, …, xT。
在每个时间步t,我们根据以下公式添加少量噪声:
x_t = √(1 - β_t) * x_{t-1} + √(β_t) * ε
其中:
- β_t 是一个方差调度,控制每个步骤添加多少噪声
- ε 是随机高斯噪声
这个过程一直持续到xT几乎是纯高斯噪声为止。
从数学上讲,我们可以将其描述为一个马尔可夫链:
q(x_t | x_{t-1}) = N(x_t; √(1 - β_t) * x_{t-1}, β_t * I)
其中N表示高斯分布。
β_t调度通常选择为早期时间步长较小,并随时间增加。常见的选择包括线性、余弦或S形调度。
逆向扩散过程
扩散模型的目标是学习这个过程的逆过程——从纯噪声xT开始,逐步去噪以恢复干净的样本x₀。
我们将这个逆向过程建模为:
p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_θ^2(x_t, t))
其中μ_θ和σ_θ^2是由θ参数化的学习函数(通常是神经网络)。
关键的创新在于,我们不需要显式地建模完整的逆向分布。相反,我们可以根据我们已知的正向过程来参数化它。
具体来说,我们可以证明最优的逆向过程均值μ*是:
μ* = 1/√(1 - β_t) * (x_t - β_t/√(1 - α_t) * ε_θ(x_t, t))
其中:
- α_t = 1 – β_t
- ε_θ 是一个学习的噪声预测网络
这给了我们一个简单的目标——训练一个神经网络ε_θ来预测每个步骤添加的噪声。
训练目标
扩散模型的训练目标可以从变分推断中推导出来。经过一些简化后,我们得到一个简单的L2损失:
L = E_t,x₀,ε [ ||ε - ε_θ(x_t, t)||² ]
其中:
- t 是从1到T均匀采样的
- x₀ 是从训练数据中采样的
- ε 是采样的高斯噪声
- x_t 是通过根据正向过程向x₀添加噪声构建的
换句话说,我们正在训练模型来预测每个时间步添加的噪声。
模型架构
U-Net架构是扩散模型中降噪步骤的核心。它具有编码器-解码器结构,并带有跳跃连接,有助于在重建过程中保留精细细节。编码器逐步下采样输入图像,同时捕获高级特征,解码器对编码后的特征进行上采样以重建图像。这种架构在需要精确定位的任务中特别有效,例如图像分割。
噪声预测网络 ε_θ 可以使用任何将输入映射到相同维度输出的架构。U-Net风格的架构是一个流行的选择,特别是对于图像生成任务。
典型的架构可能如下所示:
class DiffusionUNet(nn.Module): def __init__(self): super().__init__() # 下采样 self.down1 = UNetBlock(3, 64) self.down2 = UNetBlock(64, 128) self.down3 = UNetBlock(128, 256) # 瓶颈层 self.bottleneck = UNetBlock(256, 512) # 上采样 self.up3 = UNetBlock(512, 256) self.up2 = UNetBlock(256, 128) self.up1 = UNetBlock(128, 64) # 输出 self.out = nn.Conv2d(64, 3, 1) def forward(self, x, t): # 嵌入时间步 t_emb = self.time_embedding(t) # 下采样 d1 = self.down1(x, t_emb) d2 = self.down2(d1, t_emb) d3 = self.down3(d2, t_emb) # 瓶颈层 bottleneck = self.bottleneck(d3, t_emb) # 上采样 u3 = self.up3(torch.cat([bottleneck, d3], dim=1), t_emb) u2 = self.up2(torch.cat([u3, d2], dim=1), t_emb) u1 = self.up1(torch.cat([u2, d1], dim=1), t_emb) # 输出 return self.out(u1)














