文章插图
?首先是 GPU embedding 。我们先来回顾一下传统的推荐系统 GPU 训练流程,我们会把具体的模型放在 GPU worker 上,embedding 存在远端 PS 上 。每个迭代步会先从远端 PS 拉取参数,之后在 GPU 上进行模型的前向和反向计算,把梯度传回给 PS,在 PS 上进行参数更新 。
图中绿色的部分是在 GPU 上进行的操作,红色的部分是网络或者 CPU 上进行的 。可以看到虽然 GPU 是系统中最昂贵的部分,很多操作却都没有放在 GPU 上 。?
文章插图
传统流程并没有充分利用好 GPU 。同时,从硬件层面来说,GPU 单卡显存越来越大,dense 部分模型远远没有充分利用 GPU;在英伟达的不断优化下,NV link 以及 GPU direct RDMA 还让卡间通信速度越来越快 。
文章插图
GPU embedding 是一个非常简单的方案 。他直接把 embedding 切分放在 GPU 上——比如单机上有 8 张卡,我们把 embedding 直接切分为 8 份,每份放在一张卡上——从而保证所有的操作全都留在卡上 。GPU 的利用效率就会有明显提升,训练速度也会有质的飞跃 。如果担心 GPU 上面的显存空间不足,TorchRec 还做了 UVM 的支持,可以提前划分一部分主机上的内存作为显存的补充,从而提升单机内部能放下的 embedding 大小 。
文章插图
除去 GPU embedding 以外,TorchRec 还实现了非常优秀的 GPU kernel 。这些 kernel 充分利用了最新的硬件特性和 CUDA feature 。
文章插图
举例来说,假如果要实现一个 embedding lookup kernel,也就是要从一个大的 embedding 里面找到一堆 ID 对应的 embedding vector,那么普通的实现里,会给每个 GPU thread 分配一个 ID,让他们分别去找对应的 embedding 。这个时候我们要考虑到,GPU 底层是按 warp 进行调度的,一个 warp 里的 32 个 thread 会一起进行显存读写 。这意味着,在上述样流程里,虽然在读取 ID 时连续地访问了显存,但后续的拷贝变成了一个随机读写的状态 。对于硬件来说,随机读写无法充分利用显存带宽,运行效率也就不够高 。
文章插图
TorchRec 则是在每个 thread 读到 ID 后,利用 shuffle_sync 这样的 warp primitive,将 ID 广播至 warp 内的所有thread 上,从而让一个 wrap 里 32 个 thread 去同时处理同一个 embedding,从而可以进行连续的内存读写,使得显存的带宽利用效率有明显的提升,让 kernel 的速度得到数倍提升 。
文章插图
这个表是官方测试的 embedding lookup 性能提升 。这里 Fused EBC 是优化后的kernel,可以看到,不同的设置情况下 TorchRec 相较于原生的 PyTorch 有数十倍的性能提升 。在 TorchRec 的基础之上,我们发现对于 embedding 比较小的情况(小于128),可能有半数甚至更多的 thread 空闲,所以进一步把 warp 内的 thread 分组,让他们同时去处理多条 embedding 。
文章插图
文章插图
在我们的改进下,小 embedding dim 上 kernel 又有了 10% 到 30% 的提升 。这一优化也已经合入官方 repo 。要特别指出的是,TorchRec 的 kernel 放在了 FBGEMM 库里,有兴趣朋友可以去看一看 。
文章插图
最后想介绍一下 TorchRec 的 embedding 划分机制 。前面提到,GPU embedding 就是把 embedding 切分一下放在卡上,那么怎么分就成了一个需要考虑的问题 。传统来说有两种划分思路,Row wise 和 Column wise 。Row wise 是指假如有 2 万个 feature,0 号到第 10000 号放在卡 1 上,10000 号到 20000 号放在卡 2 上,这样我们在训练的时候,如果 ID 对应卡 1,我们就从卡 1 上拿,对应卡 2,就从卡 2 上拿 。Row wise 的问题在于,因为我们不清楚前 10000 号的通信量和后 10000 号的是不是差距很大,通信都是不均衡的,无法充分利用网络硬件 。
推荐阅读
- 基于SQL的数据可视化和数据挖掘
- 微信只需以下几步,就可以查询我们交了多久社保了!
- 视频号将推创作分成计划、微信支付介绍“微信刷掌”、“问一问”搜索功能将上线...微信公开课信息量很大
- 微信怎样撤回发送很久的消息,这个小技巧,不能不知道!
- 微信拉黑和微信删除的区别,千万别再用错了
- 从零开发一套基于React的加载动画库
- 请问微信不可以购买虚拟产品是什么意思 虚拟产品可以退款吗
- 玩了这么久的微信,还不知道青少年模式有哪些限制,你就OUT了!
- PyTorch张量的四种乘法运算
- 语音识别系列之基于脉冲神经网络的语音唤醒