@@ -153,18 +153,27 @@ def _resnet_model_fn(features, labels, mode, params):
153
153
154
154
learning_rate = tf .train .piecewise_constant (tf .train .get_global_step (),
155
155
boundaries , staged_lr )
156
- # Create a nicely-named tensor for logging
157
- learning_rate = tf .identity (learning_rate , name = 'learning_rate' )
156
+
157
+ loss = tf .reduce_mean (tower_losses , name = 'loss' )
158
+
159
+ examples_sec_hook = cifar10_utils .ExamplesPerSecondHook (
160
+ params .train_batch_size , every_n_steps = 10 )
161
+
162
+ tensors_to_log = {'learning_rate' : learning_rate , 'loss' : loss }
163
+
164
+ logging_hook = tf .train .LoggingTensorHook (
165
+ tensors = tensors_to_log , every_n_iter = 100 )
166
+
167
+ train_hooks = [logging_hook , examples_sec_hook ]
158
168
159
169
optimizer = tf .train .MomentumOptimizer (
160
170
learning_rate = learning_rate , momentum = momentum )
161
171
162
- chief_hooks = []
163
172
if params .sync :
164
173
optimizer = tf .train .SyncReplicasOptimizer (
165
174
optimizer , replicas_to_aggregate = num_workers )
166
- sync_replicas_hook = optimizer .make_session_run_hook (True )
167
- chief_hooks .append (sync_replicas_hook )
175
+ sync_replicas_hook = optimizer .make_session_run_hook (params . is_chief )
176
+ train_hooks .append (sync_replicas_hook )
168
177
169
178
# Create single grouped train op
170
179
train_op = [
@@ -185,14 +194,13 @@ def _resnet_model_fn(features, labels, mode, params):
185
194
'accuracy' :
186
195
tf .metrics .accuracy (stacked_labels , predictions ['classes' ])
187
196
}
188
- loss = tf .reduce_mean (tower_losses , name = 'loss' )
189
197
190
198
return tf .estimator .EstimatorSpec (
191
199
mode = mode ,
192
200
predictions = predictions ,
193
201
loss = loss ,
194
202
train_op = train_op ,
195
- training_chief_hooks = chief_hooks ,
203
+ training_hooks = train_hooks ,
196
204
eval_metric_ops = metrics )
197
205
198
206
return _resnet_model_fn
@@ -336,32 +344,20 @@ def _experiment_fn(run_config, hparams):
336
344
337
345
train_steps = hparams .train_steps
338
346
eval_steps = num_eval_examples // hparams .eval_batch_size
339
- examples_sec_hook = cifar10_utils .ExamplesPerSecondHook (
340
- hparams .train_batch_size , every_n_steps = 10 )
341
-
342
- tensors_to_log = {'learning_rate' : 'learning_rate' , 'loss' : 'loss' }
343
-
344
- logging_hook = tf .train .LoggingTensorHook (
345
- tensors = tensors_to_log , every_n_iter = 100 )
346
-
347
- hooks = [logging_hook , examples_sec_hook ]
348
-
347
+
349
348
classifier = tf .estimator .Estimator (
350
349
model_fn = get_model_fn (num_gpus , variable_strategy ,
351
350
run_config .num_worker_replicas or 1 ),
352
351
config = run_config ,
353
352
params = hparams )
354
353
355
354
# Create experiment.
356
- experiment = tf .contrib .learn .Experiment (
355
+ return tf .contrib .learn .Experiment (
357
356
classifier ,
358
357
train_input_fn = train_input_fn ,
359
358
eval_input_fn = eval_input_fn ,
360
359
train_steps = train_steps ,
361
360
eval_steps = eval_steps )
362
- # Adding hooks to be used by the estimator on training modes
363
- experiment .extend_train_hooks (hooks )
364
- return experiment
365
361
366
362
return _experiment_fn
367
363
@@ -386,7 +382,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
386
382
get_experiment_fn (data_dir , num_gpus , variable_strategy ,
387
383
use_distortion_for_training ),
388
384
run_config = config ,
389
- hparams = tf .contrib .training .HParams (** hparams ))
385
+ hparams = tf .contrib .training .HParams (
386
+ is_chief = config .is_chief ,
387
+ ** hparams ))
390
388
391
389
392
390
if __name__ == '__main__' :
0 commit comments