业界 | OpenAI提出Reptile:可扩展的元学习算法
选自OpenAI Blog
作者: ALEX NICHOL & JOHN SCHULMAN
机器之心编译
近日,OpenAI 发布了简单元学习算法 Reptile,该算法对一项任务进行重复采样、执行随机梯度下降、更新初始参数直到习得最终参数。该方法的性能可与 MAML(一种广泛应用的元学习算法)媲美,且比后者更易实现,计算效率更高。
元学习是学习如何学习的过程。元学习算法会学习任务的一个分布,每项任务都是学习问题,并输出快速学习器,学习器可从少量样本中学习并进行泛化。一个得到充分研究的元学习问题是 few-shot 分类,其中每项任务都是分类问题,学习器只能看到 1-5 个输入-输出样本(每个类别),之后学习器必须对新输入进行分类。下面,你可以尝试 OpenAI 的 1-shot 分类交互 Demo,其使用了 Reptile。
点击「Edit All」按钮,绘制三种不同的形状或符号,然后在后侧的输入区域绘制其中一个形状,就可以看到 Reptile 的分类效果。前三个图是标注样本:每个定义一类。最后的图表示未知样本,Reptile 输出其属于每个类别的概率。(请点击原文链接体验交互)
Reptile 的工作原理
和 MAML 类似,Reptile 会学习神经网络的参数初始化方法,以使神经网络可使用少量新任务数据进行调整。但是 MAML 通过梯度下降算法的计算图来展开微分计算过程,而 Reptile 在每个任务中执行标准形式的随机梯度下降(SGD):它不用展开计算图或计算任意二阶导数。因此 Reptile 比 MAML 所需的计算量和内存都更少。伪代码如下:
最后一步也可以把 Φ?W 作为梯度,将其插入如 Adam 等更复杂的优化器。
很令人震惊,该方法运行效果很好。如果 k=1,该算法对应「联合训练」(joint training):在多项任务上执行 SGD。尽管联合训练在很多情况下可以学到有用的初始化,但在 zero-shot 学习不可能出现的情况下(如输出标签是随机排列的)它能学习的很少。Reptile 要求 k>1,更新依赖于损失函数的高阶导数。正如 OpenAI 在论文中展示的那样,k>1 时 Reptile 的行为与 k=1(联合训练)时截然不同。
为了分析 Reptile 的工作原理,OpenAI 使用泰勒级数逼近更新。Reptile 更新最大化同一任务中不同小批量的梯度内积,以改善泛化效果。该发现可能在元学习之外也有影响,如解释 SGD 的泛化性能。OpenAI 的分析结果表明 Reptile 和 MAML 可执行类似的更新,包括具备不同权重的相同两个项。
在 OpenAI 的实验中,他们展示了 Reptile 和 MAML 在 Omniglot 和 Mini-ImageNet 基准上执行 few-shot 分类任务时具备类似的性能。Reptile 收敛速度更快,因为其更新具备更低的方差。OpenAI 关于 Reptile 的分析表明,我们可以使用不同的 SGD 梯度组合获取大量不同的算法。在下图中,假设我们在不同任务中使用不同批量大小的 SGD 执行 K 个更新步,产生 g_1,g_2,…,g_k k 个梯度。下图展示了在 Omniglot 上的学习曲线,且它由梯度的和作为元梯度而绘制出。g_2 对应一阶 MAML,即原版 MAML 论文提出的算法。由于方差缩减,使用更多的梯度会导致更快的学习或收敛。注意仅使用 g_1(对应 k=1)如预测那样在这个任务中没有什么提升,因为我们无法改进 zero-shot 的性能。
实现
实现的 GitHub 地址:http://github.com/openai/supervised-reptile
该实现应用 TensorFlow 进行相关的计算,代码可在 Omniglot 和 Mini-ImageNet 上复现。此外,OpenAI 也发布了一个更小的基于 JavaScript 的实现(http://github.com/openai/supervised-reptile/tree/master/web),其对使用 TensorFlow 预训练的模型进行了调整——以上 demo 就是基于此实现的。
最后,下面是一个 few-shot 回归的简单示例,预测 10(x,y) 对的随机正弦波。该示例基于 PyTorch:
import
numpy
as
np
import
torch
from
torch
import
nn
,
autograd
as
ag
import
matplotlib
.
pyplot
as
plt
from
copy
import
deepcopy
seed
=
0
plot
=
True
innerstepsize
=
0.02
# stepsize in inner SGD
innerepochs
=
1
# number of epochs of each inner SGD
outerstepsize0
=
0.1
# stepsize of outer optimization, i.e., meta-optimization
niterations
=
30000
# number of outer updates; each iteration we sample one task and update on it
rng
=
np
.
random
.
RandomState
(
seed
)
torch
.
manual_seed
(
seed
)
# Define task distribution
x_all
=
np
.
linspace
(-
5
,
5
,
50
)[:,
None
]
# All of the x points
ntrain
=
10
# Size of training minibatches
def
gen_task
():
"Generate classification problem"
phase
=
rng
.
uniform
(
low
=
0
,
high
=
2
*
np
.
pi
)
ampl
=
rng
.
uniform
(
0.1
,
5
)
f_randomsine
=
lambda
x
:
np
.
sin
(
x
+
phase
)
*
ampl
return
f_randomsine
# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
64
),
nn
.
Tanh
(),
nn
.
Linear
(
64
,
64
),
nn
.
Tanh
(),
nn
.
Linear
(
64
,
1
),
)
def
totorch
(
x
):
return
ag
.
Variable
(
torch
.
Tensor
(
x
))
def
train_on_batch
(
x
,
y
):
x
=
totorch
(
x
)
y
=
totorch
(
y
)
model
.
zero_grad
()
ypred
=
model
(
x
)
loss
=
(
ypred
-
y
).
pow
(
2
).
mean
()
loss
.
backward
()
for
param
in
model
.
parameters
():
param
.
data
-=
innerstepsize
*
param
.
grad
.
data
def
predict
(
x
):
x
=
totorch
(
x
)
return
model
(
x
).
data
.
numpy
()
# Choose a fixed task and minibatch for visualization
f_plot
=
gen_task
()
xtrain_plot
=
x_all
[
rng
.
choice
(
len
(
x_all
),
size
=
ntrain
)]
# Reptile training loop
for
iteration
in
range
(
niterations
):
weights_before
=
deepcopy
(
model
.
state_dict
())
# Generate task
f
=
gen_task
()
y_all
=
f
(
x_all
)
# Do SGD on this task
inds
=
rng
.
permutation
(
len
(
x_all
))
for
_
in
range
(
innerepochs
):
for
start
in
range
(
0
,
len
(
x_all
),
ntrain
):
mbinds
=
inds
[
start
:
start
+
ntrain
]
train_on_batch
(
x_all
[
mbinds
],
y_all
[
mbinds
])
# Interpolate between current weights and trained weights from this task
# I.e. (weights_before - weights_after) is the meta-gradient
weights_after
=
model
.
state_dict
()
outerstepsize
=
outerstepsize0
*
(
1
-
iteration
/
niterations
)
# linear schedule
model
.
load_state_dict
({
name
:
weights_before
[
name
]
+
(
weights_after
[
name
]
-
weights_before
[
name
])
*
outerstepsize
for
name
in
weights_before
})
# Periodically plot the results on a particular task and minibatch
if
plot
and
iteration
==
0
or
(
iteration
+
1
)
%
1000
==
0
:
plt
.
cla
()
f
=
f_plot
weights_before
=
deepcopy
(
model
.
state_dict
())
# save snapshot before evaluation
plt
.
plot
(
x_all
,
predict
(
x_all
),
label
=
"pred after 0"
,
color
=(
0
,
0
,
1
))
for
inneriter
in
range
(
32
):
train_on_batch
(
xtrain_plot
,
f
(
xtrain_plot
))
if
(
inneriter
+
1
)
%
8
==
0
:
frac
=
(
inneriter
+
1
)
/
32
plt
.
plot
(
x_all
,
predict
(
x_all
),
label
=
"pred after %i"
%(
inneriter
+
1
),
color
=(
frac
,
0
,
1
-
frac
))
plt
.
plot
(
x_all
,
f
(
x_all
),
label
=
"true"
,
color
=(
0
,
1
,
0
))
lossval
=
np
.
square
(
predict
(
x_all
)
-
f
(
x_all
)).
mean
()
plt
.
plot
(
xtrain_plot
,
f
(
xtrain_plot
),
"x"
,
label
=
"train"
,
color
=
"k"
)
plt
.
ylim
(-
4
,
4
)
plt
.
legend
(
loc
=
"lower right"
)
plt
.
pause
(
0.01
)
model
.
load_state_dict
(
weights_before
)
# restore from snapshot
(
f
"-----------------------------"
)
(
f
"iteration {iteration+1}"
)
(
f
"loss on plotted curve {lossval:.3f}"
)
# would be better to average loss ove
论文:Reptile: a Scalable Metalearning Algorithm
地址:http://d4mucfpksywv.cloudfront.net/research-covers/reptile/reptile_update.pdf
摘要:本论文讨论了元学习问题,即存在任务的一个分布,我们希望找到能在该分布所采样的任务(模型未见过的任务)中快速学习的智能体。我们提出了一种简单元学习算法 Reptile,它会学习一种能在新任务中快速精调的参数初始化方法。Reptile 会重复采样一个任务,并在该任务上执行训练,且将初始化朝该任务的已训练权重方向移动。Reptile 不像同样学习初始化的 MAML,它并不要求在优化过程中是可微的,因此它更适合于需要很多更新步的优化问题。我们的研究发现,Reptile 在一些有具备完整基准的 few-shot 分类任务上表现良好。此外,我们还提供了一些理论性分析,以帮助理解 Reptile 的工作原理。
原文链接:http://blog.openai.com/reptile/
本文为机器之心编译, 转载请联系本公众号获得授权
?------------------------------------------------
加入机器之心(全职记者/实习生):hr@jiqizhixin.com
投稿或寻求报道:editor@jiqizhixin.com
广告&商务合作:bd@jiqizhixin.com
- 两个业界首创 华为Mate RS保时捷设计重新定义万元机
- 这家百亿房企,为何提出“5年冲3000亿”命题?
- 不论是三摄还是五摄手机,这款新机拍照技术在业界大为赞赏
- 破天荒!俄罗斯竟提出这种要求,一下子让中国陷入进退两难境地
- 万科独董提出新词:宝能已经开始经济领域的“颜色革命”
- 【公考干货】申论怎么写?这种方法很有用!
- 联发科 VS 高通,同日赴深抢单,谁会成大赢家?
- 中科大提出二叠纪末生命灭绝假说
- 房产税成业界讨论热点,看国外是如何征收的?
- 手机号13、15、17、18开头的注意了!