论文地址:1706.02216

网站:GraphSAGE

代码仓库:williamleif/graphsage-simple: Simple reference implementation of GraphSAGE.

介绍

GraphSAGE(Graph Sample and Aggregate)是一种图神经网络模型,旨在有效地处理大规模图数据。它通过采样和聚合节点的邻居信息,生成节点的嵌入向量,能够进行图结构的任务,如节点分类、节点嵌入、链路预测等。

1. 背景:

传统的图神经网络(GNN)模型,如GCN(Graph Convolutional Network),通常需要处理整个图的邻居信息来生成节点嵌入。这种全图卷积的方式在小型图上表现良好,但随着图规模增大,计算和存储成本急剧增加,尤其是在处理社交网络、生物网络等大规模图时,直接使用所有邻居节点的信息变得不可行。因此,GraphSAGE模型应运而生,旨在解决大规模图数据中计算和存储效率的问题。

2. 动机:

GraphSAGE的动机是解决以下问题:

  • 大规模图上的计算效率:随着图规模的增大,传统GNN需要处理整个图的数据,导致内存和计算资源的消耗难以接受。
  • 高效地生成节点嵌入:许多实际任务需要学习节点的低维表示,而这些任务的需求往往是图不断增长的,GraphSAGE希望在无需重新训练模型的情况下,能够灵活地生成新节点的嵌入。
  • 采样与泛化能力:GraphSAGE通过局部采样机制,避免处理整个图的邻居节点,从而提升了模型的泛化能力,特别是能够处理动态图或部分未知结构的图。

3. 创新点:

GraphSAGE的创新点主要体现在以下几个方面:

  • 采样邻居节点:与GCN不同,GraphSAGE不是使用全局图卷积操作,而是为每个节点采样固定数量的邻居节点。通过局部采样的方式,可以有效地减少计算量,特别是当图的节点数和边数非常大时,这种方法可以显著降低内存和计算资源的消耗。
  • 聚合函数设计:GraphSAGE通过不同的聚合函数(如均值、池化、LSTM等)对采样的邻居节点进行信息汇总,灵活地捕捉图的局部结构信息。不同的聚合函数能够根据任务需求调整节点嵌入的表达能力。
  • 节点归纳学习:相比于传统的GNN模型需要全图信息,GraphSAGE采用的是归纳学习方式,不需要提前知道整个图的结构,也能够在训练好的模型基础上为新加入的节点生成嵌入,这使得模型在处理动态或部分未知图结构时具有很强的泛化能力。

4. 效果:

GraphSAGE在多个大规模图任务上取得了较好的效果,特别是在以下方面:

  • 计算效率:通过局部采样,GraphSAGE极大地减少了计算开销,使得在大型图数据上进行节点嵌入学习成为可能。相比于传统的GCN,它能够处理上亿节点规模的图。
  • 嵌入生成能力:GraphSAGE能够在不重新训练模型的情况下,为新加入的节点生成嵌入,这对动态图和实时应用非常有利。
  • 实验表现:在标准数据集(如Reddit、PPI等)上,GraphSAGE在节点分类和链路预测任务中展现了优异的表现,与其他模型相比,它能够在保持较高准确率的同时显著提高计算效率。

方法实现

算法

GraphSAGE 提出了一种归纳式方法,用于在图中生成节点嵌入,具体通过从节点的局部邻域聚合特征来实现。这一过程依赖于学习得到的聚合函数,记为 ,其中 代表搜索深度。每次聚合都会结合邻居节点的信息,生成特定深度的节点表示。

image-20241019230957810

算法概览

  1. 输入:

    • ,其中 是节点集, 是边集。
    • 每个节点的输入特征
    • 搜索深度
    • 权重矩阵
    • 激活函数
    • 可微分的聚合函数
    • 邻居选择函数 ,返回节点 的邻居节点集合。
  2. 输出:

    • 每个节点 的向量表示

算法步骤

  1. 初始化每个节点的特征表示:
  2. 对于每一个深度
    1. 对于每个节点 ,从其邻居中聚合特征:
    2. 更新节点 的表示,结合节点自身和邻居的信息:
    3. 对表示进行归一化:
  3. 输出每个节点的最终嵌入表示:

算法流程

在每一次迭代(即搜索深度)中,节点通过聚合其局部邻域中的信息逐步构建其嵌入表示。随着算法的迭代,节点逐渐获得来自更远邻域的信息,这样最终生成的嵌入不仅包含了节点自身的特征,还包含了邻居节点的特征信息。

GraphSAGE(Graph Sample and Aggregate)的具体实现主要包括以下几个关键步骤:邻居采样、聚合函数、节点更新、损失函数与训练过程。下面详细介绍其实现流程。

1. 邻居采样(Neighbor Sampling)

GraphSAGE通过随机采样的方式,选择每个节点的固定数量的邻居节点进行消息传递和聚合。这样可以有效减少大规模图上计算的复杂度和内存需求。对于每一层卷积,GraphSAGE只采样每个节点的部分邻居,从而避免了全图卷积带来的效率瓶颈。

  • 具体操作:对于每个节点 ,在每一层中,GraphSAGE会随机采样固定数量的邻居节点(设为 个),并使用这些邻居节点来参与聚合运算。这种局部采样方法可以有效减少计算量,并使得模型能够扩展到大规模图。

2. 聚合函数(Aggregation Function)

采样到邻居节点后,GraphSAGE的核心是如何聚合这些邻居节点的信息。GraphSAGE提供了几种不同的聚合函数选择,这些聚合方式可以捕捉邻居节点的局部特征并传递给目标节点。常用的聚合函数包括以下几种:

  • Mean Aggregator:这是最简单的一种聚合方式,直接取邻居节点的特征的均值。公式如下:

    其中, 是节点 在第 层的表示, 是节点 的邻居节点集合, 是可学习的权重矩阵, 是非线性激活函数(如ReLU)。

  • Pooling Aggregator:对于每个邻居节点,首先通过一个多层感知机(MLP)将其映射到高维特征空间,然后对所有邻居节点的表示进行池化操作(如最大池化或平均池化)。公式如下:

  • LSTM Aggregator:LSTM聚合器使用循环神经网络(RNN)中的LSTM单元来处理邻居节点信息。它将邻居节点序列输入LSTM,最后将LSTM的输出作为聚合结果。公式如下:

Mean Aggregator:适合处理大规模数据,简单且高效,适合对节点之间关系相对简单的图。

Pooling Aggregator:通过引入非线性变换增加了模型的表达能力,适合在需要较复杂的特征表达时使用。

LSTM Aggregator:适合在节点间存在序列关系或依赖时使用,但由于计算复杂度较高,可能不适合非常大规模的图数据。

3. 节点更新(Node Update)

在每一层中,目标节点 的新表示 是通过将它自己的特征 和聚合后的邻居特征结合起来生成的。具体更新规则通常为:

其中, 是可学习的权重矩阵, 是邻居节点特征的聚合函数。这个过程会迭代多层,每一层都会进一步整合节点的局部和邻居信息。

4. 归纳学习(Inductive Learning)

GraphSAGE的一个显著优点是其归纳学习能力。传统的图神经网络(如GCN)是基于整个图的全局信息来学习节点表示的,而GraphSAGE通过局部采样和聚合邻居信息,实现了在不使用全图信息的情况下进行学习和推断。因此,GraphSAGE能够处理动态图或新增节点,即使这些节点在训练阶段没有出现,模型也能生成合理的嵌入。

5. 损失函数(Loss Function)

GraphSAGE通常用于节点分类、节点嵌入等任务。其常用的损失函数包括:

  • 监督学习:在节点分类任务中,使用标准的交叉熵损失函数。假设任务是对节点进行分类,目标是最小化以下损失:

    其中, 是节点 的真实标签, 是模型预测的概率分布。

  • 无监督学习:GraphSAGE也可以在无监督的设置下用于学习节点嵌入。通常使用负采样(negative sampling)的方法,通过最大化正负样本之间的对比损失来学习节点表示。

6. 训练过程(Training Procedure)

GraphSAGE的训练流程与传统的神经网络类似,主要步骤如下:

  • 前向传播:对每个节点,采样邻居节点,并通过聚合函数生成新的节点表示。
  • 反向传播:通过标准的梯度下降方法,计算损失并更新模型参数。
  • 优化:使用优化器(如Adam或SGD)更新模型的可学习参数。

训练过程中,GraphSAGE通过小批量采样(mini-batch)的方式进行迭代,以避免处理整个图结构,从而提高了效率。每一轮迭代更新节点表示后,模型会逐渐捕捉到更深层次的图结构信息。

代码模板

模型代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class GraphSage(nn.Module):
"""
原始 GraphSAGE 模型
实现了论文 "Inductive Representation Learning on Large Graphs"
"""

def __init__(self, num_classes, enc):
super(GraphSage, self).__init__()
self.enc = enc # Encoder 对象,负责节点嵌入
self.xent = nn.CrossEntropyLoss() # 使用交叉熵损失进行分类
self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim)) # 分类层的权重
init.xavier_uniform_(self.weight) # 初始化权重

def forward(self, nodes):
"""
前向传播函数:生成给定节点的分类分数
:param nodes: 节点列表
:return: 分类分数
"""
embeds = self.enc(nodes) # 获取节点嵌入
scores = self.weight.mm(embeds) # 线性变换生成分数
return scores.t() # 返回转置后的结果,形状为 (batch_size, num_classes)

def to_prob(self, nodes):
"""
将输出转换为概率值
:param nodes: 节点列表
:return: 节点所属类别的概率
"""
pos_scores = torch.sigmoid(self.forward(nodes)) # 应用 Sigmoid 函数将分数转换为概率
return pos_scores

def loss(self, nodes, labels):
"""
计算给定节点的损失
:param nodes: 节点列表
:param labels: 对应的真实标签
:return: 交叉熵损失
"""
scores = self.forward(nodes) # 获取分类分数
return self.xent(scores, labels.squeeze()) # 计算交叉熵损失

下面的代码定义了 GraphSAGE 模型中的 Encoder 模块,它负责从节点及其邻居的特征中生成嵌入表示。Encoder 通过使用一个 聚合器(Aggregator) 来聚合节点的邻居特征,并根据是否是 GCN 模式来决定是否包含节点自身的特征。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Encoder(nn.Module):
"""
GraphSAGE 编码器模块,使用卷积式 GraphSAGE 方法
"""

def __init__(self, features, feature_dim, embed_dim, adj_lists, aggregator, num_sample=10, base_model=None,
gcn=False, cuda=False, feature_transform=False):
super(Encoder, self).__init__()
self.features = features # 节点特征函数
self.feat_dim = feature_dim # 输入特征维度
self.adj_lists = adj_lists # 节点的邻居表
self.aggregator = aggregator # 聚合器对象
self.num_sample = num_sample # 每个节点采样的邻居数量
self.gcn = gcn # 是否为 GCN 模式
self.embed_dim = embed_dim # 嵌入维度
self.cuda = cuda # 是否使用 GPU
self.aggregator.cuda = cuda # 聚合器是否使用 GPU
self.weight = nn.Parameter(
torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim)) # 权重矩阵
init.xavier_uniform_(self.weight) # 使用 Xavier 初始化

def forward(self, nodes):
"""
生成一批节点的嵌入表示
:param nodes: 节点列表
:return: 节点的嵌入表示
"""
neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)]
for node in nodes],self.num_sample) # 聚合邻居特征

if isinstance(nodes, list):
index = torch.LongTensor(nodes)
else:
index = nodes

if not self.gcn:
if self.cuda:
self_feats = self.features(index).cuda() # 获取节点自身特征
else:
self_feats = self.features(index)
combined = torch.cat((self_feats, neigh_feats), dim=1) # 将节点自身特征与邻居特征拼接
else:
combined = neigh_feats # 如果是 GCN 模式,则不拼接自身特征

combined = F.relu(self.weight.mm(combined.t())) # 通过线性层并应用 ReLU 激活
return combined # 返回嵌入表示

下面的代码实现了 GraphSAGE 中的 MeanAggregator 模块,使用邻居节点的特征均值来聚合节点的嵌入。这个模块的核心思想是从每个节点的邻居中采样,并计算邻居特征的平均值,从而生成聚合的节点表示。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class MeanAggregator(nn.Module):
"""
使用邻居节点嵌入的平均值来聚合节点的嵌入
"""

def __init__(self, features, cuda=False, gcn=False):
"""
初始化聚合器
:param features: 用于映射节点ID到特征的函数
:param cuda: 是否使用 GPU
:param gcn: 是否使用 GraphSAGE 的 GCN 变体(带自环)
"""
super(MeanAggregator, self).__init__()
self.features = features # 节点特征
self.cuda = cuda # 是否使用 GPU
self.gcn = gcn # 是否添加自环(类似 GCN)

def forward(self, nodes, to_neighs, num_sample=10):
"""
进行聚合操作
:param nodes: 批次中的节点列表
:param to_neighs: 每个节点的邻居集合
:param num_sample: 每个节点采样的邻居数量(如果为 None 则不进行采样)
:return: 聚合后的特征
"""
_set = set
if num_sample is not None:
_sample = random.sample
samp_neighs = [_set(_sample(to_neigh, num_sample))
if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs]
else:
samp_neighs = to_neighs # 不进行采样时,使用所有邻居

if self.gcn:
samp_neighs = [samp_neigh.union(set([int(nodes[i])])) for i, samp_neigh in
enumerate(samp_neighs)] # 为每个节点添加自环

unique_nodes_list = list(set.union(*samp_neighs)) # 获取所有唯一节点的列表
unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)} # 创建节点映射表

mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes))) # 初始化掩码矩阵
column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
mask[row_indices, column_indices] = 1 # 构建掩码

if self.cuda:
mask = mask.cuda()

num_neigh = mask.sum(1, keepdim=True) # 计算邻居数量
mask = mask.div(num_neigh) # 正则化掩码

if self.cuda:
embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda()) # 获取节点特征
else:
embed_matrix = self.features(torch.LongTensor(unique_nodes_list))

to_feats = mask.mm(embed_matrix) # 聚合特征
return to_feats # 返回聚合后的特征