GraphSAGE
论文地址:1706.02216
网站:GraphSAGE
代码仓库:williamleif/graphsage-simple: Simple reference implementation of GraphSAGE.
介绍
GraphSAGE(Graph Sample and Aggregate)是一种图神经网络模型,旨在有效地处理大规模图数据。它通过采样和聚合节点的邻居信息,生成节点的嵌入向量,能够进行图结构的任务,如节点分类、节点嵌入、链路预测等。
1. 背景:
传统的图神经网络(GNN)模型,如GCN(Graph Convolutional Network),通常需要处理整个图的邻居信息来生成节点嵌入。这种全图卷积的方式在小型图上表现良好,但随着图规模增大,计算和存储成本急剧增加,尤其是在处理社交网络、生物网络等大规模图时,直接使用所有邻居节点的信息变得不可行。因此,GraphSAGE模型应运而生,旨在解决大规模图数据中计算和存储效率的问题。
2. 动机:
GraphSAGE的动机是解决以下问题:
- 大规模图上的计算效率:随着图规模的增大,传统GNN需要处理整个图的数据,导致内存和计算资源的消耗难以接受。
- 高效地生成节点嵌入:许多实际任务需要学习节点的低维表示,而这些任务的需求往往是图不断增长的,GraphSAGE希望在无需重新训练模型的情况下,能够灵活地生成新节点的嵌入。
- 采样与泛化能力:GraphSAGE通过局部采样机制,避免处理整个图的邻居节点,从而提升了模型的泛化能力,特别是能够处理动态图或部分未知结构的图。
3. 创新点:
GraphSAGE的创新点主要体现在以下几个方面:
- 采样邻居节点:与GCN不同,GraphSAGE不是使用全局图卷积操作,而是为每个节点采样固定数量的邻居节点。通过局部采样的方式,可以有效地减少计算量,特别是当图的节点数和边数非常大时,这种方法可以显著降低内存和计算资源的消耗。
- 聚合函数设计:GraphSAGE通过不同的聚合函数(如均值、池化、LSTM等)对采样的邻居节点进行信息汇总,灵活地捕捉图的局部结构信息。不同的聚合函数能够根据任务需求调整节点嵌入的表达能力。
- 节点归纳学习:相比于传统的GNN模型需要全图信息,GraphSAGE采用的是归纳学习方式,不需要提前知道整个图的结构,也能够在训练好的模型基础上为新加入的节点生成嵌入,这使得模型在处理动态或部分未知图结构时具有很强的泛化能力。
4. 效果:
GraphSAGE在多个大规模图任务上取得了较好的效果,特别是在以下方面:
- 计算效率:通过局部采样,GraphSAGE极大地减少了计算开销,使得在大型图数据上进行节点嵌入学习成为可能。相比于传统的GCN,它能够处理上亿节点规模的图。
- 嵌入生成能力:GraphSAGE能够在不重新训练模型的情况下,为新加入的节点生成嵌入,这对动态图和实时应用非常有利。
- 实验表现:在标准数据集(如Reddit、PPI等)上,GraphSAGE在节点分类和链路预测任务中展现了优异的表现,与其他模型相比,它能够在保持较高准确率的同时显著提高计算效率。
方法实现
算法
GraphSAGE 提出了一种归纳式方法,用于在图中生成节点嵌入,具体通过从节点的局部邻域聚合特征来实现。这一过程依赖于学习得到的聚合函数,记为 ,其中 代表搜索深度。每次聚合都会结合邻居节点的信息,生成特定深度的节点表示。
算法概览
输入:
- 图 ,其中 是节点集, 是边集。
- 每个节点的输入特征 。
- 搜索深度 。
- 权重矩阵 ,。
- 激活函数 。
- 可微分的聚合函数 ,。
- 邻居选择函数 ,返回节点 的邻居节点集合。
输出:
- 每个节点 的向量表示 。
算法步骤
- 初始化每个节点的特征表示:
- 对于每一个深度 :
- 对于每个节点 ,从其邻居中聚合特征:
- 更新节点 的表示,结合节点自身和邻居的信息:
- 对表示进行归一化:
- 输出每个节点的最终嵌入表示:
算法流程
在每一次迭代(即搜索深度)中,节点通过聚合其局部邻域中的信息逐步构建其嵌入表示。随着算法的迭代,节点逐渐获得来自更远邻域的信息,这样最终生成的嵌入不仅包含了节点自身的特征,还包含了邻居节点的特征信息。
GraphSAGE(Graph Sample and Aggregate)的具体实现主要包括以下几个关键步骤:邻居采样、聚合函数、节点更新、损失函数与训练过程。下面详细介绍其实现流程。
1. 邻居采样(Neighbor Sampling)
GraphSAGE通过随机采样的方式,选择每个节点的固定数量的邻居节点进行消息传递和聚合。这样可以有效减少大规模图上计算的复杂度和内存需求。对于每一层卷积,GraphSAGE只采样每个节点的部分邻居,从而避免了全图卷积带来的效率瓶颈。
- 具体操作:对于每个节点 ,在每一层中,GraphSAGE会随机采样固定数量的邻居节点(设为 个),并使用这些邻居节点来参与聚合运算。这种局部采样方法可以有效减少计算量,并使得模型能够扩展到大规模图。
2. 聚合函数(Aggregation Function)
采样到邻居节点后,GraphSAGE的核心是如何聚合这些邻居节点的信息。GraphSAGE提供了几种不同的聚合函数选择,这些聚合方式可以捕捉邻居节点的局部特征并传递给目标节点。常用的聚合函数包括以下几种:
Mean Aggregator:这是最简单的一种聚合方式,直接取邻居节点的特征的均值。公式如下:
其中, 是节点 在第 层的表示, 是节点 的邻居节点集合, 是可学习的权重矩阵, 是非线性激活函数(如ReLU)。
Pooling Aggregator:对于每个邻居节点,首先通过一个多层感知机(MLP)将其映射到高维特征空间,然后对所有邻居节点的表示进行池化操作(如最大池化或平均池化)。公式如下:
LSTM Aggregator:LSTM聚合器使用循环神经网络(RNN)中的LSTM单元来处理邻居节点信息。它将邻居节点序列输入LSTM,最后将LSTM的输出作为聚合结果。公式如下:
Mean Aggregator:适合处理大规模数据,简单且高效,适合对节点之间关系相对简单的图。
Pooling Aggregator:通过引入非线性变换增加了模型的表达能力,适合在需要较复杂的特征表达时使用。
LSTM Aggregator:适合在节点间存在序列关系或依赖时使用,但由于计算复杂度较高,可能不适合非常大规模的图数据。
3. 节点更新(Node Update)
在每一层中,目标节点 的新表示 是通过将它自己的特征 和聚合后的邻居特征结合起来生成的。具体更新规则通常为:
其中, 和 是可学习的权重矩阵, 是邻居节点特征的聚合函数。这个过程会迭代多层,每一层都会进一步整合节点的局部和邻居信息。
4. 归纳学习(Inductive Learning)
GraphSAGE的一个显著优点是其归纳学习能力。传统的图神经网络(如GCN)是基于整个图的全局信息来学习节点表示的,而GraphSAGE通过局部采样和聚合邻居信息,实现了在不使用全图信息的情况下进行学习和推断。因此,GraphSAGE能够处理动态图或新增节点,即使这些节点在训练阶段没有出现,模型也能生成合理的嵌入。
5. 损失函数(Loss Function)
GraphSAGE通常用于节点分类、节点嵌入等任务。其常用的损失函数包括:
监督学习:在节点分类任务中,使用标准的交叉熵损失函数。假设任务是对节点进行分类,目标是最小化以下损失:
其中, 是节点 的真实标签, 是模型预测的概率分布。
无监督学习:GraphSAGE也可以在无监督的设置下用于学习节点嵌入。通常使用负采样(negative sampling)的方法,通过最大化正负样本之间的对比损失来学习节点表示。
6. 训练过程(Training Procedure)
GraphSAGE的训练流程与传统的神经网络类似,主要步骤如下:
- 前向传播:对每个节点,采样邻居节点,并通过聚合函数生成新的节点表示。
- 反向传播:通过标准的梯度下降方法,计算损失并更新模型参数。
- 优化:使用优化器(如Adam或SGD)更新模型的可学习参数。
训练过程中,GraphSAGE通过小批量采样(mini-batch)的方式进行迭代,以避免处理整个图结构,从而提高了效率。每一轮迭代更新节点表示后,模型会逐渐捕捉到更深层次的图结构信息。
代码模板
模型代码
1 | class GraphSage(nn.Module): |
下面的代码定义了 GraphSAGE 模型中的 Encoder
模块,它负责从节点及其邻居的特征中生成嵌入表示。Encoder
通过使用一个 聚合器(Aggregator) 来聚合节点的邻居特征,并根据是否是 GCN 模式来决定是否包含节点自身的特征。
1 | class Encoder(nn.Module): |
下面的代码实现了 GraphSAGE 中的 MeanAggregator
模块,使用邻居节点的特征均值来聚合节点的嵌入。这个模块的核心思想是从每个节点的邻居中采样,并计算邻居特征的平均值,从而生成聚合的节点表示。
1 | class MeanAggregator(nn.Module): |