如何利用PyTorch中的Moco-V2减少计算约束( 二 )

  • 另一个挑战是 , 使用反向传播学习编码器参数是不可行的 , 因为这将需要计算队列中所有样本的梯度(这将导致大的计算图) 。
  • 为了解决这两个问题 , MoCo将键编码器实现为基于动量的查询编码器的移动平均值[1] 。 这意味着它以这种方式更新关键编码器参数:
    如何利用PyTorch中的Moco-V2减少计算约束文章插图
    其中m非常接近于1(例如 , 典型值为0.999) , 这确保我们在不同的时间从相似的编码器获得编码键 。
    损失函数-InfoNCE我们希望查询接近其所有正样本 , 远离所有负样本 。 InfoNC函数E会捕获它 。 它代表信息噪声对比估计 。 对于查询q和键k , InfoNCE损失函数是:
    如何利用PyTorch中的Moco-V2减少计算约束文章插图
    我们可以重写为:
    如何利用PyTorch中的Moco-V2减少计算约束文章插图
    当q和k的相似性增大 , q与负样本的相似性减小时 , 损失值减小
    以下是损失函数的代码:
    τ = 0.05def loss_function(q, k, queue):# N是批量大小N = q.shape[0]# C是表示的维数C = q.shape[1]# bmm代表批处理矩阵乘法# 如果mat1是b×n×m张量 , 那么mat2是b×m×p张量 ,# 然后输出一个b×n×p张量 。pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ))# 在查询和队列张量之间执行矩阵乘法neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1)# 求和denominator = neg + posreturn torch.mean(-torch.log(torch.div(pos,denominator)))
    如何利用PyTorch中的Moco-V2减少计算约束文章插图
    让我们再看看这个损失函数 , 并将它与分类交叉熵损失函数进行比较 。
    如何利用PyTorch中的Moco-V2减少计算约束文章插图
    这里pred?是数据点在第i类中的概率值预测 , true?是该点属于第i类的实际概率值(可以是模糊的 , 但大多数情况下是一个one-hot) 。
    如果你不熟悉这个话题 , 你可以看这个视频来更好地理解交叉熵 。 另外 , 请注意 , 我们经常通过softmax这样的函数将分数转换为概率值:
    我们可以把信息损失函数看作交叉熵损失 。 数据样本“q”的正确类是第r类 , 底层分类器基于softmax , 它试图在K+1类之间进行分类 。
    Info-NCE还与编码表示之间的相互信息有关;关于这一点的更多细节见[4] 。
    MoCo-v2框架现在 , 让我们把所有的东西放在一起 , 看看整个Moco-v2算法是什么样子的 。
    步骤1:我们必须得到查询和键编码器 。 最初 , 键编码器具有与查询编码器相同的参数 。 它们是彼此的复制品 。 随着训练的进行 , 键编码器将成为查询编码器的移动平均值(在这一点上进展缓慢) 。
    由于计算能力的限制 , 我们使用Resnet-18体系结构来实现 。 在通常的resnet架构之上 , 我们添加了一些密集的层 , 以使表示的维数降到25 。 这些层中的某些层稍后将充当投影 。
    # 定义我们的深度学习架构resnetq = resnet18(pretrained=False)classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(resnetq.fc.in_features, 100)),('added_relu1', nn.ReLU(inplace=True)),('fc2', nn.Linear(100, 50)),('added_relu2', nn.ReLU(inplace=True)),('fc3', nn.Linear(50, 25))]))resnetq.fc = classifierresnetk = copy.deepcopy(resnetq)# 将resnet架构迁移到设备resnetq.to(device)resnetk.to(device)步骤2:现在 , 我们已经有了编码器 , 并且假设我们已经设置了其他重要的数据结构 , 现在是时候开始训练循环并理解管道了 。
    这一步是从训练批中获取编码查询和键 。 我们用L2范数对表示进行规范化 。
    只是一个约定警告 , 所有后续步骤中的代码都将位于批处理和epoch循环中 。 我们还将张量“k”从它的梯度中分离出来 , 因为我们不需要计算图中的键编码器部分 , 因为动量更新方程会更新键编码器 。
    # 梯度零化optimizer.zero_grad()# 检索xq和xk这两个图像batchxq = sample_batched['image1']xk = sample_batched['image2']# 把它们移到设备上xq = xq.to(device)xk = xk.to(device)# 获取他们的输出q = resnetq(xq)k = resnetk(xk)k = k.detach()# 将输出规范化 , 使它们成为单位向量q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1))k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))步骤3:现在 , 我们将查询、键和队列传递给前面定义的loss函数 , 并将值存储在一个列表中 。 然后 , 像往常一样 , 对损失值调用backward函数并运行优化器 。
    # 获得损失值loss = loss_function(q, k, queue)# 把这个损失值放到epoch损失列表中epoch_losses_train.append(loss.cpu().data.item())# 反向传播loss.backward()# 运行优化器optimizer.step()