图池化与图级表示

一句话概述

图池化(Graph Pooling)是将节点级表示聚合为图级表示的关键操作。在节点分类任务中,GNN输出每个节点的嵌入;但在图分类、图回归等任务中,需要将整张图的所有节点表示汇总为一个固定大小的图级向量。图池化方法从简单到复杂分为:①全局池化——对所有节点取平均/求和/最大值,简单但可能丢失结构信息;②层次化池化(如DiffPool)——分层将节点聚类为"超级节点",逐步粗化图结构;③TopK池化(如SAGPool)——基于节点重要性评分,保留Top-K个最重要的节点。图池化是图分类、分子性质预测等任务的最后一环,将复杂的图结构映射为简洁的向量表示。

💡 核心要点:①图池化将节点级表示聚合为整个图的固定大小向量 ②全局池化(平均/求和/最大)最简单但丢失结构信息 ③层次化池化(DiffPool)逐步粗化图结构,保留层次信息 ④TopK池化(SAGPool)基于重要性评分选择关键节点 ⑤图级表示用于图分类、分子性质预测、图相似度等任务

教学与演示

一、全局池化:从节点到图的简单映射

是什么(定义):全局池化(Global Pooling)是最简单的图级读出操作。给定GNN输出的所有节点嵌入H∈R^{n×d},全局池化通过一个对称函数将所有节点聚合为一个图级向量:Readout(H)=AGG({h_i : i∈V})。常用的AGG函数包括:平均(mean)、求和(sum)、最大值(max)。全局池化的输出是一个d维向量,可以直接输入分类器或回归器。

大白话 全局池化就是"全班总结"。每个学生(节点)有各科成绩(特征),校长要看全班总体情况。平均池化就是"算全班平均分"——简单直接,但可能掩盖"偏科"现象。求和池化就是"算全班总分"——保留了班级人数信息。最大池化就是"看各科最高分"——只看最好的,忽略整体。三种方式各有侧重,选哪种取决于具体任务。

为什么(原理):全局池化的选择影响模型的表达能力。求和池化保留了节点数量和度信息,表达力最强(GIN论文证明了这一点);平均池化适合节点数量变化的场景(如不同大小的图);最大池化适合只关注最显著特征的场景。全局池化的缺点是丢失了图的结构信息——两个结构完全不同的图,如果节点特征分布相同,全局池化后会得到相同的图级表示。

import numpy as np

# 全局池化:将节点嵌入聚合为图级表示
# 对比三种全局池化方式的差异

class GlobalPooling:
    def __init__(self):
        pass

    def mean_pooling(self, H):
        """平均池化:所有节点特征取平均"""
        return np.mean(H, axis=0)

    def sum_pooling(self, H):
        """求和池化:所有节点特征求和"""
        return np.sum(H, axis=0)

    def max_pooling(self, H):
        """最大池化:每个维度取最大值"""
        return np.max(H, axis=0)


# 创建两个结构不同但节点特征分布相似的图
# 图A:链状结构,4个节点
H_A = np.array([
    [0.8, 0.1, 0.2],
    [0.2, 0.9, 0.1],
    [0.1, 0.2, 0.8],
    [0.5, 0.5, 0.3],
])

# 图B:星形结构,4个节点(特征分布与A相同)
H_B = np.array([
    [0.8, 0.1, 0.2],
    [0.2, 0.9, 0.1],
    [0.1, 0.2, 0.8],
    [0.5, 0.5, 0.3],
])  # 相同特征,不同结构

pool = GlobalPooling()

print("=== 全局池化:节点到图的映射 ===\n")

print("图A和图B有相同的节点特征(但结构不同)")
print(f"\n节点特征矩阵(4个节点,3维):\n{H_A}")

mean_A = pool.mean_pooling(H_A)
sum_A = pool.sum_pooling(H_A)
max_A = pool.max_pooling(H_A)

print(f"\n平均池化: {np.round(mean_A, 3)}")
print(f"求和池化: {np.round(sum_A, 3)}")
print(f"最大池化: {np.round(max_A, 3)}")

print("\n关键观察:")
print("- 平均池化:特征在[-1,1]范围,不受节点数影响")
print("- 求和池化:保留了节点数量信息(值更大)")
print("- 最大池化:每维取最大值,忽略其他节点")
print("- 全局池化的局限:两个结构不同的图,如果节点特征相同," +
      "全局池化后图表示也相同——丢失了结构信息!")
全局池化公式\(h_G = \text{READOUT}\left(\{h_i^{(L)} : i \in V\}\right) = \text{AGG}\left(\{h_i^{(L)}\}_{i=1}^{n}\right)\)
大白话 全局池化就是"一视同仁"。把所有节点信息压缩成一个向量,不看它们之间的连接关系。就像把全班的成绩单算一个总分——你知道总分是多少,但不知道谁和谁是好朋友(图结构)。这种方法简单快速,但可能会把两个结构完全不同的图看成一样的(如果节点特征相同的话)。

什么用(应用):全局池化是图分类的基本操作。在分子毒性预测中,求和池化保留原子的总体贡献;在社交网络社区检测中,平均池化不受社区大小影响。GIN论文证明了求和池化+MLP可以达到1-WL测试的上限,是最具表达力的全局池化方式。

哪些坑(缺点):全局池化完全丢失了图结构信息——两个结构完全不同的图(如链状和星形),如果节点特征相同,全局池化后图表示完全相同。这促使了层次化池化方法的发展,后者在粗化图的过程中保留了结构信息。

二、层次化池化:DiffPool与图粗化

是什么(定义):层次化池化(Hierarchical Pooling)通过逐步将节点聚类为"超级节点"(super-node)来粗化图结构。DiffPool(Differentiable Pooling)是最具代表性的方法:它学习一个软分配矩阵S∈R^{n_l×n_{l+1}}(n_{l+1}表示粗化后的节点数),将第l层的n_l个节点分配到第l+1层的n_{l+1}个聚类中。新的节点特征为S^T H,新的邻接矩阵为S^T A S。整个过程是可微的,可以端到端训练。

大白话 DiffPool就像"行政区域的合并"。原始图有100个村庄(节点),DiffPool学习把它们合并成10个镇(超级节点)——哪个村庄归哪个镇,不是人为规定的,而是模型自己学的(软分配矩阵S)。合并后,每个镇的人口(新特征)是其下属村庄的人口汇总,镇之间的道路(新邻接矩阵)由其下属村庄之间的道路汇总。这个过程可以重复——镇合并成县,县合并成市……层层嵌套,形成层次化表示。

为什么(原理):DiffPool的核心是软分配矩阵S。S通过另一个GNN(分配GNN)学习:S=softmax(GNN_{pool}(A, H))。每行对应一个原始节点分配到各聚类的概率(行和为1)。这种软分配允许节点同时属于多个聚类(通过概率),提供了更丰富的表示。层次化粗化使得最终图级表示能够捕捉多层次的结构信息——从局部模式到全局拓扑。

import numpy as np

# DiffPool:层次化图池化的简化演示
# 展示如何通过软分配矩阵粗化图

class SimplifiedDiffPool:
    def __init__(self, n_clusters=2):
        np.random.seed(42)
        self.n_clusters = n_clusters

    def compute_assignment(self, H, A):
        """计算软分配矩阵S(简化版)"""
        n = H.shape[0]
        # 模拟基于节点特征的聚类分配
        S = np.zeros((n, self.n_clusters))
        for i in range(n):
            # 基于特征第一个维度的值来分配(实际中由GNN学习)
            if H[i, 0] > 0.3:
                S[i, 0] = 0.7
                S[i, 1] = 0.3
            else:
                S[i, 0] = 0.3
                S[i, 1] = 0.7
        return S

    def pool(self, H, A):
        """DiffPool粗化:X'=S^T H, A'=S^T A S"""
        S = self.compute_assignment(H, A)

        # 新节点特征:聚类内节点的特征加权和
        H_new = S.T @ H  # (n_clusters, d)

        # 新邻接矩阵:聚类之间的连接强度
        A_new = S.T @ A @ S  # (n_clusters, n_clusters)

        return H_new, A_new, S


# 原始图:6个节点
H = np.array([
    [0.5, 0.2],
    [0.4, 0.3],
    [0.1, 0.8],
    [0.2, 0.7],
    [0.6, 0.1],
    [0.1, 0.9],
])

A = np.array([
    [0, 1, 0, 0, 1, 0],
    [1, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 1],
    [0, 0, 1, 0, 0, 0],
    [1, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0],
])

diffpool = SimplifiedDiffPool(n_clusters=2)
H_new, A_new, S = diffpool.pool(H, A)

print("=== DiffPool:层次化图池化 ===\n")
print(f"原始图:{H.shape[0]}个节点")
print(f"粗化后:{H_new.shape[0]}个聚类(超级节点)\n")

print("软分配矩阵 S(行和为1):")
print(np.round(S, 3))
print("→ 节点0,1,4主要分配给聚类0,节点2,3,5主要分配给聚类1")

print(f"\n新节点特征 H' = S^T H({H_new.shape[0]}×{H_new.shape[1]}):")
print(np.round(H_new, 3))

print(f"\n新邻接矩阵 A' = S^T A S(聚类间连接强度):")
print(np.round(A_new, 3))
print("→ 保留了原始图的聚类间连接关系")

print("\n层次化池化的优势:")
print("- 保留了图的结构信息(通过A')")
print("- 可堆叠多层,形成层次化表示")
print("- 端到端可训练")
DiffPool的核心公式\(X^{(l+1)} = S^{(l)^T} H^{(l)}, \quad A^{(l+1)} = S^{(l)^T} A^{(l)} S^{(l)}, \quad S^{(l)} = \text{softmax}\left(\text{GNN}_{\text{pool}}^{(l)}(A^{(l)}, H^{(l)})\right)\)
大白话 DiffPool就是"自动学怎么合并村庄"。你告诉它要从100个村庄合并成10个镇,但哪个村庄归哪个镇——不是你来定,而是模型自己学。模型通过另一个GNN(分配网络)来打分,判断"这个村庄和那个村庄更像"。最终每个村庄被软分配到几个镇(概率分配),形成新的"镇级"图。这个过程可以重复——镇合并成县,县合并成市。

什么用(应用):DiffPool在多个图分类基准上达到当时最优。在ENZYMES(蛋白质分类)和DD(化合物分类)上,DiffPool显著优于全局池化。层次化池化也被用于社交网络分析(识别多层社区结构)、分子性质预测(识别功能团层次)等。后续的MinCutPool、StructPool等改进了DiffPool的计算效率和稀疏性。

哪些坑(缺点):DiffPool的计算复杂度较高——需要额外的分配GNN,且S^TAS的矩阵乘法(O(n²))在稀疏图上效率低。软分配矩阵S通常稠密(每个节点软分配到所有聚类),内存占用大。聚类数n_{l+1}需要预定义(通常为n_l的25%),超参数调优复杂。

三、TopK池化与自注意力池化

是什么(定义):TopK池化(如SAGPool)通过学习每个节点的重要性评分,保留Top-K个最重要的节点及其连接关系。具体为:评分z=GNN(A,H),选择top-K个节点idx=topk(z),输出H'=H[idx,:]⊙z[idx],A'=A[idx,:][:,idx]。自注意力池化(Set2Set、Global Attention)使用注意力机制对节点加权求和:h_G=Σ_i α_i h_i,其中α_i=softmax(W·h_i)。这两种方法在保留重要信息和计算效率之间取得了平衡。

大白话 TopK池化就是"只保留最重要的K个人"。就好像一个班级里,只选前10名代表班级——只保留这些人的特征和关系。SAGPool通过一个打分网络来判断每个节点有多重要,然后"裁员"——只留K个最重要的。自注意力池化则更温和——不裁员,但给重要的人更大发言权(加权求和)。

为什么(原理):TopK池化的优势在于保持了稀疏性——只保留K个节点,后续计算量大幅减少。SAGPool的评分可以基于节点特征(self-attention)或图结构(graph convolution),灵活适应不同任务。自注意力池化使用注意力机制动态计算每个节点对图级表示的贡献,比简单的平均/求和更灵活。

import numpy as np

# TopK池化(SAGPool)与自注意力池化
# 演示两种高级池化方法

class AdvancedPooling:
    def __init__(self, d=3, k=3):
        np.random.seed(42)
        self.d = d
        self.k = k
        # 评分权重
        self.W_score = np.random.randn(d, 1) * 0.3
        # 注意力权重
        self.W_attn = np.random.randn(d, 1) * 0.3

    def topk_pool(self, H, A):
        """TopK池化:保留最重要的K个节点"""
        # 步骤1:计算每个节点的重要性评分
        scores = H @ self.W_score  # (n, 1)
        scores = scores.flatten()

        # 步骤2:选择Top-K个节点
        topk_idx = np.argsort(scores)[-self.k:]  # 分数最高的K个

        # 步骤3:保留Top-K节点和对应边
        H_topk = H[topk_idx]
        # 加权特征:保留的节点特征乘以重要性评分(门控)
        H_topk = H_topk * scores[topk_idx].reshape(-1, 1)
        A_topk = A[topk_idx][:, topk_idx]

        return H_topk, A_topk, topk_idx, scores

    def attention_pool(self, H):
        """自注意力池化:所有节点加权求和"""
        # 计算注意力分数
        attn_scores = H @ self.W_attn  # (n, 1)
        attn_scores = attn_scores.flatten()

        # Softmax归一化
        attn_scores = attn_scores - np.max(attn_scores)
        alpha = np.exp(attn_scores) / np.sum(np.exp(attn_scores))

        # 加权求和得到图级表示
        h_G = np.sum(alpha.reshape(-1, 1) * H, axis=0)

        return h_G, alpha


# 演示
H = np.array([
    [0.5, 0.2, 0.1],
    [0.8, 0.1, 0.3],
    [0.1, 0.9, 0.2],
    [0.3, 0.2, 0.7],
    [0.2, 0.6, 0.1],
])

A = np.ones((5, 5)) - np.eye(5)

pool = AdvancedPooling(d=3, k=3)

# TopK池化
H_topk, A_topk, topk_idx, scores = pool.topk_pool(H, A)
print("=== TopK池化与自注意力池化 ===\n")

print("TopK池化(K=3):")
print(f"节点重要性评分: {np.round(scores, 3)}")
print(f"保留节点: {topk_idx.tolist()}(评分最高的3个)")
print(f"保留后邻接矩阵:\n{A_topk}")
print("→ 只保留前K个重要节点,计算量大幅减少")

# 自注意力池化
h_G, alpha = pool.attention_pool(H)
print(f"\n自注意力池化:")
print(f"注意力权重: {np.round(alpha, 3)}")
print(f"图级表示: {np.round(h_G, 3)}")
print("→ 所有节点参与,重要节点权重大")

print("\n两种方法对比:")
print("- TopK池化:硬选择,计算高效,适合大规模图")
print("- 注意力池化:软加权,保持全部信息,适合中小图")
SAGPool的TopK池化\(\mathbf{z} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H W\right), \quad \text{idx} = \text{topk}(\mathbf{z}, k), \quad H' = H_{\text{idx},:} \odot \mathbf{z}_{\text{idx}}, \quad A' = A_{\text{idx},\text{idx}}\)
大白话 SAGPool就是"选Top-K个代表"。用一个评分网络给每个节点打分,然后只保留分数最高的K个节点,把剩下的节点"裁掉"。保留的节点特征还要乘以评分(门控)——越重要的节点,特征越突出。这种方法很高效——每层减少一半节点,层数越多,计算量越小。

什么用(应用):SAGPool在多个图分类基准上效果优异。TopK池化也用于图匹配和子图检测等任务。自注意力池化(Set2Set)在分子性质预测中特别有效——它能学到一个分子的"注意力指纹",关注关键的原子子结构。

哪些坑(缺点):TopK池化是硬选择,可能丢失重要但评分不高的节点信息。k值的选择需要权衡——太小丢失信息,太大计算量增加。自注意力池化对所有节点加权平均,在大图上计算量较大。

概念关系图谱

概念核心含义与AI的关系关联概念
图池化将节点级表示聚合为图级向量图分类、图回归的最后一环读出操作、聚合
全局池化对所有节点取平均/求和/最大最简单的图池化,丢失结构信息平均、最大池化
DiffPool学习软分配矩阵,逐步粗化图保留层次结构信息的池化方法图粗化、聚类
SAGPool基于重要性评分的TopK节点选择高效保留关键节点,减少计算量门控、TopK选择
自注意力池化注意力加权求和所有节点软加权,保留全部信息注意力、加权和

重点答疑

Q1: 为什么全局池化会丢失结构信息?

全局池化对节点集合操作,不关心节点之间的连接关系。两个结构不同的图(如4个节点连成链 vs 4个节点连成星形),如果它们恰好有相同的节点特征(例如都是[0.5,0.2,0.1]等),全局池化会产生完全相同的图级表示。要区分结构差异,需要层次化池化或多分辨率方法。

Q2: DiffPool中的软分配和硬分配有什么区别?

软分配(soft assignment):每个节点以一定概率分配到多个聚类(行和为1),允许节点同时属于多个聚类,提供更丰富的信息。硬分配(hard assignment):每个节点只能分配到一个聚类,计算更高效但信息损失更多。DiffPool使用软分配,通过softmax产生概率分布。

Q3: GraphSAGE和GCN中是否需要图池化?

节点分类任务(如Cora)不需要图池化——每个节点独立输出分类结果。图分类任务(如分子毒性预测)需要图池化——将所有节点表示聚合为一个图级向量,再输入分类器。GraphSAGE和GCN都可以作为节点编码器,然后接图池化层实现图分类。

章节单词汇总

英文音标术语/释义
Graph Pooling/ɡræf ˈpuːlɪŋ/图池化,将节点表示聚合为图级表示
Readout/ˈriːdaʊt/读出操作,从节点嵌入生成图级表示
Global Pooling/ˈɡloʊbəl ˈpuːlɪŋ/全局池化,对所有节点简单聚合
Hierarchical Pooling/ˌhaɪəˈrɑːrkɪkəl ˈpuːlɪŋ/层次化池化,逐步粗化图结构
DiffPool/dɪf puːl/可微池化,学习软分配矩阵粗化图
SAGPool/sæɡ puːl/自注意力图池化,基于评分选择TopK节点
Assignment Matrix/əˈsaɪnmənt ˈmeɪtrɪks/分配矩阵,S,节点到聚类的概率分配
Super-node/ˈsuːpər noʊd/超级节点,粗化后代表节点聚类的虚拟节点
Coarsening/ˈkɔːrsənɪŋ/粗化,减少节点数量的过程

面试练习

Q1 [单选] 图池化的主要目的是什么?

  • A. 增加节点数量
  • B. 将节点级表示聚合为固定大小的图级向量
  • C. 加快消息传递速度
  • D. 增加模型参数
解答:图池化将变长的节点表示(n个节点→n个向量)聚合为固定大小的图级向量(1个向量),用于图分类、图回归等任务。

Q2 [单选] 以下哪种聚合函数在理论上具有最强的图级表达能力?

  • A. 平均(Mean)
  • B. 求和(Sum)
  • C. 最大值(Max)
  • D. 三者相同
解答:GIN论文证明求和池化具有最强的表达能力(可达1-WL测试上限),因为它保留了节点数量和图的大小信息。

Q3 [单选] DiffPool中,软分配矩阵S的形状是什么?

  • A. (d, d)
  • B. (n_l, n_{l+1})
  • C. (n_l, d)
  • D. (n_{l+1}, n_{l+1})
解答:S的形状为(n_l, n_{l+1}),将第l层的n_l个节点分配到第l+1层的n_{l+1}个聚类。每行(节点)被分配到各聚类,概率和为1。

Q4 [多选] 以下哪些是图池化方法?

  • A. 全局平均池化
  • B. DiffPool
  • C. SAGPool
  • D. 自注意力池化
  • E. Dropout
解答:全局平均池化、DiffPool、SAGPool、自注意力池化都是图池化方法。Dropout是正则化方法,不用于池化。

Q5 [单选] SAGPool中的"K"代表什么?

  • A. 保留的节点数量
  • B. 聚类数量
  • C. 注意力头数量
  • D. 卷积核大小
解答:SAGPool中的TopK指保留K个评分最高的节点,K = ratio × n,如ratio=0.5表示保留一半节点。

Q6 [单选] DiffPool中,新的邻接矩阵A'如何计算?

  • A. A' = A
  • B. A' = S S^T
  • C. A' = S^T A S
  • D. A' = A^T A
解答:DiffPool通过A' = S^T A S计算新邻接矩阵,这等价于通过软分配矩阵S将节点间的边关系聚合为聚类间的连接强度。

Q7 [多选] 关于图池化的应用场景,以下哪些是正确的?

  • A. 分子性质预测(图分类)
  • B. 蛋白质功能分类(图分类)
  • C. 社交网络社区检测(图级任务)
  • D. 论文节点分类(节点级任务)
  • E. 图相似度计算
解答:图池化用于图级任务(图分类、图回归、图相似度)。节点分类是节点级任务,不需要图池化(每个节点直接输出分类)。