TensorFlow在美团推荐系统中的分布式训练优化实践( 五 )


3.4 延迟优化
这部分优化,也是分布式计算的经典优化方向 。整个流程链路上那些可以精简、合并、重叠需要不断去挖掘 。对于机器学习系统来说,相比其它的系统,还可以用一些近似的算法来做这部分工作,从而获得较大的性能提升 。下面介绍我们在两个这方面做的一些优化实践 。
3.4.1 稀疏域参数聚合
在启用HashTable存储稀疏参数后,对应的,一些配套参数也需要替换为HashTable实现,这样整个计算图中会出现多张HashTable以及大量的相关算子 。在实践中,我们发现需要尽量降低Lookup/Insert等算子的个数,一方面降低PS的负载,一方面降低RPC QPS 。因此,针对稀疏模型的常见用法,我们进行了相关的聚合工作 。
以Adam优化器为例,需要创建两个slot,以保存优化中的动量信息,它的Shape与Embedding相同 。在原生优化器中,这两个Variable是单独创建的,并在反向梯度更新的时候会去读写 。同理,使用HashTable方案时,我们需要同时创建两张单独的HashTable用来训练m、v参数 。那么在前向,反向中需要分别对Embedding、 m、v进行一次Lookup和一次Insert,总共需要三次Lookup和三次Insert 。
这里一个优化点就是将Embedding、 m、v,以及低频过滤的计数器(见下图14的Counting HashTable)聚合到一起,作为HashTable的Value,这样对稀疏参数的相关操作就可以聚合执行,大大减少了稀疏参数操作频次,降低了PS的压力 。

TensorFlow在美团推荐系统中的分布式训练优化实践

文章插图
图14 基于HashTable的参数融合策略
该特性属于一个普适型优化,开启聚合功能后,训练速度有了显著的提高,性能提升幅度随着模型和Worker规模的变化,效果总是正向的 。在美团内部真实业务模型上,聚合之后性能相比非聚合方式能提升了45%左右 。
3.4.2 Embedding流水线优化
流水线,在工业生产中,指每一个生产单位只专注处理某个片段的工作,以提高工作效率及产量的一种生产方式 。在计算机领域内,更为大家熟知的是,流水线代表一种多任务之间Overlap执行的并行化技术 。例如在典型的RISC处理器中,用户的程序由大量指令构成,而一条指令的执行又可以大致分为:取指、译码、执行、访存、写回等环节 。这些环节会利用到指令Cache、数据Cache、寄存器、ALU等多种不同的硬件单元,在每一个指令周期内,这5个环节的硬件单元会并行执行,得以更加充分的利用硬件能力,以此提高整个处理器的指令吞吐性能 。处理器的指令流水线是一套复杂而系统的底层技术,但其中的思想在分布式深度学习框架中也被大量的使用,例如:
如果将分布式训练简单的抽象为计算和通信两个过程,绝大多数主流的深度学习框架都支持在执行计算图DAG时,通信和计算的Overlap 。如果将深度模型训练简单的分为前向和反向,在单步内,由于两者的强依赖性,无法做到有效并行,字节BytePS[8]中引入的通信调度打破了step iteration间的屏障,上一轮的部分参数更新完毕后,即可提前开始下轮的前向计算,增强了整体视角下前反向的Overlap 。百度AIBox[9]为了解决CTR场景GPU训练时,参数位于主存,但计算位于GPU的问题,巧妙调度不同硬件设备,搭建起了主要利用CPU/主存/网卡的参数预准备阶段和主要利用GPU/NVLink的网络计算阶段,通过两个阶段的Overlap达到更高的训练吞吐 。
我们看到,在深度学习框架设计上,通过分析场景,可以从不同的视角发掘可并行的阶段,来提高整体的训练吞吐 。
对于大规模稀疏模型训练时,核心模型流程是:先执行稀疏参数的Embedding,然后执行稠密部分子网络 。其中稀疏参数Embedding在远端PS上执行,主要耗费网络资源,而稠密部分子网络在本地Worker执行,主要耗费计算资源 。这两部分占了整个流程的大部分时间,在美团某实际业务模型上分别耗时占比:40%+、50%+ 。
那我们是否可以提前执行稀疏参数的Embedding,来做到通信和计算的Overlap,隐藏掉这部分时间呢?从系统实现上肯定是可行的,但从算法上讲,这样做会引入参数Staleness的问题,可能会导致模型精度受到影响 。但在实际的生产场景中,大规模异步训练时本身就会带来几十到几百个步的滞后性问题 。经过我们测试,提前获取一两步的稀疏参数,模型精度并未受到影响 。
在具体实现上,我们把整个计算图拆分为Embedding Graph(EG)和Main Graph(MG)两张子图,两者异步独立执行,做到拆分流程的Overlap(整个拆分过程,可以做到对用户透明) 。EG主要覆盖从样本中抽取Embedding Key,查询组装Embedding向量,Embedding向量更新等环节 。MG主要包含稠密部分子网络计算、梯度计算、稠密参数部分更新等环节 。