快速开启你的第一个项目:TensorFlow项目架构模板
选自GitHub
作者:Mahmoud Gemy
机器之心编译
参与:黄小天、李泽南
作为最为流行的深度学习资源库,TensorFlow 是帮助深度学习新方法走向实现的强大工具。它为大多数深度学习领域中使用的常用语言提供了大量应用程序接口。对于开发者和研究人员来说,在开启新的项目前首先面临的问题是:如何构建一个简单明了的结构,本文或许可以为你带来帮助。
项目链接:http://github.com/Mrgemy95/Tensorflow-Project-Template
TensorFlow 项目模板
简洁而精密的结构对于深度学习项目来说是必不可少的,在经过多次练习和 TensorFlow 项目开发之后,本文作者提出了一个结合简便性、优化文件结构和良好 OOP 设计的 TensorFlow 项目模板。该模板可以帮助你快速启动自己的 TensorFlow 项目,直接从实现自己的核心思想开始。
这个简单的模板可以帮助你直接从构建模型、训练等任务开始工作。
目录
概述
详述
项目架构
文件夹结构
主要组件
模型
训练器
数据加载器
记录器
配置
Main
未来工作
概述
简言之,本文介绍的是这一模板的使用方法,例如,如果你希望实现 VGG 模型,那么你应该:
在模型文件夹中创建一个名为 VGG 的类,由它继承「base_model」类
class
VGGModel
(
BaseModel
):
def
__init__
(
self
,
config
):
super
(
VGGModel
,
self
).
__init__
(
config
)
#call the build_model and init_saver functions.
self
.
build_model
()
self
.
init_saver
()
覆写这两个函数 "build_model",在其中执行你的 VGG 模型;以及定义 TensorFlow 保存的「init_saver」,随后在 initalizer 中调用它们。
def
build_model
(
self
):
# here you build the tensorflow graph of any model you want and also define the loss.
pass
def
init_saver
(
self
):
#here you initalize the tensorflow saver that will be used in saving the checkpoints.
self
.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
self
.
config
.
max_to_keep
)
在 trainers 文件夹中创建 VGG 训练器,继承「base_train」类。
class
VGGTrainer
(
BaseTrain
):
def
__init__
(
self
,
sess
,
model
,
data
,
config
,
logger
):
super
(
VGGTrainer
,
self
).
__init__
(
sess
,
model
,
data
,
config
,
logger
)
覆写这两个函数「train_step」、「train_epoch」,在其中写入训练过程的逻辑。
def
train_epoch
(
self
):
"""
implement the logic of epoch:
-loop ever the number of iteration in the config and call teh train step
-add any summaries you want using the sammary
"""
pass
def
train_step
(
self
):
"""
implement the logic of the train step
- run the tensorflow session
- return any metrics you need to summarize
"""
pass
在主文件中创建会话,创建以下对象:「Model」、「Logger」、「Data_Generator」、「Trainer」与配置:
sess
=
tf
.
Session
()
# create instance of the model you want
model
=
VGGModel
(
config
)
# create your data generator
data
=
DataGenerator
(
config
)
# create tensorboard logger
logger
=
Logger
(
sess
,
config
)
向所有这些对象传递训练器对象,通过调用「trainer.train()」开始训练。
trainer
=
VGGTrainer
(
sess
,
model
,
data
,
config
,
logger
)
# here you train your model
trainer
.
train
()
你会看到模板文件、一个示例模型和训练文件夹,向你展示如何快速开始你的第一个模型。
详述
模型架构
文件夹结构
├──
base
│
├──
base_model
.
py
-
this file contains the abstract
class
of the model
.
│
└──
ease_train
.
py
-
this file contains the abstract
class
of the trainer
.
│
│
├──
model
-
This
folder contains any model of your project
.
│
└──
example_model
.
py
│
│
├──
trainer
-
this folder contains trainers of your project
.
│
└──
example_trainer
.
py
│
├──
mains
-
here
"s the main/s of your project (you may need more than one main.
│
│
├── data _loader
│ └── data_generator.py - here"
s the data_generator that responsible
for
all data handling
.
│
└──
utils
├──
logger
.
py
└──
any_other_utils_you_need
主要组件
模型
基础模型
基础模型是一个必须由你所创建的模型继承的抽象类,其背后的思路是:绝大多数模型之间都有很多东西是可以共享的。基础模型包含:
Save-此函数可保存 checkpoint 至桌面。
Load-此函数可加载桌面上的 checkpoint。
Cur-epoch、Global_step counters-这些变量会跟踪训练 epoch 和全局步。
Init_Saver-一个抽象函数,用于初始化保存和加载 checkpoint 的操作,注意:请在要实现的模型中覆盖此函数。
Build_model-是一个定义模型的抽象函数,注意:请在要实现的模型中覆盖此函数。
你的模型
以下是你在模型中执行的地方。因此,你应该:
创建你的模型类并继承 base_model 类。
覆写 "build_model",在其中写入你想要的 tensorflow 模型。
覆写"init_save",在其中你创建 tensorflow 保存器,以用它保存和加载检查点。
在 initalizer 中调用"build_model" 和 "init_saver"
训练器
基础训练器
基础训练器(Base trainer)是一个只包装训练过程的抽象的类。
你的训练器
以下是你应该在训练器中执行的。
创建你的训练器类,并继承 base_trainer 类。
覆写这两个函数,在其中你执行每一步和每一 epoch 的训练过程。
数据加载器
这些类负责所有的数据操作和处理,并提供一个可被训练器使用的易用接口。
记录器(Logger)
这个类负责 tensorboard 总结。在你的训练器中创建一个有关所有你想要的 tensorflow 变量的词典,并将其传递给 logger.summarize()。
配置
我使用 Json 作为配置方法,接着解析它,因此写入所有你想要的配置,然后用"utils/config/process_config"解析它,并把这个配置对象传递给所有其他对象。
Main
以下是你整合的所有之前的部分。
1. 解析配置文件。
2. 创建一个 TensorFlow 会话。
3. 创建 "Model"、"Data_Generator" 和 "Logger"实例,并解析所有它们的配置。
4. 创建一个"Trainer"实例,并把之前所有的对象传递给它。
5. 现在你可通过调用"Trainer.train()"训练你的模型。
未来工作
未来,该项目计划通过新的 TensorFlow 数据集 API 替代数据加载器。
本文为机器之心编译, 转载请联系本公众号获得授权
?------------------------------------------------
加入机器之心(全职记者/实习生):hr@jiqizhixin.com
投稿或寻求报道:editor@jiqizhixin.com
广告&商务合作:bd@jiqizhixin.com
- 过年了,教你2分钟快速泡发木耳、香菇、腐竹等,简单实用!
- 这三个表现,深爱你的男人才有,一般人做不到!
- 王者荣耀:拿什么拯救你的嬴政?大兄弟醒醒,你的真心没得救了!
- 心理测试:2018年你的爱情如何?测测你的桃花运程吧!
- 我愤怒的对她说:他爱的不是你,而是你的身体
- 别让你的腰承受无法言说之痛!你知道腰肌劳损是腰间盘的前兆吗?
- 学会这14款酱汁,没有难倒你的菜!
- 十二星座代表萌犬,跟你的性格蛮搭的
- Moncler开启人海战术 将启用八名设计师组成创意团队
- 男女相处,女人这处地方“货要真”,男人给你的爱“价才实”!