CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!

点击上方关注 , All in AI中国
人类学习的一个基本特征是我们可以同时学到很多东西 。 机器学习中的等效思想被称为多任务学习(MTL) , 它在实践中变得越来越有用 , 特别是对于强化学习和自然语言处理 。 事实上 , 即使在标准的单任务情况下 , 也可以设计额外的辅助任务并将其包含在优化过程中以帮助学习 。
本文通过展示如何在图像分类基准中解决简单的多任务问题来介绍该领域 。 重点是TensorFlow(Head API)的一个实验组件 , 它通过将神经网络的共享组件与特定任务组件解耦 , 帮助设计MTL的自定义估算器 。 在这个过程中 , 我们有机会讨论TensorFlow核心的其他功能 , 包括tf.data , tf.image和自定义估算器 。
本教程的代码作为完全包含的Colab笔记本提供 , 随时可以测试和实验!
()
内容一目了然
为了使教程更有趣 , 我们通过重新实现2014年论文的一部分(通过深度多任务学习进行面部特征点检测)来考虑一个现实的用例 。 问题很简单:给我们一个面部图像 , 我们需要定位一系列特征点 , 即图像上的兴趣点(鼻子、左眼、嘴巴......)和标签 , 包括人的年龄和性别 。 每个界标/标签构成图像上的单独任务 , 并且任务之间明显相关(即 , 想预测左眼的位置 , 需要先知道右边的位置) 。
CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!文章插图
【CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!】来自数据集的示例图像(源) 。 绿点是地标 , 每个图像还与一些其他标签相关联 , 包括年龄和性别 。
我们将实现分为三个部分:(i)加载图像(使用tf.data和tf.image); (ii)从论文中实施卷积网络(使用TF的自定义估计量); (iii)使用Head API添加MTL逻辑 。
第0步 - 加载数据集
下载数据集()后 , 快速检查 , 发现图像分为三个不同的文件夹(AFLW , lfw_5590和net_7876) 。 通过不同的文本文件提供训练和测试分割 , 每行对应一个图像和标签的路径:
CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!文章插图
来自训练数据集的第一个图像和标签 。 蓝色数字是图像位置(从左上角开始) , 红色数字是类别(见下文) 。
为简单起见 , 我们将使用Pandas加载文本文件并调整Unix标准的路径URL , 例如:对于训练部分:
CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!文章插图
在Pandas和scikit-learn中加载数据
由于文本文件不是很大 , 在这种情况下使用Pandas稍微容易一些 , 并且提供了一点灵活性 。 但是 , 对于较大的文件 , 更好的选择是直接使用tf.data对象TextLineDataset 。
第1步 - 使用tf.data和Dataset对象
现在有了数据 , 我们可以使用tf.data加载它以使其估算好!在最简单的情况下 , 我们可以通过Pandas的DataFrame进行切片 , 就可以获取我们的数据:
CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!文章插图
从Pandas的DataFrame加载tf.data中的数据
以前 , 将tf.data与Estimators一起使用的一个主要问题是调试数据集相当复杂 , 必须通过tf.Session对象 。 但是 , 从最新版本开始 , 即使在使用估算器时 , 也可以通过启用即时执行来调试数据集 。 例如 , 我们可以使用数据集构建8个元素的批次 , 获取第一批 , 并在屏幕上输出所有内容:
CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!文章插图
在即时执行中调试数据集对象
现在是从路径开始加载图像的时候了!通常这不是一件容易的事 , 因为图像可以有许多不同的扩展、大小 , 有些可以是黑白 , 等等 。 幸运的是 , 我们可以从TF教程中获取灵感来构建一个简单的函数来封装所有这些逻辑 , 利用tf.image模块中的工具:
CNN|一步步教你使用Head API在TensorFlow中进行多任务学习!文章插图
使用tf.image模块解析图像
该函数负责解决大多数解析问题:

  1. 'channels'参数允许在一行中加载彩色和黑白图像;
  2. 我们将所有图像调整为我们想要的格式(40x40 , 根据原始文件);
  3. 在第8行 , 我们还标准化了我们的标签 , 以表示0和1之间的相对位置 , 而不是绝对的位置(因为我们调整了所有图像的大小 , 图像可能会有不同的形状) 。
我们可以使用其内部的“map”函数将解析函数应用于数据集的每个元素:将它与一些用于测试的额外逻辑放在一起 , 我们获得最终的加载函数: