一个简单的更改就可以让PyTorch读取表格数据的速度提高20倍


一个简单的更改就可以让PyTorch读取表格数据的速度提高20倍
本文插图

来源:DeepHub IMBA
本文约3000字 , 建议阅读5分钟
我在PyTorch中对表格的数据加载器进行的简单更改如何将训练速度提高了20倍以上 , 而循环没有任何变化!
深度学习:需要速度
在训练深度学习模型时 , 性能至关重要 。 数据集可能非常庞大 , 而低效的训练方法意味着迭代速度变慢 , 超参数优化的时间更少 , 部署周期更长以及计算成本更高 。
由于有许多潜在的问题要探索 , 很难证明花太多时间来进行加速工作是合理的 。 但是幸运的是 , 有一些简单的加速方法!
我将向您展示我在PyTorch中对表格的数据加载器进行的简单更改如何将训练速度提高了20倍以上 , 而循环没有任何变化!这只是PyTorch标准数据加载器的简单替代品 。 对于我正在训练的模型 , 可以16分钟的迭代时间 , 减少到40秒!
【一个简单的更改就可以让PyTorch读取表格数据的速度提高20倍】所有这些都无需安装任何新软件包 , 不用进行任何底层代码或任何超参数的更改 。
一个简单的更改就可以让PyTorch读取表格数据的速度提高20倍
本文插图

研究/产业裂痕
在监督学习中 , 对Arxiv-Sanity的快速浏览告诉我们 , 当前最热门的研究论文都是关于图像(无论是分类还是生成GAN)或文本(主要是BERT的变体) 。 深度学习在传统机器学习效果不好的这些领域非常有用 , 但是这需要专业知识和大量研究预算才能很好地执行 。
许多公司拥有的许多数据已经以很好的表格格式保存在数据库中 。 一些数据包括用于终生价值估算的客户详细信息 , 优化和财务的时间序列数据 。
表格数据有何特别之处?
那么 , 为什么研究与产业之间的裂痕对我们来说是一个问题呢?好吧 , 最新的文本/视觉研究人员的需求与那些在表格数据集上进行监督学习的人的需求截然不同 。
以表格形式显示数据(即数据库表 , Pandas DataFrame , NumPy Array或PyTorch Tensor)可以通过以下几种方式简化操作:

  • 可以通过切片从连续的内存块中获取训练批次 。
  • 无需按样本进行预处理 , 从而使我们能够充分利用大批量培训来提高速度(请记住要提高学习率 , 所以我们不会过拟合!)
  • 如果您的数据集足够小 , 则可以一次将其全部加载到GPU上 。 (虽然在技术上也可以使用文本/视觉数据 , 但数据集往往更大 , 并且某些预处理步骤更容易在CPU上完成) 。
对于表格数据而不是文本/视觉数据 , 这些优化是可能的 , 他们存在两个主要区别:模型和数据 。
模型:视觉研究倾向于使用大型深层卷积神经网络(CNN);文本倾向于使用大型递归神经网络(RNN)或转换器;但是在表格数据上 , 完全连接的深度神经网络(FCDNN)可以很好地完成工作 。 尽管并非总是如此 , 但与表格数据中变量之间的交互作用相比 , 一般而言 , 视觉和文本模型需要更多的参数来学习更多的细微差别的表示 , 因此向前和向后传递可能需要更长的时间 。
数据:视觉数据倾向于将数据保存为充满图像的嵌套文件夹 , 这可能需要大量的预处理(裁剪 , 缩放 , 旋转等) 。 文本数据可以是大文件或其他文本流 。 通常 , 这两种方法都将保存在磁盘上 , 并从磁盘上批量加载 。 这不是问题 , 因为瓶颈不是磁盘的读写速度 , 而是预处理或向后传递 。 另一方面 , 表格数据具有很好的特性 , 可以轻松地以数组或张量的形式加载到连续的内存块中 。 表格数据的预处理往往是预先在数据库中单独进行 , 或者作为数据集上的矢量化操作进行 。
一个简单的更改就可以让PyTorch读取表格数据的速度提高20倍


推荐阅读