教你如何理解并使用常微分方程 常微分方程


这是一篇神奇的论文,以前一层一层叠加的神经网络似乎突然变得连续了,反向传播也似乎不再需要一点一点往前传、一层一层更新参数了 。
在最近结束的 NeruIPS 2018 中,来自多伦多大学的陈天琦等研究者成为最佳论文的获得者 。他们提出了一种名为神经常微分方程的模型,这是新一类的深度神经网络 。神经常微分方程不拘于对已有架构的修修补补,它完全从另外一个角度考虑如何以连续的方式借助神经网络对数据建模 。在陈天琦的讲解下,机器之心将向各位读者介绍这一令人兴奋的神经网络新家族 。
在与机器之心的访谈中,陈天琦的导师 David Duvenaud 教授谈起这位学生也是赞不绝口 。Duvenaud 教授认为陈天琦不仅是位理解能力超强的学生,钻研起问题来也相当认真透彻 。他说:「天琦很喜欢提出新想法,他有时会在我提出建议一周后再反馈:『老师你之前建议的方法不太合理 。但是我研究出另外一套合理的方法,结果我也做出来了 。』」Ducenaud 教授评价道,现如今人工智能热度有增无减,教授能找到优秀博士生基本如同「鸡生蛋还是蛋生鸡」的问题,顶尖学校的教授通常能快速地招纳到博士生,「我很幸运地能在事业起步阶段就遇到陈天琦如此优秀的学生 。」
本文主要介绍神经常微分方程背后的细想与直观理解,很多延伸的概念并没有详细解释,例如大大降低计算复杂度的连续型流模型和官方 PyTorch 代码实现等 。这一篇文章重点对比了神经常微分方程(ODEnet)与残差网络,我们不仅能通过这一部分了解如何从熟悉的 ResNet 演化到 ODEnet,同时还能还有新模型的前向传播过程和特点 。
其次文章比较关注 ODEnet 的反向传播过程,即如何通过解常微分方程直接把梯度求出来 。这一部分与传统的反向传播有很多不同,因此先理解反向传播再看源码可能是更好的选择 。值得注意的是,ODEnet 的反传只有常数级的内存占用成本 。
  • ODEnet 的 PyTorch 实现地址:https://github.com/rtqichen/torchdiffeq
  • ODEnet 论文地址:https://arxiv.org/abs/1806.07366
如下展示了文章的主要结构:
  • 常微分方程
  • 从残差网络到微分方程
  • 从微分方程到残差网络
  • 网络对比
  • 神经常微分方程
  • 反向传播
  • 反向传播怎么做
  • 连续型的归一化流
  • 变量代换定理
常微分方程
其实初读这篇论文,还是有一些疑惑的,因为很多概念都不是我们所熟知的 。因此如果想要了解这个模型,那么同学们,你们首先需要回忆高数上的微分方程 。有了这样的概念后,我们就能愉快地连续化神经网络层级,并构建完整的神经常微分方程 。
常微分方程即只包含单个自变量 x、未知函数 f(x) 和未知函数的导数 f'(x) 的等式,所以说 f'(x) = 2x 也算一个常微分方程 。但更常见的可以表示为 df(x)/dx = g(f(x), x),其中 g(f(x), x) 表示由 f(x) 和 x 组成的某个表达式,这个式子是扩展一般神经网络的关键,我们在后面会讨论这个式子怎么就连续化了神经网络层级 。
一般对于常微分方程,我们希望解出未知的 f(x),例如 f'(x) = 2x 的通解为 f(x)=x^2 +C,其中 C 表示任意常数 。而在工程中更常用数值解,即给定一个初值 f(x_0),我们希望解出末值 f(x_1),这样并不需要解出完整的 f(x),只需要一步步逼近它就行了 。
现在回过头来讨论我们熟悉的神经网络,本质上不论是全连接、循环还是卷积网络,它们都类似于一个非常复杂的复合函数,复合的次数就等于层级的深度 。例如两层全连接网络可以表示为 Y=f(f(X, θ1), θ2),因此每一个神经网络层级都类似于万能函数逼近器 。
因为整体是复合函数,所以很容易接受复合函数的求导方法:链式法则,并将梯度从最外一层的函数一点点先向里面层级的函数传递,并且每传到一层函数,就可以更新该层的参数 θ 。现在问题是,我们前向传播过后需要保留所有层的激活值,并在沿计算路径反传梯度时利用这些激活值 。这对内存的占用非常大,因此也就限制了深度模型的训练过程 。
神经常微分方程走了另一条道路,它使用神经网络参数化隐藏状态的导数,而不是如往常那样直接参数化隐藏状态 。这里参数化隐藏状态的导数就类似构建了连续性的层级与参数,而不再是离散的层级 。因此参数也是一个连续的空间,我们不需要再分层传播梯度与更新参数 。总而言之,神经微分方程在前向传播过程中不储存任何中间结果,因此它只需要近似常数级的内存成本 。