正如作者而言,训练一个连续层级网络的主要技术难点在于令梯度穿过 ODE Solver 的反向传播 。其实如果令梯度沿着前向传播的计算路径反传回去是非常直观的,但是内存占用会比较大而且数值误差也不能控制 。作者的解决方案是将前向传播的 ODE Solver 视为一个黑箱操作,梯度很难或根本不需要传递进去,只需要「绕过」就行了 。
作者采用了一种名为 adjoint method 的梯度计算方法来「绕过」前向传播中的 ODE Solver,即模型在反传中通过第二个增广 ODE Solver 算出梯度,其可以逼近按计算路径从 ODE Solver 传递回的梯度,因此可用于进一步的参数更新 。这种方法如上图 c 所示不仅在计算和内存非常有优势,同时还能精确地控制数值误差 。
具体而言,若我们的损失函数为 L(),且它的输入为 ODE Solver 的输出:
文章插图
我们第一步需要求 L 对 z(t) 的导数,或者说模型损失的变化如何取决于隐藏状态 z(t) 的变化 。其中损失函数 L 对 z(t_1) 的导数可以为整个模型的梯度计算提供入口 。作者将这一个导数称为 adjoint a(t) = -dL/z(t),它其实就相当于隐藏层的梯度 。
在基于链式法则的传统反向传播中,我们需要从后一层对前一层求导以传递梯度 。而在连续化的 ODEnet 中,我们需要将前面求出的 a(t) 对连续的 t 进行求导,由于 a(t) 是损失 L 对隐藏状态 z(t) 的导数,这就和传统链式法则中的传播概念基本一致 。下式展示了 a(t) 的导数,它能将梯度沿着连续的 t 向前传,附录 B.1 介绍了该式具体的推导过程 。
文章插图
在获取每一个隐藏状态的梯度后,我们可以再求它们对参数的导数,并更新参数 。同样在 ODEnet 中,获取隐藏状态的梯度后,再对参数求导并积分后就能得到损失对参数的导数,这里之所以需要求积分是因为「层级」t 是连续的 。这一个方程式可以表示为:
文章插图
综上,我们对 ODEnet 的反传过程主要可以直观理解为三步骤,即首先求出梯度入口伴随 a(t_1),再求 a(t) 的变化率 da(t)/dt,这样就能求出不同时刻的 a(t) 。最后借助 a(t) 与 z(t),我们可以求出损失对参数的梯度,并更新参数 。当然这里只是简要的直观理解,更完整的反传过程展示在原论文的算法 1 。
反向传播怎么做
在算法 1 中,陈天琦等研究者展示了如何借助另一个 OED Solver 一次性求出反向传播的各种梯度和更新量 。要理解算法 1,首先我们要熟悉 ODESolver 的表达方式 。例如在 ODEnet 的前向传播中,求解过程可以表示为 ODEsolver(z(t_0), f, t_0, t_1, θ),我们可以理解为从 t_0 时刻开始令 z(t_0) 以变化率 f 进行演化,这种演化即 f 在 t 上的积分,ODESolver 的目标是通过积分求得 z(t_1) 。
同样我们能以这种方式理解算法 1,我们的目的是利用 ODESolver 从 z(t_1) 求出 z(t_0)、从 a(t_1) 按照方程 4 积出 a(t_0)、从 0 按照方程 5 积出 dL/dθ 。最后我们只需要使用 dL/dθ 更新神经网络 f(z(t), t, θ) 就完成了整个反向传播过程 。
文章插图
如上所示,若初始给定参数θ、前向初始时刻 t_0 和终止时刻 t_1、终止状态 z(t_1) 和梯度入口 ?L/?z(t_1) 。接下来我们可以将三个积分都并在一起以一次性解出所有量,因此我们可以定义初始状态 s_0,它们是解常微分方程的初值 。
注意第一个初值 z(t_1),其实在前向传播中,从 z(t_0) 到 z(t_1) 都已经算过一遍了,但是模型并不会保留计算结果,因此也就只有常数级的内存成本 。此外,在算 a(t) 时需要知道对应的 z(t),例如 ?L/?z(t_0) 就要求知道 z(t_0) 的值 。如果我们不能保存中间状态的话,那么也可以从 z(t_1) 到 z(t_0) 反向再算一遍中间状态 。这个计算过程和前向过程基本一致,即从 z(t_1) 开始以变化率 f 进行演化而推出 z(t_0) 。
定义 s_0 后,我们需要确定初始状态都是怎样「演化」到终止状态的,定义这些演化的即前面方程 (3)、(4) 和 (5) 的被积函数,也就是算法 1 中 aug_dynamics() 函数所定义的 。
其中 f(z(t), t, θ) 从 t_1 到 t_0 积出来为 z(t_0),这第一个常微分方程是为了给第二个提供条件 。而-a(t)*?L/?z(t) 从 t_1 到 t_0 积出来为 a(t_0),它类似于传统神经网络中损失函数对第一个隐藏层的导数,整个 a(t) 就相当于隐藏层的梯度 。只有获取积分路径中所有隐藏层的梯度,我们才有可能进一步解出损失函数对参数的梯度 。
- 教你正确区分巴沙鱼和多利鱼 多利鱼图片
- windows如何快速移动文件 windows文件夹的移动
- 喷墨打印机打印重影如何解决 打印机打印重影怎么解决
- 我的电脑图标不见了找回方法 电脑上图标不见了如何找回
- iphone如何设置来电拒接 iphone手机来电如何拒接
- ID卡和IC卡的区别 如何区别ic卡与id卡
- excel如何对文本进行排序 excel中怎么给文本排序
- 如何取消苹果系统更新提醒 如何关闭苹果手机的系统更新提醒
- 教你用百花鱼做酸爽的酸菜鱼 百花鱼怎么做
- 三十分钟教你快速入门隶书书法 初学隶书该做什么