Skip to content

Commit dd9a81c

Browse files
authored
Merge pull request tensorflow#2343 from elibixby/fixsyncreplicausage
Fix incorrect SyncReplicasOptimizer usage
2 parents 78bb025 + 28d37e7 commit dd9a81c

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

tutorials/image/cifar10_estimator/cifar10_main.py

Lines changed: 20 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -153,18 +153,27 @@ def _resnet_model_fn(features, labels, mode, params):
153153

154154
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
155155
boundaries, staged_lr)
156-
# Create a nicely-named tensor for logging
157-
learning_rate = tf.identity(learning_rate, name='learning_rate')
156+
157+
loss = tf.reduce_mean(tower_losses, name='loss')
158+
159+
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
160+
params.train_batch_size, every_n_steps=10)
161+
162+
tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}
163+
164+
logging_hook = tf.train.LoggingTensorHook(
165+
tensors=tensors_to_log, every_n_iter=100)
166+
167+
train_hooks = [logging_hook, examples_sec_hook]
158168

159169
optimizer = tf.train.MomentumOptimizer(
160170
learning_rate=learning_rate, momentum=momentum)
161171

162-
chief_hooks = []
163172
if params.sync:
164173
optimizer = tf.train.SyncReplicasOptimizer(
165174
optimizer, replicas_to_aggregate=num_workers)
166-
sync_replicas_hook = optimizer.make_session_run_hook(True)
167-
chief_hooks.append(sync_replicas_hook)
175+
sync_replicas_hook = optimizer.make_session_run_hook(params.is_chief)
176+
train_hooks.append(sync_replicas_hook)
168177

169178
# Create single grouped train op
170179
train_op = [
@@ -185,14 +194,13 @@ def _resnet_model_fn(features, labels, mode, params):
185194
'accuracy':
186195
tf.metrics.accuracy(stacked_labels, predictions['classes'])
187196
}
188-
loss = tf.reduce_mean(tower_losses, name='loss')
189197

190198
return tf.estimator.EstimatorSpec(
191199
mode=mode,
192200
predictions=predictions,
193201
loss=loss,
194202
train_op=train_op,
195-
training_chief_hooks=chief_hooks,
203+
training_hooks=train_hooks,
196204
eval_metric_ops=metrics)
197205

198206
return _resnet_model_fn
@@ -336,32 +344,20 @@ def _experiment_fn(run_config, hparams):
336344

337345
train_steps = hparams.train_steps
338346
eval_steps = num_eval_examples // hparams.eval_batch_size
339-
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
340-
hparams.train_batch_size, every_n_steps=10)
341-
342-
tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}
343-
344-
logging_hook = tf.train.LoggingTensorHook(
345-
tensors=tensors_to_log, every_n_iter=100)
346-
347-
hooks = [logging_hook, examples_sec_hook]
348-
347+
349348
classifier = tf.estimator.Estimator(
350349
model_fn=get_model_fn(num_gpus, variable_strategy,
351350
run_config.num_worker_replicas or 1),
352351
config=run_config,
353352
params=hparams)
354353

355354
# Create experiment.
356-
experiment = tf.contrib.learn.Experiment(
355+
return tf.contrib.learn.Experiment(
357356
classifier,
358357
train_input_fn=train_input_fn,
359358
eval_input_fn=eval_input_fn,
360359
train_steps=train_steps,
361360
eval_steps=eval_steps)
362-
# Adding hooks to be used by the estimator on training modes
363-
experiment.extend_train_hooks(hooks)
364-
return experiment
365361

366362
return _experiment_fn
367363

@@ -386,7 +382,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
386382
get_experiment_fn(data_dir, num_gpus, variable_strategy,
387383
use_distortion_for_training),
388384
run_config=config,
389-
hparams=tf.contrib.training.HParams(**hparams))
385+
hparams=tf.contrib.training.HParams(
386+
is_chief=config.is_chief,
387+
**hparams))
390388

391389

392390
if __name__ == '__main__':

0 commit comments

Comments (0)