近几年,神经网络在自然语言、图像、语音等数据上都取得了显著的突破,将模型性能带到了一个前所未有的高度,但如何在图数据上训练仍然是一个可研究的点。
传统神经网络输入的数据通常每个sample之间都不存在关系,而图数据更加复杂,每个节点之间存在联系,也更符合真实世界中的数据存储方式。真实世界的物体通常根据它们与其他事物的联系来定义的,一组对象以及它们之间的联系可以很自然地表示为一个图(graph),基于图数据的神经网络也称为Graph Neural Network(GNN)。
图神经网络的发展逐渐受到更多关注,在推理、常识等方面也取得很多成就,来自Google的研究员们最近发表了一篇博客,介绍了图神经网络的发展历程,还对现代图神经网络进行了探讨和解释。
一个图由顶点和边组成,在人的脑海中,可以很自然地把社交网络等数据表示为图,那如何把图像和文本表示为图你想过吗?
通常认为图像是带有通道(channels)的矩形网格,将它们表示为例如244x244x3的三维矩阵。
另一种看待图像的方式是有规则结构的图像,其中每个像素代表一个节点,并通过边缘连接到相邻的像素。每个非边界像素恰好有8个相邻节点,并且存储在每个节点上的信息是表示像素 RGB 值的三维向量。
可视化图的连通性的一种方法是邻接矩阵。对这些节点进行排序,在一个5x5的图像中有25个像素,构造一个矩阵,如果两个节点之间存在一条边那么在邻接矩阵中就存在一个入口。
对于文本来说,可以将索引与每个字符、单词或标记相关联,并将文表示为一个有向图,其中每个字符或索引都是一个节点,并通过一条边连接到后面的节点。
但文本和图像在实际使用上通常不采用这种编码方式,用图来表示是比较多余的一步操作,因为所有图像和文本都具有非常规则的结构。例如,图像的邻接矩阵中通常有一条带状结构,因为所有的节点或像素都连接包含在在一个网格结构中。文本的邻接矩阵只包括一条对角线,因为每个单词只连接到前一个单词和下一个单词。
在使用神经网络表示图任务时,一个最重要的表示就是它的连通性,一个比较好的选择就是邻接矩阵,但如前文所说,邻接矩阵过于稀疏,空间利用率不高;另一个问题就是同一个图的邻接矩阵有多种表示方法,神经网络无法保证这些邻接矩阵的输出结果都相同,也就是说不存在置换不变性(permutation invariant)。
并且不同形状的图可能也包含相同的邻接矩阵。
一种优雅且高效来表示稀疏矩阵的方法是邻接列表。它们将节点之间的边的连通性描述为邻接列表第k个条目中的元组(i,j)。由于边的数量远低于邻接矩阵的条目数量,因此可以避免了在图的断开部分(不含边)进行计算和存储。
既然图的描述是以排列不变的矩阵格式,那图神经网络(GNNs)就可以用来解决图预测任务。GNN是对图的所有属性(节点、边、全局上下文)的可优化变换,它可以保持图的对称性(置换不变性)。GNN采用“图形输入,图形输出”架构,这意味着这些模型类型接受图作为输入,将信息加载到其节点、边和全局上下文,并逐步转换这些embedding,而不更改输入图形的连通性。
最简单的GNN模型架构还没有使用图形的连通性,在图的每个组件上使用一个单独的多层感知器(MLP)(其他可微模型都可以)就可以称之为GNN层。
对于每个节点向量,使用MLP并返回一个可学习的节点向量。对每一条边也做同样的事情,学习每一条边的embedding,也对全局上下文向量做同样的事情,学习整个图的单个embedding。
与神经网络模块或层一样,我们可以将这些GNN层堆叠在一起。
由于GNN不会更新输入图的连通性,因此可以使用与输入图相同的邻接列表和相同数量的特征向量来描述GNN的输出图。
构建了一个简单的GNN后,下一步就是考虑如何在上面描述的任务中进行预测。
首先考虑二分类的情况,这个框架也可以很容易地扩展到多分类或回归情况。如果任务是在图节点上进行二分类预测,并且图已经包含节点信息,那么对于每个节点embedding应用线性分类器即可。
实际情况可能更复杂,例如图形中的信息可能存储在边中,而且节点中没有信息,但仍然需要对节点进行预测。所以就需要一种从边收集信息并将其提供给节点进行预测的方法。
可以通过Pooling来实现这一点。Pooling分两步进行:对于要池化的每个item,收集它们的每个embedding并将它们连接到一个矩阵中,通常通过求和操作聚合收集的embedding。
更复杂地,可以通过在 GNN 层内使用池化来进行更复杂的预测,以使学习到的embedding更了解图的连通性。可以使用消息传递(Message Passing)来做到这一点,其中相邻节点或边缘交换信息并影响彼此更新的embedding。
消息传递包含三个步骤:
1、对于图中的每个节点,收集所有相邻节点embedding(或消息)。
2、通过聚合函数(如sum)聚合所有消息。
3、所有汇集的消息都通过一个更新函数传递,通常是一个学习的神经网络。
这些步骤是利用图的连接性的关键,还可以在GNN层中构建更复杂的消息传递变体,以产生更高表达能力的GNN模型。
本质上,消息传递和卷积是聚合和处理元素的邻居信息以更新元素值的操作。在图中,元素是节点,在图像中,元素是像素。然而,图中相邻节点的数量可以是可变的,这与图像中每个像素都有一定数量的相邻元素不同。通过将传递给GNN层的消息堆叠在一起,节点最终可以合并整个图形中的信息。
节点学习完embedding后的下一步就是边。在真实场景中,数据集并不总是包含所有类型的信息(节点、边缘和全局上下文),当用户想要对节点进行预测,但提供的数据集只有边信息时,在上面展示了如何使用池将信息从边路由到节点,但也仅局限在模型的最后一步预测中。除此之外,还可以使用消息传递在GNN层内的节点和边之间共享信息。
可以采用与之前使用相邻节点信息相同的方式合并来自相邻边缘的信息,首先合并边缘信息,使用更新函数对其进行转换并存储。
但存储在图中的节点和边信息不一定具有相同的大小或形状,因此目前还没有一种明确有效的方法来组合他们,一种比较好的方法是学习从边空间到节点空间的线性映射,反之亦然。或者,可以在update函数之前将它们concatenate在一起。
最后一步就是获取全局的节点、边表示。
之前所描述的网络存在一个缺陷:即使多次应用消息传递,在图中彼此不直接连接的节点可能永远无法有效地将信息传递给彼此。对于一个节点,如果有k层网络,那么信息最多传播k步。
对于预测任务依赖于相距很远的节点或节点组的情况,这可能是一个问题。一种解决方案是让所有节点都能够相互传递信息。但不幸的是,对于大型的图来说,所需要的计算成本相当高,但在小图形中已经可以有所应用。
这个问题的一个解决方案是使用图(U)的全局表示,它有时被称为主节点或上下文向量。该全局上下文向量连接到网络中的所有其他节点和边,并可以作为它们之间传递信息的桥梁,为整个图形建立表示。这可以创建一个比其他方法更丰富、更复杂的图形表示。
从这方面来看,所有的图形的属性都已经学习到了对应的表示,因此可以通过调整感兴趣的属性相对于其余属性的信息在池中利用它们。例如对于一个节点,可以考虑来自相邻节点、连接边和全局信息的信息。为了将新节点嵌入到所有这些可能的信息源上,还可以简单地将它们连接起来。此外,还可以通过线性映射将它们映射到同一空间,并应用特征调节层(feature-wise modulation layer)。
通过上述流程,相信大家已经对简单的GNN如何发展为sota模型有了了解。在获取图的节点、边表示后,就可以为之后的任务再单独设计网络,GNN为神经网络提供了一种处理图数据的方式。
在原文博客中,还包括一些GNN的真实案例和数据集,并了解GNN在其中的具体作用,想了解更多内容可以访问参考链接进行阅读。