『』组合求解器 + 深度学习 =?这篇ICLR 2020论文告诉你答案( 二 )


如图 , f(黑色)是分段恒定的 。 插值(橙色)以合理的方式连接恒定区域 。 例如 , 我们可以注意到最小值并没有变化 。
当然 , f 的域是多维的 。 这样 , 我们可以观察到 f 取相同值时输入 ω 的集合是一个多面体 。 自然地 , 在 f 的域中有许多这样的多面体 。 超参数 λ 有效地通过扰动求解器输入 ω 来使多面体偏移 。 定义了分段仿射目标的插值器 g 将多面体的偏移边界与原始边界相连 。
下图描述了这种情况 , 取值 f(y2) 的多面体边界偏移至了取值 f(y1) 处 。 这也直观地解释了为什么更倾向使用较大的 λ 。 偏移量必须足够大才能获得提供有用梯度的内插器 g 。 (详细证明过程参见原论文 。 )
首先 , 我们定义该扰动优化问题的解 , 其中扰动由超参数 λ 控制:
如果我们假设损失函数 c(ω,y) 是 y 和 ω 之间的点积 , 则我们可将插值目标定义为:
请注意 , 损失函数的线性度并不像乍一看那样有限制性 。 所有涉及边选择的问题都属于此类别 , 这类问题中损失是边权重之和 。 最短路径问题(SPP)和旅行商问题(TSP)都属于此类问题 。
在这个动画中 , 我们可以看到插值随 λ 增加的变化情况 。
算法
使用该方法 , 我们可以通过简单地通过修改反向传播来计算梯度 , 从而消除经典组合求解器和深度学习之间的断裂 。
def forward(ctx, w_): '''''' ctx: Context for backward pass w_: Estimated problem weights '''''' y_ = solver(w_) # Save context for backward pass ctx.w_ = w_ ctx.y_ = y_ return y_在前向传播中 , 我们只需给嵌入求解器提供 ω , 然后将解向前传递 。 此外 , 我们保存了 ω 和在前向传播中计算得到的解 y_ 。
def backward(ctx, grad): '''''' ctx: Context from forward pass '''''' w = ctx.w_ + lmda*grad # Calculate perturbed weights y_lmda = solver(w) return -(ctx.y_ - y_lmda)da至于反向传播 , 我们只需使用缩放系数为 λ 的反向传播梯度来扰动 ω , 并取先前解与扰动问题解之差即可 。
计算插值梯度的计算开销取决于求解器 , 额外的开销出现在前向传播和反向传播中 , 每个过程均调用了一次求解器 。
实验
我们使用包含一定组合复杂度的综合任务来验证该方法的有效性 。 在以下任务中 , 我们证明了该方法对于组合泛化的必要性 , 因为简单的监督学习方法无法泛化至没有见过的数据 。 同样 , 其目标是学习到正确的组合问题描述 。
对于魔兽争霸最短路径问题 , 训练集包含《魔兽争霸 II》地图和地图对应的最短路径作为目标 。 测试集包含没有见过的《魔兽争霸 II》地图 。 地图本身编码了 k × k 网格 。 地图被输入卷积神经网络 , 网络输出地图顶点的损失 , 然后将该损失送入求解器 。 最后 , 求解器(Dijkstra 最短路径算法)以指示矩阵的形式在地图上输出最短路径 。
自然地 , 在训练开始时 , 网络不知道如何为地图块分配正确的损失 , 但是使用该新方法后 , 我们能够学习到正确的地图块损失 , 从而获得正确的最短路径 。 下列直方图表明 , 相比于 ResNet 的传统监督训练方法 , 我们的方法泛化能力明显更好 。
MNIST 最小损失完美匹配问题的目标是 , 输出 MNIST 数字组成网格的最小损失完美匹配 。 具体而言 , 在最小损失完美匹配问题中 , 我们应该选择一些边 , 使得所有顶点都恰好被包含一次 , 并且边损失之和最小 。 网格中的每个单元都包含一个 MNIST 数字 , 该数字是图中具备垂直和水平方向邻近点的一个节点 。 垂直向下或水平向右读取两位数字 , 即可确定边损失 。
对于这个问题 , 卷积神经网络(CNN)接受 MNIST 网格图像作为输入 , 并输出被转换为边损失的顶点损失网格 。 接着将边损失提供给 Blossom V 完美匹配求解器 。