88from caicloud .clever .tensorflow import dist_base
99from caicloud .clever .tensorflow import model_exporter
1010
11+ tf .app .flags .DEFINE_string ("export_dir" ,
12+ "/tmp/mnist/saved_model" ,
13+ "model export directory path." )
14+ tf .app .flags .DEFINE_string ("checkpoint_dir" ,
15+ "" ,
16+ "model checkpoint directory path." )
17+ FLAGS = tf .app .flags .FLAGS
18+
19+ _train_op = None
20+
1121def model_fn (sync , num_replicas ):
1222 """TensorFlow 模型定义函数。
1323
@@ -18,13 +28,33 @@ def model_fn(sync, num_replicas):
1828 - `sync`:当前是否采用参数同步更新模式。
1929 - `num_replicas`:分布式 TensorFlow 的计算节点(worker)个数。
2030 """
31+ global _train_op
2132
2233 # TODO:添加业务模型定义操作。
23-
34+ # global_step = ...
35+ # _train_op = ...
36+
37+ # 添加模型评估配置:
38+ # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
39+ # def accuracy_evalute_fn(session):
40+ # return session.run(accuracy, ...)
41+ # model_metric_ops = {
42+ # "accuracy": accuracy_evalute_fn
43+ # }
44+
45+ # 定义模型导出配置
46+ # model_export_spec = model_exporter.ModelExportSpec(
47+ # export_dir=FLAGS.export_dir,
48+ # input_tensors={"image": _input_images},
49+ # output_tensors={"logits": logits})
50+
2451 # model_fn 函数需要返回 ModelFnHandler 对象告知 TaaS 平台所构建的模型的一些信息,
2552 # 例如 global_step、优化器 Optimizer、模型评估指标以及模型导出的相关配置等等。
2653 # 详细信息请参考 docs.caicloud.io。
27- return dist_base .ModelFnHandler ()
54+ return dist_base .ModelFnHandler (
55+ global_step = global_step ,
56+ model_metric_ops = model_metric_ops ,
57+ model_export_spec = model_export_spec )
2858
2959def train_fn (session , num_global_step ):
3060 """模型训练的每一轮操作。
@@ -38,21 +68,41 @@ def train_fn(session, num_global_step):
3868
3969 # TODO:添加业务模型训练操作。
4070
41- # train_fn 函数返回一个 bool 值,用于告知 TaaS 平台是否要提前终止模型训练。返回 True,
42- # 表示终止训练;否则,TaaS 将继续下一轮训练。
43- # 例如,为了防止训练模型过拟合,在训练过程中定时使用验证数据评测模型效果。如果发现模型
44- # 在训练数据集上的效果有优化,而在验证数据集上的效果却开始劣化,则说明模型可能出现了过
45- # 拟合,此时我们就可以通过返回 True 来告知 TaaS 平台提前终止模型训练。
71+ # train_fn 函数返回一个 bool 值,用于告知 TaaS 平台是否要提前终止模型训练。
72+ # 返回 True,表示终止训练;否则,TaaS 将继续下一轮训练。
73+ # 例如,为了防止训练模型过拟合,在训练过程中定时使用验证数据评测模型效果。当模型效果
74+ # 达到预期效果,便可以通过返回 True 来结束模型训练。
4675 return False
4776
4877def gen_init_fn ():
4978 """获取自定义初始化函数。
5079
5180 有些情况下,我需要从某个事先训练好的 checkpoint 文件中加载模型的参数。此时,我们需要自
5281 己实现使用 tf.Saver() 从该 checkpoint 中加载模型参数进行自定义初始化的函数。
82+
83+ 注:如果不需要自定义初始化,可以不提供 gen_init_fn 实现,或者 gen_init_fn 返回 None。
5384 """
54- return None
5585
86+ # TODO: 添加自己的处理逻辑
87+
88+ # 定义 tf.train.Saver 会修改 TensorFlow 的 Graph 结构,
89+ # 而当 Base 框架调用自定义初始化函数 init_from_checkpoint 的时候,
90+ # TensorFlow 模型的 Graph 结构已经变成 finalized,不再允许修改 Graph 结构。
91+ # 所以,这个定义必须放在 init_from_checkpoint 函数外面。
92+ saver = tf .train .Saver (tf .trainable_variables ())
93+
94+ def init_from_checkpoint (scaffold , sess ):
95+ """执行自定义初始化的函数。
96+
97+ TaaS 平台会优先从设置的日志保存路径中获取最新的 checkpoint 来 restore 模型参数,
98+ 如果日志保存路径中找不到 checkpoint 文件,才会调用本函数来进行模型初始化。
99+
100+ 本函数必须接收两个参数:
101+ - scafford: tf.train.Scaffold 对象;
102+ - sess: tf.Session 对象。
103+ """
104+ saver .restore (sess , checkpoint_path )
105+ return init_from_checkpoint
56106
57107def after_train_hook (session ):
58108 """模型训练操作。
0 commit comments