生成对抗网络(GAN):生成器与判别器

一句话概述

生成对抗网络(Generative Adversarial Network, GAN)是Ian Goodfellow于2014年提出的革命性生成模型。GAN由两个神经网络组成:生成器(Generator, G)和判别器(Discriminator, D),它们进行一场"对抗游戏"。生成器的目标是从随机噪声生成"以假乱真"的数据,骗过判别器;判别器的目标是正确区分真实数据和生成数据。这个博弈过程可以用一个最小最大化目标函数描述:min_G max_D E[log D(x)] + E[log(1-D(G(z)))]。理想情况下,两者达到纳什均衡——判别器无法区分真假(输出恒为0.5),生成器完美再现了真实数据分布。GAN的对抗训练思想开创了生成模型的新范式,产生的图像质量远超当时的VAE等方法。

💡 核心要点:①GAN由生成器和判别器组成,进行对抗博弈 ②生成器从随机噪声生成假数据,判别器区分真伪 ③训练是交替进行的:先训练判别器识别真伪,再训练生成器骗过判别器 ④理想均衡时判别器输出0.5,生成器完美拟合数据分布 ⑤GAN生成的图像清晰锐利,但训练不稳定是主要挑战

教学与演示

一、GAN的对抗博弈框架

是什么(定义):GAN的框架由两个玩家组成。生成器G:输入随机噪声z~p_z(通常为标准高斯或均匀分布),输出合成数据G(z)。判别器D:输入数据x,输出D(x)∈[0,1],表示x为真实数据的概率。目标是max_D E[log D(x)] + E[log(1-D(G(z)))]。判别器希望D(x)→1(真实数据),D(G(z))→0(生成数据);生成器希望D(G(z))→1(骗过判别器)。

大白话 GAN就像"造假者vs鉴定师"的对决。造假者(生成器)刚开始技术很差,做的假货一眼就能被鉴定师(判别器)识破。但每被识破一次,造假者就改进技术(参数更新)——"这个细节做得更像真的"。鉴定师也在不断进步——"这种纹理是假的典型特征"。经过多轮较量,造假者技术登峰造极,鉴定师再也分不出真假——这时造假者就学到了"真品的本质"。

为什么(原理):GAN的博弈理论基于二人零和博弈(two-player zero-sum game)。当判别器D达到最优时(D*(x)=p_data(x)/(p_data(x)+p_g(x))),生成器的优化目标等价于最小化JS散度JSD(p_data||p_g)——真实数据分布和生成数据分布之间的距离。因此,GAN在本质上是通过对抗训练来隐式地匹配数据分布,不需要显式定义似然函数。

import numpy as np

# GAN的生成器和判别器(简化版)
# 演示对抗博弈的核心机制

class SimpleGAN:
    def __init__(self, z_dim=5, data_dim=2):
        np.random.seed(42)
        # 生成器:z → 生成数据
        self.G_W1 = np.random.randn(z_dim, 8) * 0.3
        self.G_b1 = np.zeros(8)
        self.G_W2 = np.random.randn(8, data_dim) * 0.3
        self.G_b2 = np.zeros(data_dim)

        # 判别器:数据 → 真伪概率
        self.D_W1 = np.random.randn(data_dim, 8) * 0.3
        self.D_b1 = np.zeros(8)
        self.D_W2 = np.random.randn(8, 1) * 0.3
        self.D_b2 = np.zeros(1)

    def generator(self, z):
        """生成器:噪声 → 假数据"""
        h = np.maximum(0, z @ self.G_W1 + self.G_b1)
        x_fake = np.tanh(h @ self.G_W2 + self.G_b2)
        return x_fake

    def discriminator(self, x):
        """判别器:数据 → 真伪概率(sigmoid)"""
        h = np.maximum(0, x @ self.D_W1 + self.D_b1)
        logit = h @ self.D_W2 + self.D_b2
        prob = 1.0 / (1.0 + np.exp(-logit))  # sigmoid
        return prob

    def generate_samples(self, n):
        """生成n个假样本"""
        z = np.random.randn(n, self.G_W1.shape[0])
        return self.generator(z)


# 演示对抗博弈
gan = SimpleGAN(z_dim=5, data_dim=2)

print("=== GAN的对抗博弈框架 ===\n")

# 真实数据(模拟)
np.random.seed(1)
X_real = np.random.randn(100, 2) * 0.5

# 生成假数据
z_sample = np.random.randn(5, 5)
X_fake = gan.generator(z_sample)

print("生成器:噪声z → 假数据")
print(f"  噪声z形状: {z_sample.shape}")
print(f"  生成数据形状: {X_fake.shape}")

# 判别器评估
d_real = gan.discriminator(X_real[:5])
d_fake = gan.discriminator(X_fake)

print(f"\n判别器对真数据的评分: {np.round(d_real.flatten(), 3)}")
print(f"判别器对假数据的评分: {np.round(d_fake.flatten(), 3)}")

print("\n对抗博弈的核心:")
print("- 生成器目标:使判别器对假数据输出>0.5(骗过)")
print("- 判别器目标:对真数据输出→1,对假数据→0")
print("- 两者交替训练,互相促进")
print("- 理想均衡:判别器输出恒为0.5,无法区分")
GAN的极小极大目标函数\(\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\)
大白话 GAN的公式就是"两个人在拉锯"。判别器(D)想最大化得分——给真货打高分(log D(x)接近0,因为D(x)→1),给假货打低分(log(1-D(G(z)))接近0,因为D(G(z))→0)。生成器(G)想最小化得分——让假货得高分。两个人此消彼长,最终达到平衡。

什么用(应用):GAN在图像生成领域取得了惊人的效果。StyleGAN生成的高清人脸可以达到"以假乱真";CycleGAN实现图像风格迁移(照片→梵高风格);Pix2Pix实现图像到图像的翻译(轮廓→实物图)。GAN还用于超分辨率(SRGAN)、图像修复、数据增强、文本到图像生成等。

哪些坑(缺点):GAN训练极不稳定——模式坍缩(生成器只生成几种样本)、梯度消失(判别器太强生成器无法学习)、难以收敛(非凸博弈没有保证收敛到纳什均衡)。这些训练困难是GAN的主要挑战,也是后续WGAN等改进的动机。

二、交替训练策略

是什么(定义):GAN的训练是交替进行的,而非同时优化。标准流程:①固定生成器G,训练判别器D k步(通常k=1或k=5),使D更好地区分真伪;②固定判别器D,训练生成器G 1步,使G生成更逼真的假数据。重复此过程直到收敛。每次训练D时使用真实数据和生成数据的混合批次,标签分别为1和0;训练G时使用噪声生成假数据,标签设为1(假装是真数据)。

大白话 GAN的训练就像"轮流学习"。先让鉴定师(判别器)学习识别假货——给他看一批真货和一批假货,告诉他哪个是哪个。然后让造假者(生成器)学习改进——做出新的假货,告诉鉴定师"这些都是真货",让鉴定师打分,根据鉴定师的意见改进工艺。两人轮流学习,鉴定师越来越精明,造假者越来越高明。

为什么(原理):交替训练的必要性在于GAN的非凸博弈特性。如果只训练生成器而不更新判别器,生成器会"钻空子"——找到一个能让当前判别器误判的模式就停止学习。交替训练确保判别器始终"跟得上"生成器的进步,推动生成器持续改善。在实践中,训练判别器k步(k通常=1或根据任务调整)可以使判别器保持在最优状态附近。

import numpy as np

# GAN交替训练策略演示
# 展示判别器和生成器的交替更新

class GANTraining:
    def __init__(self, z_dim=5, data_dim=2):
        np.random.seed(42)
        self.z_dim = z_dim
        self.data_dim = data_dim
        # 生成器参数
        self.G_W = np.random.randn(z_dim, data_dim) * 0.5
        self.G_b = np.zeros(data_dim)
        # 判别器参数
        self.D_W = np.random.randn(data_dim, 1) * 0.5
        self.D_b = np.zeros(1)

    def G(self, z):
        """生成器"""
        return z @ self.G_W + self.G_b

    def D(self, x):
        """判别器:sigmoid输出"""
        logits = x @ self.D_W + self.D_b
        return 1.0 / (1.0 + np.exp(-logits))

    def train_discriminator(self, X_real, z, lr=0.01):
        """训练判别器:提升辨别真伪的能力"""
        X_fake = self.G(z)
        # 简化梯度更新
        d_real = self.D(X_real)
        d_fake = self.D(X_fake)
        loss_D = -np.mean(np.log(d_real + 1e-8) + np.log(1 - d_fake + 1e-8))
        return loss_D

    def train_generator(self, z, lr=0.01):
        """训练生成器:提升欺骗判别器的能力"""
        X_fake = self.G(z)
        d_fake = self.D(X_fake)
        loss_G = -np.mean(np.log(d_fake + 1e-8))
        return loss_G


print("=== GAN交替训练策略 ===\n")
print("标准流程:")
print("  步骤1:固定G,训练D k步(提升辨别能力)")
print("  步骤2:固定D,训练G 1步(提升生成质量)")
print("  重复直到D(G(z)) ≈ 0.5(无法区分)\n")

gan = GANTraining()
z = np.random.randn(30, 5)
X_real = np.random.randn(30, 2) * 0.5

# 模拟训练过程
print("训练模拟:")
for step in range(5):
    # 训练判别器
    loss_D = gan.train_discriminator(X_real, z)
    # 训练生成器
    loss_G = gan.train_generator(z)
    print(f"  步{step+1}: D损失={loss_D:.3f}, G损失={loss_G:.3f}")

print("\n交替训练的核心要点:")
print("- D学太快 → G学不到(梯度消失)")
print("- G学太快 → D跟不上(模式坍缩)")
print("- 需要平衡两者的训练速度")
GAN交替训练的损失函数\(L_D = -\frac{1}{m}\sum_{i=1}^{m}\left[\log D(x_i) + \log(1 - D(G(z_i)))\right], \quad L_G = -\frac{1}{m}\sum_{i=1}^{m}\log D(G(z_i))\)
大白话 判别器的损失就是"真货该打高分,假货该打低分";生成器的损失就是"假的也要打高分"。实际中生成器用-log D(G(z))而不是log(1-D(G(z))),因为后者在D(G(z))很小时梯度太小(生成器初始太差,学不动),前者梯度更强,帮助生成器快速起步。

什么用(应用):交替训练是GAN的标准训练方式。在实际中,k值的选择依赖任务:对于简单任务k=1即可;对于复杂任务(如高分辨率图像),可能需要k=3-5。此外,一些技巧如两时间尺度更新规则(TTUR)——给D和G不同的学习率——也用于平衡训练。

哪些坑(缺点):判别器和生成器的训练速度需要精细平衡。如果D太强(判别准确率接近100%),G的梯度消失(log(1-D(G(z)))梯度过小);如果G太强,D无法学到有意义的决策边界。模式坍缩是另一个常见问题——G学会生成少数几种"安全"样本(D无法区分的),而不是覆盖整个数据分布。

三、GAN的训练挑战与评估

是什么(定义):GAN的训练面临三个核心挑战:①模式坍缩(Mode Collapse)——生成器只生成数据分布中的少数模式,缺乏多样性;②梯度消失——判别器太强导致生成器梯度为零;③训练不收敛——GAN是非凸博弈,理论上不保证收敛到纳什均衡。评估GAN生成质量的主要指标包括Inception Score(IS)和Fréchet Inception Distance(FID)。

大白话 GAN训练就像一个"猫鼠游戏"中猫永远比老鼠快——你需要小心控制猫的速度(判别器学习率),不能太快(老鼠学不到东西)也不能太慢(猫自己跟不上)。模式坍缩就是"老鼠只学会了一种逃跑路线"——生成器发现"画圆"就能骗过判别器,于是只会画圆,从来不画方形。这就是为什么GAN的训练被称为"艺术而非科学"。

为什么(原理):模式坍缩的数学原因是GAN优化的是生成分布和真实分布之间的"覆盖"而非"匹配"。JS散度在不重叠区域恒为log2,梯度为零——如果生成分布只覆盖真实分布的一个子集,JS散度可能已经"最小化"(因为在该子集上两者重叠很好),导致优化停滞。WGAN通过使用Wasserstein距离替代JS散度,从根本上缓解了这一问题。

import numpy as np

# 模式坍缩(Mode Collapse)演示
# 展示生成器只生成少数模式的问题

def demo_mode_collapse():
    print("=== 模式坍缩(Mode Collapse)演示 ===\n")

    # 真实数据:8个高斯模式(如8种手写数字)
    print("真实数据分布:8个模式(如数字0-7)")
    print("  0  1  2  3  4  5  6  7")

    # 正常生成器:覆盖所有模式
    print("\n【正常生成器(覆盖所有模式)】")
    all_modes = [0, 1, 2, 3, 4, 5, 6, 7]
    print(f"  生成的类别: {sorted(all_modes)}")
    print("  → 多样性好,每种模式都有")

    # 模式坍缩的生成器:只生成少数模式
    print("\n【模式坍缩的生成器】")
    collapsed_modes = [1, 1, 1, 3, 3, 1, 3, 1, 1, 1]
    unique = np.unique(collapsed_modes)
    print(f"  生成的类别: {sorted(unique)}(只有2种!)")
    print("  → 虽然每个样本质量高,但多样性差")

    print("\n模式坍缩的原因:")
    print("  1. 生成器发现'模式1和3'容易骗过判别器")
    print("  2. 只生成这两种,其他模式被遗忘")
    print("  3. 判别器针对性地学会辨别1和3")

    print("\n解决方案:")
    print("  - WGAN/WGAN-GP:使用Wasserstein距离")
    print("  - Minibatch Discrimination:让判别器看到批次多样性")
    print("  - Unrolled GAN:考虑判别器的未来反应")


demo_mode_collapse()
FID评估指标\(\text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{\frac{1}{2}}\right)\)
大白话 FID就是"两个数据集的相似度"。把真实图片和生成图片都输入Inception网络(一个预训练的图像分类器),取中间层的特征,比较两组特征的"均值差异"和"分布差异"。FID越小说明生成的越像真的。Inception Score(IS)只看"每张图是否清晰"(类别置信度高)和"全局多样性",FID还能检测"是否所有类别都生成了"。

什么用(应用):FID和IS是评估GAN生成质量的标准指标。对于人脸生成(StyleGAN),FID越低表示生成的人脸越真实多样。在GAN训练中,监控FID可以判断模型是否在改进——FID持续下降表示训练正常,FID突然上升可能是模式坍缩的信号。

哪些坑(缺点):FID依赖Inception网络的预训练特征,对于非自然图像(如医学图像),Inception特征可能不适用。IS对批次数敏感。两个指标都不能完美衡量"图像质量"——低FID不保证图像在人类眼中好看。

概念关系图谱

概念核心含义与AI的关系关联概念
生成器(Generator)从噪声生成假数据的网络GAN的"造假者",学习数据分布解码器、转置卷积
判别器(Discriminator)区分真伪的二分类网络GAN的"鉴定师",提供训练信号二分类器、特征提取
对抗训练两个网络交替博弈的训练方式GAN的核心训练范式博弈论、纳什均衡
模式坍缩生成器只生成少数模式GAN训练的主要失败模式多样性、WGAN
FIDFrenchet Inception DistanceGAN生成质量的评估指标Inception Score
JS散度Jensen-Shannon距离原始GAN隐式优化的分布距离KL散度、Wasserstein距离

重点答疑

Q1: GAN中为什么使用sigmoid作为判别器的输出?

Sigmoid将判别器的输出映射到(0,1),可以解释为"输入是真实数据的概率"。这对于GAN的对抗损失(log D(x) + log(1-D(G(z))))是自然的——log D(x)在D(x)→1时最大化,log(1-D(G(z)))在D(G(z))→0时最大化。WGAN去掉了sigmoid,直接优化Wasserstein距离,判别器变为"评分器"而非"概率输出器"。

Q2: 模式坍缩和梯度消失如何区分?

模式坍缩:生成器生成的样本看起来质量不错,但缺乏多样性(只生成几种样板)。梯度消失:生成器生成的样本质量很差(几乎随机噪声),训练过程中生成器的损失不再下降。前者是生成器"走捷径",后者是判别器太强导致生成器学不动。

Q3: 为什么训练GAN时生成器的损失要用非饱和版本?

原始GAN中生成器最小化log(1-D(G(z)))。当生成器很差(D(G(z))≈0)时,log(1-D(G(z)))≈0,梯度≈0——生成器学不到东西。非饱和版本改为最小化-log D(G(z)),当D(G(z))≈0时梯度过大——生成器有强力推动。实践中非饱和版本帮助生成器在早期快速提升。

章节单词汇总

英文音标术语/释义
Generative Adversarial Network/ˈdʒenərətɪv ædvərˈseriəl ˈnetwɜːrk/生成对抗网络,由生成器和判别器组成
Generator (G)/ˈdʒenəreɪtər/生成器,从噪声生成假数据
Discriminator (D)/dɪˈskrɪmɪneɪtər/判别器,区分真伪数据
Adversarial Training/ædvərˈseriəl ˈtreɪnɪŋ/对抗训练,两个网络交替博弈
Mode Collapse/moʊd kəˈlæps/模式坍缩,生成器只生成少数模式
Nash Equilibrium/næʃ ˌiːkwɪˈlɪbriəm/纳什均衡,博弈双方都无法单方面改进
Wasserstein Distance/ˈvɑːsərʃtaɪn ˈdɪstəns/Wasserstein距离,衡量分布差异的指标
Inception Score (IS)/ɪnˈsepʃən skɔːr/Inception Score,GAN生成质量评估指标

面试练习

Q1 [单选] GAN的两个核心组件是什么?

  • A. 编码器和解码器
  • B. 生成器和判别器
  • C. 卷积层和全连接层
  • D. 注意力机制和前馈网络
解答:GAN由生成器(Generator)和判别器(Discriminator)组成,两者进行对抗博弈。

Q2 [单选] GAN的目标函数中,判别器的目标是什么?

  • A. 让D(x)接近0
  • B. 对真数据输出1,对假数据输出0
  • C. 让D(G(z))接近1
  • D. 让所有输出恒为0.5
解答:判别器最大化E[log D(x)] + E[log(1-D(G(z)))],即真数据D(x)→1,假数据D(G(z))→0。

Q3 [单选] GAN的模式坍缩(Mode Collapse)是指什么?

  • A. 判别器无法区分真伪
  • B. 生成器只生成数据分布中的少数模式
  • C. 训练损失降为零
  • D. 网络参数全部变为零
解答:模式坍缩指生成器只学习到数据分布的少数模式(如只生成数字"1"),缺乏多样性。

Q4 [多选] 关于GAN训练,以下哪些是正确的?

  • A. 训练是交替进行的(先D后G)
  • B. 判别器训练过多会导致梯度消失
  • C. GAN的非凸博弈不保证收敛
  • D. 生成器和判别器可以同时更新
  • E. 模式坍缩是GAN训练的常见问题
解答:GAN交替训练,D过强导致梯度消失,不保证收敛,模式坍缩常见。同时更新在理论上有问题(G得不到有意义的梯度信号)。

Q5 [单选] FID(Fréchet Inception Distance)衡量的是什么?

  • A. 生成速度
  • B. 真实数据和生成数据分布的差异
  • C. 判别器的准确率
  • D. 网络参数量
解答:FID衡量真实数据和生成数据在Inception特征空间中的分布差异(均值和协方差),值越小表示生成质量越好。