引入大语言模型、首个支持国产算力生态,全新开源RL框架RLLTE来了

近年来,强化学习的研究热度不断攀升,在智能制造、自动驾驶、大语言模型等多个领域取得了耀眼成绩,展示出巨大的研究潜力 。然而,强化学习算法高效、可靠的工程实现仍是长期存在的问题 。由于其算法结构的复杂性,微小的代码差异就可能严重影响实际性能 。
为了解决这一问题,科研和开发者社区先后提出了多个强化学习框架,如强调稳定性和可靠性的 Stable-Baselines3、模块化设计的 Tianshou 以及单文件实现算法的 CleanRL , 为强化学习的学术研究和应用开发做出了积极贡献 。然而 , 大部分项目的活跃周期较短,未建立合理的长期演进计划,并且代码风格迥异 , 限制了社区的开源协作 。它们也没有构建完整的项目生态,仅专注于模型训练,而忽略了评估、部署等现实需求 。同时 , 这些项目缺乏完备的测试数据,导致复现成本极高,阻碍了后续研究的进行 。
为了解决以上问题,来自香港理工大学、宁波东方理工大学(暂名)、普渡大学和大疆科技的研究者和算法团队,受到电信中「长期演进技术(LTE)」的启发,发布了名为 RLLTE 的开源强化学习框架,旨在为促进强化学习研究和应用提供开发组件和工程标准 。RLLTE 不仅提供高质量的算法实现,还可作为工具库用于新算法的开发 。

引入大语言模型、首个支持国产算力生态,全新开源RL框架RLLTE来了

文章插图
论文链接:https://arxiv.org/pdf/2309.16382.pdf
代码 / Demo 链接:https://Github.com/RLE-Foundation/rllte
官方网站:https://docs.rllte.dev/

引入大语言模型、首个支持国产算力生态,全新开源RL框架RLLTE来了

文章插图
RLLTE 从「探索 - 利用」的角度出发对强化学习算法进行完全解耦,将它们分解为若干最小基元,例如用于处理观测的编码器(Encoder),以及用于经验存储和采样的存储器(Storage) 。RLLTE 为每一种基元提供了丰富的模组供开发者选择,使得开发者可以以「搭积木」的方式进行强化学习算法的构建 。RLLTE 框架的主要功能和亮点如下:
极致模块化:RLLTE 的核心设计思想是像 PyTorch 一样为强化学习算法开发提供标准、便捷、即插即用的开发组件,而并非专注于提供具体的算法实现 。因此 , RLLTE 实现的算法中每个组件都是可替换的,并且支持用户使用自定义的模块 。这一解耦过程有助于算法可解释性的研究和更深层次的改进探索 。
长期演进:RLLTE 作为一个长期演进的框架,将会持续更新强化学习中的先进算法和工具 。为了保持项目的体量和高质量 , RLLTE 只更新通用的算法,抑或是在采样效率或者泛化能力方面做出的改进 , 并且要求这些算法必须在广受认可的基线上取得杰出的性能 。
数据增强:近年来的大量研究将数据增强技巧引入强化学习算法,以实现采样效率和泛化能力的显著提升 。RLLTE 默认支持数据增强操作,并提供大量观测(Observations)增强和内在奖励(Intrinsic Rewards)模组供开发者选择 。
丰富的项目生态:RLLTE 同时考虑学术界和工业界的需求 , 构建了丰富的项目生态 。开发者可以在一个框架中实现任务设计、模型训练、评估以及部署 。并且,RLLTE 还尝试将大语言模型引入该框架,以降低用户学习成本,加速强化学习的应用构建 。
完备的基线数据:现有的强化学习框架通常只在有限的任务上进行算法测试,缺乏完备的测试数据(训练曲线和测试分数等) 。由于强化学习训练的高算力消耗,这是可以理解的,但阻碍了后续的研究进行 。为了解决这一问题,RLLTE 依托 Hugging Face 平台建立了数据仓库,并对内置算法在广受认可的基线上进行测试以提供完整的训练数据 。
多硬件支持:在当前全球算力需求激增的背景下,RLLTE 被设计为支持多种算力设备以保证灵活性和可拓展性 。当前,框架支持使用 NVIDIA GPU 和 HUAWEI NPU 进行训练,并支持在 NVIDIA TensorRT 以及 HUAWEI CANN 架构下进行推理端部署 。RLTLE 也是首个支持国产算力生态的强化学习框架 。

引入大语言模型、首个支持国产算力生态,全新开源RL框架RLLTE来了

文章插图
RLLTE 框架介绍
RLLTE 框架主要包含三个层级:核心层(Core)、应用层(Application)以及工具层(Tool) 。
核心层从「探索 - 利用」的角度对 RL 算法进行完全解耦,并将其拆分成以下 6 大基元:

引入大语言模型、首个支持国产算力生态,全新开源RL框架RLLTE来了

文章插图
RLLTE 为每一类基元提供了大量模组供开发者选择 。例如,storage 模块中提供了 VanillaReplayStorage 和 DictReplayStorage 分别用于存储普通格式和字典格式的观测数据 。


推荐阅读