联系我们

谷歌如何打造世界上最快的AI超级计算机系统?(上)

2021-01-16

最近谷歌公司发表了一篇轰动人工智能系统界的论文[1],介绍了他们如何基于MLPerf标准去刷新深度学习训练速度的世界纪录。比如,23秒完成BERT训练!28秒完成ImageNet训练!文章洋洋洒洒,从系统,算法,编译器,应用角度全方位地阐述了谷歌如何打造基于他们自己芯片的AI超级计算机以及深度学习系统。文章一共19位作者,包含了谷歌深度学习系统团队的一些专家。本文简要讨论一下谷歌公司的这篇文章,分为上下两部分。如有问题,可以通过邮箱youy@comp.nus.edu.sg联系笔者(新加坡国立大学高性能人工智能实验室主任、壁仞科技顾问尤洋)。笔者曾在UC Berkeley读博期间在谷歌公司总部的谷歌大脑团队实习4次。


1节:介绍

 

以自然语言处理为代表的深度学习革命正处在类似于美苏60年代“太空竞赛”的境地,导致模型的大小正以爆炸性的速度增长。显然,大模型带来了极大的好处,比如OpenAI的超大模型GPT-3已经可以在轻微的人类引导下去生成一些有可读性的散文,效果明显是之前的模型无法达到的。同时OpenAI研究院也观察到,世界最佳模型所消耗的计算资源每3.5个月就要翻一番。这比当时的摩尔定律夸张多了。之前摩尔定律表明每18月单位面积的晶体管数量翻一番。由摩尔定律导致的安迪-比尔定律直接地预示了计算机产业几十年的辉煌。

所以,为了满足这些大模型的快速训练需求,我们需要更大的超级计算机。这场深度学习革命引发的超算革命给GPU带来了巨额的投资:过去几年,英伟达公司的营收见证过50%的年增长率。英伟达公司的股票也从201211月的11美元增长到了202011582美元(53)

2015年,谷歌的TPU神经网络专用加速器横空出世。谷歌宣称,每个TPU芯片比之前的神经网络芯片在功耗比,低延迟效率,以及最高性能上提升了10(Jouppi等人的论文[2],2017)。在之后的两年,谷歌基于二代TPU256芯片超计算机就展示出了完美的并行效率(Jouppi等人的论文[3],2020)。之后谷歌基于3TPU将超级计算机的规模做到了1024芯片的规模(Kumar等人的论文[4],2019)。英伟达和其它GPU供应商也在建造类似规模的AI超级计算机。微软和OpenAI甚至要打造一个一万个GPU的超级计算机(Langston的文章[5],2020)。这场类似于“美苏星球大战”的比拼将会用精准的模型去试图逼近通用人工智能。这一点是否能实现还有待观察。但是,无疑的是,这场比拼中的各个巨头对硬件投入是雄心勃勃的。

美苏太空竞赛可以用类似于“绕地飞行”或“登陆月球”等一些里程碑事件去标记成就。然而,这次AI超算竞争却在开始的时候没有具体衡量指标。所以,谷歌、英伟达、英特尔、阿里巴巴、华为等工业界巨头意识到这个问题后提出了MLPerf标准(MLPerf.org)。特别地,MLPerf训练标准吸引了众多有高性能计算能力的实体去用超级计算机以最短的时间完成神经网络的训练(Mattson等人的论文[6],2019)

事实证明,在过去的两年,MLPerf标准对深度学习社区产生了积极而深远的影响。因为刷新MLPerf标准的努力和创新产生了新的系统优化技术,更实用的代码库,更高效的编译器,以及最合适的应用层代码。

常规的TPU超级计算机只有1024个芯片。我们在上次的MLPerfv0.6版本比赛中只用了常规的TPU超级计算机。为了探索MLPerf模型并行计算的极限,我们用第三代TPU芯片组装了一个4096芯片的多组超级计算机(详见图1)。两个TPU超级计算机在网格结构的X方向上被光纤连接起来(详见图2)。这些组间的光纤链接要比普通的组内光纤链接长。我们用了128x32的二维网状结构将所有的芯片连接起来用于训练MLPerf的模型。因为第三代TPU芯片的路由表只能容纳1024个地址,所以我们用了一种稀疏的路由策略。在这种策略下,每个芯片只能看到与它同一行或同一列的芯片。这个策略足以使all-reduce操作达到峰值的通讯传输效率。

1:常规TPU超级计算与多组TPU超级计算机

24TPU超级计算机的配置,跨组链接在横向连接了两个TPU超级计算机


2节:多种编程框架

 

尽管TPU的主要前端编程框架一直是TensorFlow(Abadi等人的论文[7],2016)TPU硬件和XLA编译器其实是通用的工具:它们是可以支持其它的编程框架。因此,在这篇论文中,我们选择给TensorFlowJAX都做评测。JAX(Frostig等人的论文[8],2018)是一种面向研究的基于XLA的数值计算编程框架。这两种编程模型都需要额外的软件工程在多组超级计算机上获得高效的扩展性。但是它们最终可以获得相似的评测结果。

如图3所示,由于TensorFlowJAX在架构上的不同,它们在大规模运算上也有性能差异。首先,它们有不同分阶段方法。TensorFlowPython中嵌入了一种表达式的动态即时编程语言,然后用XLA去即时编译图的子集(TensorFlow的图可以分布在TPU加速器上和CPU)。相比之下,JAX少了一个阶段:它是一个在Python中嵌入了即时编译的XLA程序的编程框架。其中,即时编译的XLA程序用于在加速器上的静态编译性能以及加速器网络上的并行性。JAX用于动态以及加速计算。所以,TensorFlowJAX多了一个编译阶段,我们用多线程去加速了这个多出来的编译阶段。同时,JAX也需要我们更加小心地去处理Python的瓶颈。比如,把一些类似于数据读取的干扰性任务移出主线程。

3TensorFlowJAX编程模型在第三代TPU上使用的示意图

其次,TensorFlowJAX可以支持不同的分布式编程模型。JAX采用了一种多客户端的方式去做分布式编程,使得超级计算机中每个主机分别有一份同样的JAX代码(包括Python解释器)JAX程序只在两个地方进行通讯:(1)在初始阶段去设置TPU的网格连接。(2)在模型训练过程中在网络上进行XLA编译好的all-reduce操作。与之相反的是,TensorFlow采用单客户端的方式在TPU上进行编程。这种方式会给一个Python进程全局视角并允许它控制整个分布式系统。这个Python进程运行在超级计算机的一个主机上。这个主机将TensorFlow的图均分后通过RPC在网络上发给剩下的主机执行。

TensorFlowJAX在实用性和性能特点上也有不同之处。尽管TensorFlow的单客户端分布式系统能够让用户代码直接控制整体的运行负载,JAX的多客户端方法使得代码可以直接控制各个独立的计算单元。JAX在各个主机上独立地调用XLA编译器,这依赖于确定性编译技术去避免不同主机程序的不兼容性。然而,TensorFlow只编译一次并且把二进制文件分给所有主机节点。TensorFlow的这种多设备图其实能引起Amdahl定律瓶颈。原因是因为客服端进程在图构造和优化时间上的开销是跟节点数成正比的。然后,除了TPU拓扑网格链接的初始化之外,JAX的设置时间是不随着节点数增加而增加的。所以,JAX在这一点上看似有更好的可扩展性。 

(转载自壁仞科技研究院)

诚聘英才
友好链接
业务咨询及参观访问:0755-86576085    0755-86576086    地址:深圳市南山区笃学路9号
国家超级计算深圳中心(深圳云计算中心)  ©2014-2020  粤ICP备10220126号