![在Pytorch中构建流数据集](http://img.jiangsulong.com/220425/1416392201-0.jpg)
文章插图
在处理监督机器学习任务时,最重要的东西是数据——而且是大量的数据 。当面对少量数据时,特别是需要深度神经网络的任务时,该怎么办?如何创建一个快速高效的数据管道来生成更多的数据,从而在不花费数百美元在昂贵的云GPU单元上的情况下进行深度神经网络的训练?
这是我们在MAFAT雷达分类竞赛中遇到的一些问题 。我的队友hezi hershkovitz为生成更多训练数据而进行的增强,以及我们首次尝试使用数据加载器在飞行中生成这些数据 。
要解决的问题我们在比赛中使用数据管道也遇到了一些问题,主要涉及速度和效率:
它没有利用Numpy和Pandas在Python中提供的快速矢量化操作的优势
每个批次所需的信息都首先编写并存储为字典,然后使用Python for循环在getitem方法中进行访问,从而导致迭代和处理速度缓慢 。
从音轨生成"移位的"片段会导致每次检索新片段时都重新构建相同的音轨,这也会减缓管道的速度 。
管道无法处理2D或3D输入,因为我们同时使用了scalograms和spectrograms但是无法处理 。
如果我们简单地按照批处理的方式进行所有的移位和翻转,那么批处理中就会充斥着与其他示例过于相似的示例,从而使模型不能很好地泛化 。
这些低效率的核心原因是,管道是以分段作为基本单元运行,而不是在音轨上运行 。
数据格式概述在制作我们的流数据之前,先再次介绍一下数据集,MAFAT数据由多普勒雷达信号的固定长度段组成,表示为128x32 I / Q矩阵; 但是,在数据集中,有许多段属于同一磁道,即,雷达信号持续时间较长,一条磁道中有1到43个段 。
![在Pytorch中构建流数据集](http://img.jiangsulong.com/220425/1416393C0-1.jpg)
文章插图
上面的图像来自hezi hershkovitz 的文章,并显示了一个完整的跟踪训练数据集时,结合所有的片段 。红色的矩形是包含在这条轨迹中的单独的部分 。白点是"多普勒脉冲",代表被跟踪物体的质心 。
借助"多普勒脉冲"白点,我们可以很容易地看到,航迹是由相邻的段组成的,即段id 1942之后是1943,然后是1944,等等 。
片段相邻的情况下允许我们使用移位来创建"新的"样本 。
![在Pytorch中构建流数据集](http://img.jiangsulong.com/220425/1416395227-2.jpg)
文章插图
但是,由于每个音轨由不同数量的片段组成,因此从任何给定音轨生成的增补数目都会不同,这使我们无法使用常规的Pytorch Dataset 类 。这里就需要依靠Pytorch中的IterableDataset 类从每个音轨生成数据流 。
数据流管道设计这三个对象的高级目标是创建一个_Segment对象流,它能够足够灵活地处理音轨和段,并且在代码中提供一致的语义:
class _Segment(Dict, ABC):segment_id: Union[int, str]output_array: np.ndarraydoppler_burst: np.ndarraytarget_type: np.ndarraysegment_count: int
为此,我们创建了:一个配置类,它将为一个特定的实验保存所有必要的超参数和环境变量——这实际上只是一个具有预定义键的简单字典 。
一个DataDict类,它处理原始片段的加载,验证每一条轨迹,创建子轨迹以防止数据泄漏,并将数据转换为正确的格式,例如2D或3D,并为扩展做好准备
StreamingDataset类,是Pytorch IterableDataset的子类,处理模型的扩充和流段 。
config = Config(file_path=PATH_DATA,num_tracks=3,valratio=6,get_shifts=True,output_data_type='spectrogram',get_horizontal_flip=True,get_vertical_flip=True,mother_wavelet='cgau1',wavelet_scale=3,batch_size=50,tracks_in_memory=25,include_doppler=True,shift_segment=2)dataset = DataDict(config=config)train_dataset = StreamingDataset(dataset.train_data, config, shuffle=True)train_loader = DataLoader(train_dataset,batch_size=config['batch_size'])
DataDict实现在DataDict中将片段处理为音轨,然后再处理为片段,为加速代码提供了很好的机会,特别是在数据验证、重新分割和轨创建都可以向量化的情况下 。我们使用了Numpy和Pandas中的一堆技巧和简洁的特性,大量使用了布尔矩阵来进行验证,并将scalogram/spectrogram 图转换应用到音轨中连接的片段上 。代码太长,但你可以去最后的源代码地址中查看一下DataDict create
推荐阅读
- 如何使用 Python 来自动交易加密货币
- 德国生化僵尸实验 苏联恐怖实验
- 慈禧太后的坟墓是谁挖掘出来的 慈禧墓中遗失的7件绝世珍宝
- 世界飞的最高的鸟 空中最大的鸟类是什么
- 古代的一两银子是现在的多少人民币 古时候1两银子相当于现在多少钱
- 中国的顶级豪宅有哪些 古代豪宅叫什么
- 日本人对中国做了些什么 日本在中国做的残忍的事情
- 鼠尾草在中国叫什么,迷迭香在菜市场叫什么
- 中药槐米和槐花的区别,槐花和槐米的药用价值样吗
- 网页如何唤起应用程序?