小胖有技能|AssemblyAI 在 PyTorch 中建立端到端的语音识别模型,利用( 四 )


小胖有技能|AssemblyAI 在 PyTorch 中建立端到端的语音识别模型,利用
文章图片
选择合适的优化器和调度器–具有超融合的AdamW
优化器和学习率调度器在使模型收敛到最佳点方面起着非常重要的作用 。 选择合适的的优化器和调度器还可以节省计算时间 , 并有助于你的模型更好应用到实际案例中 。
对于我们的模型 , 我们将使用AdamW和一个周期学习率调度器 。 Adam是一种广泛使用的优化器 , 可帮助你的模型更快地收敛 , 节省计算时间 , 但由于没有推广性 , 和随机梯度下降(SGD)一样臭名昭著 。
AdamW最初是在“去耦权重衰减正则化”中引入的 , 被认为是对Adam的“修复” 。 该论文指出 , 原始的Adam算法权重衰减的实现上存在错误 , AdamW试图解决该问题 。 这个修复程序有助于解决Adam的推广问题 。
单周期学习率调度算法最早是在《超收敛:大学习率下神经网络的快速训练》一文中引入的 。 本文表明 , 你可以使用一个简单的技巧 , 在保持其可推广能力的同时 , 将神经网络的训练速度提高一个数量级 。
开始时学习率很低 , 逐渐上升到一个很大的最大学习率 , 然后线性衰减到最初开始时的位置 。
小胖有技能|AssemblyAI 在 PyTorch 中建立端到端的语音识别模型,利用
文章图片
最大学习率比最低学习率要高很多 , 你可以获得一些正则化好处 , 如果数据量较小 , 可以帮助你的模型更好地推广 。
使用PyTorch , 这两种方法已经成为软件包的一部分 。
optimizer=optim.AdamW(model.parameters,hparams['learning_rate'])scheduler=optim.lr_scheduler.OneCycleLR(optimizer,max_lr=hparams['learning_rate'],steps_per_epoch=int(len(train_loader)),epochs=hparams['epochs'],anneal_strategy='linear')
小胖有技能|AssemblyAI 在 PyTorch 中建立端到端的语音识别模型,利用
文章图片
CTC损失功能–将音频与文本对齐
我们的模型将接受训练 , 预测输入到模型中的声谱图中每一帧(即时间步长)字母表中所有字符的概率分布 。
小胖有技能|AssemblyAI 在 PyTorch 中建立端到端的语音识别模型,利用
文章图片
传统的语音识别模型将要求你在训练之前将文本与音频对齐 , 并且将训练模型来预测特定帧处的特定标签 。
CTC损失功能的创新之处在于它允许我们可以跳过这一步 。 我们的模型将在训练过程中学习对齐文本本身 。 关键在于CTC引入的“空白”标签 , 该标签使模型能够表明某个音频帧没有产生字符 。 你可以在这篇出色的文章中看到有关CTC及其工作原理的更详细说明 。
PyTorch还内置了CTC损失功能 。
criterion=nn.CTCLoss(blank=28).to(device)
小胖有技能|AssemblyAI 在 PyTorch 中建立端到端的语音识别模型,利用
文章图片
语音模型评估
在评估语音识别模型时 , 行业标准使用的是单词错误率(WER)作为度量标准 。 错误率这个词的作用就像它说的那样——它获取你的模型输出的转录和真实的转录 , 并测量它们之间的误差 。
你可以在此处查看它是如何实现 。 另一个有用的度量标准称为字符错误率(CER) 。 CER测量模型输出和真实标签之间的字符误差 。 这些指标有助于衡量模型的性能 。
在本教程中 , 我们使用“贪婪”解码方法将模型的输出处理为字符 , 这些字符可组合创建文本 。 “贪婪”解码器接收模型输出 , 该输出是字符的最大概率矩阵 , 对于每个时间步长(频谱图帧) , 它选择概率最高的标签 。 如果标签是空白标签 , 则将其从最终的文本中删除 。
defGreedyDecoder(output,labels,label_lengths,blank_label=28,collapse_repeated=True):arg_maxes=torch.argmax(output,dim=2)decodes=targets=fori,argsinenumerate(arg_maxes):decode=targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist))forj,indexinenumerate(args):ifindex!=blank_label:ifcollapse_repeatedandj!=0andindex==args[j-1]:continuedecode.append(index.item)decodes.append(text_transform.int_to_text(decode))returndecodes,targets