
Commit f88def2

Merge pull request tensorflow#2690 from tensorflow/tf-data
Changing tf.contrib.data to tf.data for release of tf 1.4
2 parents 4cfa0d3 + ae5adb5 commit f88def2
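
In TensorFlow 1.4 the Dataset API graduated from tf.contrib.data to tf.data, and Dataset.map dropped the num_threads and output_buffer_size arguments in favor of num_parallel_calls plus an explicit prefetch call. A minimal sketch of the migration pattern this PR applies throughout (the file names and parse_fn below are placeholders, not code from these files):

import tensorflow as tf

def make_dataset(filenames, parse_fn, batch_size):
  # Before TF 1.4 this was tf.contrib.data.TFRecordDataset plus
  # map(..., num_threads=N, output_buffer_size=batch_size).
  dataset = tf.data.TFRecordDataset(filenames)
  dataset = dataset.map(parse_fn, num_parallel_calls=4)  # replaces num_threads
  dataset = dataset.prefetch(batch_size)                 # replaces output_buffer_size
  return dataset.batch(batch_size)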

5 files changed (+49, -47 lines)

official/mnist/mnist.py

Lines changed: 4 additions & 4 deletions
@@ -53,7 +53,7 @@
 
 
 def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-  """A simple input_fn using the contrib.data input pipeline."""
+  """A simple input_fn using the tf.data input pipeline."""
 
   def example_parser(serialized_example):
     """Parses a single tf.Example into image and label tensors."""
@@ -71,8 +71,9 @@ def example_parser(serialized_example):
     label = tf.cast(features['label'], tf.int32)
     return image, tf.one_hot(label, 10)
 
-  dataset = tf.contrib.data.TFRecordDataset([filename])
+  dataset = tf.data.TFRecordDataset([filename])
 
+  # Apply dataset transformations
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance. Because MNIST is
@@ -84,8 +85,7 @@ def example_parser(serialized_example):
     dataset = dataset.repeat(num_epochs)
 
   # Map example_parser over dataset, and batch results by up to batch_size
-  dataset = dataset.map(
-      example_parser, num_threads=1, output_buffer_size=batch_size)
+  dataset = dataset.map(example_parser).prefetch(batch_size)
   dataset = dataset.batch(batch_size)
   iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
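
For context, a usage sketch (not part of this commit) of the updated input_fn; it assumes the module is importable as mnist, that input_fn returns the (images, labels) tensors from the one-shot iterator, and that the placeholder file train.tfrecords exists:

import tensorflow as tf
import mnist  # official/mnist/mnist.py

# Build the tf.data pipeline and pull one batch of MNIST examples.
images, labels = mnist.input_fn(
    is_training=True, filename='train.tfrecords', batch_size=32, num_epochs=1)
with tf.Session() as sess:
  image_batch, label_batch = sess.run([images, labels])
  print(image_batch.shape, label_batch.shape)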

official/resnet/cifar10_main.py

Lines changed: 35 additions & 34 deletions
@@ -71,13 +71,11 @@
     'validation': 10000,
 }
 
-_SHUFFLE_BUFFER = 20000
-
 
 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
   record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
-  return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
+  return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
 
 
 def get_filenames(is_training, data_dir):
@@ -97,74 +95,77 @@ def get_filenames(is_training, data_dir):
   return [os.path.join(data_dir, 'test_batch.bin')]
 
 
-def dataset_parser(value):
-  """Parse a CIFAR-10 record from value."""
+def parse_record(raw_record):
+  """Parse CIFAR-10 image and label from a raw record."""
   # Every record consists of a label followed by the image, with a fixed number
   # of bytes for each.
   label_bytes = 1
   image_bytes = _HEIGHT * _WIDTH * _DEPTH
   record_bytes = label_bytes + image_bytes
 
-  # Convert from a string to a vector of uint8 that is record_bytes long.
-  raw_record = tf.decode_raw(value, tf.uint8)
+  # Convert bytes to a vector of uint8 that is record_bytes long.
+  record_vector = tf.decode_raw(raw_record, tf.uint8)
 
-  # The first byte represents the label, which we convert from uint8 to int32.
-  label = tf.cast(raw_record[0], tf.int32)
+  # The first byte represents the label, which we convert from uint8 to int32
+  # and then to one-hot.
+  label = tf.cast(record_vector[0], tf.int32)
+  label = tf.one_hot(label, _NUM_CLASSES)
 
   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
-  depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
-                           [_DEPTH, _HEIGHT, _WIDTH])
+  depth_major = tf.reshape(
+      record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])
 
   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.
   image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
 
-  return image, tf.one_hot(label, _NUM_CLASSES)
+  return image, label
 
 
-def train_preprocess_fn(image, label):
-  """Preprocess a single training image of layout [height, width, depth]."""
-  # Resize the image to add four extra pixels on each side.
-  image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
+def preprocess_image(image, is_training):
+  """Preprocess a single image of layout [height, width, depth]."""
+  if is_training:
+    # Resize the image to add four extra pixels on each side.
+    image = tf.image.resize_image_with_crop_or_pad(
+        image, _HEIGHT + 8, _WIDTH + 8)
 
-  # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
-  image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
+    # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
+    image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
 
-  # Randomly flip the image horizontally.
-  image = tf.image.random_flip_left_right(image)
+    # Randomly flip the image horizontally.
+    image = tf.image.random_flip_left_right(image)
 
-  return image, label
+  # Subtract off the mean and divide by the variance of the pixels.
+  image = tf.image.per_image_standardization(image)
+  return image
 
 
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
-  """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
+  """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
 
   Args:
     is_training: A boolean denoting whether the input is for training.
+    data_dir: The directory containing the input data.
+    batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.
 
   Returns:
     A tuple of images and labels.
   """
   dataset = record_dataset(get_filenames(is_training, data_dir))
-  dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * batch_size)
 
-  # For training, preprocess the image and shuffle.
   if is_training:
-    dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * batch_size)
-
     # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes have better performance.
-    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
+    # randomness, while smaller sizes have better performance. Because CIFAR-10
+    # is a relatively small dataset, we choose to shuffle the full epoch.
+    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
 
-  # Subtract off the mean and divide by the variance of the pixels.
+  dataset = dataset.map(parse_record)
   dataset = dataset.map(
-      lambda image, label: (tf.image.per_image_standardization(image), label),
-      num_threads=1,
-      output_buffer_size=2 * batch_size)
+      lambda image, label: (preprocess_image(image, is_training), label))
+
+  dataset = dataset.prefetch(2 * batch_size)
 
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
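
As an illustration of how the renamed helpers compose (a sketch under assumptions, not code from this PR: the data_dir path is a placeholder and the CIFAR-10 binaries are assumed to exist there), an evaluation pipeline parses records and standardizes images without the training-only crop and flip:

import tensorflow as tf
import cifar10_main  # official/resnet/cifar10_main.py

data_dir = '/tmp/cifar10_data'  # placeholder path
dataset = cifar10_main.record_dataset(
    cifar10_main.get_filenames(is_training=False, data_dir=data_dir))
dataset = dataset.map(cifar10_main.parse_record)
dataset = dataset.map(
    lambda image, label: (cifar10_main.preprocess_image(image, False), label))
images, labels = dataset.batch(64).make_one_shot_iterator().get_next()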

official/resnet/cifar10_test.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def test_dataset_input_fn(self):
     data_file.close()
 
     fake_dataset = cifar10_main.record_dataset(filename)
-    fake_dataset = fake_dataset.map(cifar10_main.dataset_parser)
+    fake_dataset = fake_dataset.map(cifar10_main.parse_record)
     image, label = fake_dataset.make_one_shot_iterator().get_next()
 
     self.assertEqual(label.get_shape().as_list(), [10])
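
A self-contained sketch of what the updated test exercises (the temp-file setup here is illustrative, not the test's actual fixture): write one fake 1 + 32*32*3 = 3073-byte record, parse it with parse_record, and confirm the label is now a 10-way one-hot vector:

import os
import tempfile
import numpy as np
import tensorflow as tf
import cifar10_main  # official/resnet/cifar10_main.py

filename = os.path.join(tempfile.mkdtemp(), 'fake_batch.bin')
with open(filename, 'wb') as f:
  f.write(np.zeros(3073, dtype=np.uint8).tobytes())  # one zeroed CIFAR-10 record

fake_dataset = cifar10_main.record_dataset(filename)
fake_dataset = fake_dataset.map(cifar10_main.parse_record)
image, label = fake_dataset.make_one_shot_iterator().get_next()
print(label.get_shape().as_list())  # [10], because parse_record one-hot encodes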

official/resnet/imagenet_main.py

Lines changed: 6 additions & 6 deletions
@@ -134,23 +134,23 @@ def dataset_parser(value, is_training):
 
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input function which provides batches for train or eval."""
-  dataset = tf.contrib.data.Dataset.from_tensor_slices(
+  dataset = tf.data.Dataset.from_tensor_slices(
       filenames(is_training, data_dir))
 
   if is_training:
     dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
 
-  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
-
-  dataset = dataset.map(lambda value: dataset_parser(value, is_training),
-                        num_threads=5,
-                        output_buffer_size=batch_size)
+  dataset = dataset.flat_map(tf.data.TFRecordDataset)
 
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance.
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
+  dataset = dataset.map(lambda value: dataset_parser(value, is_training),
+                        num_parallel_calls=5)
+  dataset = dataset.prefetch(batch_size)
+
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
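
The net effect in imagenet_main.py is that parsing now happens after the record-level shuffle and uses num_parallel_calls and prefetch; a condensed sketch of that ordering (the file list, buffer sizes, and parse_fn below are placeholders, not the module's constants):

import tensorflow as tf

def input_fn(filenames, parse_fn, batch_size, num_epochs=1, is_training=True):
  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  if is_training:
    dataset = dataset.shuffle(buffer_size=len(filenames))  # shuffle file shards
  dataset = dataset.flat_map(tf.data.TFRecordDataset)      # read records from each file
  if is_training:
    dataset = dataset.shuffle(buffer_size=10000)           # shuffle individual records
  dataset = dataset.map(parse_fn, num_parallel_calls=5)    # replaces num_threads=5
  dataset = dataset.prefetch(batch_size)                   # replaces output_buffer_size
  dataset = dataset.repeat(num_epochs)
  return dataset.batch(batch_size)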

official/wide_deep/wide_deep.py

Lines changed: 3 additions & 2 deletions
@@ -178,12 +178,13 @@ def parse_csv(value):
     return features, tf.equal(labels, '>50K')
 
   # Extract lines from input files using the Dataset API.
-  dataset = tf.contrib.data.TextLineDataset(data_file)
-  dataset = dataset.map(parse_csv, num_threads=5)
+  dataset = tf.data.TextLineDataset(data_file)
 
   if shuffle:
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
+  dataset = dataset.map(parse_csv, num_parallel_calls=5)
+
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
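
Similarly in wide_deep.py, the raw CSV lines are now shuffled before being parsed in parallel; a runnable sketch of that ordering with a stand-in two-column schema (the column names, defaults, and file name are placeholders, not the census schema):

import tensorflow as tf

_CSV_COLUMNS = ['age', 'income_bracket']  # placeholder columns
_CSV_DEFAULTS = [[0], ['']]               # placeholder defaults

def parse_csv(value):
  columns = tf.decode_csv(value, record_defaults=_CSV_DEFAULTS)
  features = dict(zip(_CSV_COLUMNS, columns))
  labels = features.pop('income_bracket')
  return features, tf.equal(labels, '>50K')

dataset = tf.data.TextLineDataset('adult.data')          # placeholder file
dataset = dataset.shuffle(buffer_size=1000)              # shuffle raw lines first
dataset = dataset.map(parse_csv, num_parallel_calls=5)   # then parse in parallel
dataset = dataset.repeat(1).batch(32)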
