|训练超大规模图模型,PyTorchBigGraph如何做到?


编辑:Panda
Facebook 提出了一种可高效训练包含数十亿节点和数万亿边的图模型的框架 BigGraph 并开源了其 PyTorch 实现 。 本文将解读它的创新之处 , 解析它能从大规模图网络高效提取知识的原因 。
|训练超大规模图模型,PyTorchBigGraph如何做到?
本文插图
图(graph)是机器学习应用中最基本的数据结构之一 。 具体来说 , 图嵌入方法是一种无监督学习方法 , 可使用本地图结构来学习节点的表征 。 社交媒体预测、物联网模式检测或药物序列建模等主流场景中的训练数据可以很自然地表征为图结构 。 其中每一种场景都可以轻松得到具有数十亿相连节点的图 。 图结构非常丰富且具有与生俱来的导向能力 , 因此非常适合机器学习模型 。 尽管如此 , 图结构却非常复杂 , 难以进行大规模扩展应用 。 也因此 , 现代深度学习框架对大规模图数据结构的支持仍非常有限 。
Facebook 推出过一个框架 PyTorch BigGraph:https://github.com/facebookresearch/PyTorch-BigGraph , 它能更快更轻松地为 PyTorch 模型中的超大图结构生成图嵌入 。
某种程度上讲 , 图结构可视为有标注训练数据集的一种替代 , 因为节点之间的连接可用于推理特定的关系 。 这种方法遵照无监督图嵌入方法的模式 , 它可以学习图中每个节点的向量表征 , 其具体做法是优化节点对的嵌入 , 使得之间有边相连的节点对的嵌入比无边相连的节点对的嵌入更近 。 这类似于在文本上训练的 word2vec 的词嵌入的工作方式 。
|训练超大规模图模型,PyTorchBigGraph如何做到?
本文插图
当应用于大型图结构时 , 大多数图嵌入方法的结果都相当局限 。 举个例子 , 如果一个模型有 20 亿个节点 , 每个节点有 100 个嵌入参数(用浮点数表示) , 则光是存储这些参数就需要 800 GB 内存 , 因此很多标准方法都超过了典型商用服务器的内存容量 。 这是深度学习模型面临的一大挑战 , 也是 Facebook 开发 BigGraph 框架的原因 。
PyTorch BigGraph
PyTorch BigGraph(PBG)的目标是扩展图嵌入模型 , 使其有能力处理包含数十亿节点和数万亿边的图 。 PBG 为什么有能力做到这一点?因为它使用了四大基本构建模块:

  1. 图分区 , 这让模型不必完全载入到内存中 。
  2. 在每台机器上的多线程计算
  3. 在多台机器上的分布式执行(可选) , 所有操作都在图上不相连的部分进行
  4. 分批负采样 , 当每条边 100 个负例时 , 可实现每台机器每秒处理超过 100 万条边 。
通过将图结构分区为随机划分的 P 个分区 , 使得可将两个分区放入内存中 , PBG 解决了传统图嵌入方法的一些短板 。 举个例子 , 如果一条边的起点在分区 p1 , 终点在分区 p2 , 则它会被放入 bucket (p1, p2) 。 然后 , 在同一模型中 , 根据源节点和目标节点将这些图节点划分到 P2 bucket 。 完成节点和边的分区之后 , 可以每次在一个 bucket 内执行训练 。 bucket (p1, p2) 的训练仅需要将分区 p1 和 p2 的嵌入存储到内存中 。 PBG 结构能保证 bucket 至少有一个之前已训练的嵌入分区 。
|训练超大规模图模型,PyTorchBigGraph如何做到?
本文插图
PBG 的另一大创新是训练机制的并行化和分布式 。 PBG 使用 PyTorch 自带的并行化机制实现了一种分布式训练模型 , 这用到了前面描述的模块分区结构 。 在这个模型中 , 各个机器会协调在不相交的 bucket 上进行训练 。 这会用到一个锁服务器(lock server) , 其负责将 bucket 分派给工作器(worker) , 从而尽可能地减少不同机器之间的通信 。 每台机器都可以使用不同的 bucket 并行地训练模型 。