71
71
'validation' : 10000 ,
72
72
}
73
73
74
- _SHUFFLE_BUFFER = 20000
75
-
76
74
77
75
def record_dataset(filenames):
  """Build a Dataset of fixed-length raw records from `filenames`.

  Each CIFAR-10 record is a single label byte followed by the raw
  image bytes (_HEIGHT * _WIDTH * _DEPTH of them).

  Args:
    filenames: A list of paths to CIFAR-10 binary files.

  Returns:
    A `tf.data.FixedLengthRecordDataset` yielding one raw record per element.
  """
  bytes_per_record = _HEIGHT * _WIDTH * _DEPTH + 1
  return tf.data.FixedLengthRecordDataset(filenames, bytes_per_record)
81
79
82
80
83
81
def get_filenames (is_training , data_dir ):
@@ -97,74 +95,77 @@ def get_filenames(is_training, data_dir):
97
95
return [os .path .join (data_dir , 'test_batch.bin' )]
98
96
99
97
100
def parse_record(raw_record):
  """Parse CIFAR-10 image and label from a raw record.

  Args:
    raw_record: A scalar string Tensor holding one fixed-length record:
      a single label byte followed by the raw image bytes.

  Returns:
    A tuple `(image, label)` where `image` is a `[_HEIGHT, _WIDTH, _DEPTH]`
    float32 Tensor and `label` is a one-hot Tensor of length `_NUM_CLASSES`.
  """
  # Record layout: one label byte, then the image bytes, at fixed sizes.
  label_bytes = 1
  image_bytes = _HEIGHT * _WIDTH * _DEPTH
  record_bytes = label_bytes + image_bytes

  # Reinterpret the record string as a uint8 vector of length record_bytes.
  record_vector = tf.decode_raw(raw_record, tf.uint8)

  # The first byte is the label: uint8 -> int32 -> one-hot.
  label = tf.one_hot(tf.cast(record_vector[0], tf.int32), _NUM_CLASSES)

  # The remaining bytes are the image, stored depth-major:
  # [depth * height * width] -> [depth, height, width].
  depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
                           [_DEPTH, _HEIGHT, _WIDTH])

  # Reorder to [height, width, depth] and cast to float32 for the model.
  image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

  return image, label
124
124
125
125
126
def preprocess_image(image, is_training):
  """Preprocess a single image of layout [height, width, depth].

  Args:
    image: A `[height, width, depth]` image Tensor.
    is_training: If True, apply random-crop and random-flip augmentation
      before standardization.

  Returns:
    The per-image-standardized (and, when training, augmented) image Tensor.
  """
  if is_training:
    # Pad/crop to (_HEIGHT + 8) x (_WIDTH + 8), then take a random
    # _HEIGHT x _WIDTH crop — the standard CIFAR-10 augmentation.
    padded = tf.image.resize_image_with_crop_or_pad(
        image, _HEIGHT + 8, _WIDTH + 8)
    cropped = tf.random_crop(padded, [_HEIGHT, _WIDTH, _DEPTH])

    # Randomly mirror the image horizontally.
    image = tf.image.random_flip_left_right(cropped)

  # Subtract off the mean and divide by the variance of the pixels.
  return tf.image.per_image_standardization(image)
138
142
139
143
140
144
def input_fn (is_training , data_dir , batch_size , num_epochs = 1 ):
141
- """Input_fn using the contrib .data input pipeline for CIFAR-10 dataset.
145
+ """Input_fn using the tf .data input pipeline for CIFAR-10 dataset.
142
146
143
147
Args:
144
148
is_training: A boolean denoting whether the input is for training.
149
+ data_dir: The directory containing the input data.
150
+ batch_size: The number of samples per batch.
145
151
num_epochs: The number of epochs to repeat the dataset.
146
152
147
153
Returns:
148
154
A tuple of images and labels.
149
155
"""
150
156
dataset = record_dataset (get_filenames (is_training , data_dir ))
151
- dataset = dataset .map (dataset_parser , num_threads = 1 ,
152
- output_buffer_size = 2 * batch_size )
153
157
154
- # For training, preprocess the image and shuffle.
155
158
if is_training :
156
- dataset = dataset .map (train_preprocess_fn , num_threads = 1 ,
157
- output_buffer_size = 2 * batch_size )
158
-
159
159
# When choosing shuffle buffer sizes, larger sizes result in better
160
- # randomness, while smaller sizes have better performance.
161
- dataset = dataset .shuffle (buffer_size = _SHUFFLE_BUFFER )
160
+ # randomness, while smaller sizes have better performance. Because CIFAR-10
161
+ # is a relatively small dataset, we choose to shuffle the full epoch.
162
+ dataset = dataset .shuffle (buffer_size = _NUM_IMAGES ['train' ])
162
163
163
- # Subtract off the mean and divide by the variance of the pixels.
164
+ dataset = dataset . map ( parse_record )
164
165
dataset = dataset .map (
165
- lambda image , label : (tf . image . per_image_standardization (image ), label ),
166
- num_threads = 1 ,
167
- output_buffer_size = 2 * batch_size )
166
+ lambda image , label : (preprocess_image (image , is_training ), label ))
167
+
168
+ dataset = dataset . prefetch ( 2 * batch_size )
168
169
169
170
# We call repeat after shuffling, rather than before, to prevent separate
170
171
# epochs from blending together.
0 commit comments