人走茶凉|API统一、干净,新型EagerPy实现多框架无缝衔接( 二 )


此外 , 能够编写自动与所有支持的框架一起运行的代码 , 不仅需要语法 , 还需要语义统一 。 为了确保这一点 , EagerPy 附带了一个庞大的测试套件 , 该套件可以验证不同框架特定子类之间的一致性 。 它会在所有 pull-request 上自动运行 , 并且需要通过之后才能合并新代码 。
测试套件还可以作为所支持的操作和参数组合的最终参考 。 这样就可以避免文档和实现之间出现不一致 , 并在实践中引出测试驱动开发过程 。
原始性能
没有 EagerPy , 想要与不同深度学习框架进行交互的代码必须经过 NumPy 实现 。 这需要在 CPU(NumPy)和 GPU(PyTorch、TensorFlow 和 JAX)之间进行高成本的内存复制 , 反之亦然 。
此外 , 许多计算仅在 CPU 上执行 , 为了避免这种情况 , EagerPy 仅保留对原始框架特定张量的引用(例如 GPU 上的 PyTorch 张量) , 并将所有的操作委托给相应的框架 。 这几乎不产生任何的计算开销 。
完全可链接的 API
求和或平方之类的许多运算都要采用张量并返回一个张量 。 通常情况下 , 这些运算按顺序被调用 。 例如使用平方、求和和开平方根以计算 L2 范数 。
在 EagerPy 中 , 所有运算都成为了张量对象(tensor object)上可用的方法 。 这样就可以按照它们的自然顺序(x.square().sum().sqrt())来链接操作 。 相反 , 例如 , NumPy 需要相反的操作顺序 , 即 np.sqrt(np.square(x).sum()) 。
类型检查
在 Python3.5 中 , Python 语法的扩展已经实现了对类型注释的支持(van Rossum 等人 , 2015 年) 。 即使具有类型注释 , Python 仍然是一种动态类型化的编程语言 , 并且当前在运行时会忽略所有类型注释 。 但是 , 我们可以在运行代码之前通过静态代码分析器检查这些类型注释 。
EagerPy 带有所有参数和返回值的全面类型注释 , 并使用 Mypy(Lehtosalo 等人 , 2016 年)对这些注释进行检查 。 这有助于我们捕获 EagerPy 中的漏洞 , 否则这些漏洞将一直不会被发现 。
EagerPy 用户可以通过键入自己代码的注释 , 并根据 EagerPy 的函数签名(function signature)自动检查代码来进一步优化 。 这一点很关键 , 因为 TensorFlow、NumPy 和 JAX 当前自身不提供类型注释 。
EagerPy 的代码实例解析
如下代码 1 为一个通用 EagerPy 范数函数 , 它可以通过任何框架中的原生张量被调用 , 并且返回的范数依然作为同一个框架中的原生张量 。
人走茶凉|API统一、干净,新型EagerPy实现多框架无缝衔接代码 1:框架无关的范数函数 。
EagerPy 和原生张量之间的转换
原生张量可以是 PyTorch GPU 或 CPU 张量 , 如下代码 2 所示:
人走茶凉|API统一、干净,新型EagerPy实现多框架无缝衔接代码 2:原生 PyTorch 张量 。
可以是 TensorFlow 张量 , 如下代码 3 所示:
人走茶凉|API统一、干净,新型EagerPy实现多框架无缝衔接代码 3:原生 TensorFlow 张量 。
可以是 JAX 数组 , 如下代码 4 所示:
人走茶凉|API统一、干净,新型EagerPy实现多框架无缝衔接代码 4:原生 JAX 数组 。
可以是 NumPy 数组 , 如下代码 5 所示:
人走茶凉|API统一、干净,新型EagerPy实现多框架无缝衔接代码 5:原生 NumPy 数组 。