NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法( 二 )


该学习范式自然地将解释梯度策略优化、熵最大化、世界模型优化、置信度优化四部分整合到了统一的框架下 。 实验证明 , 该方法显著提高了基于模型规划算法的样本利用效率 。 在以视觉图像为输入的机器人控制任务公开数据集上 , 该方法超越了Dreamer , 取得了该领域中最高的性能 。
二、方法介绍

NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片
该图显示了总体的算法框架 。 一方面 , 作者们利用解析梯度方法对策略函数进行优化(PolicyImprovement) , 使得其在世界模型中的虚拟轨迹上获得最高的预测收益 。 另一方面 , 作者们最大化了虚拟轨迹和真实轨迹之间的互信息(MutualInformation) , 使得虚拟轨迹尽可能真实 。 因此 , 总体的优化目标函数为:

NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片
其中 , 表示根据实际执行动作从环境收集到的真实路径 , 表示相同的执行动作在世界模型中得到的虚拟路径 。 表示在解析梯度算法优化过程中 , 将世界模型带入策略函数得到的端到端的虚拟路径 。 分别表示策略网络 , 值网络 , 世界模型的参数 。 SVG表示stochasticvaluegradients , 一种经典的解析梯度算法 。 TD表示值函数的TDerror优化 。
传统的基于模型的强化学习方法中通常也会使用真实轨迹来优化世界模型预测误差 , 但这与BIRD框架有很大不同 。 在复杂的问题中 , 世界模型即使按照真实轨迹优化 , 得到的预测误差也往往无法忽略不计 。 用一个不精确的世界模型预测多步会产生很大的累积误差 , 使生成的轨迹与实际轨迹之间存在很大的差距 。
该问题在解析梯度算法中进一步恶化 , 因为解析梯度算法直接沿着世界模型求策略函数的梯度 , 需要世界模型在当前策略函数的邻域内也要表现良好 , 而邻域的数据往往超出了训练集的覆盖范围 , 对世界模型的泛化能力提出了更高的要求 。
为了解决这个问题 , 我们的方法从世界模型和策略函数两方面优化互信息 , 不仅能像传统方法那样优化世界模型的预测误差 , 也会使得策略优化对世界模型的置信度敏感 。 也就是说 , BIRD一方面会提高世界模型准确度 , 准确度高的地方会加大力度优化策略 , 另一方面在世界模型精度不够的地方 , 会比较保守的去降低策略的优化强度 。
1.世界模型中的策略优化
本文采用了一种经典的解析梯度算法SVG来优化策略函数 , 目标是最大化当前策略在世界模型中的预测收益 , 基于端到端的策略梯度进行优化 。 目标函数如下:

NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片
其中 , 是指数平均的值估计 , 用于平衡bias和variance , 其公式如下:

NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片
对于值函数的优化 , 本文采取了TD更新:
NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片
2.虚拟路径和真实路径的互信息优化
本文对互信息优化公式进行了展开 , 并分别对世界模型参数和策略函数参数进行求导 , 得到了三项:模型预测误差最小化、策略熵最大化、基于置信度的策略优化 。

NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片

NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片

NeurIPS 2020 | 清华联合密歇根大学: 兼顾想象与现实的基于模型强化学习算法
文章图片
第一项模型预测误差最小化 , 与普通基于模型的强化学习算法中的世界模型构建部分完全相同 , 本文采用了常见的似然函数 , 根据利用采样得到的真实样本对此项进行优化 。
第二项策略熵最大化 , 直接对策略的熵进行最大化 , 提高策略的多样性 。 直观上看 , 相当于提高了可微规划过程中的策略搜索空间 。
第三项基于置信度的策略优化 , 鼓励在模型预测置信度高的地方增大学习力度 , 在模型预测置信度低的地方降低学习强度 。 直观上看 , 相当于提高了策略搜索的质量 。
第二项和第三项结合起来看 , 本文的新型策略优化目标 , 相当于一边扩大了搜索的可能性 , 一边保证搜到的新数据能按照置信度去优化 , 从而避免了被高熵策略带来的离群值破坏优化 。 本文从对互信息的优化出发 , 推导出这三项 , 很自然地将模型误差、熵、置信度统一到一个完整的框架内 。