Skip to content

Commit d34cf18

Browse files
committed
sort of works...still a namespace problem in _train
1 parent 16454bd commit d34cf18

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

tutorials/image/cifar10/cifar10.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,20 @@
4040
import sys
4141
import tarfile
4242

43+
import argparse
4344
from six.moves import urllib
4445
import tensorflow as tf
4546

4647
import cifar10_input
4748

48-
FLAGS = tf.app.flags.FLAGS
49+
parser = argparse.ArgumentParser()
4950

5051
# Basic model parameters.
51-
tf.app.flags.DEFINE_integer('batch_size', 128,
52-
"""Number of images to process in a batch.""")
53-
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
54-
"""Path to the CIFAR-10 data directory.""")
55-
tf.app.flags.DEFINE_boolean('use_fp16', False,
56-
"""Train the model using fp16.""")
52+
parser.add_argument('--batch_size', type=int, default=128, help='Number of images to process in a batch.')
53+
parser.add_argument('--data_dir', type=str, default='/tmp/cifar10_data', help='Path to the CIFAR-10 data directory.')
54+
parser.add_argument('--use_fp16', type=bool, default=False, help='Train the model using fp16.')
5755

56+
FLAGS = parser.parse_args()
5857
# Global constants describing the CIFAR-10 data set.
5958
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
6059
NUM_CLASSES = cifar10_input.NUM_CLASSES

tutorials/image/cifar10/cifar10_multi_gpu_train.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from __future__ import division
4040
from __future__ import print_function
4141

42+
import argparse
4243
from datetime import datetime
4344
import os.path
4445
import re
@@ -49,17 +50,15 @@
4950
import tensorflow as tf
5051
import cifar10
5152

52-
FLAGS = tf.app.flags.FLAGS
53+
parser = argparse.ArgumentParser()
5354

54-
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
55-
"""Directory where to write event logs """
56-
"""and checkpoint.""")
57-
tf.app.flags.DEFINE_integer('max_steps', 1000000,
58-
"""Number of batches to run.""")
59-
tf.app.flags.DEFINE_integer('num_gpus', 1,
60-
"""How many GPUs to use.""")
61-
tf.app.flags.DEFINE_boolean('log_device_placement', False,
62-
"""Whether to log device placement.""")
55+
parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train', help='Directory where to write event logs and checkpoint.')
56+
57+
parser.add_argument('--max_steps', type=int, default=1000000, help='Number of batches to run.')
58+
59+
parser.add_argument('--num_gpus', type=int, default=1, help='How many GPUs to use.')
60+
61+
parser.add_argument('--log_device_placement', type=bool, default=False, help='Whether to log device placement.')
6362

6463

6564
def tower_loss(scope, images, labels):
@@ -274,4 +273,5 @@ def main(argv=None): # pylint: disable=unused-argument
274273

275274

276275
if __name__ == '__main__':
276+
FLAGS = parser.parse_args()
277277
tf.app.run()

0 commit comments

Comments
 (0)