|
39 | 39 | from __future__ import division
|
40 | 40 | from __future__ import print_function
|
41 | 41 |
|
| 42 | +import argparse |
42 | 43 | from datetime import datetime
|
43 | 44 | import os.path
|
44 | 45 | import re
|
|
49 | 50 | import tensorflow as tf
|
50 | 51 | import cifar10
|
51 | 52 |
|
52 |
| -FLAGS = tf.app.flags.FLAGS |
| 53 | +parser = argparse.ArgumentParser() |
53 | 54 |
|
54 |
| -tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', |
55 |
| - """Directory where to write event logs """ |
56 |
| - """and checkpoint.""") |
57 |
| -tf.app.flags.DEFINE_integer('max_steps', 1000000, |
58 |
| - """Number of batches to run.""") |
59 |
| -tf.app.flags.DEFINE_integer('num_gpus', 1, |
60 |
| - """How many GPUs to use.""") |
61 |
| -tf.app.flags.DEFINE_boolean('log_device_placement', False, |
62 |
| - """Whether to log device placement.""") |
| 55 | +parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train', help='Directory where to write event logs and checkpoint.') |
| 56 | + |
| 57 | +parser.add_argument('--max_steps', type=int, default=1000000, help='Number of batches to run.') |
| 58 | + |
| 59 | +parser.add_argument('--num_gpus', type=int, default=1, help='How many GPUs to use.') |
| 60 | + |
| 61 | +parser.add_argument('--log_device_placement', type=bool, default=False, help='Whether to log device placement.') |
63 | 62 |
|
64 | 63 |
|
65 | 64 | def tower_loss(scope, images, labels):
|
@@ -274,4 +273,5 @@ def main(argv=None): # pylint: disable=unused-argument
|
274 | 273 |
|
275 | 274 |
|
276 | 275 | if __name__ == '__main__':
|
| 276 | + FLAGS = parser.parse_args() |
277 | 277 | tf.app.run()
|
0 commit comments