Toriskia 's Blog

13 篇文章 · 12 个标签 · 6 个友链

← 返回文章列表

2026.03.26

Generative Adversarial Network

包含 AI 辅助生成内容

一、从显式建模到隐式生成#

1. 生成模型到底想做什么#

生成模型的目标是学习数据分布:

pdata(x)p_{\text{data}}(x)

在 EBM 里,我们显式写出一个概率密度:

pθ(x)=1Z(θ)eEθ(x)p_\theta(x)=\frac{1}{Z(\theta)}e^{-E_\theta(x)}

问题在于:

  • 概率密度往往难算
  • 配分函数 Z(θ)Z(\theta) 难算
  • 采样困难

所以另一种思路是:不直接建模密度,直接学习一个“会生成样本的函数”


2. 隐式生成模型#

令随机噪声:

zp(z),p(z)=N(0,I)z \sim p(z), \quad p(z)=\mathcal N(0,I)

生成器定义为:

x=Gθ(z)x=G_\theta(z)

这里的 GθG_\theta 不直接给出 pθ(x)p_\theta(x) 的解析式,它只负责把简单噪声映射成图像。

这类模型叫做 隐式生成模型(implicit generative model)

目标:

Gθ(z)pdataG_\theta(z) \sim p_{\text{data}}

也就是说,我们只关心生成出来的样本像不像真实数据,而不一定要写出显式似然。


二、GAN#

minθmaxϕExpdata[logDϕ(x)]+Ezp(z)[log(1Dϕ(Gθ(z)))]\min_\theta \max_\phi \mathbb E_{x\sim p_{\text{data}}}\big[\log D_\phi(x)\big] + \mathbb E_{z\sim p(z)}\big[\log(1-D_\phi(G_\theta(z)))\big]

1. 优化判别器 DϕD_\phi#

判别器 DϕD_\phi 的目标是:

maxϕExpdata[logDϕ(x)]+Ezp(z)[log(1Dϕ(Gθ(z)))]\max_\phi \mathbb E_{x\sim p_{\text{data}}}\big[\log D_\phi(x)\big] + \mathbb E_{z\sim p(z)}\big[\log(1-D_\phi(G_\theta(z)))\big]

Dϕ(x)D_\phi(x)(真数据)越接近 1 越好,Dϕ(Gθ(z))D_\phi(G_\theta(z))(生成的假数据)越接近 0 越好。

2. 优化生成器 GθG_\theta#

如果严格按照原始极小极大目标,生成器优化的是:

minθEzp(z)[log(1Dϕ(Gθ(z)))]\min_\theta \mathbb E_{z\sim p(z)}\big[\log(1-D_\phi(G_\theta(z)))\big]

调整 θ\theta 只会影响第二项,目标是让生成的假数据更像真数据,从而骗过判别器。

实践中更常用的 non-saturating 形式是:

maxθEzp(z)[logDϕ(Gθ(z))]\max_\theta \mathbb E_{z\sim p(z)}\big[\log D_\phi(G_\theta(z))\big]

它和原始目标动机一致,但梯度通常更强,后文会再解释。


三、优化了什么#

1. 固定生成器时,最优判别器是什么#

设生成器固定,此时生成分布为 pG(x)p_G(x)

对任意一个固定的 xx,目标里只和 D(x)D(x) 有关的部分是:

f(D)=pdata(x)logD+pG(x)log(1D)f(D)=p_{\text{data}}(x)\log D + p_G(x)\log(1-D)

DD 求导,令导数为 0:

fD=pdata(x)DpG(x)1D=0D(x)=pdata(x)pdata(x)+pG(x)\frac{\partial f}{\partial D} = \frac{p_{\text{data}}(x)}{D} - \frac{p_G(x)}{1-D} = 0 \quad\Rightarrow\quad D^*(x)=\frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_G(x)}

这个结果很重要,它说明:

  • 如果某个位置更像真实数据,则 D(x)D^*(x) 更接近 1
  • 如果某个位置更多来自生成器,则 D(x)D^*(x) 更接近 0
  • pG=pdatap_G=p_{\text{data}} 时,
D(x)=12D^*(x)=\frac12

此时判别器完全分不清真假。


2. 把最优判别器代回去#

D(x)D^*(x) 代回 GAN 目标:

V(D,G)=Expdata[logpdata(x)pdata(x)+pG(x)]+ExpG[logpG(x)pdata(x)+pG(x)]V(D^*,G) = \mathbb E_{x\sim p_{\text{data}}} \left[ \log \frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_G(x)} \right] + \mathbb E_{x\sim p_G} \left[ \log \frac{p_G(x)}{p_{\text{data}}(x)+p_G(x)} \right]

定义中间分布:

m(x)=12[pdata(x)+pG(x)]m(x)=\frac12\left[p_{\text{data}}(x)+p_G(x)\right]

则有:

V(D,G)=Expdata[logpdata(x)m(x)]+ExpG[logpG(x)m(x)]2log2=KL(pdatam)+KL(pGm)log4=2JSD(pdatapG)log4\begin{aligned} V(D^*,G) &= \mathbb E_{x\sim p_{\text{data}}} \left[\log \frac{p_{\text{data}}(x)}{m(x)}\right] + \mathbb E_{x\sim p_G} \left[\log \frac{p_G(x)}{m(x)}\right] -2\log 2 \\ &= KL(p_{\text{data}}\|m)+KL(p_G\|m)-\log 4 \\ &= 2\,JSD(p_{\text{data}}\|p_G)-\log 4 \end{aligned}

因此,在最优判别器下,GAN 本质上是在最小化:

JSD(pdatapG)JSD(p_{\text{data}}\|p_G)

全局最优点是:

pG=pdatap_G = p_{\text{data}}

此时:

V(D,G)=log4V(D^*,G^*)=-\log 4

四、GAN 的训练细节#

1. 交替优化#

实际训练不会真的每一步都把 DD 优化到最优,而是交替进行:

  1. 固定 GG,更新几步 DD
  2. 固定 DD,更新一步 GG
  3. 重复

判别器训练用标准二分类损失。


2. 生成器的常用损失#

严格按照极小极大目标,生成器最小化:

LGsat(θ)=Ezp(z)[log(1Dϕ(Gθ(z)))]\mathcal L_G^{\text{sat}}(\theta) = \mathbb E_{z\sim p(z)} \big[\log(1-D_\phi(G_\theta(z)))\big]

但这在训练早期常常梯度太弱,因为如果判别器很强,D(G(z))0D(G(z))\approx 0,生成器会很难学。

实践中更常用 non-saturating loss

LGns(θ)=Ezp(z)[logDϕ(Gθ(z))]\mathcal L_G^{\text{ns}}(\theta) = -\mathbb E_{z\sim p(z)} \big[\log D_\phi(G_\theta(z))\big]

它不改变“让假样本更像真样本”的目标,但通常更容易优化。


五、GAN 的评估#

1. 为什么 GAN 难评估#

GAN 是隐式生成模型,只会采样,不直接给出显式密度:

pθ(x)p_\theta(x)

所以它不能像显式生成模型那样直接做似然评估。

评估 GAN 时,通常要同时看两件事:

  • 样本质量:单张图像看起来是否真实
  • 样本多样性:是否覆盖了真实数据中的不同模式

2. Inception Score(IS)#

设有一个预训练分类器 f(yx)f(y|x)。如果生成样本质量高,那么对于单张图像,分类器输出应当比较确定,也就是熵较低。

另一方面,如果生成样本足够多样,那么整体类别边缘分布

pf(y)=ExpG[f(yx)]p_f(y)=\mathbb E_{x\sim p_G}[f(y|x)]

又不应坍塌到少数类。

于是定义:

IS=exp(ExpG[KL(f(yx)pf(y))])IS=\exp\left( \mathbb E_{x\sim p_G} \big[ KL(f(y|x)\|p_f(y)) \big] \right)

它希望同时满足:

  • 单个样本“可识别”,即 f(yx)f(y|x) 尽量尖锐
  • 整体样本“够丰富”,即 pf(y)p_f(y) 不要太集中

因此 IS 越高越好

但 IS 的问题也很明显:

  • 它没有直接比较 pGp_Gpdatap_{\text{data}}
  • 它更像在衡量“生成样本是否容易被分类器识别”
  • 如果模型对每个类别只记住一张图,IS 仍可能很高

所以 IS 更适合做粗略参考,而不是最终标准。


3. Fréchet Inception Distance(FID)#

FID 的思路是:先用预训练网络提取特征,再比较真实样本和生成样本在特征空间中的统计量。

设真实特征分布近似为:

N(μr,Σr)\mathcal N(\mu_r,\Sigma_r)

生成特征分布近似为:

N(μg,Σg)\mathcal N(\mu_g,\Sigma_g)

则 FID 定义为:

FID=μrμg2+Tr(Σr+Σg2(ΣrΣg)1/2)FID= \|\mu_r-\mu_g\|^2 + \operatorname{Tr} \left( \Sigma_r+\Sigma_g-2(\Sigma_r\Sigma_g)^{1/2} \right)

因此:

  • 均值差异越大,FID 越大
  • 协方差差异越大,FID 越大
  • FID 越低越好

相比 IS,FID 更常用,因为它同时反映:

  • 图像质量
  • 模式覆盖情况
  • 生成分布与真实分布的接近程度

六、GAN 为什么难训练#

1. JSD 在支撑集不重叠时会出问题#

GAN 的理论目标对应 JSD,但高维数据里,训练初期的 pGp_Gpdatap_{\text{data}} 往往几乎没有重叠。

这时:

  • 判别器很容易把真假完全分开
  • D(x)1,  D(G(z))0D(x)\approx 1,\; D(G(z))\approx 0
  • 生成器拿到的有效梯度很差

这就是 GAN 训练不稳定的根源之一。


2. 模式坍塌#

GAN 还有一个经典问题:模式坍塌(mode collapse)

也就是生成器只学会少数几个“最容易骗过判别器”的样本模式,比如:

  • 不同噪声生成几乎一样的图
  • 在几个样本之间来回跳动

从优化角度看,生成器并不是直接在最大化“覆盖所有模式”,而是在当前判别器下寻找最容易得高分的区域,因此很容易“投机取巧”。


3. 这是一个博弈,不是普通最优化#

GAN 不是单一目标最小化,而是:

minGmaxD\min_G \max_D

这意味着训练动力学像博弈而不是普通梯度下降。

即使存在均衡点,简单交替梯度下降也可能:

  • 持续振荡
  • 不收敛
  • 对学习率、归一化、网络容量非常敏感

所以 GAN 往往需要很多训练技巧。


七、让 GAN 工作起来:工程优化技巧#

1. DCGAN 的基本经验#

很多早期最有效、现在依然值得记的工程经验都来自 DCGAN。最核心的几条是:

  • 尽量使用全卷积结构,少用笨重的全连接层
  • 生成器中常用 ReLU,输出层常用 tanh
  • 判别器中常用 LeakyReLU
  • 使用较小学习率的 Adam,经典设置常见为 lr = 2e-4beta_1 = 0.5
  • 让生成器和判别器的容量大致匹配,不要一边明显过强

这些经验本质上都在做同一件事:让对抗训练保持稳定


2. 归一化层的使用#

GAN 对归一化特别敏感。一个经典经验是:

  • 可以在生成器中使用 BatchNorm
  • 判别器中要谨慎使用 BatchNorm
  • 通常不要在 生成器输出层 使用 BatchNorm
  • 通常不要在 判别器输入层 使用 BatchNorm

原因是判别器本来就在比较真样本和假样本,如果过度依赖 batch 统计量,真假样本之间会互相干扰,训练会更不稳定。

在 WGAN-GP 等方法里,这个问题更明显,因此常见替代方案是:

  • LayerNorm
  • InstanceNorm
  • 谱归一化(Spectral Normalization)

还有一个较老但有代表性的技巧是 Virtual BatchNorm:为归一化固定一批参考样本,减少同一 mini-batch 内样本彼此影响,缓解生成器输出强依赖同 batch 其他样本的问题。


3. 生成器侧的稳定化技巧#

Feature Matching

直接优化 D(G(z))D(G(z)) 有时太激进,生成器容易去钻判别器当前的漏洞。于是可以改成匹配判别器中间层特征的统计量:

LFM(θ)=Expdata[fϕ(x)]Ezp(z)[fϕ(Gθ(z))]2\mathcal L_{\text{FM}}(\theta)= \left\| \mathbb E_{x\sim p_{\text{data}}}[f_\phi(x)] - \mathbb E_{z\sim p(z)}[f_\phi(G_\theta(z))] \right\|^2

其中 fϕ(x)f_\phi(x) 是判别器某一层的特征。

它的直觉是:不要只追最终真假分数,而是让生成样本在更稳定的语义特征上靠近真实样本。

Historical Averaging

还可以给参数加入历史平均正则项:

θ1Tt=1Tθt2\left\| \theta-\frac{1}{T}\sum_{t=1}^{T}\theta_t \right\|^2

它的作用是减小参数震荡,让博弈过程更平滑。


4. 防止模式坍塌的技巧#

Minibatch Discrimination

如果生成器发生模式坍塌,那么同一个 batch 里的假样本通常会异常相似。

于是可以让判别器不只看单个样本,还显式利用 batch 内部样本之间的相似性特征。这样判别器更容易识别“这一整批图都长得太像”的情况,从而反过来逼生成器增加多样性。

One-sided Label Smoothing

把真样本标签从 1 改成 0.9,而假样本仍然保持 0:

  • 真样本:1 -> 0.9
  • 假样本:0 -> 0

这样可以减弱判别器过度自信的问题,避免它太快进入饱和区。

之所以常做单侧而不是双侧,是因为如果连假标签也被抬高,就会削弱生成器把明显错误样本往回拉的动力。


5. 训练流程上的实用经验#

除了损失函数本身,GAN 的很多成败其实取决于训练流程。

常见经验包括:

  • 平衡 GGDD 的更新速度,避免一边压制另一边
  • 必要时多更新几步判别器,再更新一次生成器
  • 使用更小、更稳的学习率,而不是一味追求收敛快
  • 定期看生成样本和 FID,而不只盯着 loss

GAN 训练最重要的不是把某个 loss 压到最低,而是维持生成器和判别器之间“刚好能互相促进”的平衡。


6. 大规模 GAN 中常见的现代技巧#

在 BigGAN 一类更大规模的模型里,常见还会加入一些更“工程化”的技巧:

  • Hinge Loss:比原始 sigmoid 交叉熵更稳定,常写成
LD=Expdata[max(0,1D(x))]+Ezp(z)[max(0,1+D(G(z)))]\mathcal L_D= \mathbb E_{x\sim p_{\text{data}}}[\max(0,1-D(x))] + \mathbb E_{z\sim p(z)}[\max(0,1+D(G(z)))] LG=Ezp(z)[D(G(z))]\mathcal L_G= -\mathbb E_{z\sim p(z)}[D(G(z))]
  • Spectral Normalization:约束判别器每层的谱范数,控制 Lipschitz 常数,往往能明显稳定训练
  • 更大的 batch size:让梯度估计和 batch 统计更稳定
  • Orthogonal Initialization:减少训练初期梯度传播问题
  • Truncation Trick:采样时从截断高斯中取 zz,通常能提高样本质量,但会牺牲一些多样性

这些技巧说明了一点:GAN 的上限不只来自理论目标,也高度依赖实现细节。


八、从 GAN 到 WGAN:为什么要换距离#

1. Wasserstein 距离的动机#

JSD 在分布几乎不重叠(高维空间中真实分布和生成分布的流形不相交)时不够平滑,自然想法就是换掉它。

Wasserstein 距离定义为:

W(P,Q)=infγΠ(P,Q)E(x,y)γ[xy]W(P,Q)=\inf_{\gamma\in\Pi(P,Q)} \mathbb E_{(x,y)\sim\gamma}\big[\|x-y\|\big]

也叫做 Earth Mover’s Distance,是一个非常形象的描述:把分布 PP 看成一堆土,分布 QQ 看成一堆坑,W(P,Q)W(P,Q) 就是把 PP 的土搬到 QQ 的坑里所需的最小运输代价。

相比 JSD,它在分布错位时通常更平滑,梯度信息也更有用。


2. WGAN 的核心形式#

利用 Kantorovich-Rubinstein 对偶,Wasserstein 距离可写为:

W(pdata,pG)=supfL1Expdata[f(x)]ExpG[f(x)]W(p_{\text{data}},p_G) = \sup_{\|f\|_L\le 1} \mathbb E_{x\sim p_{\text{data}}}[f(x)] - \mathbb E_{x\sim p_G}[f(x)]

这里的 ff 不再叫 discriminator,而常叫 critic

于是 WGAN 目标变成:

maxϕExpdata[fϕ(x)]Ezp(z)[fϕ(Gθ(z))]\max_\phi \mathbb E_{x\sim p_{\text{data}}}[f_\phi(x)] - \mathbb E_{z\sim p(z)}[f_\phi(G_\theta(z))]

生成器对应地最小化:

minθEzp(z)[fϕ(Gθ(z))]\min_\theta \mathbb E_{z\sim p(z)}[f_\phi(G_\theta(z))]

关键约束是:

fϕL1\|f_\phi\|_L \le 1

也就是 critic 必须是 1-Lipschitz。为了满足这个约束,WGAN 最初使用了 weight clipping,即把 critic 的权重 ϕi\phi_i 限制在一个小范围内。显然这种方法非常粗糙。


3. 原始 WGAN 怎么训练#

原始 WGAN 为了让 critic 满足 1-Lipschitz 约束,采用了几个非常具体的训练技巧:

  • Weight Clipping:每次更新 critic 后,把参数裁剪到一个小区间内
ϕiclip(ϕi,c,c)\phi_i \leftarrow \operatorname{clip}(\phi_i,-c,c)

常见取值例如 c=0.01c=0.01

  • 不用动量:原论文更推荐 RMSProp,而不是带强动量的 Adam
  • 多训几步 critic:通常每更新一次生成器,先更新 ncriticn_{\text{critic}} 次 critic,常见设定是
ncritic=5n_{\text{critic}}=5

这些技巧的目标都很一致:让 critic 更接近 Wasserstein 对偶中的最优函数,从而给生成器提供更可靠的梯度。

从结果上看,WGAN 相比原始 GAN 的改进点是:

  • 把目标从 JSD 换成了更平滑的 Wasserstein 距离
  • 让生成器在分布尚未重叠时也更可能拿到有效梯度
  • 在实践中显著改善了训练稳定性

但原始 WGAN 仍然不够理想,因为 weight clipping 会带来两个问题:

  • 如果裁剪范围太小,critic 表达能力不足,容易梯度消失
  • 如果裁剪范围太大,又很难真正约束 Lipschitz 常数

所以原始 WGAN 虽然比普通 GAN 更稳定,但还远远不算完美。


4. WGAN-GP:更实用的 WGAN#

WGAN-GP 的核心想法是:与其粗暴裁剪参数,不如直接对函数梯度加惩罚。

理论上,最优 critic 在真实分布和生成分布之间的关键区域应满足:

xf(x)21\|\nabla_x f(x)\|_2 \approx 1

于是可以在真实样本 xx 和生成样本 x~\tilde x 之间做随机插值:

x^=(1ϵ)x+ϵx~,ϵUnif(0,1)\hat x=(1-\epsilon)x+\epsilon \tilde x, \quad \epsilon\sim \operatorname{Unif}(0,1)

然后加入梯度惩罚项:

LGP(ϕ)=Ex^[(x^fϕ(x^)21)2]\mathcal L_{\text{GP}}(\phi)= \mathbb E_{\hat x} \left[ \left(\|\nabla_{\hat x}f_\phi(\hat x)\|_2-1\right)^2 \right]

于是 critic 的目标变成:

maxϕExpdata[fϕ(x)]Ex~pG[fϕ(x~)]λLGP(ϕ)\max_\phi \mathbb E_{x\sim p_{\text{data}}}[f_\phi(x)] - \mathbb E_{\tilde x\sim p_G}[f_\phi(\tilde x)] - \lambda \mathcal L_{\text{GP}}(\phi)

相比原始 WGAN,WGAN-GP 的优点是:

  • Lipschitz 约束更自然
  • 训练通常更稳定
  • 更容易和更深的网络、ResNet、Adam 等搭配使用

一个实用注意点是:由于梯度惩罚本身依赖输入梯度,critic 中通常不要使用 BatchNorm,更常见的是 LayerNorm 或 InstanceNorm。


5. WGAN 之后的几个代表性扩展#

  • BigGAN:把 GAN 推到大规模类别条件生成,代表性技巧包括大 batch、谱归一化、hinge loss、truncation trick
  • GigaGAN:继续做更大规模和更高分辨率的生成,强调大模型架构设计与多尺度训练
  • R3GAN:使用更现代的网络和更简洁的收敛约束,例如成对损失(pairwise loss)和零中心梯度惩罚