Skip to content

Commit 0e09477

Browse files
authored
Merge pull request tensorflow#2443 from tensorflow/flags
Flags
2 parents 1630da3 + 8815a86 commit 0e09477

File tree

4 files changed

+60
-40
lines changed

4 files changed

+60
-40
lines changed

tutorials/image/cifar10/cifar10.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from __future__ import division
3636
from __future__ import print_function
3737

38+
import argparse
3839
import os
3940
import re
4041
import sys
@@ -45,15 +46,19 @@
4546

4647
import cifar10_input
4748

48-
FLAGS = tf.app.flags.FLAGS
49+
parser = argparse.ArgumentParser()
4950

5051
# Basic model parameters.
51-
tf.app.flags.DEFINE_integer('batch_size', 128,
52-
"""Number of images to process in a batch.""")
53-
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
54-
"""Path to the CIFAR-10 data directory.""")
55-
tf.app.flags.DEFINE_boolean('use_fp16', False,
56-
"""Train the model using fp16.""")
52+
parser.add_argument('--batch_size', type=int, default=128,
53+
help='Number of images to process in a batch.')
54+
55+
parser.add_argument('--data_dir', type=str, default='/tmp/cifar10_data',
56+
help='Path to the CIFAR-10 data directory.')
57+
58+
parser.add_argument('--use_fp16', type=bool, default=False,
59+
help='Train the model using fp16.')
60+
61+
FLAGS = parser.parse_args()
5762

5863
# Global constants describing the CIFAR-10 data set.
5964
IMAGE_SIZE = cifar10_input.IMAGE_SIZE

tutorials/image/cifar10/cifar10_eval.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from __future__ import division
3535
from __future__ import print_function
3636

37+
import argparse
3738
from datetime import datetime
3839
import math
3940
import time
@@ -43,20 +44,25 @@
4344

4445
import cifar10
4546

46-
FLAGS = tf.app.flags.FLAGS
47+
parser = cifar10.parser
4748

48-
tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
49-
"""Directory where to write event logs.""")
50-
tf.app.flags.DEFINE_string('eval_data', 'test',
51-
"""Either 'test' or 'train_eval'.""")
52-
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
53-
"""Directory where to read model checkpoints.""")
54-
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
55-
"""How often to run the eval.""")
56-
tf.app.flags.DEFINE_integer('num_examples', 10000,
57-
"""Number of examples to run.""")
58-
tf.app.flags.DEFINE_boolean('run_once', False,
59-
"""Whether to run eval only once.""")
49+
parser.add_argument('--eval_dir', type=str, default='/tmp/cifar10_eval',
50+
help='Directory where to write event logs.')
51+
52+
parser.add_argument('--eval_data', type=str, default='test',
53+
help='Either `test` or `train_eval`.')
54+
55+
parser.add_argument('--checkpoint_dir', type=str, default='/tmp/cifar10_train',
56+
help='Directory where to read model checkpoints.')
57+
58+
parser.add_argument('--eval_interval_secs', type=int, default=60*5,
59+
help='How often to run the eval.')
60+
61+
parser.add_argument('--num_examples', type=int, default=10000,
62+
help='Number of examples to run.')
63+
64+
parser.add_argument('--run_once', type=bool, default=False,
65+
help='Whether to run eval only once.')
6066

6167

6268
def eval_once(saver, summary_writer, top_k_op, summary_op):
@@ -154,4 +160,5 @@ def main(argv=None): # pylint: disable=unused-argument
154160

155161

156162
if __name__ == '__main__':
163+
FLAGS = parser.parse_args()
157164
tf.app.run()

tutorials/image/cifar10/cifar10_multi_gpu_train.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from __future__ import division
4040
from __future__ import print_function
4141

42+
import argparse
4243
from datetime import datetime
4344
import os.path
4445
import re
@@ -49,17 +50,19 @@
4950
import tensorflow as tf
5051
import cifar10
5152

52-
FLAGS = tf.app.flags.FLAGS
53+
parser = cifar10.parser
5354

54-
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
55-
"""Directory where to write event logs """
56-
"""and checkpoint.""")
57-
tf.app.flags.DEFINE_integer('max_steps', 1000000,
58-
"""Number of batches to run.""")
59-
tf.app.flags.DEFINE_integer('num_gpus', 1,
60-
"""How many GPUs to use.""")
61-
tf.app.flags.DEFINE_boolean('log_device_placement', False,
62-
"""Whether to log device placement.""")
55+
parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train',
56+
help='Directory where to write event logs and checkpoint.')
57+
58+
parser.add_argument('--max_steps', type=int, default=1000000,
59+
help='Number of batches to run.')
60+
61+
parser.add_argument('--num_gpus', type=int, default=1,
62+
help='How many GPUs to use.')
63+
64+
parser.add_argument('--log_device_placement', type=bool, default=False,
65+
help='Whether to log device placement.')
6366

6467

6568
def tower_loss(scope, images, labels):
@@ -274,4 +277,5 @@ def main(argv=None): # pylint: disable=unused-argument
274277

275278

276279
if __name__ == '__main__':
280+
FLAGS = parser.parse_args()
277281
tf.app.run()

tutorials/image/cifar10/cifar10_train.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,27 @@
3636
from __future__ import division
3737
from __future__ import print_function
3838

39+
import argparse
3940
from datetime import datetime
4041
import time
4142

4243
import tensorflow as tf
4344

4445
import cifar10
4546

46-
FLAGS = tf.app.flags.FLAGS
47+
parser = cifar10.parser
4748

48-
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
49-
"""Directory where to write event logs """
50-
"""and checkpoint.""")
51-
tf.app.flags.DEFINE_integer('max_steps', 1000000,
52-
"""Number of batches to run.""")
53-
tf.app.flags.DEFINE_boolean('log_device_placement', False,
54-
"""Whether to log device placement.""")
55-
tf.app.flags.DEFINE_integer('log_frequency', 10,
56-
"""How often to log results to the console.""")
49+
parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train',
50+
help='Directory where to write event logs and checkpoint.')
51+
52+
parser.add_argument('--max_steps', type=int, default=1000000,
53+
help='Number of batches to run.')
54+
55+
parser.add_argument('--log_device_placement', type=bool, default=False,
56+
help='Whether to log device placement.')
57+
58+
parser.add_argument('--log_frequency', type=int, default=10,
59+
help='How often to log results to the console.')
5760

5861

5962
def train():
@@ -124,4 +127,5 @@ def main(argv=None): # pylint: disable=unused-argument
124127

125128

126129
if __name__ == '__main__':
130+
FLAGS = parser.parse_args()
127131
tf.app.run()

0 commit comments

Comments
 (0)