图池化与图级表示
一句话概述
图池化(Graph Pooling)是将节点级表示聚合为图级表示的关键操作。在节点分类任务中,GNN输出每个节点的嵌入;但在图分类、图回归等任务中,需要将整张图的所有节点表示汇总为一个固定大小的图级向量。图池化方法从简单到复杂分为:①全局池化——对所有节点取平均/求和/最大值,简单但可能丢失结构信息;②层次化池化(如DiffPool)——分层将节点聚类为"超级节点",逐步粗化图结构;③TopK池化(如SAGPool)——基于节点重要性评分,保留Top-K个最重要的节点。图池化是图分类、分子性质预测等任务的最后一环,将复杂的图结构映射为简洁的向量表示。
💡 核心要点:①图池化将节点级表示聚合为整个图的固定大小向量 ②全局池化(平均/求和/最大)最简单但丢失结构信息 ③层次化池化(DiffPool)逐步粗化图结构,保留层次信息 ④TopK池化(SAGPool)基于重要性评分选择关键节点 ⑤图级表示用于图分类、分子性质预测、图相似度等任务
教学与演示
一、全局池化:从节点到图的简单映射
是什么(定义):全局池化(Global Pooling)是最简单的图级读出操作。给定GNN输出的所有节点嵌入H∈R^{n×d},全局池化通过一个对称函数将所有节点聚合为一个图级向量:Readout(H)=AGG({h_i : i∈V})。常用的AGG函数包括:平均(mean)、求和(sum)、最大值(max)。全局池化的输出是一个d维向量,可以直接输入分类器或回归器。
大白话 全局池化就是"全班总结"。每个学生(节点)有各科成绩(特征),校长要看全班总体情况。平均池化就是"算全班平均分"——简单直接,但可能掩盖"偏科"现象。求和池化就是"算全班总分"——保留了班级人数信息。最大池化就是"看各科最高分"——只看最好的,忽略整体。三种方式各有侧重,选哪种取决于具体任务。
为什么(原理):全局池化的选择影响模型的表达能力。求和池化保留了节点数量和度信息,表达力最强(GIN论文证明了这一点);平均池化适合节点数量变化的场景(如不同大小的图);最大池化适合只关注最显著特征的场景。全局池化的缺点是丢失了图的结构信息——两个结构完全不同的图,如果节点特征分布相同,全局池化后会得到相同的图级表示。
import numpy as np
# 全局池化:将节点嵌入聚合为图级表示
# 对比三种全局池化方式的差异
class GlobalPooling:
def __init__(self):
pass
def mean_pooling(self, H):
"""平均池化:所有节点特征取平均"""
return np.mean(H, axis=0)
def sum_pooling(self, H):
"""求和池化:所有节点特征求和"""
return np.sum(H, axis=0)
def max_pooling(self, H):
"""最大池化:每个维度取最大值"""
return np.max(H, axis=0)
# 创建两个结构不同但节点特征分布相似的图
# 图A:链状结构,4个节点
H_A = np.array([
[0.8, 0.1, 0.2],
[0.2, 0.9, 0.1],
[0.1, 0.2, 0.8],
[0.5, 0.5, 0.3],
])
# 图B:星形结构,4个节点(特征分布与A相同)
H_B = np.array([
[0.8, 0.1, 0.2],
[0.2, 0.9, 0.1],
[0.1, 0.2, 0.8],
[0.5, 0.5, 0.3],
]) # 相同特征,不同结构
pool = GlobalPooling()
print("=== 全局池化:节点到图的映射 ===\n")
print("图A和图B有相同的节点特征(但结构不同)")
print(f"\n节点特征矩阵(4个节点,3维):\n{H_A}")
mean_A = pool.mean_pooling(H_A)
sum_A = pool.sum_pooling(H_A)
max_A = pool.max_pooling(H_A)
print(f"\n平均池化: {np.round(mean_A, 3)}")
print(f"求和池化: {np.round(sum_A, 3)}")
print(f"最大池化: {np.round(max_A, 3)}")
print("\n关键观察:")
print("- 平均池化:特征在[-1,1]范围,不受节点数影响")
print("- 求和池化:保留了节点数量信息(值更大)")
print("- 最大池化:每维取最大值,忽略其他节点")
print("- 全局池化的局限:两个结构不同的图,如果节点特征相同," +
"全局池化后图表示也相同——丢失了结构信息!")
大白话 全局池化就是"一视同仁"。把所有节点信息压缩成一个向量,不看它们之间的连接关系。就像把全班的成绩单算一个总分——你知道总分是多少,但不知道谁和谁是好朋友(图结构)。这种方法简单快速,但可能会把两个结构完全不同的图看成一样的(如果节点特征相同的话)。
什么用(应用):全局池化是图分类的基本操作。在分子毒性预测中,求和池化保留原子的总体贡献;在社交网络社区检测中,平均池化不受社区大小影响。GIN论文证明了求和池化+MLP可以达到1-WL测试的上限,是最具表达力的全局池化方式。
哪些坑(缺点):全局池化完全丢失了图结构信息——两个结构完全不同的图(如链状和星形),如果节点特征相同,全局池化后图表示完全相同。这促使了层次化池化方法的发展,后者在粗化图的过程中保留了结构信息。
二、层次化池化:DiffPool与图粗化
是什么(定义):层次化池化(Hierarchical Pooling)通过逐步将节点聚类为"超级节点"(super-node)来粗化图结构。DiffPool(Differentiable Pooling)是最具代表性的方法:它学习一个软分配矩阵S∈R^{n_l×n_{l+1}}(n_{l+1}表示粗化后的节点数),将第l层的n_l个节点分配到第l+1层的n_{l+1}个聚类中。新的节点特征为S^T H,新的邻接矩阵为S^T A S。整个过程是可微的,可以端到端训练。
大白话 DiffPool就像"行政区域的合并"。原始图有100个村庄(节点),DiffPool学习把它们合并成10个镇(超级节点)——哪个村庄归哪个镇,不是人为规定的,而是模型自己学的(软分配矩阵S)。合并后,每个镇的人口(新特征)是其下属村庄的人口汇总,镇之间的道路(新邻接矩阵)由其下属村庄之间的道路汇总。这个过程可以重复——镇合并成县,县合并成市……层层嵌套,形成层次化表示。
为什么(原理):DiffPool的核心是软分配矩阵S。S通过另一个GNN(分配GNN)学习:S=softmax(GNN_{pool}(A, H))。每行对应一个原始节点分配到各聚类的概率(行和为1)。这种软分配允许节点同时属于多个聚类(通过概率),提供了更丰富的表示。层次化粗化使得最终图级表示能够捕捉多层次的结构信息——从局部模式到全局拓扑。
import numpy as np
# DiffPool:层次化图池化的简化演示
# 展示如何通过软分配矩阵粗化图
class SimplifiedDiffPool:
def __init__(self, n_clusters=2):
np.random.seed(42)
self.n_clusters = n_clusters
def compute_assignment(self, H, A):
"""计算软分配矩阵S(简化版)"""
n = H.shape[0]
# 模拟基于节点特征的聚类分配
S = np.zeros((n, self.n_clusters))
for i in range(n):
# 基于特征第一个维度的值来分配(实际中由GNN学习)
if H[i, 0] > 0.3:
S[i, 0] = 0.7
S[i, 1] = 0.3
else:
S[i, 0] = 0.3
S[i, 1] = 0.7
return S
def pool(self, H, A):
"""DiffPool粗化:X'=S^T H, A'=S^T A S"""
S = self.compute_assignment(H, A)
# 新节点特征:聚类内节点的特征加权和
H_new = S.T @ H # (n_clusters, d)
# 新邻接矩阵:聚类之间的连接强度
A_new = S.T @ A @ S # (n_clusters, n_clusters)
return H_new, A_new, S
# 原始图:6个节点
H = np.array([
[0.5, 0.2],
[0.4, 0.3],
[0.1, 0.8],
[0.2, 0.7],
[0.6, 0.1],
[0.1, 0.9],
])
A = np.array([
[0, 1, 0, 0, 1, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 1],
[0, 0, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
])
diffpool = SimplifiedDiffPool(n_clusters=2)
H_new, A_new, S = diffpool.pool(H, A)
print("=== DiffPool:层次化图池化 ===\n")
print(f"原始图:{H.shape[0]}个节点")
print(f"粗化后:{H_new.shape[0]}个聚类(超级节点)\n")
print("软分配矩阵 S(行和为1):")
print(np.round(S, 3))
print("→ 节点0,1,4主要分配给聚类0,节点2,3,5主要分配给聚类1")
print(f"\n新节点特征 H' = S^T H({H_new.shape[0]}×{H_new.shape[1]}):")
print(np.round(H_new, 3))
print(f"\n新邻接矩阵 A' = S^T A S(聚类间连接强度):")
print(np.round(A_new, 3))
print("→ 保留了原始图的聚类间连接关系")
print("\n层次化池化的优势:")
print("- 保留了图的结构信息(通过A')")
print("- 可堆叠多层,形成层次化表示")
print("- 端到端可训练")
大白话 DiffPool就是"自动学怎么合并村庄"。你告诉它要从100个村庄合并成10个镇,但哪个村庄归哪个镇——不是你来定,而是模型自己学。模型通过另一个GNN(分配网络)来打分,判断"这个村庄和那个村庄更像"。最终每个村庄被软分配到几个镇(概率分配),形成新的"镇级"图。这个过程可以重复——镇合并成县,县合并成市。
什么用(应用):DiffPool在多个图分类基准上达到当时最优。在ENZYMES(蛋白质分类)和DD(化合物分类)上,DiffPool显著优于全局池化。层次化池化也被用于社交网络分析(识别多层社区结构)、分子性质预测(识别功能团层次)等。后续的MinCutPool、StructPool等改进了DiffPool的计算效率和稀疏性。
哪些坑(缺点):DiffPool的计算复杂度较高——需要额外的分配GNN,且S^TAS的矩阵乘法(O(n²))在稀疏图上效率低。软分配矩阵S通常稠密(每个节点软分配到所有聚类),内存占用大。聚类数n_{l+1}需要预定义(通常为n_l的25%),超参数调优复杂。
三、TopK池化与自注意力池化
是什么(定义):TopK池化(如SAGPool)通过学习每个节点的重要性评分,保留Top-K个最重要的节点及其连接关系。具体为:评分z=GNN(A,H),选择top-K个节点idx=topk(z),输出H'=H[idx,:]⊙z[idx],A'=A[idx,:][:,idx]。自注意力池化(Set2Set、Global Attention)使用注意力机制对节点加权求和:h_G=Σ_i α_i h_i,其中α_i=softmax(W·h_i)。这两种方法在保留重要信息和计算效率之间取得了平衡。
大白话 TopK池化就是"只保留最重要的K个人"。就好像一个班级里,只选前10名代表班级——只保留这些人的特征和关系。SAGPool通过一个打分网络来判断每个节点有多重要,然后"裁员"——只留K个最重要的。自注意力池化则更温和——不裁员,但给重要的人更大发言权(加权求和)。
为什么(原理):TopK池化的优势在于保持了稀疏性——只保留K个节点,后续计算量大幅减少。SAGPool的评分可以基于节点特征(self-attention)或图结构(graph convolution),灵活适应不同任务。自注意力池化使用注意力机制动态计算每个节点对图级表示的贡献,比简单的平均/求和更灵活。
import numpy as np
# TopK池化(SAGPool)与自注意力池化
# 演示两种高级池化方法
class AdvancedPooling:
def __init__(self, d=3, k=3):
np.random.seed(42)
self.d = d
self.k = k
# 评分权重
self.W_score = np.random.randn(d, 1) * 0.3
# 注意力权重
self.W_attn = np.random.randn(d, 1) * 0.3
def topk_pool(self, H, A):
"""TopK池化:保留最重要的K个节点"""
# 步骤1:计算每个节点的重要性评分
scores = H @ self.W_score # (n, 1)
scores = scores.flatten()
# 步骤2:选择Top-K个节点
topk_idx = np.argsort(scores)[-self.k:] # 分数最高的K个
# 步骤3:保留Top-K节点和对应边
H_topk = H[topk_idx]
# 加权特征:保留的节点特征乘以重要性评分(门控)
H_topk = H_topk * scores[topk_idx].reshape(-1, 1)
A_topk = A[topk_idx][:, topk_idx]
return H_topk, A_topk, topk_idx, scores
def attention_pool(self, H):
"""自注意力池化:所有节点加权求和"""
# 计算注意力分数
attn_scores = H @ self.W_attn # (n, 1)
attn_scores = attn_scores.flatten()
# Softmax归一化
attn_scores = attn_scores - np.max(attn_scores)
alpha = np.exp(attn_scores) / np.sum(np.exp(attn_scores))
# 加权求和得到图级表示
h_G = np.sum(alpha.reshape(-1, 1) * H, axis=0)
return h_G, alpha
# 演示
H = np.array([
[0.5, 0.2, 0.1],
[0.8, 0.1, 0.3],
[0.1, 0.9, 0.2],
[0.3, 0.2, 0.7],
[0.2, 0.6, 0.1],
])
A = np.ones((5, 5)) - np.eye(5)
pool = AdvancedPooling(d=3, k=3)
# TopK池化
H_topk, A_topk, topk_idx, scores = pool.topk_pool(H, A)
print("=== TopK池化与自注意力池化 ===\n")
print("TopK池化(K=3):")
print(f"节点重要性评分: {np.round(scores, 3)}")
print(f"保留节点: {topk_idx.tolist()}(评分最高的3个)")
print(f"保留后邻接矩阵:\n{A_topk}")
print("→ 只保留前K个重要节点,计算量大幅减少")
# 自注意力池化
h_G, alpha = pool.attention_pool(H)
print(f"\n自注意力池化:")
print(f"注意力权重: {np.round(alpha, 3)}")
print(f"图级表示: {np.round(h_G, 3)}")
print("→ 所有节点参与,重要节点权重大")
print("\n两种方法对比:")
print("- TopK池化:硬选择,计算高效,适合大规模图")
print("- 注意力池化:软加权,保持全部信息,适合中小图")
大白话 SAGPool就是"选Top-K个代表"。用一个评分网络给每个节点打分,然后只保留分数最高的K个节点,把剩下的节点"裁掉"。保留的节点特征还要乘以评分(门控)——越重要的节点,特征越突出。这种方法很高效——每层减少一半节点,层数越多,计算量越小。
什么用(应用):SAGPool在多个图分类基准上效果优异。TopK池化也用于图匹配和子图检测等任务。自注意力池化(Set2Set)在分子性质预测中特别有效——它能学到一个分子的"注意力指纹",关注关键的原子子结构。
哪些坑(缺点):TopK池化是硬选择,可能丢失重要但评分不高的节点信息。k值的选择需要权衡——太小丢失信息,太大计算量增加。自注意力池化对所有节点加权平均,在大图上计算量较大。
概念关系图谱
| 概念 | 核心含义 | 与AI的关系 | 关联概念 |
|---|---|---|---|
| 图池化 | 将节点级表示聚合为图级向量 | 图分类、图回归的最后一环 | 读出操作、聚合 |
| 全局池化 | 对所有节点取平均/求和/最大 | 最简单的图池化,丢失结构信息 | 平均、最大池化 |
| DiffPool | 学习软分配矩阵,逐步粗化图 | 保留层次结构信息的池化方法 | 图粗化、聚类 |
| SAGPool | 基于重要性评分的TopK节点选择 | 高效保留关键节点,减少计算量 | 门控、TopK选择 |
| 自注意力池化 | 注意力加权求和所有节点 | 软加权,保留全部信息 | 注意力、加权和 |
重点答疑
Q1: 为什么全局池化会丢失结构信息?
全局池化对节点集合操作,不关心节点之间的连接关系。两个结构不同的图(如4个节点连成链 vs 4个节点连成星形),如果它们恰好有相同的节点特征(例如都是[0.5,0.2,0.1]等),全局池化会产生完全相同的图级表示。要区分结构差异,需要层次化池化或多分辨率方法。
Q2: DiffPool中的软分配和硬分配有什么区别?
软分配(soft assignment):每个节点以一定概率分配到多个聚类(行和为1),允许节点同时属于多个聚类,提供更丰富的信息。硬分配(hard assignment):每个节点只能分配到一个聚类,计算更高效但信息损失更多。DiffPool使用软分配,通过softmax产生概率分布。
Q3: GraphSAGE和GCN中是否需要图池化?
节点分类任务(如Cora)不需要图池化——每个节点独立输出分类结果。图分类任务(如分子毒性预测)需要图池化——将所有节点表示聚合为一个图级向量,再输入分类器。GraphSAGE和GCN都可以作为节点编码器,然后接图池化层实现图分类。
章节单词汇总
| 英文 | 音标 | 术语/释义 |
|---|---|---|
| Graph Pooling | /ɡræf ˈpuːlɪŋ/ | 图池化,将节点表示聚合为图级表示 |
| Readout | /ˈriːdaʊt/ | 读出操作,从节点嵌入生成图级表示 |
| Global Pooling | /ˈɡloʊbəl ˈpuːlɪŋ/ | 全局池化,对所有节点简单聚合 |
| Hierarchical Pooling | /ˌhaɪəˈrɑːrkɪkəl ˈpuːlɪŋ/ | 层次化池化,逐步粗化图结构 |
| DiffPool | /dɪf puːl/ | 可微池化,学习软分配矩阵粗化图 |
| SAGPool | /sæɡ puːl/ | 自注意力图池化,基于评分选择TopK节点 |
| Assignment Matrix | /əˈsaɪnmənt ˈmeɪtrɪks/ | 分配矩阵,S,节点到聚类的概率分配 |
| Super-node | /ˈsuːpər noʊd/ | 超级节点,粗化后代表节点聚类的虚拟节点 |
| Coarsening | /ˈkɔːrsənɪŋ/ | 粗化,减少节点数量的过程 |
面试练习
Q1 [单选] 图池化的主要目的是什么?
- A. 增加节点数量
- B. 将节点级表示聚合为固定大小的图级向量
- C. 加快消息传递速度
- D. 增加模型参数
解答:图池化将变长的节点表示(n个节点→n个向量)聚合为固定大小的图级向量(1个向量),用于图分类、图回归等任务。
Q2 [单选] 以下哪种聚合函数在理论上具有最强的图级表达能力?
- A. 平均(Mean)
- B. 求和(Sum)
- C. 最大值(Max)
- D. 三者相同
解答:GIN论文证明求和池化具有最强的表达能力(可达1-WL测试上限),因为它保留了节点数量和图的大小信息。
Q3 [单选] DiffPool中,软分配矩阵S的形状是什么?
- A. (d, d)
- B. (n_l, n_{l+1})
- C. (n_l, d)
- D. (n_{l+1}, n_{l+1})
解答:S的形状为(n_l, n_{l+1}),将第l层的n_l个节点分配到第l+1层的n_{l+1}个聚类。每行(节点)被分配到各聚类,概率和为1。
Q4 [多选] 以下哪些是图池化方法?
- A. 全局平均池化
- B. DiffPool
- C. SAGPool
- D. 自注意力池化
- E. Dropout
解答:全局平均池化、DiffPool、SAGPool、自注意力池化都是图池化方法。Dropout是正则化方法,不用于池化。
Q5 [单选] SAGPool中的"K"代表什么?
- A. 保留的节点数量
- B. 聚类数量
- C. 注意力头数量
- D. 卷积核大小
解答:SAGPool中的TopK指保留K个评分最高的节点,K = ratio × n,如ratio=0.5表示保留一半节点。
Q6 [单选] DiffPool中,新的邻接矩阵A'如何计算?
- A. A' = A
- B. A' = S S^T
- C. A' = S^T A S
- D. A' = A^T A
解答:DiffPool通过A' = S^T A S计算新邻接矩阵,这等价于通过软分配矩阵S将节点间的边关系聚合为聚类间的连接强度。
Q7 [多选] 关于图池化的应用场景,以下哪些是正确的?
- A. 分子性质预测(图分类)
- B. 蛋白质功能分类(图分类)
- C. 社交网络社区检测(图级任务)
- D. 论文节点分类(节点级任务)
- E. 图相似度计算
解答:图池化用于图级任务(图分类、图回归、图相似度)。节点分类是节点级任务,不需要图池化(每个节点直接输出分类)。