import tensorflow as tf

from ..layers.utils import combined_dnn_input
def input_fn_pandas(df, features, label=None, batch_size=256, num_epochs=1, shuffle=False, queue_capacity_factor=10,
                    num_threads=1):
    """Build a ``tf.estimator`` input function from a pandas DataFrame.

    :param df: pandas DataFrame holding the feature columns (and the label column, if any).
    :param features: list of feature column names to select from ``df``.
    :param label: optional label column name; when ``None`` no labels are fed.
    :param batch_size: number of examples per batch.
    :param num_epochs: number of passes over the data.
    :param shuffle: whether to shuffle the input queue.
    :param queue_capacity_factor: the feeding queue's capacity is
        ``batch_size * queue_capacity_factor``.
    :param num_threads: number of threads feeding the queue.
    :return: an ``input_fn`` callable suitable for ``tf.estimator`` train/eval.
    """
    if label is not None:
        y = df[label]
    else:
        y = None
    # NOTE(review): lexicographic string comparison of versions; correct for
    # "1.x"/"2.x" but would misorder a hypothetical "10.x" — confirm acceptable.
    if tf.__version__ >= "2.0.0":
        # TF 2.x keeps the pandas input helper only under the compat.v1 namespace.
        return tf.compat.v1.estimator.inputs.pandas_input_fn(df[features], y, batch_size=batch_size,
                                                             num_epochs=num_epochs,
                                                             shuffle=shuffle,
                                                             queue_capacity=batch_size * queue_capacity_factor,
                                                             num_threads=num_threads)

    return tf.estimator.inputs.pandas_input_fn(df[features], y, batch_size=batch_size, num_epochs=num_epochs,
                                               shuffle=shuffle, queue_capacity=batch_size * queue_capacity_factor,
                                               num_threads=num_threads)
3020
3121
def input_fn_tfrecord(filenames, feature_description, label=None, batch_size=256, num_epochs=1, num_parallel_calls=8,
                      shuffle_factor=10, prefetch_factor=1,
                      ):
    """Build a ``tf.estimator`` input function that reads TFRecord files.

    :param filenames: TFRecord file path or list of paths, as accepted by
        ``tf.data.TFRecordDataset``.
    :param feature_description: feature spec dict passed to ``tf.parse_single_example``.
    :param label: optional label feature name; when given it is split out of the
        parsed feature dict and returned separately.
    :param batch_size: number of examples per batch.
    :param num_epochs: number of passes over the data.
    :param num_parallel_calls: parallelism for the parsing ``map``.
    :param shuffle_factor: shuffle buffer is ``batch_size * shuffle_factor``;
        pass 0 to disable shuffling.
    :param prefetch_factor: prefetch buffer is ``batch_size * prefetch_factor``;
        pass 0 to disable prefetching.
    :return: an ``input_fn`` callable suitable for ``tf.estimator`` train/eval.
    """

    def _parse_examples(serial_exmp):
        # Parse one serialized Example. Popping the label out of the feature
        # dict yields the (features, labels) pair the estimator API expects.
        features = tf.parse_single_example(serial_exmp, features=feature_description)
        if label is not None:
            labels = features.pop(label)
            return features, labels
        return features

    def input_fn():
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parse_examples, num_parallel_calls=num_parallel_calls)
        if shuffle_factor > 0:
            dataset = dataset.shuffle(buffer_size=batch_size * shuffle_factor)

        dataset = dataset.repeat(num_epochs).batch(batch_size)

        # Prefetch AFTER batching so the buffer holds whole batches.
        if prefetch_factor > 0:
            dataset = dataset.prefetch(buffer_size=batch_size * prefetch_factor)

        # TF 1.x-style one-shot iterator; the returned tensors are re-evaluated
        # by the estimator on each session step.
        iterator = dataset.make_one_shot_iterator()

        return iterator.get_next()

    return input_fn
54-
55-