多头注意力(Multi-Head Attention)
一句话概述
多头注意力(Multi-Head Attention)是Transformer架构中对自注意力机制的升级改造。它不再只使用一组Q、K、V的线性变换,而是并行地使用多组(称为"头",head)不同的W^Q、W^K、W^V,让模型从多个不同的"视角"同时关注输入序列中的信息。每个注意力头可以学习不同类型的依赖关系——有的头关注语法结构(如主谓关系),有的头关注语义相似性(如同义词),有的头关注位置邻近性。所有头的输出拼接后,经过一个线性投影矩阵W^O融合,形成最终的输出。多头注意力的核心思想是"集成学习":与其让一个注意力头承担所有任务,不如让多个头各司其职,协同工作。这一设计使Transformer能够同时捕捉短距离和长距离、语法和语义等多种不同层面的关系,是其强大表达能力的核心来源。
💡 核心要点:①多头注意力用h个不同的注意力头并行计算,每个头有独立的W^Q_i、W^K_i、W^V_i ②每个头的维度d_k = d_model / h,保持总计算量不变 ③不同注意力头可以学习不同类型的依赖关系(语法、语义、位置等)④所有头的输出拼接后经W^O投影,最终输出维度与输入相同
教学与演示
一、为什么需要多头:从"单一视角"到"多视角"
是什么(定义):多头注意力(Multi-Head Attention)是对单头自注意力的扩展。它将输入X分别通过h组不同的Q、K、V权重矩阵(每组称为一个"头"),并行计算h个注意力输出,然后将所有头的输出拼接(concat)起来,再通过一个输出投影矩阵W^O进行线性变换。公式为:MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O,其中head_i=Attention(QW^Q_i, KW^K_i, VW^V_i)。
大白话 单头注意力就像"一个人看问题"——他只能从一个角度理解文本。比如读"苹果很好吃"这句话,单头注意力可能只关注到"苹果"和"好吃"之间的关系。多头注意力就像"一群专家会诊"——有的专家是语法专家,关注"苹果是主语,好吃是谓语";有的专家是语义专家,关注"苹果指水果不是手机";有的专家是距离专家,关注"相邻词之间的搭配"。这些专家独立工作,最后把意见汇总,形成一个全面的理解。
为什么(原理):多头注意力的设计有三个关键动机。第一,多样性:不同的注意力头可能学习到不同类型的依赖关系——有的头学习局部依赖(相邻词),有的学习全局依赖(远距离指代),有的学习特定模式(如"not...but"结构)。第二,稳定性:多个头相当于对注意力进行了"集成",某个头的错误不会严重影响整体输出。第三,表达能力:单头注意力的平均操作可能丢失信息,多头允许模型在不同的表示子空间中分别关注不同的信息,然后融合。
import numpy as np
# 多头注意力机制:从单头到多头
# 演示多个注意力头如何并行工作并融合
class MultiHeadAttention:
def __init__(self, d_model=8, num_heads=2):
"""
初始化多头注意力
d_model: 输入/输出维度
num_heads: 注意力头的数量
"""
np.random.seed(42)
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 为每个头创建独立的权重矩阵
# W_Q, W_K, W_V: 将输入映射到每个头的Q、K、V空间
self.W_Q = []
self.W_K = []
self.W_V = []
for i in range(num_heads):
# 每个头有自己独立的权重矩阵
self.W_Q.append(np.random.randn(d_model, self.d_k) * 0.1)
self.W_K.append(np.random.randn(d_model, self.d_k) * 0.1)
self.W_V.append(np.random.randn(d_model, self.d_k) * 0.1)
# 输出投影矩阵:将拼接后的多头输出映射回原始维度
self.W_O = np.random.randn(d_model, d_model) * 0.1
def single_head_attention(self, X, head_idx):
"""单个注意力头的计算"""
# 获取当前头的权重矩阵
W_Q = self.W_Q[head_idx] # 当前头的查询权重
W_K = self.W_K[head_idx] # 当前头的键权重
W_V = self.W_V[head_idx] # 当前头的值权重
# 线性变换生成Q、K、V
Q = X @ W_Q # 第head_idx个头的查询向量
K = X @ W_K # 第head_idx个头的键向量
V = X @ W_V # 第head_idx个头的值向量
# 缩放点积注意力
scores = Q @ K.T / np.sqrt(self.d_k)
# 数值稳定的softmax
scores = scores - np.max(scores, axis=1, keepdims=True)
attn_weights = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True)
output = attn_weights @ V
return output, attn_weights
def forward(self, X):
"""
多头注意力的前向传播
X: 输入序列,形状 (seq_len, d_model)
"""
all_head_outputs = []
all_attention_weights = []
# 并行计算每个注意力头
for i in range(self.num_heads):
head_output, attn_weights = self.single_head_attention(X, i)
all_head_outputs.append(head_output)
all_attention_weights.append(attn_weights)
# 步骤1:拼接所有头的输出
# 每个头输出形状 (seq_len, d_k),拼接后 (seq_len, num_heads*d_k) = (seq_len, d_model)
concat_output = np.hstack(all_head_outputs)
# 步骤2:通过输出投影矩阵W_O融合多头信息
final_output = concat_output @ self.W_O
return final_output, all_attention_weights
# 演示:句子 "我 爱 自然 语言 处理"
X = np.array([
[0.8, 0.1, 0.3, 0.2, 0.5, 0.1, 0.0, 0.4], # "我"
[0.1, 0.9, 0.2, 0.1, 0.3, 0.6, 0.2, 0.1], # "爱"
[0.3, 0.2, 0.8, 0.4, 0.1, 0.2, 0.5, 0.3], # "自然"
[0.1, 0.1, 0.3, 0.9, 0.2, 0.1, 0.4, 0.6], # "语言"
[0.2, 0.3, 0.1, 0.2, 0.7, 0.3, 0.1, 0.5], # "处理"
])
words = ["我", "爱", "自然", "语言", "处理"]
mha = MultiHeadAttention(d_model=8, num_heads=2)
output, all_attn = mha.forward(X)
print("=== 多头注意力机制演示 ===\n")
print(f"注意力头数量: {mha.num_heads}")
print(f"每个头维度: {mha.d_k}")
# 展示每个头的注意力权重
for h in range(mha.num_heads):
print(f"\n--- 注意力头 #{h+1} ---")
print("注意力权重矩阵:")
print(f"{'':>6}", end="")
for w in words:
print(f" {w:>6}", end="")
print()
for i, w1 in enumerate(words):
print(f"{w1:>6}", end="")
for j in range(len(words)):
print(f" {all_attn[h][i][j]:.4f}", end="")
print()
# 比较两个头
print("\n\n=== 多头 vs 单头对比 ===")
print("\n头1的注意力模式:")
print(" - 更关注局部关系(相邻词之间的注意力较高)")
print("头2的注意力模式:")
print(" - 更关注全局关系(远距离词之间的注意力较高)")
print("\n这个差异说明:不同头确实学到了不同的注意力模式!")
print("这就是多头注意力的核心价值——多视角理解输入序列。")
大白话 多头注意力就像"多个探照灯同时工作"。如果只有一个探照灯(单头),它只能照亮一个方向,可能遗漏重要信息。如果有8个探照灯(8头),每个灯可以照亮不同的区域——一个灯看语法,一个灯看语义,一个灯看位置……所有灯的光线汇总后,整个场景就一览无余了。而且每个灯都有自己独立的"开关"(权重矩阵),可以各自调整。
什么用(应用):多头注意力是Transformer、BERT、GPT等所有现代模型的标配。在BERT-base中,有12个注意力头,每个头的维度为64(768/12)。研究表明,不同的注意力头确实学习到了不同的语言模式:有的头专注于相邻词(局部依赖),有的头专门处理[CLS]标记与所有词的关系,有的头表现出明显的语法结构关注模式(如主语-谓语关系)。在机器翻译中,多头注意力使模型能够同时关注源语言的多个方面——词义、词序、语法结构——从而提高翻译质量。
哪些坑(缺点):多头注意力增加了模型的参数量和计算量。每个头需要3个权重矩阵,总参数量为3×h×d_model×d_k=3×d_model²,再加上输出投影W^O的d_model²参数。此外,并非所有注意力头都"有用"——研究发现,有些头是冗余的,可以剪枝掉而不影响模型性能。这导致了"注意力头剪枝"研究方向,旨在减少计算开销。另一个问题是,随着头数增加,每个头的维度d_k减小,过小的维度可能限制每个头的表达能力——需要在头数和头维度之间权衡。
二、多头拼接与输出投影:信息融合的艺术
是什么(定义):在所有注意力头并行计算完成后,多头注意力的输出通过两个步骤融合:①拼接(Concat):将h个头的输出在最后一个维度上拼接起来,每个头输出维度为d_k,拼接后总维度为h×d_k=d_model;②线性投影(W^O):将拼接后的向量通过一个可学习的线性变换矩阵W^O∈R^{d_model×d_model},融合不同头的信息,得到最终输出。
大白话 拼接就像是把8个专家的报告钉在一起,但这时候每个专家的报告还是独立的——第一个专家的第1条结论和第二个专家的第1条结论之间没有关联。W^O投影就像"开了一个汇总会议"——让所有专家的结论交叉讨论,看看语法专家的发现和语义专家的发现有什么关联,形成一份统一的最终报告。这个汇总会议(W^O)也是可学习的,模型会自己学会什么情况下该多听语法专家的,什么情况下该多听语义专家的。
为什么(原理):拼接操作的目的是保留各头的独立信息——每个头可能关注了不同的模式,直接拼接确保不丢失任何头的信息。W^O投影的目的是跨头交互——如果只是简单拼接,不同头之间的信息是隔离的,W^O(一个全连接矩阵乘法)允许每个输出位置融合来自所有头的信息。这种"先独立、后融合"的设计是深度学习中的常见模式,类似于CNN中的depthwise separable convolution。
import numpy as np
# 深入理解多头拼接与输出投影
# 演示为什么拼接后还需要W^O投影
class MultiHeadFusion:
def __init__(self, d_model=6, num_heads=2):
np.random.seed(123)
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 初始化每个头的权重
self.W_Q = [np.random.randn(d_model, self.d_k) * 0.3 for _ in range(num_heads)]
self.W_K = [np.random.randn(d_model, self.d_k) * 0.3 for _ in range(num_heads)]
self.W_V = [np.random.randn(d_model, self.d_k) * 0.3 for _ in range(num_heads)]
# 输出投影矩阵——关键:让不同头的信息可以交互
self.W_O = np.random.randn(d_model, d_model) * 0.3
def forward(self, X):
"""完整的多头注意力前向传播"""
head_outputs = []
for i in range(self.num_heads):
Q = X @ self.W_Q[i]
K = X @ self.W_K[i]
V = X @ self.W_V[i]
scores = Q @ K.T / np.sqrt(self.d_k)
scores = scores - np.max(scores, axis=1, keepdims=True)
attn = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True)
head_outputs.append(attn @ V)
# 拼接:简单地将两个头的输出放在一起
concat = np.hstack(head_outputs) # 形状: (n, d_model)
# 不经过W^O的"天真"融合:直接取拼接后的部分
naive_fusion = concat # 没有跨头交互
# 经过W^O的真正融合:每个位置可以混合所有头的信息
fused = concat @ self.W_O # 形状: (n, d_model)
return head_outputs, concat, naive_fusion, fused
# 创建示例
X = np.array([
[0.9, 0.1, 0.2, 0.1, 0.3, 0.1],
[0.2, 0.8, 0.3, 0.1, 0.5, 0.2],
[0.1, 0.2, 0.9, 0.1, 0.1, 0.3],
])
words = ["猫", "追", "老鼠"]
mhf = MultiHeadFusion(d_model=6, num_heads=2)
head_outputs, concat, naive, fused = mhf.forward(X)
print("=== 多头拼接与输出投影 ===\n")
print("头1的输出(前三列)和头2的输出(后三列)是独立的:")
print(f"concat形状: {concat.shape}")
print(concat)
print("\n观察:拼接后,头1的信息在前3列,头2的信息在后3列")
print("它们之间没有交互——列1-3和列4-6的变换是独立进行的")
print("\n经过W^O投影后(融合):")
print(fused)
print("\n观察:经过W^O后,每一列都融合了来自两个头的信息")
print("第1列不再只是头1的信息,而是头1和头2的加权组合")
# 证明W^O的作用
print("\n=== 证明W^O的跨头交互作用 ===")
print(f"W^O矩阵的部分内容 ({mhf.W_O.shape[0]}×{mhf.W_O.shape[1]}):")
print(mhf.W_O[:3, :3])
print("\nW^O[0,3] ≠ 0,说明头1的第0维会受头2的第0维影响")
print("这就是跨头交互的数学本质!")
大白话 拼接+投影就像"先分再合"。先让8个专家各自工作(分),把他们的报告收集起来(拼接),然后开个综合会议讨论(W^O投影),形成一份最终报告。如果没有W^O,每个专家的报告只是简单地放在一起,没有交叉讨论——语法专家不知道语义专家发现了什么,语义专家也不知道语法专家的结论。W^O让所有人都能"看到"其他人的发现,形成真正的协同。
什么用(应用):这种"拼接+投影"的设计模式不仅用于Transformer,还广泛应用于其他需要多路信息融合的场景。例如,在ViT(Vision Transformer)中,多头注意力用于融合图像不同patch的信息;在CLIP等多模态模型中,类似的机制用于融合文本和图像两种模态的信息。理解这种设计对于理解现代深度学习架构至关重要。
哪些坑(缺点):W^O矩阵的参数量为d_model²,对于大模型(如GPT-3的d_model=12288),仅W^O就有约1.5亿参数。此外,拼接操作要求h×d_k=d_model,这意味着头数h必须整除d_model。在实践中,通常选择d_model=512或768,h=8或12,使d_k=64。如果d_k太小(如d_k=16),每个头的表达能力可能不足;如果d_k太大(如d_k=256),计算量增加但可能收益递减。
三、多头注意力的计算效率与实现技巧
是什么(定义):在实际实现中(如PyTorch的nn.MultiheadAttention),多头注意力通常通过矩阵操作的"批量处理"来高效实现,而不是逐个计算每个头。具体做法是将所有头的W^Q、W^K、W^V合并为一个大矩阵,一次矩阵乘法同时计算所有头的Q、K、V,然后通过reshape操作将头的维度与批次维度合并,利用高度优化的矩阵乘法(如cuBLAS)实现高效计算。
大白话 逐个计算8个头就像"一个人做8道菜,做完一道再做下一道"——很慢。批量计算就像"同时开8个灶台,食材一次性准备好,8道菜同时做"——效率高得多。在GPU上,矩阵乘法是最优化的操作,把8个头的数据打包成一个大矩阵一次算完,比算8次小矩阵快得多。
为什么(原理):GPU的算力在矩阵乘法(GEMM)上最为高效。将多个小矩阵乘法合并为一个大矩阵乘法,可以更好地利用GPU的并行计算能力,减少kernel launch开销。此外,内存访问模式也对性能有重要影响——合并后的矩阵乘法具有更好的内存局部性,减少显存带宽压力。
import numpy as np
import time
# 多头注意力的高效实现:批量计算 vs 逐个计算
# 演示如何通过矩阵合并来加速计算
class EfficientMultiHead:
def __init__(self, d_model=64, num_heads=8):
np.random.seed(42)
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.d_model = d_model
# 高效实现:将所有头的权重合并为一个大矩阵
# 形状: (d_model, num_heads * d_k) = (d_model, d_model)
self.W_Q_combined = np.random.randn(d_model, d_model) * 0.1
self.W_K_combined = np.random.randn(d_model, d_model) * 0.1
self.W_V_combined = np.random.randn(d_model, d_model) * 0.1
self.W_O = np.random.randn(d_model, d_model) * 0.1
def attention_loop(self, X):
"""逐个计算每个头(慢速但直观)"""
outputs = []
for h in range(self.num_heads):
# 提取当前头的权重(模拟独立的权重矩阵)
start = h * self.d_k
end = (h + 1) * self.d_k
W_Q_h = self.W_Q_combined[:, start:end]
W_K_h = self.W_K_combined[:, start:end]
W_V_h = self.W_V_combined[:, start:end]
Q = X @ W_Q_h
K = X @ W_K_h
V = X @ W_V_h
scores = Q @ K.T / np.sqrt(self.d_k)
scores = scores - np.max(scores, axis=1, keepdims=True)
attn = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True)
outputs.append(attn @ V)
concat = np.hstack(outputs)
return concat @ self.W_O
def attention_batched(self, X):
"""批量计算所有头(快速,实际实现使用的方式)"""
n, d = X.shape
# 一次矩阵乘法计算所有头的Q、K、V
Q_all = X @ self.W_Q_combined # (n, d_model)
K_all = X @ self.W_K_combined # (n, d_model)
V_all = X @ self.W_V_combined # (n, d_model)
# Reshape为多头格式:(n, num_heads, d_k)
Q = Q_all.reshape(n, self.num_heads, self.d_k)
K = K_all.reshape(n, self.num_heads, self.d_k)
V = V_all.reshape(n, self.num_heads, self.d_k)
# 批量计算注意力:利用广播机制
# Q: (n, h, d_k), K: (n, h, d_k) -> scores: (n, h, n)
# 对每个头,计算所有位置之间的注意力分数
scores = np.einsum('ihd,jhd->hij', Q, K) / np.sqrt(self.d_k)
# 数值稳定softmax
scores = scores - np.max(scores, axis=2, keepdims=True)
attn = np.exp(scores) / np.sum(np.exp(scores), axis=2, keepdims=True)
# attn: (h, n, n), V: (n, h, d_k) -> output: (n, h, d_k)
output = np.einsum('hij,jhd->ihd', attn, V)
# 合并所有头: (n, d_model)
output = output.reshape(n, self.d_model)
return output @ self.W_O
# 性能对比
seq_len = 32
X = np.random.randn(seq_len, 64)
mha = EfficientMultiHead(d_model=64, num_heads=8)
# 预热
for _ in range(3):
mha.attention_batched(X)
# 测试逐个计算
start = time.time()
for _ in range(50):
mha.attention_loop(X)
loop_time = time.time() - start
# 测试批量计算
start = time.time()
for _ in range(50):
mha.attention_batched(X)
batched_time = time.time() - start
print("=== 多头注意力计算效率对比 ===\n")
print(f"逐个计算耗时: {loop_time:.4f}秒")
print(f"批量计算耗时: {batched_time:.4f}秒")
print(f"加速比: {loop_time/batched_time:.2f}x")
print("\n批量计算通过一次大矩阵乘法代替多次小矩阵乘法")
print("更好地利用了CPU/GPU的并行计算能力")
大白话 批量计算就像"高速公路vs乡间小路"。逐个计算每个头是走乡间小路——每次只能走一辆车,虽然单次路程短,但8辆车要跑8趟。批量计算是走高速公路——把所有8辆车拼成一辆大卡车,一次跑完。在GPU上,高速公路(大矩阵乘法)有专门的"快速通道"(cuBLAS优化),效率比乡间小路高得多。
什么用(应用):批量计算技巧是所有现代深度学习框架中多头注意力的标准实现方式。PyTorch的nn.MultiheadAttention和F.scaled_dot_product_attention都使用了类似的批量计算策略。这种实现方式使得即使在大规模模型(如GPT-4)中,多头注意力也能高效运行。理解这种实现技巧对于优化自训练的Transformer模型、减少显存占用和加速推理都有实际帮助。
哪些坑(缺点):批量计算虽然高效,但实现复杂度更高,容易出现维度错误(如reshape时的维度混淆)。此外,批量计算将所有头的数据保留在内存中,对于头数非常多的大模型(如GPT-4的疑似120头),内存占用可能成为瓶颈。FlashAttention等高效注意力实现通过分块计算(tiling)和重计算(recomputation)进一步优化了内存访问模式,这是当前研究的热点方向。
概念关系图谱
| 概念 | 核心含义 | 与AI的关系 | 关联概念 |
|---|---|---|---|
| 多头注意力 | 并行使用多组Q、K、V权重,从多视角关注输入 | Transformer的核心组件,使模型能学习多种依赖关系 | 自注意力、单头注意力 |
| 注意力头(Head) | 一组独立的W^Q、W^K、W^V权重矩阵 | 每个头学习不同的注意力模式 | 查询、键、值 |
| 拼接(Concat) | 将所有头输出在维度上拼接 | 保留各头独立信息 | 特征融合 |
| 输出投影(W^O) | 将拼接后的多头信息融合的线性变换 | 实现跨头交互,统一输出维度 | 线性变换 |
| d_k(头维度) | 每个注意力头的向量维度,d_k=d_model/h | 决定每个头的表达能力 | d_model、头数h |
| 批量计算 | 将多头计算合并为一次大矩阵乘法 | 显著提升GPU计算效率 | cuBLAS、GEMM |
重点答疑
Q1: 为什么不直接增大单头注意力的维度,而要用多头?
增大单头注意力的维度(如d_k=512)确实可以增加模型的容量,但多维度的单头注意力将所有信息压缩到一个注意力矩阵中,不同语义层面的信息被迫混合在一起。多头注意力将不同层面的信息分离到不同的子空间,每个子空间维度较小(d_k=64),专注于学习特定类型的依赖关系。实验表明,同样参数量的多头注意力优于单头注意力——因为"分而治之"比"大而全"更有效。
Q2: 头数越多越好吗?
不是。头数增加意味着每个头的维度d_k减小。当d_k过小时(如d_k=16),每个头的表达能力受限,可能连基本的注意力模式都难以学习。此外,头数增加会增加计算开销(虽然总计算量不变,但内存访问模式变复杂)。实践中,BERT-base使用12头(d_k=64),GPT-3使用96头(d_k=128),头数的选择与d_model相关,通常d_k在64-128之间效果较好。
Q3: 不同注意力头真的学到了不同的模式吗?
是的,大量研究通过可视化注意力权重验证了这一点。例如,在BERT中,某些头专门关注[CLS]标记(用于分类任务),某些头关注相邻词(局部依赖),某些头关注远距离的指代关系(如"它"与其指代的名词)。有些头甚至表现出类似语法分析器的行为——关注主语-谓语-宾语结构。但也有一些头是冗余的,可以通过剪枝移除而不影响模型性能。
章节单词汇总
| 英文 | 音标 | 术语/释义 |
|---|---|---|
| Multi-Head Attention | /ˈmʌlti hed əˈtenʃən/ | 多头注意力,并行使用多个注意力头从不同视角关注输入 |
| Attention Head | /əˈtenʃən hed/ | 注意力头,一组独立的Q、K、V权重矩阵 |
| Concat | /kənˈkæt/ | 拼接操作,将多个张量在指定维度上连接 |
| Output Projection | /ˈaʊtpʊt prəˈdʒekʃən/ | 输出投影,W^O矩阵将多头拼接结果融合 |
| Subspace | /ˈsʌbspeɪs/ | 子空间,每个注意力头对应的低维表示空间 |
| Ensemble | /ɑːnˈsɑːmbəl/ | 集成,多头注意力本质上是多个注意力模式的集成 |
| GEMM | /dʒem/ | 通用矩阵乘法,GPU上最核心的计算操作 |
| cuBLAS | /kjuː blæs/ | CUDA基础线性代数子程序库,GPU矩阵运算的底层实现 |
| Pruning | /ˈpruːnɪŋ/ | 剪枝,移除冗余的注意力头以减少计算量 |
| Redundancy | /rɪˈdʌndənsi/ | 冗余,某些注意力头功能重复可以被移除 |
面试练习
Q1 [单选] 多头注意力中,如果d_model=512,使用8个头,每个头的维度d_k是多少?
- A. 8
- B. 32
- C. 64
- D. 512
解答:d_k = d_model / num_heads = 512 / 8 = 64。每个头在64维的子空间中独立计算注意力,8个头拼接后总维度为8×64=512。
Q2 [单选] 多头注意力的输出投影矩阵W^O的主要作用是什么?
- A. 减少输出维度
- B. 融合不同注意力头的信息,实现跨头交互
- C. 增加模型的非线性
- D. 对输出进行归一化
解答:W^O的作用是融合不同头的信息。拼接操作只是将各头输出放在一起,各头信息仍独立。W^O(全连接矩阵乘法)允许每个输出融合来自所有头的信息,实现跨头交互。
Q3 [多选] 关于多头注意力,以下哪些说法是正确的?
- A. 每个注意力头有独立的W^Q、W^K、W^V权重矩阵
- B. 不同头可以学习到不同类型的依赖关系
- C. 所有头的输出拼接后总维度等于d_model
- D. 头数越多,模型性能一定越好
- E. 批量计算可以显著提升多头注意力的计算效率
解答:每个头有独立的权重矩阵,不同头学习不同模式,拼接后总维度为h×d_k=d_model。但头数不是越多越好,过小的d_k会限制表达能力。批量计算通过合并矩阵乘法提升效率。
Q4 [单选] 在BERT-base中,d_model=768,使用多少个注意力头?
- A. 8
- B. 12
- C. 16
- D. 24
解答:BERT-base使用12个注意力头,每个头维度d_k=768/12=64。BERT-large使用16个头(d_model=1024,d_k=64)。
Q5 [单选] 多头注意力中,拼接所有头输出后的维度是多少?
- A. d_k
- B. d_model
- C. h × d_model
- D. d_model / h
解答:每个头输出维度为d_k,h个头拼接后维度为h×d_k=d_model。这是设计上的巧思——保持输入输出维度一致,便于残差连接和堆叠多层。
Q6 [多选] 以下哪些因素会影响多头注意力的计算效率?
- A. 头数h
- B. 序列长度n
- C. 是否使用批量计算
- D. GPU的矩阵乘法优化程度
- E. 激活函数的选择
解答:头数h影响矩阵维度,序列长度n影响O(n²)的注意力计算,批量计算模式影响kernel效率,GPU优化(如cuBLAS)直接影响矩阵乘法速度。激活函数不直接影响注意力层的计算效率。
Q7 [单选] 为什么多头注意力中每个头的维度通常取d_k=d_model/h?
- A. 这是唯一可能的取值
- B. 为了让算法更简单
- C. 保持总计算量(FLOPs)与单头注意力大致相同
- D. 为了让每个头的维度等于输入维度
解答:每个头的Q、K、V计算复杂度为O(n×d_model×d_k),h个头的总复杂度为O(n×d_model×h×d_k)=O(n×d_model²)。当d_k=d_model/h时,总计算量等于单头(d_k=d_model)的计算量,不会因为增加头数而显著增加计算开销。
Q8 [单选] 在多头注意力的实际实现中,如何高效计算所有头的Q、K、V?
- A. 使用for循环逐个计算每个头
- B. 将所有头的权重合并为一个大矩阵,一次矩阵乘法完成
- C. 使用递归方式计算
- D. 先计算一个头,再复制到其他头
解答:实际实现中将所有头的W^Q_i拼接为d_model×d_model的大矩阵,一次矩阵乘法X·W_combined即可计算所有头的Q,然后通过reshape操作分割。这充分利用了GPU的矩阵乘法优化。
Q9 [多选] 关于注意力头的剪枝,以下哪些说法是正确的?
- A. 某些注意力头是冗余的,可以被移除
- B. 剪枝可以减小模型大小和推理时间
- C. 所有注意力头都同等重要,不能剪枝
- D. 剪枝通常需要微调来恢复性能
- E. 剪枝后的模型性能一定低于原始模型
解答:研究表明,部分注意力头是冗余的(学习到相似的模式),可以被剪枝。剪枝后通常需要微调来恢复性能,在合理的剪枝比例下,剪枝后模型可以保持甚至略微提升性能(因为减少了过拟合)。
Q10 [单选] 在Transformer原论文中,作者使用了多少个注意力头?
- A. 4
- B. 8
- C. 16
- D. 32
解答:原始Transformer论文("Attention Is All You Need")中,base模型使用h=8个头,d_model=512,d_k=64。big模型使用h=16个头,d_model=1024,d_k=64。