easy_rec/python/protos/variational_dropout.proto
@@ -1061,6 +1184,51 @@ AutoInt
+
+
easy_rec/python/protos/cmbf.proto Top
+
+
+
+
+ CMBF
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ config
+ CMBFTower
+ required
+
+
+
+
+ final_dnn
+ DNN
+ required
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/collaborative_metric_learning.proto Top
@@ -1252,7 +1420,15 @@ DatasetConfig
shard
bool
optional
- shard dataset to 1/num_workers in distribute mode Default: false
+ shard dataset to 1/num_workers in distributed mode;
+this param is no longer used Default: false
+
+
+
+ file_shard
+ bool
+ optional
+ shard by file, not by sample, valid only for CSVInput Default: false
@@ -1370,6 +1546,23 @@ DatasetConfig
n data for one feature in tfrecord
+
+ with_header
+ bool
+ optional
+ for csv files, an optional header line may be present;
+in that case, input_name must match the header name,
+and the number and order of input_fields
+need not be the same as in the csv files. Default: false
+
+
+
+ feature_fields
+ string
+ repeated
+
+
+
negative_sampler
NegativeSampler
@@ -1398,6 +1591,20 @@ DatasetConfig
+
+ negative_sampler_in_memory
+ NegativeSamplerInMemory
+ optional
+
+
+
+
+ eval_batch_size
+ uint32
+ optional
+ Default: 4096
+
+
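As a usage sketch for the DatasetConfig fields added above, the following data_config fragment is written in EasyRec's protobuf text format; all values are illustrative and only fields documented in this section are shown:

    data_config {
      input_type: CSVInput
      with_header: true        # input_name must then match the csv header names
      file_shard: true         # shard by file instead of by sample (CSVInput only)
      eval_batch_size: 4096
    }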
@@ -1540,6 +1747,13 @@ HardNegativeSampler
Default: 0
+
+ field_delimiter
+ string
+ optional
+ only works on DataScience/Local Default:
+
+
@@ -1638,6 +1852,13 @@ HardNegativeSamplerV2
Default: 0
+
+ field_delimiter
+ string
+ optional
+ only works on DataScience/Local Default:
+
+
@@ -1698,6 +1919,13 @@ NegativeSampler
Default: 0
+
+ field_delimiter
+ string
+ optional
+ only works on DataScience/Local Default:
+
+
@@ -1705,8 +1933,8 @@ NegativeSampler
- NegativeSamplerV2
- Weighted Random Sampling ItemID not with Edge
+ NegativeSamplerInMemory
+
@@ -1716,16 +1944,83 @@ NegativeSamplerV2
- user_input_path
+ input_path
string
required
- user data path
-userid weight
+ sample data path
+itemid weight attrs
- item_input_path
- string
+ num_sample
+ uint32
+ required
+ number of negative samples
+
+
+
+ attr_fields
+ string
+ repeated
+ field names of attrs in train data or eval data
+
+
+
+ item_id_field
+ string
+ required
+ field name of item_id in train data or eval data
+
+
+
+ attr_delimiter
+ string
+ optional
+ Default: :
+
+
+
+ num_eval_sample
+ uint32
+ optional
+ Default: 0
+
+
+
+ field_delimiter
+ string
+ optional
+ only works on DataScience/Local Default:
+
+
+
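A hedged sketch of a NegativeSamplerInMemory block using the fields above (the sample file path and attribute names are placeholders):

    negative_sampler_in_memory {
      input_path: "data/negative_sample_items.txt"   # placeholder; lines of: itemid weight attrs
      num_sample: 1024
      num_eval_sample: 2048
      item_id_field: "item_id"
      attr_fields: "cate_id"                          # placeholder attribute columns
      attr_fields: "brand_id"
      attr_delimiter: ":"
    }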
+
+
+
+
+
+
+ NegativeSamplerV2
+ Weighted Random Sampling ItemID not with Edge
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ user_input_path
+ string
+ required
+ user data path
+userid weight
+
+
+
+ item_input_path
+ string
required
item data path
itemid weight attrs
@@ -1781,6 +2076,13 @@ NegativeSamplerV2
Default: 0
+
+ field_delimiter
+ string
+ optional
+ only works on DataScience/Local Default:
+
+
@@ -1848,7 +2150,8 @@
CSVInput
10
- csv format input, could be used in local or hdfs
+ csv format input, can be used on local or hdfs;
+supports .gz compression (but not .tar.gz files)
@@ -1901,7 +2204,13 @@
OdpsRTPInput
- 6
+ 601
+
+
+
+
+ OdpsRTPInputV2
+ 602
@@ -1930,6 +2239,30 @@
+
+ HiveInput
+ 16
+
+
+
+
+ HiveRTPInput
+ 17
+
+
+
+
+ HiveParquetInput
+ 18
+
+
+
+
+ CriteoInput
+ 1001
+
+
+
@@ -1945,6 +2278,44 @@ easy_rec/python/protos/data_so
+
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ category_path
+ string
+ repeated
+ support gfile.Glob
+
+
+
+ dense_path
+ string
+ repeated
+
+
+
+
+ label_path
+ string
+ repeated
+
+
+
+
+
+
+
+
+
+
DatahubServer
@@ -1970,7 +2341,7 @@ DatahubServer
- region
+ endpoint
string
required
@@ -1991,17 +2362,19 @@ DatahubServer
- shard_num
- uint32
- required
-
+ offset_info
+ string
+ optional
+ in json format: {"0":{"cursor": ""}, "1":{"cursor":""}}
- life_cycle
- uint32
- required
-
+ offset_time
+ string
+ optional
+ offset_time can be given in one of two formats:
+1: %Y%m%d %H:%M:%S "20220508 12:00:00"
+2: %s "1651982400"
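A minimal sketch of a DatahubServer input using the offset fields above (the endpoint and times are placeholders; other required DatahubServer fields are omitted):

    datahub_train_input {
      endpoint: "https://dh-cn-hangzhou.aliyuncs.com"   # placeholder endpoint
      offset_time: "20220508 12:00:00"                  # or a unix timestamp, e.g. "1651982400"
      # alternatively, resume from explicit per-shard cursors:
      # offset_info: "{\"0\":{\"cursor\": \"\"}, \"1\":{\"cursor\":\"\"}}"
    }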
@@ -2043,17 +2416,33 @@ KafkaServer
- partitions
- uint32
- required
-
+ offset_info
+ string
+ optional
+ in json format: {'0':10, '1':20}
- offset
- uint32
+ offset_time
+ string
+ optional
+ offset_time can be given in one of two formats:
+1: %Y%m%d %H:%M:%S '20220508 12:00:00'
+2: %s '1651982400'
+
+
+
+ config_global
+ string
repeated
-
+ kafka global config, such as: fetch.max.bytes=1024
+
+
+
+ config_topic
+ string
+ repeated
+ kafka topic config, such as: max.partition.fetch.bytes=1024
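Similarly, a hedged sketch of the new KafkaServer options (other required KafkaServer fields such as the topic are omitted; the key=value strings follow the examples above):

    kafka_train_input {
      offset_time: "20220508 12:00:00"
      config_global: "fetch.max.bytes=1024"
      config_topic: "max.partition.fetch.bytes=1024"
    }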
@@ -2087,6 +2476,20 @@ DBMTL
+
+ bottom_cmbf
+ CMBFTower
+ optional
+ shared bottom cmbf layer
+
+
+
+ bottom_uniter
+ UniterTower
+ optional
+ shared bottom uniter layer
+
+
bottom_dnn
DNN
@@ -2300,6 +2703,79 @@ DeepFM
+
+
easy_rec/python/protos/dlrm.proto Top
+
+
+
+
+ DLRM
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ top_dnn
+ DNN
+ required
+
+
+
+
+ bot_dnn
+ DNN
+ required
+
+
+
+
+ arch_interaction_op
+ string
+ optional
+ options are: dot and cat Default: dot
+
+
+
+ arch_interaction_itself
+ bool
+ optional
+ whether a feature will interact with itself Default: false
+
+
+
+ arch_with_dense_feature
+ bool
+ optional
+ whether to include dense features after interaction Default: false
+
+
+
+ l2_regularization
+ float
+ optional
+ Default: 1e-05
+
+
+
+
+
+
+
+
+
+
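A hedged DLRM configuration sketch covering the fields above (hidden unit sizes are illustrative; hidden_units is assumed from easy_rec's DNN message):

    dlrm {
      bot_dnn { hidden_units: [64, 32, 16] }
      top_dnn { hidden_units: [128, 64, 32] }
      arch_interaction_op: "dot"
      arch_interaction_itself: false
      arch_with_dense_feature: true
      l2_regularization: 1e-5
    }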
+
+
+
+
+
+
+
easy_rec/python/protos/dnn.proto Top
@@ -2511,6 +2987,20 @@ DSSM
add a layer for scaling the similarity Default: true
+
+ item_id
+ string
+ optional
+
+
+
+
+ ignore_in_batch_neg_sam
+ bool
+ required
+ Default: false
+
+
@@ -2558,12 +3048,19 @@ DSSMTower
-
easy_rec/python/protos/eas_serving.proto Top
+
easy_rec/python/protos/easy_rec_model.proto Top
- Config
+ DummyModel
+ for input performance test
+
+
+
+
+
+ EasyRecModel
@@ -2574,254 +3071,230 @@ Config
- column_delim
- string
-
- For example, if the input feature is "1005,109;0;93eaba74", the semicolon separates columns,
-the comma separates the multiple features of each column, and the underscore separates a feature name from its value.
-
-
-
- feature_delim
+ model_class
string
-
+ required
- hash
- string
-
- Specifies the string hash bucketing algorithm; supports HarmHash (corresponding to tf.strings.to_hash_bucket_fast())
-and SipHash (corresponding to tf.strings.to_hash_bucket_strong()).
+ feature_groups
+ FeatureGroupConfig
+ repeated
+ actually input layers, each layer produce a group of feature
- embeddings
- Config.EmbeddingsEntry
- repeated
- embedding_name to embedding
+ dummy
+ DummyModel
+ optional
+
- embedding_max_norm
- Config.EmbeddingMaxNormEntry
- repeated
- Specifies the maximum L2-norm of the embedding lookup results
+ wide_and_deep
+ WideAndDeep
+ optional
+
- embedding_combiner
- Config.EmbeddingCombinerEntry
- repeated
- Specifies the embedding combiner strategy; supports sum, mean and sqrtn
+ deepfm
+ DeepFM
+ optional
+
- model
- Model
-
+ multi_tower
+ MultiTower
+ optional
-
-
-
-
-
-
-
- Config.EmbeddingCombinerEntry
-
-
-
-
-
- Field Type Label Description
-
-
-
- key
- string
-
+ fm
+ FM
+ optional
- value
- string
-
+ dcn
+ DCN
+ optional
-
-
-
-
-
-
-
- Config.EmbeddingMaxNormEntry
-
-
-
-
-
- Field Type Label Description
-
-
-
- key
- string
-
+ autoint
+ AutoInt
+ optional
- value
- float
-
+ dlrm
+ DLRM
+ optional
-
-
-
-
-
-
-
- Config.EmbeddingsEntry
-
-
+
+ cmbf
+ CMBF
+ optional
+
+
-
+
+ mind
+ MIND
+ optional
+
+
+
+ dropoutnet
+ DropoutNet
+ optional
+
+
+
+ metric_learning
+ CoMetricLearningI2I
+ optional
+
+
+
+ mmoe
+ MMoE
+ optional
+
+
+
+ esmm
+ ESMM
+ optional
+
+
- Embedding
-
+
+ dbmtl
+ DBMTL
+ optional
+
+
+
+ simple_multi_task
+ SimpleMultiTask
+ optional
+
+
-
-
-
-
-
-
- EmbeddingPart
-
-
+
+ embedding_regularization
+ float
+ optional
+ implemented in easy_rec/python/model/easy_rec_estimator
+add regularization to all variables with "embedding_weights:"
+in name Default: 0
+
-
-
- Field Type Label Description
-
-
+
+ loss_type
+ LossType
+ optional
+ Default: CLASSIFICATION
+
- embedding_part_path
- string
-
- Specifies the path where EmbeddingPartData (*.pb) is located
+ num_class
+ uint32
+ optional
+ Default: 1
- partition_id
- int32
-
- Specifies which part this embedding part corresponds to
+ ev_params
+ EVParams
+ optional
+
- shape
- int64
+ kd
+ KD
repeated
- Specifies the shape of this embedding part (can be read from EmbeddingPartData)
+
- deploy_strategy
+ restore_filters
string
-
- Deployment strategy of the embedding part; supports local deployment (local) and remote deployment (remote)
+ repeated
+ filter variables matching any pattern in restore_filters
+common filters are Adam, Momentum, etc.
-
-
-
-
-
-
-
- EmbeddingPartData
-
-
-
-
-
- Field Type Label Description
-
-
-
- shape
- int64
- repeated
- Shape of the embedding
+ variational_dropout
+ VariationalDropoutLayer
+ optional
+
- data
- float
+ losses
+ Loss
repeated
- Data
+
@@ -2831,8 +3304,8 @@ EmbeddingPartData
- Model
-
+ KD
+ for knowledge distillation
@@ -2842,24 +3315,59 @@ Model
- model_path
+ loss_name
string
-
- Specifies the path of the model, to facilitate loading the model
+ optional
+
- model_signature_name
+ pred_name
string
-
- Specifies the name of the model signature
+ required
+
- model_inputs
- ModelInput
- repeated
- model input description
+ pred_is_logits
+ bool
+ optional
+ default to be logits Default: true
+
+
+
+ soft_label_name
+ string
+ required
+ for CROSS_ENTROPY_LOSS, soft_label must be logits instead of probs
+
+
+
+ label_is_logits
+ bool
+ optional
+ default to be logits Default: true
+
+
+
+ loss_type
+ LossType
+ required
+ currently only support CROSS_ENTROPY_LOSS and L2_LOSS
+
+
+
+ loss_weight
+ float
+ optional
+ Default: 1
+
+
+
+ temperature
+ float
+ optional
+ only for loss_type == CROSS_ENTROPY_LOSS Default: 1
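A hedged sketch of a knowledge-distillation entry built from the KD fields above (the loss and tensor/label names are hypothetical and depend on how the teacher predictions are provided):

    kd {
      loss_name: "ctr_distill_loss"          # hypothetical name
      soft_label_name: "teacher_logits"      # hypothetical soft label field
      pred_name: "logits"
      loss_type: CROSS_ENTROPY_LOSS
      loss_weight: 1.0
      temperature: 2.0
    }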
@@ -2869,7 +3377,21 @@ Model
-
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/esmm.proto Top
+
+
+
+
+ ESMM
@@ -2880,31 +3402,31 @@
- feature_name
- string
-
+ groups
+ Tower
+ repeated
- embedding_name
- string
-
+ ctr_tower
+ TaskTower
+ required
- placeholder_name
- string
-
+ cvr_tower
+ TaskTower
+ required
- weight_name
- string
-
-
+ l2_regularization
+ float
+ required
+ Default: 0.0001
@@ -2923,19 +3445,43 @@
-
easy_rec/python/protos/easy_rec_model.proto Top
+
easy_rec/python/protos/eval.proto Top
- DummyModel
- for input performance test
+ AUC
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ num_thresholds
+ uint32
+ optional
+ Default: 200
+
+
+
- EasyRecModel
+
+ Accuracy
+
+
+
+
+
+
+ AvgPrecisionAtTopK
@@ -2946,201 +3492,170 @@ EasyRecModel
- model_class
- string
- required
-
+ topk
+ uint32
+ optional
+ Default: 5
-
- feature_groups
- FeatureGroupConfig
- repeated
- actually input layers, each layer produce a group of feature
-
+
+
+
+
+
+
+
+ EvalConfig
+ Message for configuring EasyRecModel evaluation jobs (eval.py).
+
+
+
+
+ Field Type Label Description
+
+
- dummy
- DummyModel
+ num_examples
+ uint32
optional
-
+ Number of examples to process for evaluation. Default: 0
- wide_and_deep
- WideAndDeep
+ eval_interval_secs
+ uint32
optional
-
+ How often to run evaluation. Default: 300
- deepfm
- DeepFM
+ max_evals
+ uint32
optional
-
+ Maximum number of times to run evaluation. If set to 0, will run forever. Default: 0
- multi_tower
- MultiTower
+ save_graph
+ bool
optional
-
+ Whether the TensorFlow graph used for evaluation should be saved to disk. Default: false
- fm
- FM
- optional
-
+ metrics_set
+ EvalMetrics
+ repeated
+ Type of metrics to use for evaluation.
+possible values:
- dcn
- DCN
+ eval_online
+ bool
optional
-
+ Evaluate online with the batch forward data of training Default: false
+
+
+
+
+
+
+
+ EvalMetrics
+
+
+
+
+
+ Field Type Label Description
+
+
+
- autoint
- AutoInt
+ auc
+ AUC
optional
- dssm
- DSSM
+ recall_at_topk
+ RecallAtTopK
optional
- mind
- MIND
+ mean_absolute_error
+ MeanAbsoluteError
optional
- dropoutnet
- DropoutNet
+ mean_squared_error
+ MeanSquaredError
optional
- metric_learning
- CoMetricLearningI2I
+ accuracy
+ Accuracy
optional
- mmoe
- MMoE
+ max_f1
+ Max_F1
optional
- esmm
- ESMM
+ root_mean_squared_error
+ RootMeanSquaredError
optional
- dbmtl
- DBMTL
+ gauc
+ GAUC
optional
- simple_multi_task
- SimpleMultiTask
+ session_auc
+ SessionAUC
optional
- ple
- PLE
+ recall
+ Recall
optional
- rocket_launching
- RocketLaunching
+ precision
+ Precision
optional
- seq_att_groups
- SeqAttGroupConfig
- repeated
-
-
-
-
- embedding_regularization
- float
- optional
- implemented in easy_rec/python/model/easy_rec_estimator
-add regularization to all variables with "embedding_weights:"
-in name Default: 0
-
-
-
- loss_type
- LossType
- optional
- Default: CLASSIFICATION
-
-
-
- num_class
- uint32
- optional
- Default: 1
-
-
-
- use_embedding_variable
- bool
- optional
- Default: false
-
-
-
- kd
- KD
- repeated
-
-
-
-
- restore_filters
- string
- repeated
- filter variables matching any pattern in restore_filters
-common filters are Adam, Momentum, etc.
-
-
-
- variational_dropout
- VariationalDropoutLayer
- optional
-
-
-
-
- losses
- Loss
- repeated
+ precision_at_topk
+ AvgPrecisionAtTopK
+ optional
@@ -3151,8 +3666,8 @@ EasyRecModel
- KD
- for knowledge distillation
+ GAUC
+
@@ -3162,59 +3677,20 @@ KD
- loss_name
- string
- optional
-
-
-
-
- pred_name
+ uid_field
string
required
-
-
-
-
- pred_is_logits
- bool
- optional
- default to be logits Default: true
+ uid field name
- soft_label_name
+ reduction
string
- required
- for CROSS_ENTROPY_LOSS, soft_label must be logits instead of probs
-
-
-
- label_is_logits
- bool
- optional
- default to be logits Default: true
-
-
-
- loss_type
- LossType
- required
- currently only support CROSS_ENTROPY_LOSS and L2_LOSS
-
-
-
- loss_weight
- float
optional
- Default: 1
-
-
-
- temperature
- float
- optional
- only for loss_type == CROSS_ENTROPY_LOSS Default: 1
+ reduction method for auc of different users
+* "mean": simple mean of different users
+* "mean_by_sample_num": weighted mean with sample num of different users
+* "mean_by_positive_num": weighted mean with positive sample num of different users Default: mean
@@ -3224,80 +3700,42 @@ KD
+ Max_F1
+
-
-
-
-
-
easy_rec/python/protos/esmm.proto Top
-
-
-
-
- ESMM
+ MeanAbsoluteError
-
-
- Field Type Label Description
-
-
-
-
- groups
- Tower
- repeated
-
-
-
-
- ctr_tower
- TaskTower
- required
-
-
-
- cvr_tower
- TaskTower
- required
-
-
-
- l2_regularization
- float
- required
- Default: 0.0001
-
-
-
+ MeanSquaredError
+
+ Precision
+
+ Recall
+
-
-
easy_rec/python/protos/eval.proto Top
-
-
- AUC
+ RecallAtTopK
@@ -3308,10 +3746,10 @@ AUC
- num_thresholds
+ topk
uint32
optional
- Default: 200
+ Default: 5
@@ -3321,14 +3759,14 @@ AUC
- Accuracy
+ RootMeanSquaredError
- AvgPrecisionAtTopK
+ SessionAUC
@@ -3339,10 +3777,20 @@ AvgPrecisionAtTopK
- topk
- uint32
+ session_id_field
+ string
+ required
+ session id field name
+
+
+
+ reduction
+ string
optional
- Default: 5
+ reduction method for auc of different sessions
+* "mean": simple mean of different sessions
+* "mean_by_sample_num": weighted mean with sample num of different sessions
+* "mean_by_positive_num": weighted mean with positive sample num of different sessions Default: mean
@@ -3352,8 +3800,22 @@ AvgPrecisionAtTopK
- EvalConfig
- Message for configuring EasyRecModel evaluation jobs (eval.py).
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/export.proto Top
+
+
+
+
+ ExportConfig
+ Message for configuring exporting models.
@@ -3363,147 +3825,123 @@ EvalConfig
- num_examples
- uint32
+ batch_size
+ int32
optional
- Number of examples to process of evaluation. Default: 0
+ batch size used for the exported model; -1 indicates batch_size is None,
+which is only supported by classification models right now, while
+other models support a static batch_size Default: -1
- eval_interval_secs
- uint32
+ exporter_type
+ string
optional
- How often to run evaluation. Default: 300
+ type of exporter [final | latest | best | none] used during train_and_evaluate
+final: performs a single export at the end of training
+latest: regularly exports the serving graph and checkpoints
+best: exports the best model according to best_exporter_metric
+none: does not perform export Default: final
- max_evals
- uint32
+ best_exporter_metric
+ string
optional
- Maximum number of times to run evaluation. If set to 0, will run forever. Default: 0
+ the metric used to determine the best checkpoint Default: auc
- save_graph
+ metric_bigger
bool
optional
- Whether the TensorFlow graph used for evaluation should be saved to disk. Default: false
+ whether a bigger metric value is better Default: true
- metrics_set
- EvalMetrics
- repeated
- Type of metrics to use for evaluation.
-possible values:
+ enable_early_stop
+ bool
+ optional
+ enable early stop Default: false
- eval_online
- bool
+ early_stop_func
+ string
optional
- Evaluation online with batch forward data of training Default: false
-
-
-
-
-
-
-
-
-
- EvalMetrics
-
-
-
-
-
- Field Type Label Description
-
-
-
-
- auc
- AUC
- optional
-
-
-
-
- recall_at_topk
- RecallAtTopK
- optional
-
+ custom early stop function, format:
+ early_stop_func(eval_results, early_stop_params)
+returns True if training should stop
- mean_absolute_error
- MeanAbsoluteError
+ early_stop_params
+ string
optional
-
+ custom early stop parameters
- mean_squared_error
- MeanSquaredError
+ max_check_steps
+ int32
optional
-
+ early stop max check steps Default: 10000
- accuracy
- Accuracy
+ multi_placeholder
+ bool
optional
-
+ each feature has a placeholder Default: true
- max_f1
- Max_F1
+ exports_to_keep
+ int32
optional
-
+ export to keep, only for exporter_type in [best, latest] Default: 1
- root_mean_squared_error
- RootMeanSquaredError
+ multi_value_fields
+ MultiValueFields
optional
-
+ multi value field list
- gauc
- GAUC
+ placeholder_named_by_input
+ bool
optional
-
+ is placeholder named by input Default: false
- session_auc
- SessionAUC
+ filter_inputs
+ bool
optional
-
+ filter out inputs, only keep effective ones Default: true
- recall
- Recall
+ export_features
+ bool
optional
-
+ export the original feature values as string Default: false
- precision
- Precision
+ export_rtp_outputs
+ bool
optional
-
+ export the outputs required by RTP Default: false
- precision_at_topk
- AvgPrecisionAtTopK
- optional
-
+ asset_files
+ string
+ repeated
+ export asset files
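A hedged export_config sketch combining several of the fields above (the early-stop function path is a placeholder module path, not a real function):

    export_config {
      exporter_type: "best"
      best_exporter_metric: "auc"
      metric_bigger: true
      exports_to_keep: 3
      enable_early_stop: true
      early_stop_func: "my_pkg.early_stop.stop_on_auc_plateau"   # placeholder
      multi_placeholder: true
      filter_inputs: true
    }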
@@ -3513,7 +3951,7 @@ EvalMetrics
- GAUC
+ MultiValueFields
@@ -3524,20 +3962,10 @@ GAUC
- uid_field
- string
- required
- uid field name
-
-
-
- reduction
+ input_name
string
- optional
- reduction method for auc of different users
-* "mean": simple mean of different users
-* "mean_by_sample_num": weighted mean with sample num of different users
-* "mean_by_positive_num": weighted mean with positive sample num of different users Default: mean
+ repeated
+
@@ -3547,42 +3975,28 @@ GAUC
- Max_F1
-
-
-
-
-
-
- MeanAbsoluteError
-
-
-
-
-
-
- MeanSquaredError
-
- Precision
-
+
+
easy_rec/python/protos/feature_config.proto Top
+
+
- Recall
+ AttentionCombiner
- RecallAtTopK
+ EVParams
@@ -3593,10 +4007,17 @@ RecallAtTopK
- topk
- uint32
+ filter_freq
+ uint64
optional
- Default: 5
+ Default: 0
+
+
+
+ steps_to_live
+ uint64
+ optional
+ Default: 0
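A short sketch of attaching EVParams to a feature through the ev_params field of FeatureConfig, documented later in this section (values are illustrative):

    feature_configs: {
      input_names: "user_id"
      feature_type: IdFeature
      embedding_dim: 16
      ev_params {
        filter_freq: 2
        steps_to_live: 1000000
      }
    }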
@@ -3606,14 +4027,7 @@ RecallAtTopK
- RootMeanSquaredError
-
-
-
-
-
-
- SessionAUC
+ FeatureConfig
@@ -3624,299 +4038,92 @@ SessionAUC
- session_id_field
+ feature_name
string
- required
- session id field name
+ optional
+
- reduction
+ input_names
string
- optional
- reduction: reduction method for auc of different sessions
-* "mean": simple mean of different sessions
-* "mean_by_sample_num": weighted mean with sample num of different sessions
-* "mean_by_positive_num": weighted mean with positive sample num of different sessions Default: mean
+ repeated
+ input field names: must be included in DatasetConfig.input_fields
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
easy_rec/python/protos/export.proto Top
-
-
-
-
- ExportConfig
- Message for configuring exporting models.
-
-
-
-
- Field Type Label Description
-
-
-
- batch_size
- int32
- optional
- batch size used for exported model, -1 indicates batch_size is None
-which is only supported by classification model right now, while
-other models support static batch_size Default: -1
+ feature_type
+ FeatureConfig.FeatureType
+ required
+ Default: IdFeature
- exporter_type
+ embedding_name
string
optional
- type of exporter [final | latest | best | none] when train_and_evaluation
-final: performs a single export in the end of training
-latest: regularly exports the serving graph and checkpoints
-latest: export the best model according to best_exporter_metric
-none: do not perform export Default: final
+
- dump_embedding_shape
- bool
+ embedding_dim
+ uint32
optional
- for large embedding models to serve on eas
-embedding lookup is done outside of tensorflow graph;
-so the tensorflow graph contains only the dnn graphs(the attention part included);
-the lookuped results are passed to dnn graphs via embedding placeholders;
-we dump embedding placeholder shapes, so that embedding
-placeholders could be built. Default: false
+ Default: 0
- best_exporter_metric
- string
+ hash_bucket_size
+ uint64
optional
- the metric used to determine the best checkpoint Default: auc
+ Default: 0
- metric_bigger
- bool
+ num_buckets
+ uint64
optional
- metric value the bigger the best Default: true
+ for categorical_column_with_identity Default: 0
- enable_early_stop
- bool
- optional
- enable early stop Default: false
+ boundaries
+ double
+ repeated
+ only for raw features
- early_stop_func
+ separator
string
optional
- custom early stop function, format:
- early_stop_func(eval_results, early_stop_params)
-return True if should stop
+ separator within features Default: |
- early_stop_params
+ kv_separator
string
optional
- custom early stop parameters
+ delimiter to separate key from value
- max_check_steps
- int32
+ seq_multi_sep
+ string
optional
- early stop max check steps Default: 10000
+ delimiter to separate sequence multi-values
- multi_placeholder
- bool
+ max_seq_len
+ uint32
optional
- each feature has a placeholder Default: true
+ truncate sequence data to max_seq_len
- exports_to_keep
- int32
- optional
- export to keep, only for exporter_type in [best, latest] Default: 1
-
-
-
- multi_value_fields
- MultiValueFields
- optional
- multi value field list
-
-
-
- placeholder_named_by_input
- bool
- optional
- is placeholder named by input Default: false
-
-
-
-
-
-
-
-
-
- MultiValueFields
-
-
-
-
-
- Field Type Label Description
-
-
-
-
- input_name
- string
- repeated
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
easy_rec/python/protos/feature_config.proto Top
-
-
-
-
- AttentionCombiner
-
-
-
-
-
-
- FeatureConfig
-
-
-
-
-
- Field Type Label Description
-
-
-
-
- feature_name
- string
- optional
-
-
-
-
- input_names
- string
- repeated
- input field names: must be included in DatasetConfig.input_fields
-
-
-
- feature_type
- FeatureConfig.FeatureType
- required
- Default: IdFeature
-
-
-
- embedding_name
- string
- optional
-
-
-
-
- embedding_dim
- uint32
- optional
- Default: 0
-
-
-
- hash_bucket_size
- uint64
- optional
- Default: 0
-
-
-
- num_buckets
- uint64
- optional
- for categorical_column_with_identity Default: 0
-
-
-
- boundaries
- double
- repeated
- only for raw features
-
-
-
- separator
- string
- optional
- separator with in features Default: |
-
-
-
- kv_separator
- string
- optional
- delimeter to separator key from value
-
-
-
- seq_multi_sep
- string
- optional
- delimeter to separate sequence multi-values
-
-
-
- vocab_file
- string
+ vocab_file
+ string
optional
@@ -3953,7 +4160,7 @@ FeatureConfig
combiner
string
optional
- combiner Default: mean
+ combiner Default: sum
@@ -3986,6 +4193,14 @@ FeatureConfig
Default: 0
+
+ normalizer_fn
+ string
+ optional
+ normalization function for raw features:
+ such as: tf.math.log1p
+
+
raw_input_dim
uint32
@@ -4000,6 +4215,34 @@ FeatureConfig
sequence feature combiner
+
+ sub_feature_type
+ FeatureConfig.FeatureType
+ optional
+ sub feature type for sequence feature Default: IdFeature
+
+
+
+ sequence_length
+ uint32
+ optional
+ sequence length Default: 1
+
+
+
+ expression
+ string
+ optional
+ for expr feature
+
+
+
+ ev_params
+ EVParams
+ optional
+ embedding variable params
+
+
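The following sketch illustrates the newly documented FeatureConfig fields (feature and column names are placeholders; feature type names other than IdFeature and ExprFeature are assumed from easy_rec's FeatureType enum, and the expression syntax is only illustrative):

    feature_configs: {
      input_names: "price"
      feature_type: RawFeature          # assumed enum value
      normalizer_fn: "tf.math.log1p"
    }
    feature_configs: {
      input_names: "click_seq_item_id"
      feature_type: SequenceFeature     # assumed enum value
      sub_feature_type: IdFeature
      sequence_length: 50
      embedding_dim: 16
      hash_bucket_size: 100000
    }
    feature_configs: {
      feature_name: "price_x_cnt"
      input_names: "price"
      input_names: "click_cnt"
      feature_type: ExprFeature
      expression: "price*click_cnt"     # illustrative expression
    }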
@@ -4065,10 +4308,17 @@ FeatureGroupConfig
sequence_features
SeqAttGroupConfig
- optional
+ repeated
+
+ negative_sampler
+ bool
+ optional
+ Default: false
+
+
@@ -4128,6 +4378,20 @@ SeqAttGroupConfig
Default: false
+
+ need_key_feature
+ bool
+ optional
+ Default: true
+
+
+
+ allow_key_transform
+ bool
+ optional
+ Default: false
+
+
@@ -4159,6 +4423,13 @@ SeqAttMap
+
+ aux_hist_seq
+ string
+ repeated
+
+
+
@@ -4281,10 +4552,16 @@ FeatureConfig.FeatureType
+
+ ExprFeature
+ 6
+
+
+
- WideOrDeep
+ FeatureConfig.FieldType
@@ -4293,33 +4570,80 @@ WideOrDeep
- DEEP
+ INT32
0
- WIDE
+ INT64
1
- WIDE_AND_DEEP
+ STRING
2
-
-
-
+
+ FLOAT
+ 4
+
+
+
+ DOUBLE
+ 5
+
+
+
+ BOOL
+ 6
+
+
+
+
+ WideOrDeep
+
+
+
+ Name Number Description
+
+
+
+ DEEP
+ 0
+
+
-
+
+ WIDE
+ 1
+
+
+
+
+ WIDE_AND_DEEP
+ 2
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/fm.proto Top
@@ -4357,6 +4681,72 @@ FM
+
+
easy_rec/python/protos/hive_config.proto Top
+
+
+
+
+ HiveConfig
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ host
+ string
+ required
+ hive master's ip
+
+
+
+ port
+ uint32
+ required
+ hive port Default: 10000
+
+
+
+ username
+ string
+ required
+ hive username Default: admin
+
+
+
+ database
+ string
+ required
+ hive database Default: default
+
+
+
+ table_name
+ string
+ required
+
+
+
+
+
+
+
+
+
+
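A minimal hive input sketch built from the HiveConfig fields above (host and table name are placeholders):

    hive_train_input {
      host: "192.168.0.1"
      port: 10000
      username: "admin"
      database: "default"
      table_name: "easy_rec_train_table"
    }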
+
+
+
+
+
+
+
+
easy_rec/python/protos/hyperparams.proto Top
@@ -4632,7 +5022,7 @@ easy_rec/python/protos/layer.proto
- HighWayTower
+ CMBFTower
@@ -4643,62 +5033,136 @@ HighWayTower
- input
- string
+ multi_head_num
+ uint32
required
-
+ The number of heads of cross modal fusion layer Default: 1
- emb_size
+ image_multi_head_num
uint32
required
-
+ The number of heads of image feature learning layer Default: 1
-
-
-
-
-
-
+
+ text_multi_head_num
+ uint32
+ required
+ The number of heads of text feature learning layer Default: 1
+
+
+ text_head_size
+ uint32
+ required
+ The dimension of text heads
+
+
+ image_head_size
+ uint32
+ required
+ The dimension of image heads Default: 64
+
+
+ image_feature_patch_num
+ uint32
+ required
+ The number of patches of the image feature, takes effect when there is only one image feature Default: 1
+
+
+ image_feature_dim
+ uint32
+ required
+ Reduce the image feature to this dimension before the single modal learning module Default: 0
+
+
+ image_self_attention_layer_num
+ uint32
+ required
+ The number of self attention layers for image features Default: 0
+
+
+ text_self_attention_layer_num
+ uint32
+ required
+ The number of self attention layers for text features Default: 1
+
+
+ cross_modal_layer_num
+ uint32
+ required
+ The number of cross modal layers Default: 1
+
+
+ image_cross_head_size
+ uint32
+ required
+ The dimension of image cross modal heads
+
-
-
easy_rec/python/protos/loss.proto Top
-
-
+
+ text_cross_head_size
+ uint32
+ required
+ The dimension of text cross modal heads
+
+
+ hidden_dropout_prob
+ float
+ required
+ Dropout probability for hidden layers Default: 0
+
- CircleLoss
-
+
+ attention_probs_dropout_prob
+ float
+ required
+ Dropout probability of the attention probabilities Default: 0
+
+
+ use_token_type
+ bool
+ required
+ Whether to add embeddings for different text sequence features Default: false
+
-
-
- Field Type Label Description
-
-
+
+ use_position_embeddings
+ bool
+ required
+ Whether to add position embeddings for the position of each token in the text sequence Default: true
+
- margin
- float
+ max_position_embeddings
+ uint32
required
- Default: 0.25
+ Maximum sequence length that might ever be used with this model Default: 0
- gamma
+ text_seq_emb_dropout_prob
float
required
- Default: 32
+ Dropout probability for text sequence embeddings Default: 0.1
+
+
+
+ other_feature_dnn
+ DNN
+ optional
+ dnn layers for other features
@@ -4708,7 +5172,7 @@ CircleLoss
- Loss
+ HighWayTower
@@ -4719,17 +5183,17 @@ Loss
- loss_type
- LossType
+ input
+ string
required
- weight
- float
+ emb_size
+ uint32
required
- Default: 1
+
@@ -4739,7 +5203,7 @@ Loss
- MultiSimilarityLoss
+ UniterTower
@@ -4750,76 +5214,82 @@ MultiSimilarityLoss
- alpha
- float
+ hidden_size
+ uint32
required
- Default: 2
+ Size of the encoder layers and the pooler layer
- beta
- float
+ num_hidden_layers
+ uint32
required
- Default: 50
+ Number of hidden layers in the Transformer encoder
- lamb
- float
+ num_attention_heads
+ uint32
required
- Default: 1
+ Number of attention heads for each attention layer in the Transformer encoder
- eps
- float
+ intermediate_size
+ uint32
required
- Default: 0.1
+ The size of the "intermediate" (i.e. feed-forward) layer in the Transformer encoder
-
-
-
-
-
-
+
+ hidden_act
+ string
+ required
+ The non-linear activation function (function or string) in the encoder and pooler.
- SoftmaxCrossEntropyWithNegativeMining
-
+"gelu", "relu", "tanh" and "swish" are supported. Default: gelu
+
+
+ hidden_dropout_prob
+ float
+ required
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler Default: 0.1
+
-
-
- Field Type Label Description
-
-
+
+ attention_probs_dropout_prob
+ float
+ required
+ The dropout ratio for the attention probabilities Default: 0.1
+
- num_negative_samples
+ max_position_embeddings
uint32
required
-
+ The maximum sequence length that this model might ever be used with Default: 512
- margin
- float
+ use_position_embeddings
+ bool
required
- Default: 0
+ Whether to add position embeddings for the position of each token in the text sequence Default: true
- gamma
+ initializer_range
float
required
- Default: 1
+ The stddev of the truncated_normal_initializer for initializing all weight matrices Default: 0.02
- coefficient_of_support_vector
- float
- required
- Default: 1
+ other_feature_dnn
+ DNN
+ optional
+ dnn layers for other features
@@ -4831,71 +5301,6 @@ SoftmaxCrossEntropyWithNeg
- LossType
-
-
-
- Name Number Description
-
-
-
-
- CLASSIFICATION
- 0
-
-
-
-
- L2_LOSS
- 1
-
-
-
-
- SIGMOID_L2_LOSS
- 2
-
-
-
-
- CROSS_ENTROPY_LOSS
- 3
- crossentropy loss/log loss
-
-
-
- SOFTMAX_CROSS_ENTROPY
- 4
-
-
-
-
- CIRCLE_LOSS
- 5
-
-
-
-
- MULTI_SIMILARITY_LOSS
- 6
-
-
-
-
- SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
- 7
-
-
-
-
- PAIR_WISE_LOSS
- 8
-
-
-
-
-
-
@@ -4903,12 +5308,12 @@ LossType
-
easy_rec/python/protos/mind.proto Top
+
easy_rec/python/protos/loss.proto Top
- Capsule
+ CircleLoss
@@ -4919,45 +5324,48 @@ Capsule
- max_k
- uint32
- optional
- max number of high capsules Default: 5
-
-
-
- max_seq_len
- uint32
+ margin
+ float
required
- max behaviour sequence length
+ Default: 0.25
- high_dim
- uint32
+ gamma
+ float
required
- high capsule embedding vector dimension
+ Default: 32
-
- num_iters
- uint32
- optional
- number EM iterations Default: 3
-
+
+
+
+
+
+
+
+ F1ReweighedLoss
+
+
+
+
+
+ Field Type Label Description
+
+
- routing_logits_scale
+ f1_beta_square
float
- optional
- routing logits scale Default: 20
+ required
+ Default: 1
- routing_logits_stddev
+ label_smoothing
float
- optional
- routing logits initial stddev Default: 1
+ required
+ Default: 0
@@ -4967,7 +5375,7 @@ Capsule
- MIND
+ Loss
@@ -4978,62 +5386,45 @@ MIND
- pre_capsule_dnn
- DNN
- optional
- preprocessing dnn before entering capsule layer
-
-
-
- user_dnn
- DNN
+ loss_type
+ LossType
required
- dnn layers applied on concated results of
-capsule output and user_context(none sequence features)
-
-
-
- user_seq_combine
- MIND.UserSeqCombineMethod
- optional
- method to combine several user sequences
-such as item_ids, category_ids Default: SUM
+
- item_dnn
- DNN
+ weight
+ float
required
- dnn layers applied on item features
+ Default: 1
- capsule_config
- Capsule
- required
+ f1_reweighted_loss
+ F1ReweighedLoss
+ optional
- simi_pow
- float
+ softmax_loss
+ SoftmaxCrossEntropyWithNegativeMining
optional
- similarity power, the paper says that the big
-the better Default: 10
+
- simi_func
- Similarity
+ circle_loss
+ CircleLoss
optional
- Default: COSINE
+
- l2_regularization
- float
- required
- Default: 0.0001
+ multi_simi_loss
+ MultiSimilarityLoss
+ optional
+
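Since a model (or task tower) can now carry several Loss entries, a brief sketch of the repeated losses field (weights are illustrative):

    losses {
      loss_type: CLASSIFICATION
      weight: 1.0
    }
    losses {
      loss_type: F1_REWEIGHTED_LOSS
      weight: 0.5
      f1_reweighted_loss {
        f1_beta_square: 1.0
        label_smoothing: 0.0
      }
    }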
@@ -5043,44 +5434,7 @@ MIND
-
-
- MIND.UserSeqCombineMethod
-
-
-
- Name Number Description
-
-
-
-
- CONCAT
- 0
-
-
-
-
- SUM
- 1
-
-
-
-
-
-
-
-
-
-
-
-
-
-
easy_rec/python/protos/mmoe.proto Top
-
-
-
-
- ExpertTower
+ MultiSimilarityLoss
@@ -5091,17 +5445,31 @@ ExpertTower
- expert_name
- string
+ alpha
+ float
required
-
+ Default: 2
- dnn
- DNN
+ beta
+ float
required
-
+ Default: 50
+
+
+
+ lamb
+ float
+ required
+ Default: 1
+
+
+
+ eps
+ float
+ required
+ Default: 0.1
@@ -5111,7 +5479,7 @@ ExpertTower
- MMoE
+ SoftmaxCrossEntropyWithNegativeMining
@@ -5122,38 +5490,31 @@ MMoE
- experts
- ExpertTower
- repeated
- deprecated: original mmoe experts config
-
-
-
- expert_dnn
- DNN
- optional
- mmoe expert dnn layer definition
+ num_negative_samples
+ uint32
+ required
+
- num_expert
- uint32
- optional
- number of mmoe experts Default: 0
+ margin
+ float
+ required
+ Default: 0
- task_towers
- TaskTower
- repeated
- task tower
+ gamma
+ float
+ required
+ Default: 1
- l2_regularization
+ coefficient_of_support_vector
float
required
- l2 regularization Default: 0.0001
+ Default: 1
@@ -5165,19 +5526,90 @@ MMoE
+ LossType
+
+
+
+ Name Number Description
+
+
+
+
+ CLASSIFICATION
+ 0
+
+
+
+
+ L2_LOSS
+ 1
+
+
+
+
+ SIGMOID_L2_LOSS
+ 2
+
+
+
+
+ CROSS_ENTROPY_LOSS
+ 3
+ cross-entropy loss/log loss
+
+
+
+ SOFTMAX_CROSS_ENTROPY
+ 4
+
+
+
+
+ CIRCLE_LOSS
+ 5
+
+
+
+
+ MULTI_SIMILARITY_LOSS
+ 6
+
+
+
+
+ SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
+ 7
+
+
+
+
+ PAIR_WISE_LOSS
+ 8
+
+
+
+
+ F1_REWEIGHTED_LOSS
+ 9
+
+
+
+
+
+
+
-
-
easy_rec/python/protos/multi_tower.proto Top
+
easy_rec/python/protos/mind.proto Top
- BSTTower
+ Capsule
@@ -5188,55 +5620,67 @@ BSTTower
- input
- string
- required
-
+ max_k
+ uint32
+ optional
+ max number of high capsules Default: 5
- seq_len
+ max_seq_len
uint32
required
- Default: 5
+ max behaviour sequence length
- multi_head_size
+ high_dim
uint32
required
- Default: 4
+ high capsule embedding vector dimension
-
-
-
-
-
-
+
+ num_iters
+ uint32
+ optional
+ number of EM iterations Default: 3
+
- DINTower
-
+
+ routing_logits_scale
+ float
+ optional
+ routing logits scale Default: 20
+
+
+ routing_logits_stddev
+ float
+ optional
+ routing logits initial stddev Default: 1
+
-
-
- Field Type Label Description
-
-
+
+ squash_pow
+ float
+ optional
+ squash power Default: 1
+
- input
- string
- required
-
+ scale_ratio
+ float
+ optional
+ output ratio Default: 1
- dnn
- DNN
- required
-
+ const_caps_num
+ bool
+ optional
+ constant interest number;
+by default, use log(seq_len) Default: false
@@ -5246,7 +5690,7 @@ DINTower
- MultiTower
+ MIND
@@ -5257,19 +5701,70 @@ MultiTower
- towers
- Tower
- repeated
-
+ pre_capsule_dnn
+ DNN
+ optional
+ preprocessing dnn before entering capsule layer
- final_dnn
+ user_dnn
+ DNN
+ required
+ dnn layers applied on user_context (non-sequence features)
+
+
+
+ concat_dnn
+ DNN
+ required
+ concat user and capsule dnn
+
+
+
+ user_seq_combine
+ MIND.UserSeqCombineMethod
+ optional
+ method to combine several user sequences
+such as item_ids, category_ids Default: SUM
+
+
+
+ item_dnn
DNN
required
+ dnn layers applied on item features
+
+
+
+ capsule_config
+ Capsule
+ required
+
+ simi_pow
+ float
+ optional
+ similarity power; the paper says the bigger
+the better Default: 10
+
+
+
+ simi_func
+ Similarity
+ optional
+ Default: COSINE
+
+
+
+ scale_simi
+ bool
+ optional
+ add a layer for scaling the similarity Default: true
+
+
l2_regularization
float
@@ -5278,19 +5773,35 @@ MultiTower
- din_towers
- DINTower
- repeated
+ time_id_fea
+ string
+ optional
- bst_towers
- BSTTower
- repeated
+ item_id
+ string
+ optional
+
+ ignore_in_batch_neg_sam
+ bool
+ optional
+ Default: false
+
+
+
+ max_interests_simi
+ float
+ optional
+ if smaller than 1.0, then a loss will be added to
+limit the maximal interest similarities, but
+in experiments, adding such a loss leads to low hitrate. Default: 1
+
+
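A condensed MIND sketch exercising the capsule and similarity fields above (layer sizes and the time-id feature name are placeholders; hidden_units is assumed from easy_rec's DNN message):

    mind {
      user_dnn { hidden_units: [128, 64] }
      concat_dnn { hidden_units: [64, 32] }
      item_dnn { hidden_units: [64, 32] }
      capsule_config {
        max_k: 5
        max_seq_len: 64
        high_dim: 32
        num_iters: 3
        const_caps_num: false
      }
      simi_pow: 20
      simi_func: COSINE
      scale_simi: true
      time_id_fea: "time_id"      # placeholder feature name
      l2_regularization: 1e-4
    }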
@@ -5300,6 +5811,29 @@ MultiTower
+ MIND.UserSeqCombineMethod
+
+
+
+ Name Number Description
+
+
+
+
+ CONCAT
+ 0
+
+
+
+
+ SUM
+ 1
+
+
+
+
+
+
@@ -5307,13 +5841,13 @@ MultiTower
-
easy_rec/python/protos/optimizer.proto Top
+
easy_rec/python/protos/mmoe.proto Top
- AdagradOptimizer
- Configuration message for the AdagradOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
+ ExpertTower
+
@@ -5323,21 +5857,28 @@ AdagradOptimizer
- learning_rate
- LearningRate
- optional
+ expert_name
+ string
+ required
-
+
+ dnn
+ DNN
+ required
+
+
+
+
- AdamAsyncOptimizer
- Only available on pai-tf, which has better performance than AdamOptimizer
+ MMoE
+
@@ -5347,24 +5888,38 @@ AdamAsyncOptimizer
- learning_rate
- LearningRate
+ experts
+ ExpertTower
+ repeated
+ deprecated: original mmoe experts config
+
+
+
+ expert_dnn
+ DNN
optional
-
+ mmoe expert dnn layer definition
- beta1
- float
+ num_expert
+ uint32
optional
- Default: 0.9
+ number of mmoe experts Default: 0
- beta2
+ task_towers
+ TaskTower
+ repeated
+ task tower
+
+
+
+ l2_regularization
float
- optional
- Default: 0.999
+ required
+ l2 regularization Default: 0.0001
@@ -5374,7 +5929,21 @@ AdamAsyncOptimizer
- AdamAsyncWOptimizer
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/multi_tower.proto Top
+
+
+
+
+ BSTTower
@@ -5385,31 +5954,24 @@ AdamAsyncWOptimizer
- learning_rate
- LearningRate
- optional
+ input
+ string
+ required
- weight_decay
- float
- optional
- Default: 1e-06
-
-
-
- beta1
- float
- optional
- Default: 0.9
+ seq_len
+ uint32
+ required
+ Default: 5
- beta2
- float
- optional
- Default: 0.999
+ multi_head_size
+ uint32
+ required
+ Default: 4
@@ -5419,8 +5981,8 @@ AdamAsyncWOptimizer
- AdamOptimizer
- Configuration message for the AdamOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
+ DINTower
+
@@ -5430,24 +5992,17 @@ AdamOptimizer
- learning_rate
- LearningRate
- optional
+ input
+ string
+ required
- beta1
- float
- optional
- Default: 0.9
-
-
-
- beta2
- float
- optional
- Default: 0.999
+ dnn
+ DNN
+ required
+
@@ -5457,7 +6012,7 @@ AdamOptimizer
- AdamWOptimizer
+ MultiTower
@@ -5468,31 +6023,38 @@ AdamWOptimizer
- learning_rate
- LearningRate
- optional
+ towers
+ Tower
+ repeated
- weight_decay
- float
- optional
- Default: 1e-06
+ final_dnn
+ DNN
+ required
+
- beta1
+ l2_regularization
float
- optional
- Default: 0.9
+ required
+ Default: 0.0001
- beta2
- float
- optional
- Default: 0.999
+ din_towers
+ DINTower
+ repeated
+
+
+
+
+ bst_towers
+ BSTTower
+ repeated
+
@@ -5502,32 +6064,22 @@ AdamWOptimizer
- ConstantLearningRate
- Configuration message for a constant learning rate.
-
-
- Field Type Label Description
-
-
-
- learning_rate
- float
- optional
- Default: 0.002
-
-
-
+
+
easy_rec/python/protos/multi_tower_recall.proto Top
+
+
+
- CosineDecayLearningRate
- Configuration message for a cosine decaying learning rate as defined in utils/learning_schedules.py
+ MultiTowerRecall
+
@@ -5537,38 +6089,38 @@ CosineDecayLearningRate
- learning_rate_base
- float
- optional
- Default: 0.002
+ user_tower
+ RecallTower
+ required
+
- total_steps
- uint32
- optional
- Default: 4000000
+ item_tower
+ RecallTower
+ required
+
- warmup_learning_rate
+ l2_regularization
float
- optional
- Default: 0.0002
+ required
+ Default: 0.0001
- warmup_steps
- uint32
- optional
- Default: 10000
+ final_dnn
+ DNN
+ required
+
- hold_base_rate_steps
- uint32
- optional
- Default: 0
+ ignore_in_batch_neg_sam
+ bool
+ required
+ Default: false
@@ -5578,8 +6130,8 @@ CosineDecayLearningRate
- ExponentialDecayLearningRate
- Configuration message for an exponentially decaying learning rate. See https://www.tensorflow.org/versions/master/api_docs/python/train/ \ decaying_the_learning_rate#exponential_decay
+ RecallTower
+
@@ -5589,63 +6141,35 @@ ExponentialDecayLearningRate
- initial_learning_rate
- float
- optional
- Default: 0.002
+ dnn
+ DNN
+ required
+
-
- decay_steps
- uint32
- optional
- Default: 4000000
-
+
+
-
- decay_factor
- float
- optional
- Default: 0.95
-
-
- staircase
- bool
- optional
- Default: true
-
-
- burnin_learning_rate
- float
- optional
- Default: 0
-
-
- burnin_steps
- uint32
- optional
- Default: 0
-
-
- min_learning_rate
- float
- optional
- Default: 0
-
-
-
- FtrlOptimizer
-
+
+
+
+
easy_rec/python/protos/optimizer.proto Top
+
+
+
+
+ AdagradOptimizer
+ Configuration message for the AdagradOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
@@ -5658,42 +6182,45 @@ FtrlOptimizer
learning_rate
LearningRate
optional
- optional float learning_rate = 1 [default=1e-4];
+
-
- learning_rate_power
- float
- optional
- Default: -0.5
-
+
+
+
-
- initial_accumulator_value
- float
- optional
- Default: 0.1
-
+
+
+
+ AdamAsyncOptimizer
+ Only available on pai-tf, which has better performance than AdamOptimizer
+
+
+
+
+ Field Type Label Description
+
+
- l1_reg
- float
+ learning_rate
+ LearningRate
optional
- Default: 0
+
- l2_reg
+ beta1
float
optional
- Default: 0
+ Default: 0.9
- l2_shrinkage_reg
+ beta2
float
optional
- Default: 0
+ Default: 0.999
@@ -5703,8 +6230,8 @@ FtrlOptimizer
- LearningRate
- Configuration message for optimizer learning rate.
+ AdamAsyncWOptimizer
+
@@ -5714,45 +6241,31 @@ LearningRate
- constant_learning_rate
- ConstantLearningRate
- optional
-
-
-
-
- exponential_decay_learning_rate
- ExponentialDecayLearningRate
- optional
-
-
-
-
- manual_step_learning_rate
- ManualStepLearningRate
+ learning_rate
+ LearningRate
optional
- cosine_decay_learning_rate
- CosineDecayLearningRate
+ weight_decay
+ float
optional
-
+ Default: 1e-06
- poly_decay_learning_rate
- PolyDecayLearningRate
+ beta1
+ float
optional
-
+ Default: 0.9
- transformer_learning_rate
- TransformerLearningRate
+ beta2
+ float
optional
-
+ Default: 0.999
@@ -5762,8 +6275,8 @@ LearningRate
- ManualStepLearningRate
- Configuration message for a manually defined learning rate schedule.
+ AdamOptimizer
+ Configuration message for the AdamOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
@@ -5773,25 +6286,24 @@ ManualStepLearningRate
- initial_learning_rate
- float
+ learning_rate
+ LearningRate
optional
- Default: 0.002
+
- schedule
- ManualStepLearningRate.LearningRateSchedule
- repeated
-
+ beta1
+ float
+ optional
+ Default: 0.9
- warmup
- bool
+ beta2
+ float
optional
- Whether to linearly interpolate learning rates for steps in
-[0, schedule[0].step]. Default: false
+ Default: 0.999
@@ -5801,7 +6313,7 @@ ManualStepLearningRate
- ManualStepLearningRate.LearningRateSchedule
+ AdamWOptimizer
@@ -5812,17 +6324,31 @@ ManualStepLearningRa
- step
- uint32
+ learning_rate
+ LearningRate
optional
- learning_rate
+ weight_decay
float
optional
- Default: 0.002
+ Default: 1e-06
+
+
+
+ beta1
+ float
+ optional
+ Default: 0.9
+
+
+
+ beta2
+ float
+ optional
+ Default: 0.999
@@ -5832,8 +6358,8 @@ ManualStepLearningRa
- MomentumOptimizer
- Configuration message for the MomentumOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
+ ConstantLearningRate
+ Configuration message for a constant learning rate.
@@ -5844,16 +6370,9 @@ MomentumOptimizer
learning_rate
- LearningRate
- optional
-
-
-
-
- momentum_optimizer_value
float
optional
- Default: 0.9
+ Default: 0.002
@@ -5863,8 +6382,8 @@ MomentumOptimizer
- MomentumWOptimizer
-
+ CosineDecayLearningRate
+ Configuration message for a cosine decaying learning rate as defined in utils/learning_schedules.py
@@ -5874,24 +6393,38 @@ MomentumWOptimizer
- learning_rate
- LearningRate
+ learning_rate_base
+ float
optional
-
+ Default: 0.002
- weight_decay
- float
+ total_steps
+ uint32
optional
- Default: 1e-06
+ Default: 4000000
- momentum_optimizer_value
+ warmup_learning_rate
float
optional
- Default: 0.9
+ Default: 0.0002
+
+
+
+ warmup_steps
+ uint32
+ optional
+ Default: 10000
+
+
+
+ hold_base_rate_steps
+ uint32
+ optional
+ Default: 0
@@ -5901,8 +6434,8 @@ MomentumWOptimizer
- Optimizer
- Top level optimizer message.
+ ExponentialDecayLearningRate
+ Configuration message for an exponentially decaying learning rate. See https://www.tensorflow.org/versions/master/api_docs/python/train/decaying_the_learning_rate#exponential_decay
-
- adam_asyncw_optimizer
- AdamAsyncWOptimizer
- optional
-
-
-
- use_moving_average
- bool
+
+
+
+ FtrlOptimizer
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ learning_rate
+ LearningRate
optional
- Default: false
+ optional float learning_rate = 1 [default=1e-4];
- moving_average_decay
+ learning_rate_power
float
optional
- Default: 0.9999
+ Default: -0.5
- embedding_learning_rate_multiplier
+ initial_accumulator_value
float
optional
-
+ Default: 0.1
+
+
+
+ l1_reg
+ float
+ optional
+ Default: 0
+
+
+
+ l2_reg
+ float
+ optional
+ Default: 0
+
+
+
+ l2_shrinkage_reg
+ float
+ optional
+ Default: 0
@@ -6002,8 +6559,8 @@ Optimizer
- PolyDecayLearningRate
- Configuration message for a poly decaying learning rate. See https://www.tensorflow.org/api_docs/python/tf/train/polynomial_decay.
+ LearningRate
+ Configuration message for optimizer learning rate.
@@ -6013,31 +6570,45 @@ PolyDecayLearningRate
- learning_rate_base
- float
- required
+ constant_learning_rate
+ ConstantLearningRate
+ optional
- total_steps
- int64
- required
+ exponential_decay_learning_rate
+ ExponentialDecayLearningRate
+ optional
- power
- float
- required
+ manual_step_learning_rate
+ ManualStepLearningRate
+ optional
- end_learning_rate
- float
+ cosine_decay_learning_rate
+ CosineDecayLearningRate
optional
- Default: 0
+
+
+
+
+ poly_decay_learning_rate
+ PolyDecayLearningRate
+ optional
+
+
+
+
+ transformer_learning_rate
+ TransformerLearningRate
+ optional
+
@@ -6047,8 +6618,8 @@ PolyDecayLearningRate
- RMSPropOptimizer
- Configuration message for the RMSPropOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+ ManualStepLearningRate
+ Configuration message for a manually defined learning rate schedule.
@@ -6058,31 +6629,25 @@ RMSPropOptimizer
- learning_rate
- LearningRate
- optional
-
-
-
-
- momentum_optimizer_value
+ initial_learning_rate
float
optional
- Default: 0.9
+ Default: 0.002
- decay
- float
- optional
- Default: 0.9
+ schedule
+ ManualStepLearningRate.LearningRateSchedule
+ repeated
+
- epsilon
- float
+ warmup
+ bool
optional
- Default: 1
+ Whether to linearly interpolate learning rates for steps in
+[0, schedule[0].step]. Default: false
@@ -6092,7 +6657,7 @@ RMSPropOptimizer
-
+ ManualStepLearningRate.LearningRateSchedule
@@ -6103,31 +6668,48 @@
- learning_rate_base
- float
- required
+ step
+ uint32
+ optional
- hidden_size
- int32
- required
-
+ learning_rate
+ float
+ optional
+ Default: 0.002
+
+
+
+
+
+
+
+ MomentumOptimizer
+ Configuration message for the MomentumOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
+
+
+
+
+ Field Type Label Description
+
+
+
- warmup_steps
- int32
- required
+ learning_rate
+ LearningRate
+ optional
- step_scaling_rate
+ momentum_optimizer_value
float
optional
- Default: 1
+ Default: 0.9
@@ -6137,22 +6719,46 @@
+ MomentumWOptimizer
+
+
+
+ Field Type Label Description
+
+
+
+ learning_rate
+ LearningRate
+ optional
+
+
+
+ weight_decay
+ float
+ optional
+ Default: 1e-06
+
+
+ momentum_optimizer_value
+ float
+ optional
+ Default: 0.9
+
+
+
-
-
easy_rec/python/protos/pipeline.proto Top
-
-
- EasyRecConfig
-
+
+ Optimizer
+ Top level optimizer message.
+
+
+
+
+
+ PolyDecayLearningRate
+ Configuration message for a poly decaying learning rate. See https://www.tensorflow.org/api_docs/python/tf/train/polynomial_decay.
+
+
+
+
+ Field Type Label Description
+
+
+
- model_config
- EasyRecModel
+ learning_rate_base
+ float
required
- recommendation model config
+
- export_config
- ExportConfig
- optional
+ total_steps
+ int64
+ required
- fg_json_path
- string
- optional
+ power
+ float
+ required
+
+ end_learning_rate
+ float
+ optional
+ Default: 0
+
+
@@ -6273,21 +6903,52 @@ EasyRecConfig
+ RMSPropOptimizer
+ Configuration message for the RMSPropOptimizer See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+
+
+
+
+ Field Type Label Description
+
+
+
+ learning_rate
+ LearningRate
+ optional
+
+
+
+ momentum_optimizer_value
+ float
+ optional
+ Default: 0.9
+
+
+ decay
+ float
+ optional
+ Default: 0.9
+
+
+ epsilon
+ float
+ optional
+ Default: 1
+
+
+
-
-
easy_rec/python/protos/ple.proto Top
-
-
-
+
@@ -6298,40 +6959,31 @@
- network_name
- string
+ learning_rate_base
+ float
required
- expert_num_per_task
- uint32
+ hidden_size
+ int32
required
- number of experts per task
-
-
-
- share_num
- uint32
- optional
- number of experts for share
-For the last extraction_network, no need to configure this
+
- task_expert_net
- DNN
+ warmup_steps
+ int32
required
- dnn network of experts per task
+
- share_expert_net
- DNN
+ step_scaling_rate
+ float
optional
- dnn network of experts for share
-For the last extraction_network, no need to configure this
+ Default: 1
@@ -6341,7 +6993,21 @@
- PLE
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/pipeline.proto Top
+
+
+
+
+ EasyRecConfig
@@ -6352,24 +7018,145 @@ PLE
- extraction_networks
- ExtractionNetwork
- repeated
- extraction network
+ train_input_path
+ string
+ optional
+
- task_towers
- TaskTower
+ kafka_train_input
+ KafkaServer
+ optional
+
+
+
+
+ datahub_train_input
+ DatahubServer
+ optional
+
+
+
+
+ hive_train_input
+ HiveConfig
+ optional
+
+
+
+
+ binary_train_input
+ BinaryDataInput
+ optional
+
+
+
+
+ eval_input_path
+ string
+ optional
+
+
+
+
+ kafka_eval_input
+ KafkaServer
+ optional
+
+
+
+
+ datahub_eval_input
+ DatahubServer
+ optional
+
+
+
+
+ hive_eval_input
+ HiveConfig
+ optional
+
+
+
+
+ binary_eval_input
+ BinaryDataInput
+ optional
+
+
+
+
+ model_dir
+ string
+ required
+
+
+
+
+ train_config
+ TrainConfig
+ optional
+ train config, including optimizer, weight decay, num_steps and so on
+
+
+
+ eval_config
+ EvalConfig
+ optional
+
+
+
+
+ data_config
+ DatasetConfig
+ optional
+
+
+
+
+ feature_configs
+ FeatureConfig
repeated
- task tower
+ for compatibility
- l2_regularization
- float
+ feature_config
+ FeatureConfigV2
optional
- l2 regularization Default: 0.0001
+
+
+
+
+ model_config
+ EasyRecModel
+ required
+ recommendation model config
+
+
+
+ export_config
+ ExportConfig
+ optional
+
+
+
+
+ fg_json_path
+ string
+ optional
+ Json file[RTP FG] to define input data and features:
+* In easy_rec.python.utils.fg_util.load_fg_json_to_config:
+ data_config and feature_config will be generated
+ based on fg_json.
+* After generation, a prefix '!' is added:
+ fg_json_path = '!' + fg_json_path
+ indicates the config update is already done and should not
+ be applied again. In this way, the load_fg_json_to_config
+ function is made reentrant.
+This step is done before edit_config_json to take effect.
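To tie the pipeline fields together, a skeletal EasyRecConfig sketch (paths and the model class are illustrative; when fg_json_path is set, data_config and feature_config are generated from the fg json as described above):

    train_input_path: "data/train.csv"
    eval_input_path: "data/eval.csv"
    model_dir: "experiments/demo"
    fg_json_path: "fg.json"
    train_config {
      num_steps: 100000
    }
    eval_config {
      metrics_set { auc {} }
    }
    model_config {
      model_class: "DeepFM"      # illustrative model class
    }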
@@ -6388,12 +7175,678 @@ PLE
-
easy_rec/python/protos/rocket_launching.proto Top
+
easy_rec/python/protos/ple.proto Top
+
+
+
+
+
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ network_name
+ string
+ required
+
+
+
+
+ expert_num_per_task
+ uint32
+ required
+ number of experts per task
+
+
+
+ share_num
+ uint32
+ optional
+ number of shared experts;
+for the last extraction_network, there is no need to configure this
+
+
+
+ task_expert_net
+ DNN
+ required
+ dnn network of experts per task
+
+
+
+ share_expert_net
+ DNN
+ optional
+ dnn network of the shared experts;
+for the last extraction_network, there is no need to configure this
+
+
+
+
+
+
+
+
+
+ PLE
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ extraction_networks
+ ExtractionNetwork
+ repeated
+ extraction network
+
+
+
+ task_towers
+ TaskTower
+ repeated
+ task tower
+
+
+
+ l2_regularization
+ float
+ optional
+ l2 regularization Default: 0.0001
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/rocket_launching.proto Top
+
+
+
+
+ RocketLaunching
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ share_dnn
+ DNN
+ required
+
+
+
+
+ booster_dnn
+ DNN
+ required
+
+
+
+
+ light_dnn
+ DNN
+ required
+
+
+
+
+ l2_regularization
+ float
+ optional
+ Default: 0.0001
+
+
+
+ feature_based_distillation
+ bool
+ optional
+ Default: false
+
+
+
+ feature_distillation_function
+ Similarity
+ optional
+ COSINE = 0; EUCLID = 1; Default: COSINE
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/simi.proto Top
+
+
+
+
+
+
+ Similarity
+
+
+
+ Name Number Description
+
+
+
+
+ COSINE
+ 0
+
+
+
+
+ INNER_PRODUCT
+ 1
+
+
+
+
+ EUCLID
+ 2
+
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/simple_multi_task.proto Top
+
+
+
+
+ SimpleMultiTask
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ task_towers
+ TaskTower
+ repeated
+
+
+
+
+ l2_regularization
+ float
+ required
+ Default: 0.0001
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/tf_predict.proto Top
+
+
+
+
+ ArrayProto
+ Protocol buffer representing an array
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ dtype
+ ArrayDataType
+
+ Data Type.
+
+
+
+ array_shape
+ ArrayShape
+
+ Shape of the array.
+
+
+
+ float_val
+ float
+ repeated
+ DT_FLOAT.
+
+
+
+ double_val
+ double
+ repeated
+ DT_DOUBLE.
+
+
+
+ int_val
+ int32
+ repeated
+ DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
+
+
+
+ string_val
+ bytes
+ repeated
+ DT_STRING.
+
+
+
+ int64_val
+ int64
+ repeated
+ DT_INT64.
+
+
+
+ bool_val
+ bool
+ repeated
+ DT_BOOL.
+
+
+
+
+
+
+
+
+
+ ArrayShape
+ Dimensions of an array
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ dim
+ int64
+ repeated
+
+
+
+
+
+
+
+
+
+
+ PredictRequest
+ PredictRequest specifies which TensorFlow model to run, as well as how inputs are mapped to tensors and how outputs are filtered before returning to user.
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ signature_name
+ string
+
+ A named signature to evaluate. If unspecified, the default signature
+will be used
+
+
+
+ inputs
+ PredictRequest.InputsEntry
+ repeated
+ Input tensors.
+Names of input tensor are alias names. The mapping from aliases to real
+input tensor names is expected to be stored as named generic signature
+under the key "inputs" in the model export.
+Each alias listed in a generic signature named "inputs" should be provided
+exactly once in order to run the prediction.
+
+
+
+ output_filter
+ string
+ repeated
+ Output filter.
+Names specified are alias names. The mapping from aliases to real output
+tensor names is expected to be stored as named generic signature under
+the key "outputs" in the model export.
+Only tensors specified here will be run/fetched and returned, with the
+exception that when none is specified, all tensors specified in the
+named signature will be run/fetched and returned.
+
+
+
+ debug_level
+ int32
+
+
+
+
+
+
+
+
+
+
+
+ PredictRequest.InputsEntry
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ key
+ string
+
+
+
+
+
+ value
+ ArrayProto
+
+
+
+
+
+
+
+
+
+
+
+ PredictResponse
+ Response for PredictRequest on successful run.
+
+
+
+
+
+
+
+
+ PredictResponse.OutputsEntry
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ key
+ string
+
+
+
+
+
+ value
+ ArrayProto
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ArrayDataType
+
+
+
+ Name Number Description
+
+
+
+
+ DT_INVALID
+ 0
+ Not a legal value for DataType. Used to indicate a DataType field
+has not been set.
+
+
+
+ DT_FLOAT
+ 1
+ Data types that all computation devices are expected to be
+capable to support.
+
+
+
+ DT_DOUBLE
+ 2
+
+
+
+
+ DT_INT32
+ 3
+
+
+
+
+ DT_UINT8
+ 4
+
+
+
+
+ DT_INT16
+ 5
+
+
+
+
+ DT_INT8
+ 6
+
+
+
+
+ DT_STRING
+ 7
+
+
+
+
+ DT_COMPLEX64
+ 8
+ Single-precision complex
+
+
+
+ DT_INT64
+ 9
+
+
+
+
+ DT_BOOL
+ 10
+
+
+
+
+ DT_QINT8
+ 11
+ Quantized int8
+
+
+
+ DT_QUINT8
+ 12
+ Quantized uint8
+
+
+
+ DT_QINT32
+ 13
+ Quantized int32
+
+
+
+ DT_BFLOAT16
+ 14
+ Float32 truncated to 16 bits. Only for cast ops.
+
+
+
+ DT_QINT16
+ 15
+ Quantized int16
+
+
+
+ DT_QUINT16
+ 16
+ Quantized uint16
+
+
+
+ DT_UINT16
+ 17
+
+
+
+
+ DT_COMPLEX128
+ 18
+ Double-precision complex
+
+
+
+ DT_HALF
+ 19
+
+
+
+
+ DT_RESOURCE
+ 20
+
+
+
+
+ DT_VARIANT
+ 21
+ Arbitrary C++ data types
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/tower.proto Top
- RocketLaunching
+ BayesTaskTower
@@ -6404,112 +7857,202 @@ RocketLaunching
- share_dnn
- DNN
+ tower_name
+ string
required
-
+ task name for the task tower
- booster_dnn
- DNN
- required
-
+ label_name
+ string
+ optional
+ label for the task, default is label_fields by order
- light_dnn
- DNN
- required
-
+ metrics_set
+ EvalMetrics
+ repeated
+ metrics for the task
- l2_regularization
- float
+ loss_type
+ LossType
optional
- Default: 0.0001
+ loss for the task Default: CLASSIFICATION
- feature_based_distillation
- bool
+ num_class
+ uint32
optional
- Default: false
+ num_class for multi-class classification loss Default: 1
- feature_distillation_function
- Similarity
+ dnn
+ DNN
optional
- COSINE = 0; EUCLID = 1; Default: COSINE
+ task specific dnn
-
-
+
+ relation_tower_names
+ string
+ repeated
+ related tower names
+
+
+ relation_dnn
+ DNN
+ optional
+ relation dnn
+
+
+ weight
+ float
+ optional
+ training loss weights Default: 1
+
+
+ task_space_indicator_label
+ string
+ optional
+ label name for indicating the sample space for the task tower
+
+
+ in_task_space_weight
+ float
+ optional
+ the loss weight for sample in the task space Default: 1
+
+
+ out_task_space_weight
+ float
+ optional
+ the loss weight for sample out the task space Default: 1
+
+
+ losses
+ Loss
+ repeated
+ multiple losses
+(the proto also contains commented-out fields: required uint32 prediction_level = 13, for prediction level;
+optional float prediction_weight = 14 [default = 1.0], for prediction weights)
+
+
+
+ TaskTower
+
-
-
easy_rec/python/protos/simi.proto Top
-
-
+
+
+ Field Type Label Description
+
+
+
+ tower_name
+ string
+ required
+ task name for the task tower
+
+
+ label_name
+ string
+ optional
+ label for the task, default is label_fields by order
+
- Similarity
-
-
-
- Name Number Description
-
-
+
+ metrics_set
+ EvalMetrics
+ repeated
+ metrics for the task
+
-
- COSINE
- 0
-
-
+
+ loss_type
+ LossType
+ optional
+ loss for the task Default: CLASSIFICATION
+
-
- INNER_PRODUCT
- 1
-
-
+
+ num_class
+ uint32
+ optional
+ num_class for multi-class classification loss Default: 1
+
-
- EUCLID
- 2
-
-
+
+ dnn
+ DNN
+ optional
+ task specific dnn
+
-
-
+
+ weight
+ float
+ optional
+ training loss weights Default: 1
+
+
+
+ task_space_indicator_label
+ string
+ optional
+ label name for indicating the sample space for the task tower
+
+
+ in_task_space_weight
+ float
+ optional
+ the loss weight for samples in the task space Default: 1
+
+
+ out_task_space_weight
+ float
+ optional
+ the loss weight for samples outside the task space Default: 1
+
+
+ losses
+ Loss
+ repeated
+ multiple losses
+
+
+
-
-
easy_rec/python/protos/simple_multi_task.proto Top
-
-
- SimpleMultiTask
+ Tower
@@ -6520,17 +8063,17 @@ SimpleMultiTask
- task_towers
- TaskTower
- repeated
+ input
+ string
+ required
- l2_regularization
- float
+ dnn
+ DNN
required
- Default: 0.0001
+
@@ -6549,108 +8092,75 @@ SimpleMultiTask
-
easy_rec/python/protos/tower.proto Top
+
easy_rec/python/protos/train.proto Top
- BayesTaskTower
+ IncrementSaveConfig
-
-
-
- Field Type Label Description
-
-
-
-
- tower_name
- string
- required
- task name for the task tower
-
-
-
- label_name
- string
- optional
- label for the task, default is label_fields by order
-
-
-
- metrics_set
- EvalMetrics
- repeated
- metrics for the task
-
-
-
- loss_type
- LossType
- optional
- loss for the task Default: CLASSIFICATION
-
+
+
+
+ Field Type Label Description
+
+
- num_class
- uint32
+ sparse_save_secs
+ int32
optional
- num_class for multi-class classification loss Default: 1
+ Default: 0
- dnn
- DNN
+ dense_save_secs
+ int32
optional
- task specific dnn
+ Default: 0
- relation_tower_names
- string
- repeated
- related tower names
+ sparse_save_steps
+ int32
+ optional
+ Default: 0
- relation_dnn
- DNN
+ dense_save_steps
+ int32
optional
- relation dnn
+ Default: 0
- weight
- float
+ debug_save_update
+ bool
optional
- training loss weights Default: 1
+ if open, will save increment updates to model_dir/incr_save/ Default: false
- task_space_indicator_label
- string
+ kafka
+ IncrementSaveConfig.Kafka
optional
- label name for indcating the sample space for the task tower
+
- in_task_space_weight
- float
+ datahub
+ IncrementSaveConfig.Datahub
optional
- the loss weight for sample in the task space Default: 1
+
- out_task_space_weight
- float
+ fs
+ IncrementSaveConfig.File
optional
- the loss weight for sample out the task space
-
-level for prediction
-required uint32 prediction_level = 13;
-prediction weights
-optional float prediction_weight = 14 [default = 1.0]; Default: 1
+
@@ -6660,7 +8170,7 @@ BayesTaskTower
- TaskTower
+ IncrementSaveConfig.Datahub
@@ -6671,73 +8181,115 @@ TaskTower
- tower_name
+ akId
string
required
- task name for the task tower
+
- label_name
+ akSecret
string
- optional
- label for the task, default is label_fields by order
+ required
+
- metrics_set
- EvalMetrics
- repeated
- metrics for the task
+ region
+ string
+ required
+
- loss_type
- LossType
- optional
- loss for the task Default: CLASSIFICATION
+ project
+ string
+ required
+
- num_class
- uint32
- optional
- num_class for multi-class classification loss Default: 1
+ topic
+ string
+ required
+
- dnn
- DNN
+ consumer
+ IncrementSaveConfig.Datahub.Consumer
+ required
+
+
+
+
+
+
+
+
+
+
+ IncrementSaveConfig.Datahub.Consumer
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ offset
+ int64
optional
- task specific dnn
+ Default: 0
- weight
- float
+ timeout
+ int32
optional
- training loss weights Default: 1
+ Default: 600
+
+
+
+
+
+
+
+ IncrementSaveConfig.File
+
+
+
+
+
+ Field Type Label Description
+
+
+
- task_space_indicator_label
+ incr_save_dir
string
optional
- label name for indcating the sample space for the task tower
+ Default: incr_save
- in_task_space_weight
- float
+ relative
+ bool
optional
- the loss weight for sample in the task space Default: 1
+ relative to model_dir Default: true
- out_task_space_weight
- float
+ mount_path
+ string
optional
- the loss weight for sample out the task space Default: 1
+ for online inference, please set storage.mount_path to this mount_path, otherwise the
+online service will fail Default: /home/admin/docker_ml/workspace/incr_save/
@@ -6747,7 +8299,7 @@ TaskTower
- Tower
+ IncrementSaveConfig.Kafka
@@ -6758,15 +8310,22 @@ Tower
- input
+ server
string
required
- dnn
- DNN
+ topic
+ string
+ required
+
+
+
+
+ consumer
+ IncrementSaveConfig.Kafka.Consumer
required
@@ -6778,18 +8337,49 @@ Tower
+ IncrementSaveConfig.Kafka.Consumer
+
+
+
+
+
+ Field Type Label Description
+
+
+
+ config_topic
+ string
+ optional
+
+
+
+ config_global
+ string
+ optional
+
+
+
+ offset
+ int64
+ optional
+ Default: 0
+
+
+ timeout
+ int32
+ optional
+ Default: 600
+
+
+
-
-
easy_rec/python/protos/train.proto Top
-
-
TrainConfig
@@ -6846,6 +8436,16 @@ TrainConfig
In case so, build a SyncReplicateOptimizer Default: true
+
+ sparse_accumulator_type
+ string
+ optional
+ only takes effect on pai-tf when sync_replicas is set,
+options are:
+ raw, hash, multi_map, list, parallel
+in general, multi_map runs faster than other options. Default: multi_map
+
+
startup_delay_steps
float
@@ -6919,13 +8519,6 @@ TrainConfig
Number of gpus per machine Default: 1
-
- separate_save
- bool
- optional
- Default: false
-
-
summary_model_vars
bool
@@ -6977,6 +8570,29 @@ TrainConfig
match variable patterns to freeze
+
+ incr_save_config
+ IncrementSaveConfig
+ optional
+ increment save config
+
+
+
+ enable_oss_stop_signal
+ bool
+ optional
+ enable oss stop signal:
+stop training by creating OSS_STOP_SIGNAL under model_dir Default: false
+
+
+
+ dead_line
+ string
+ optional
+ stop training after dead_line time, format:
+ 20220508 23:59:59
+
+
@@ -7042,6 +8658,51 @@ DistributionStrategy
+
+
easy_rec/python/protos/uniter.proto Top
+
+
+
+
+ Uniter
+
+
+
+
+
+ Field Type Label Description
+
+
+
+
+ config
+ UniterTower
+ required
+
+
+
+
+ final_dnn
+ DNN
+ required
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
easy_rec/python/protos/variational_dropout.proto Top
diff --git a/docs/source/pycharm_vscode_docker.md b/docs/source/pycharm_vscode_docker.md
new file mode 100644
index 000000000..ecc010baf
--- /dev/null
+++ b/docs/source/pycharm_vscode_docker.md
@@ -0,0 +1,101 @@
+# PyCharm / VSCode
+
+
+
+### 构建镜像
+
+```bash
+git clone https://github.com/Alibaba/EasyRec
+cd EasyRec
+sh scripts/build_docker.sh
+```
+
+### 运行容器
+
+- 查看docker镜像
+
+```bash
+docker images
+```
+
+
+
+- 启动docker镜像
+
+```bash
+docker run -it <image_id> /bin/bash
+```
+
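+以下是一个启动容器并挂载本地代码目录的示意(镜像名和路径均为示例占位,请按实际构建出的镜像替换),便于在IDE中直接编辑容器内的代码:
+
+```bash
+# 示意:后台启动容器,并把本地EasyRec代码目录挂载进容器
+docker run -td --network host -v /local_path/EasyRec:/docker_path/EasyRec <image_name_or_id> /bin/bash
+# 查看容器ID,然后进入容器
+docker ps
+docker exec -it <container_id> /bin/bash
+```
+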
+### vscode配置
+
+#### 连接本地容器
+
+- vscode 安装 插件 remote - containers、remote - wsl 、remote development
+
+
+
+- 安装插件后,vscode 状态栏会出现远程连接的图标,点击图标。remote_explorer 选择containers ,CONTAINERS 显示出 运行的容器。点击 + ,连接容器。
+
+
+
+- 弹出新的window
+
+
+
+#### 连接远程容器
+
+- vscode 安装 插件docker ,remote-ssh
+
+
+
+
+- vscode 连接远程服务器
+
+
+
+- 弹出 window , 点击 docker 图标,展示出运行的容器
+
+
+
+- 选择容器,右键attach shell,打开终端
+
+
+
+### pycharm配置
+
+#### pycharm版本
+
+- 本示例使用的版本是: **professional 2022.2.2**
+
+#### 安装插件 docker
+
+
+
+- 安装插件后 pycharm 底部的services会显示docker connect
+
+#### 配置docker 连接
+
+- 菜单路径: View=>Tools=>Services
+
+
+
+- 本地docker: 选择 docker for mac
+- 远程服务器docker 选择 ssh, 填上 <user>@<host>
+
+
+
+- 确定后,显示出containers和images.
+
+
+
+- 点击右上端terminal,进入交互.
+
+
+
+- 选择容器,右键,点击show files ,显示容器内所有文件.
+- 选择文件,右键,查看和下载到本地.
+
+
diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst
index 556affe5d..64e4491fd 100644
--- a/docs/source/quick_start.rst
+++ b/docs/source/quick_start.rst
@@ -6,4 +6,7 @@
quick_start/local_tutorial.md
quick_start/mc_tutorial.md
+ quick_start/mc_tutorial_inner.md
+ quick_start/dlc_tutorial.md
quick_start/emr_tutorial.md
+ quick_start/designer_tutorial.md
diff --git a/docs/source/quick_start/designer_tutorial.md b/docs/source/quick_start/designer_tutorial.md
new file mode 100644
index 000000000..66a22d15c
--- /dev/null
+++ b/docs/source/quick_start/designer_tutorial.md
@@ -0,0 +1,100 @@
+# PAI-Designer Tutorial
+
+## PAI-Designer介绍
+
+PAI-Designer(Studio 2.0)是基于云原生架构Pipeline Service(PAIFlow)的可视化建模工具, 提供可视化的机器学习开发环境,同时提供丰富且成熟的机器学习算法,覆盖商品推荐、金融风控及广告预测等场景,支持基于MaxCompute、PAI-DLC、Flink等计算资源进行大规模分布式运算,可以满足您不同方向的业务需求,实现低门槛开发人工智能服务。
+[使用文档](https://help.aliyun.com/document_detail/114522.html)
+
+### 在Designer进行EasyRec训练的优势
+
+- 可视化编辑配置文件并自动保存至OSS
+- 简化rolearn、执行资源等配置
+- 历史任务记录及版本回滚
+- 一键部署DataWorks定时调度任务
+
+## 使用入口
+
+点击[阿里云PAI管控台](https://pai.console.aliyun.com/#/studio),选择进入一个工作空间
+开始使用Designer。
+
+新建一个工作流,可以在画布上拖拉拽左侧组件按照业务需求构建工作流,对MaxCompute数据表/OSS文件等数据源进行分析及模型构建。
+
+
+## EasyRec训练组件
+
+### 输入桩配置
+
+| 输入桩(从左到右) | 限制数据类型 | 对应PAI命令参数 | 是否必选 |
+| ---------- | ----------- | ------------------------------------------------------- | ---- |
+| 训练表 | MaxCompute表 | `train_tables` | 是 |
+| 评估表 | MaxCompute表 | `eval_tables` | 否 |
+| checkpoint | OSS存储的模型 | `edit_config_json`中的`train_config.fine_tune_checkpoint` | 否 |
+| 分箱表 | MaxCompute表 | `boundary_table` | 否 |
+
+### 右侧参数说明
+
+| 页签 | 参数 | 是否必选 | 描述 | 默认值 |
+| ---- | ---------------------- | ---- | ------------------------------------------------------------------------------------------------------ | ------------ |
+| 参数设置 | 模型路径 | 否 | 对应PAI命令参数`model_dir` | 工作流自动设置的工作路径 |
+| 参数设置 | EasyRec配置 | 是 | 在下方编辑框填写config配置,保存至指定的OSS路径下,对应PAI命令参数`config` | |
+| 参数设置 | 指定算法版本 | 否 | 点开高级选项后,可以自定义EasyRec的执行版本。请先参考文档[EasyRec版本更新](../release.md)上传对应版本的tar包到OSS,在这个参数中选中上传的文件。对应参数`script` | 空 |
+| 执行调优 | ps数量 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 2 |
+| 执行调优 | ps CPU数量 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 6 |
+| 执行调优 | ps Memory数量(MB) | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 30000 |
+| 执行调优 | Worker数量 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 6 |
+| 执行调优 | Worker CPU数量 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 6 |
+| 执行调优 | Worker Memory用量(单位为MB) | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 30000 |
+| 执行调优 | Worker GPU卡数 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 0 |
+
+### 输出桩配置
+
+| 输出桩(从左到右) | 数据类型 | 对应PAI命令参数 |
+| --------- | -------- | ------------ |
+| 输出模型 | OSS存储的模型 | `model_dir` |
+
+### 对应PAI命令
+
+在页面提交该组件执行,底层实际等同于执行了名为`easy_rec_ext`的PAI命令进行模型训练
+`pai -name easy_rec_ext -project algo_public -Dcmd=train`
+
+- 具体命令及详细[参数说明](../train.md#on-pai)
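+
+下面给出一个按上述参数拼装后的完整训练命令示意(其中表名、OSS路径、arn、资源规格等均为假设的占位,实际以组件生成的命令为准),可通过odpscmd执行:
+
+```bash
+# 示意:Designer训练组件底层拼装出的PAI命令,路径与表名需按实际情况替换
+odpscmd -e "
+pai -name easy_rec_ext -project algo_public
+-Dcmd=train
+-Dconfig=oss://your_bucket/config/your_model.config
+-Dtrain_tables='odps://your_project/tables/train_table'
+-Deval_tables='odps://your_project/tables/eval_table'
+-Dmodel_dir=oss://your_bucket/ckpt/your_model
+-Dcluster='{\"ps\":{\"count\":1,\"cpu\":1000},\"worker\":{\"count\":3,\"cpu\":1000,\"memory\":40000}}'
+-Darn=acs:ram::xxx:role/xxx
+-Dbuckets=oss://your_bucket/
+-DossHost=oss-cn-beijing-internal.aliyuncs.com;
+"
+```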
+
+## EasyRec预测组件
+
+### 输入桩配置
+
+| 输入桩(从左到右) | 限制数据类型 | 对应PAI命令参数 | 是否必选 |
+| --------- | ----------- | ----------------- | ---- |
+| 输入模型 | OSS存储的模型 | `saved_model_dir` | 是 |
+| 输入表 | MaxCompute表 | `input_table` | 是 |
+
+### 右侧参数说明
+
+| 页签 | 参数 | 是否必选 | 描述 | 默认值 |
+| ---- | ---------------------- | ---- | -------------------------------------------------------------------------------------------------------------- | ----------------- |
+| 参数设置 | 输入选择列 | 否 | 从输入表选择特征列给到预测模型,不能与排除列同时使用 | - |
+| 参数设置 | 排除列 | 否 | 预测模型不需要使用的输入列,不能和输入选择列同时使用 | - |
+| 参数设置 | 输出保留列 | 否 | 在预测结构表中原样输出的列 | - |
+| 参数设置 | 预测详情输出列 | 否 | 选择预测模型的输出到MaxCompute表的映射,细节请参见[EasyRec离线预测文档](../predict/MaxCompute%20%E7%A6%BB%E7%BA%BF%E9%A2%84%E6%B5%8B.md) | 默认为"probs double" |
+| 参数设置 | miniBatch的大小 | 否 | 对应参数`batch_size` | 1024 |
+| 执行调优 | Worker数量 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 6 |
+| 执行调优 | Worker CPU数量 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 6 |
+| 执行调优 | Worker Memory用量(单位为MB) | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 30000 |
+| 执行调优 | Worker GPU卡数 | 否 | 完整的执行调优参数会拼装成`cluster`参数 | 0 |
+
+### 输出桩配置
+
+| 输出桩(从左到右) | 数据类型 | 对应PAI命令参数 |
+| --------- | ----------- | --------------- |
+| 输出表 | MaxCompute表 | `output_table` |
+
+### 对应PAI命令
+
+在页面提交该组件执行,底层实际等同于执行了名为`easy_rec_ext`的PAI命令进行数据批量预测
+`pai -name easy_rec_ext -project algo_public -Dcmd=predict`
+
+- 具体命令及详细[参数说明](../train.md#on-pai)
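+
+同样地,下面是一个拼装后的批量预测命令示意(表名、模型路径、资源规格等均为假设的占位,实际以组件生成的命令为准):
+
+```bash
+# 示意:Designer预测组件底层拼装出的PAI命令
+odpscmd -e "
+pai -name easy_rec_ext -project algo_public
+-Dcmd=predict
+-Dsaved_model_dir=oss://your_bucket/ckpt/your_model/export/final
+-Dinput_table='odps://your_project/tables/predict_input'
+-Doutput_table='odps://your_project/tables/predict_output'
+-Dbatch_size=1024
+-Dcluster='{\"worker\":{\"count\":6,\"cpu\":1000,\"memory\":30000}}'
+-Darn=acs:ram::xxx:role/xxx
+-Dbuckets=oss://your_bucket/
+-DossHost=oss-cn-beijing-internal.aliyuncs.com;
+"
+```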
+
+### 推荐算法定制的方案
+
+- 在Designer中做推荐算法特征工程、排序模型训练、向量召回等案例的阿里云官网[文档链接](https://help.aliyun.com/zh/pai/use-cases/overview-18)
diff --git a/docs/source/quick_start/dlc_tutorial.md b/docs/source/quick_start/dlc_tutorial.md
new file mode 100644
index 000000000..f766a5f93
--- /dev/null
+++ b/docs/source/quick_start/dlc_tutorial.md
@@ -0,0 +1,116 @@
+# DLC Tutorial
+
+[PAI-DLC](https://help.aliyun.com/document_detail/165124.html)(Deep Learning Containers)是基于阿里巴巴容器服务ACK(Alibaba Cloud Container Service for Kubernetes)的深度学习训练平台,为您提供灵活、稳定、易用和极致性能的深度学习训练环境。
+
+## 上传数据到OSS
+
+使用DLC运行EasyRec,首先需要将EasyRec的[训练数据](http://easyrec.oss-cn-beijing.aliyuncs.com/demo/dwd_avazu_ctr_deepmodel_10w.csv)和[配置文件](http://easyrec.oss-cn-beijing.aliyuncs.com/demo/wide_and_deep_on_avazau_ctr.config)上传到Aliyun OSS。
+
+
+## 创建数据集
+
+进入[PAI控制台](https://pai.console.aliyun.com/?regionId=cn-beijing),并选择需要使用的工作空间,点击AI资源管理/数据集,创建数据集。
+
+创建方式选择阿里云存储,属性选择文件夹,选择数据和配置文件所在的OSS路径,并设置数据集的挂载路径。任务运行时,会从挂载路径下读取训练数据和配置文件。
+
+## 创建DLC任务
+
+### 配置任务
+
+进入[PAI控制台](https://pai.console.aliyun.com),并选择需要使用的工作空间,点击模型开发和训练/容器训练(DLC),点击创建任务。
+选择运行镜像,以及数据集,并输入执行命令:
+
+执行命令如下:
+
+```bash
+python -m easy_rec.python.train_eval --pipeline_config_path /mnt/data/dlc_demo/wide_and_deep_on_avazau_ctr.config --continue_train --train_input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv --eval_input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv --model_dir /mnt/data/dlc_demo/wide_and_deep_v3/ --edit_config_json='{"train_config.num_steps":1200, "eval_config.num_examples":10240}'
+```
+
+- 可以通过edit_config_json修改配置,避免频繁修改配置文件,如train_config.num_steps等信息
+- 注意:这里仅仅是训练的demo, 所以使用train_config.num_steps, 实际实验时不建议设置train_config.num_steps,建议设置data_config.num_epochs, 实际实验时也不建议设置eval_config.num_examples, 不设置时默认评估整个测试集.
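+
+按照上面的建议,正式实验时可以参考如下启动命令示意(路径与参数值仅为示例,请按需替换):
+
+```bash
+# 示意:用 data_config.num_epochs 控制训练轮数,不再设置 num_steps 和 num_examples
+python -m easy_rec.python.train_eval --pipeline_config_path /mnt/data/dlc_demo/wide_and_deep_on_avazau_ctr.config --continue_train --train_input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv --eval_input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv --model_dir /mnt/data/dlc_demo/wide_and_deep_v3/ --edit_config_json='{"data_config.num_epochs":1}'
+```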
+
+#### 使用ODPS表作为输入:
+
+```bash
+cat << EOF >> odps_conf
+access_id=xxxx
+access_key=xxxx
+end_point=http://xxxx
+
+EOF
+
+export ODPS_CONFIG_FILE_PATH=odps_conf
+python -m easy_rec.python.train_eval --pipeline_config_path /mnt/data/dlc_demo/wide_and_deep_on_avazau_ctr.config --continue_train --train_input_path odps://project/tables/train_input_table --eval_input_path odps://project/tables/test_input_table --model_dir /mnt/data/dlc_demo/wide_and_deep_v3/ --edit_config_json='{"data_config.num_epochs":1, "data_config.input_type":"OdpsInputV3"}'
+```
+
+- data_config.input_type: 加载输入数据的类是OdpsInputV3, 目前只支持OdpsInputV3.
+
+#### 评估
+
+```bash
+python -m easy_rec.python.eval --pipeline_config_path /mnt/data/dlc_demo/wide_and_deep_v3/pipeline.config --eval_input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv
+```
+
+- checkpoint_path: 指定要评估的checkpoint, 默认评估model_dir下面最新的checkpoint
+- eval_input_path: 评估的输入路径, 如果是Odps表的话, 需要配置odps_conf文件,并且设置环境变量: export ODPS_CONFIG_FILE_PATH=odps_conf
+- distribute_eval: 是否使用分布式预测,如果使用分布式预测, 在配置任务资源时需要设置ps.
+- eval_result_path: 评估指标的保存位置
+
+#### 预测
+
+```bash
+python -m easy_rec.python.tools.predict_and_chk --saved_model_dir /mnt/data/dlc_demo/wide_and_deep_v3/export/final --input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv --label_id 0 --separator "," --save_path /mnt/data/dlc_demo/wide_and_deep_v3/predict.out
+```
+
+- saved_model_dir: saved_model目录
+- input_path: 需要预测的文件
+- label_id: label column id, 可以指定多个, 比如: --label_id 0 1 2
+- separator: 输入的分隔符
+- save_path: 预测结果的保存目录
+- 注意: 目前只支持单worker预测,多worker预测适配中; odps表预测适配中.
+
+### 配置任务资源
+
+任务的资源配置选择进阶模式,我们选择了1个Chief、1个Worker、一个PS、一个Evaluator的配置。
+
+
+### 查看任务详情
+
+然后点击 提交 即可,点击 任务 能看到 任务列表,可以查看任务详情:
+
+点击生成脚本, 可以查看如何通过[命令行](https://help.aliyun.com/document_detail/214317.html)提交任务, 方便在DataWorks里面做例行训练.
+
+```bash
+dlc submit tfjob \
+ --name=easy_rec_test \
+ --command='python -m easy_rec.python.train_eval --pipeline_config_path /mnt/data/dlc_demo/wide_and_deep_on_avazau_ctr.config --continue_train --train_input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv --eval_input_path /mnt/data/dlc_demo/dwd_avazu_ctr_deepmodel_10w.csv --model_dir /mnt/data/dlc_demo/wide_and_deep_v3/ --edit_config_json='\''{"train_config.num_steps":1200, "eval_config.num_examples":10240}'\''' \
+ --data_sources=d-5sf0ox5pw1pgi4vl7e \
+ --workspace_id=67849 \
+ --priority=1 \
+ --workers=1 \
+ --worker_image=mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.7.4 \
+ --worker_spec=ecs.g6.2xlarge \
+ --ps=1 \
+ --ps_image=mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.7.4 \
+ --ps_spec=ecs.g6.2xlarge \
+ --chief=true \
+ --chief_image=mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.7.4 \
+ --chief_spec=ecs.g6.2xlarge \
+ --evaluators=1 \
+ --evaluator_image=mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.7.4 \
+ --evaluator_spec=ecs.g6.2xlarge
+```
+
+通过dlc命令提交的任务也可以在 任务列表 中查看.
+
+### 查看tensorboard
+
+创建tensorboard, 填写model_dir的相对路径
+
+查看tensorboard
+
+
+### 查看模型
+
+当任务运行成功后,找到对应的oss路径,可以看到任务生成的模型。
+
diff --git a/docs/source/quick_start/local_tutorial.md b/docs/source/quick_start/local_tutorial.md
index bbde40e9d..4a00140b4 100644
--- a/docs/source/quick_start/local_tutorial.md
+++ b/docs/source/quick_start/local_tutorial.md
@@ -1,20 +1,99 @@
# Local Tutorial
-### 下载安装EasyRec
+### 安装EasyRec
+
+我们提供了`本地Anaconda安装`和`Docker镜像启动`两种方式。
+
+有技术问题可加钉钉群:37930014162
+
+#### 本地Anaconda安装
+
+温馨提示:**在搭载Apple M系列芯片的MacBook上必须使用TensorFlow 2.5或更高版本**,安装方法请查看TF官方文档。
+
+Demo实验中使用的环境为 `python=3.6.8` + `tensorflow=1.12.0`
+
+```bash
+conda create -n py36_tf12 python=3.6.8
+conda activate py36_tf12
+pip install tensorflow==1.12.0
+pip install tensorflow_probability==0.5.0
+```
+
+注意:必须要安装`tensorflow_probability`包,需要根据tensorflow的版本安装对应版本的`tensorflow_probability`包。
+
+常见版本对应关系:
+
+| TensorFlow版本 | TensorFlowProbability版本 |
+| ------------ | ----------------------- |
+| 1.12 | 0.5.0 |
+| 1.15 | 0.8.0 |
+| 2.5.0 | 0.13.0 |
+| 2.6.0 | 0.14.0 |
+| 2.7.0 | 0.15.0 |
+| 2.8.0 | 0.16.0 |
+| 2.10 | 0.18.0 |
+| 2.11 | 0.19.0 |
+| 2.12 | 0.20.0 |
+
+其他版本对应关系请查看链接:[Releases · tensorflow/probability](https://github.com/tensorflow/probability/releases)。
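+
+例如,使用TensorFlow 2.12时,可按上表安装对应版本(示意):
+
+```bash
+# 示意:TF 2.12 对应 tensorflow_probability 0.20.0
+pip install tensorflow==2.12.0
+pip install tensorflow_probability==0.20.0
+```
+
+安装好TensorFlow之后,从源码安装EasyRec: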
```bash
git clone https://github.com/alibaba/EasyRec.git
cd EasyRec
-wget https://easyrec.oss-cn-beijing.aliyuncs.com/data/easyrec_data_20210818.tar.gz
-bash scripts/gen_proto.sh # 根据proto文件生成 配置解析.py文件
+bash scripts/init.sh
python setup.py install
```
+#### Docker镜像启动
+
+Docker的环境为`python=3.6.9` + `tensorflow=1.15.5`
+
+##### 方法一:拉取已上传的镜像(推荐)
+
+```bash
+git clone https://github.com/alibaba/EasyRec.git
+cd EasyRec
+docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.8.5
+docker run -td --network host -v /local_path/EasyRec:/docker_path/EasyRec mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.8.5
+docker exec -it <container_id> bash
+```
+
+可选镜像:
+
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-paitf1.12-0.8.5 \[只能跑在DLC环境\]
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-paitf1.15-0.8.5
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15.5-0.8.5
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15.5-gpu-0.8.5
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py39-tf2.11-0.8.5
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-0.8.5
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-gpu-0.8.5
+
+##### 方法二:自行构建Docker镜像
+
+我们提供四个版本的tensorflow镜像构建示例,对应的脚本路径如下:
+
+- scripts/build_docker_tf112.sh
+- scripts/build_docker_tf115.sh
+- scripts/build_docker_tf210.sh
+- scripts/build_docker_tf212.sh
+
+默认使用`tensorflow 1.15`的版本,示例脚本如下,请根据需要替换脚本路径:
+
+```bash
+git clone https://github.com/alibaba/EasyRec.git
+cd EasyRec
+bash scripts/build_docker.sh
+sudo docker run -td --network host -v /local_path:/docker_path mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-<easyrec_version>
+sudo docker exec -it <container_id> bash
+```
+
+注:<easyrec_version>需匹配当前EasyRec版本。
+
### 输入数据:
输入一般是csv格式的文件。
-#### 示例数据
+#### 示例数据(点击下载)
- train: [dwd_avazu_ctr_deepmodel_train.csv](http://easyrec.oss-cn-beijing.aliyuncs.com/data/dwd_avazu_ctr_deepmodel_train.csv)
- test: [dwd_avazu_ctr_deepmodel_test.csv](http://easyrec.oss-cn-beijing.aliyuncs.com/data/dwd_avazu_ctr_deepmodel_test.csv)
diff --git a/docs/source/quick_start/mc_tutorial.md b/docs/source/quick_start/mc_tutorial.md
index 641ef1071..0f6065f0c 100644
--- a/docs/source/quick_start/mc_tutorial.md
+++ b/docs/source/quick_start/mc_tutorial.md
@@ -1,13 +1,21 @@
# MaxCompute Tutorial
+**此文档是针对公网用户(非阿里内部开发)**,在MaxCompute公有云版本上使用EasyRec的说明。
+
+针对阿里集团内部用户,请参考[mc_tutorial_inner](mc_tutorial_inner.md)。
+
+有技术问题可加钉钉群:37930014162
+
### 输入数据:
-输入一般是odps表:
+输入一般是MaxCompute表:
- train: pai_online_project.dwd_avazu_ctr_deepmodel_train
-- test: pai_online_project.dwd_avazu_ctr_deepmodel_test
+- test: pai_online_project.dwd_avazu_ctr_deepmodel_test
+
+说明:原则上这两张表是自己odps的表,为了方便,以上提供case的两张表可在国内用户的MaxCompute项目空间中访问。
-说明:原则上这两张表是自己odps的表,为了方便,以上提供case的两张表在任何地方都可以访问。两个表可以带分区,也可以不带分区。
+两个表可以带分区,也可以不带分区。带分区的方式:odps://xyz_project/tables/table1/dt=20240101
### 训练:
@@ -18,28 +26,36 @@
pai -name easy_rec_ext -project algo_public
-Dcmd=train
-Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config
--Dtables=odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_train,odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_test
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000, "gpu":100, "memory":40000}}'
--Dwith_evaluator=1
+-Dtrain_tables='odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_train'
+-Deval_tables='odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_test'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000, "gpu":0, "memory":40000}}'
+-Deval_method=separate
-Dmodel_dir=oss://easyrec/ckpt/MultiTower
-Darn=acs:ram::xxx:role/xxx
-Dbuckets=oss://easyrec/
-DossHost=oss-cn-beijing-internal.aliyuncs.com;
```
-- -Dcmd: train 模型训练
+- -Dcmd: train 表示模型训练
- -Dconfig: 训练用的配置文件
-- -Dtables: 定义训练表和测试表,默认最后一个表示测试表。
+- -Dtrain_tables: 定义训练表
+- -Deval_tables: 定义评估表
+- -Dtables: 定义其他依赖表(可选),如负采样的表
- -Dcluster: 定义PS的数目和worker的数目。具体见:[PAI-TF任务参数介绍](https://help.aliyun.com/document_detail/154186.html?spm=a2c4g.11186623.4.3.e56f1adb7AJ9T5)
-- -Dwith_evaluator,训练时定义一个worker将被用于做评估
+- -Deval_method: 评估方法
+  - separate: 用worker(task_id=1)做评估。找到MaxCompute训练任务的logview,打开logview之后在worker1机器的stderr日志中查看评估指标数据。
+  - none: 不需要评估
+  - master: 在master(task_id=0)上做评估
+- -Dfine_tune_checkpoint: 可选,从checkpoint restore参数,进行finetune
+  - 可以指定directory,将使用directory里面的最新的checkpoint.
- -Dmodel_dir: 如果指定了model_dir将会覆盖config里面的model_dir,一般在周期性调度的时候使用。
-- -Darn: rolearn 注意这个的arn要替换成客户自己的。可以从dataworks的设置中查看arn。
+- -Darn: rolearn 注意这个的arn要替换成客户自己的。可以从dataworks的设置中查看arn;或者阿里云控制台人工智能平台PAI,左侧菜单"开通和授权",找到全部云产品依赖->Designer->OSS->查看授权信息。
- -Dbuckets: config所在的bucket和保存模型的bucket; 如果有多个bucket,逗号分割
- -DossHost: ossHost地址
### 注意:
-- dataworks和pai的project 一样,案例都是pai_online_project,用户需要根据自己的环境修改。如果需要使用gpu,PAI的project需要设置开通GPU。链接:[https://pai.data.aliyun.com/console?projectId=®ionId=cn-beijing#/visual](https://pai.data.aliyun.com/console?projectId=%C2%AEionId=cn-beijing#/visual) ,其中regionId可能不一致。
+- dataworks和PAI的project一样,案例都是pai_online_project,用户需要根据自己的环境修改。如果需要使用gpu,PAI的project需要设置开通GPU。链接:[https://pai.data.aliyun.com/console?projectId=&regionId=cn-beijing#/visual](https://pai.data.aliyun.com/console?projectId=&regionId=cn-beijing#/visual) ,其中regionId可能不一致。

@@ -55,8 +71,8 @@ pai -name easy_rec_ext -project algo_public
pai -name easy_rec_ext -project algo_public
-Dcmd=evaluate
-Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config
--Dtables=odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_test
--Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}'
+-Deval_tables='odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_test'
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":0, "memory":40000}}'
-Dmodel_dir=oss://easyrec/ckpt/MultiTower
-Darn=acs:ram::xxx:role/xxx
-Dbuckets=oss://easyrec/
@@ -65,10 +81,12 @@ pai -name easy_rec_ext -project algo_public
- -Dcmd: evaluate 模型评估
- -Dconfig: 同训练
-- -Dtables: 只需要指定测试 tables
+- -Deval_tables: 指定测试tables
+- -Dtables: 指定其他依赖表,如负采样的表
- -Dcluster: 评估不需要PS节点,指定一个worker节点即可
- -Dmodel_dir: 如果指定了model_dir将会覆盖config里面的model_dir,一般在周期性调度的时候使用
- -Dcheckpoint_path: 使用指定的checkpoint_path,如oss://easyrec/ckpt/MultiTower/model.ckpt-1000。不指定的话,默认model_dir中最新的ckpt文件。
+- arn,buckets,ossHost同训练.
### 导出:
@@ -94,6 +112,33 @@ pai -name easy_rec_ext -project algo_public
- -Dexport_dir: 导出的目录
- -Dcluster: 评估不需要PS节点,指定一个worker节点即可
- -Dcheckpoint_path: 同评估
+- arn,buckets,ossHost同训练.
+
+### 导出RTP serving checkpoint:
+
+导出RTP serving支持的checkpoint, 更多参考[RTP Serving的文档](../feature/rtp_native.md).
+
+```sql
+pai -name easy_rec_ext -project algo_public
+-Dcmd=export_checkpoint
+-Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config
+-Dmodel_dir=oss://easyrec/ckpt/MultiTower
+-Dexport_dir=oss://easyrec/ckpt/MultiTower/export
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
+-Darn=acs:ram::xxx:role/xxx
+-Dbuckets=oss://easyrec/
+-DossHost=oss-cn-beijing-internal.aliyuncs.com
+```
+
+- -Dcmd: export_checkpoint, 导出RTP支持的checkpoint
+- -Dconfig: 同训练
+- -Dmodel_dir: 同训练
+- -Dexport_dir: 导出的目录
+- -Dcluster: 评估不需要PS节点,指定一个worker节点即可
+- -Dcheckpoint_path: 同评估
+- arn,buckets,ossHost同训练.
### 配置文件:
@@ -147,7 +192,7 @@ data_config {
#### 特征相关
-特征配置具体见:[特征](../feature/feature.md)
+特征配置具体见:[特征](../feature/feature.rst)
```protobuf
feature_config: {
diff --git a/docs/source/quick_start/mc_tutorial_inner.md b/docs/source/quick_start/mc_tutorial_inner.md
new file mode 100644
index 000000000..d04dc2e8d
--- /dev/null
+++ b/docs/source/quick_start/mc_tutorial_inner.md
@@ -0,0 +1,344 @@
+# MaxCompute(Inner) Tutorial
+
+**此文档是针对阿里集团内部用户**,在MaxCompute上使用EasyRec的说明。
+
+### 输入数据:
+
+输入一般是odps表:
+
+- train: your_own_project.dwd_avazu_ctr_deepmodel_train
+- test: your_own_project.dwd_avazu_ctr_deepmodel_test
+
+说明:这两张表是自己odps project的表,由于阿里集团内部无法访问公共云的pai_online_project,可以通过自行下载数据集到本地,并使用odpscmd [tunnel upload](https://help.aliyun.com/document_detail/27833.html?spm=a2c4g.27797.0.i1)命令上传至自己的project进行使用。
+
+- [dwd_avazu_ctr_deepmodel_train.csv](http://easyrec.oss-cn-beijing.aliyuncs.com/data/dwd_avazu_ctr_deepmodel_train.csv)
+
+- [dwd_avazu_ctr_deepmodel_test.csv](http://easyrec.oss-cn-beijing.aliyuncs.com/data/dwd_avazu_ctr_deepmodel_test.csv)
+
+参考命令:
+
+```sql
+create table if not exists dwd_avazu_ctr_deepmodel_train (
+  click bigint, hour string, c1 string, banner_pos string,
+  site_id string, site_domain string, site_category string,
+  app_id string, app_domain string, app_category string,
+  device_id string, device_ip string, device_model string,
+  device_type string, device_conn_type string,
+  c14 string, c15 string, c16 string, c17 string, c18 string,
+  c19 string, c20 string, c21 string
+);
+
+tunnel upload dwd_avazu_ctr_deepmodel_train.csv your_own_project.dwd_avazu_ctr_deepmodel_train;
+```
+
+### 训练:
+
+- 配置文件: [dwd_avazu_ctr_deepmodel_ext.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config), 配置文件采用prototxt格式,内容解析见[配置文件](#Qgqxc)
+ - 修改配置文件里面的**model_dir**字段为: 自己的实验oss目录
+
+```sql
+pai -name easy_rec_ext -project algo_public
+-Dcmd=train
+-Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config
+-Dtrain_tables='odps://your_own_project/tables/dwd_avazu_ctr_deepmodel_train'
+-Deval_tables='odps://your_own_project/tables/dwd_avazu_ctr_deepmodel_test'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000, "gpu":0, "memory":40000}}'
+-Deval_method=separate
+-Dmodel_dir=oss://easyrec/ckpt/MultiTower
+-Dbuckets=oss://easyrec/?role_arn=acs:ram::xxx:role/xxx&host=oss-cn-beijing-internal.aliyuncs.com;
+```
+
+- -Dcmd: train 模型训练
+- -Dconfig: 训练用的配置文件
+- -Dtrain_tables: 定义训练表
+- -Deval_tables: 定义测试表
+- -Dtables: 定义其他依赖表(可选),如负采样的表
+- -Dcluster: 定义PS的数目和worker的数目。具体见:[PAI-TF任务参数介绍](https://help.aliyun.com/document_detail/154186.html?spm=a2c4g.11186623.4.3.e56f1adb7AJ9T5)
+- -Deval_method: 评估方法
+  - separate: 用worker(task_id=1)做评估
+  - none: 不需要评估
+  - master: 在master(task_id=0)上做评估
+- -Dfine_tune_checkpoint: 可选,从checkpoint restore参数,进行finetune
+  - 可以指定directory,将使用directory里面的最新的checkpoint.
+- -Dmodel_dir: 如果指定了model_dir将会覆盖config里面的model_dir,一般在周期性调度的时候使用。
+- -Dbuckets: config所在的bucket和保存模型的bucket; 如果有多个bucket,逗号分割
+
+### 注意:
+
+- dataworks和pai的project 一样,案例都是pai_online_project,用户需要根据自己的环境修改。如果需要使用gpu,PAI的project需要设置开通GPU。链接:[https://pai.data.aliyun.com/console?projectId=&regionId=cn-beijing#/visual](https://pai.data.aliyun.com/console?projectId=&regionId=cn-beijing#/visual) ,其中regionId可能不一致。
+
+ 
+
+- oss的bucket需要提前开通好,案例中bucket名称是easyrec。创建bucket请参考:[创建存储空间](https://help.aliyun.com/document_detail/31885.html)
+
+- arn需要在PAI-studio的project(当前案例中的project是pai_online_project)的OSS访问授权设置页面查看和创建,如下图:
+
+
+
+### 评估:
+
+```sql
+pai -name easy_rec_ext -project algo_public
+-Dcmd=evaluate
+-Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config
+-Deval_tables='odps://your_own_project/tables/dwd_avazu_ctr_deepmodel_test'
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dmodel_dir=oss://easyrec/ckpt/MultiTower
+-Dbuckets=oss://easyrec/?role_arn=acs:ram::xxx:role/xxx&host=oss-cn-beijing-internal.aliyuncs.com;
+```
+
+- -Dcmd: evaluate 模型评估
+- -Dconfig: 同训练
+- -Deval_tables: 指定测试tables
+- -Dtables: 指定其他依赖表,如负采样的表
+- -Dcluster: 评估不需要PS节点,指定一个worker节点即可
+- -Dmodel_dir: 如果指定了model_dir将会覆盖config里面的model_dir,一般在周期性调度的时候使用
+- -Dcheckpoint_path: 使用指定的checkpoint_path,如oss://easyrec/ckpt/MultiTower/model.ckpt-1000。不指定的话,默认model_dir中最新的ckpt文件。
+- -Dbuckets: oss bucket,同训练.
+
+### 导出:
+
+```
+由于导出模型,只需要读入checkpoint导出model,因此不需要ps 结点,也不需要GPU。
+```
+
+```sql
+pai -name easy_rec_ext -project algo_public
+-Dcmd=export
+-Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config
+-Dmodel_dir=oss://easyrec/ckpt/MultiTower
+-Dexport_dir=oss://easyrec/ckpt/MultiTower/export
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
+-Dbuckets=oss://easyrec/?role_arn=acs:ram::xxx:role/xxx&host=oss-cn-beijing-internal.aliyuncs.com;
+```
+
+- -Dcmd: export 模型导出
+- -Dconfig: 同训练
+- -Dmodel_dir: 同训练
+- -Dexport_dir: 导出的目录
+- -Dcluster: 评估不需要PS节点,指定一个worker节点即可
+- -Dcheckpoint_path: 同评估
+- -Dbuckets: oss bucket,同训练.
+
+### 导出RTP serving:
+
+导出RTP serving支持的checkpoint, 更多参考[RTP Serving的文档](../feature/rtp_native.md).
+
+```sql
+pai -name easy_rec_ext -project algo_public
+-Dcmd=export_checkpoint
+-Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config
+-Dmodel_dir=oss://easyrec/ckpt/MultiTower
+-Dexport_dir=oss://easyrec/ckpt/MultiTower/export
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
+-Dbuckets=oss://easyrec/?role_arn=acs:ram::xxx:role/xxx&host=oss-cn-beijing-internal.aliyuncs.com;
+```
+
+- -Dcmd: export_checkpoint, 导出RTP支持的checkpoint
+- -Dconfig: 同训练
+- -Dmodel_dir: 同训练
+- -Dexport_dir: 导出的目录
+- -Dcluster: 评估不需要PS节点,指定一个worker节点即可
+- -Dcheckpoint_path: 同评估
+- -Dbuckets: oss bucket,同训练.
+
+### 配置文件:
+
+#### 输入输出
+
+```protobuf
+# 训练表和测试表,如果在PAI上,不需要设置,会被-Dtables参数覆盖
+train_input_path: ""
+eval_input_path: ""
+# 模型保存路径
+model_dir: "oss://easyrec/easy_rec_test/experiment/dwd_avazu_ctr"
+```
+
+#### 数据相关
+
+数据配置具体见:[数据](../feature/data.md)
+
+```protobuf
+# 数据相关的描述
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:""
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: STRING
+ default_val:""
+ }
+ ...
+ input_fields: {
+ input_name: "c20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: OdpsInputV2
+}
+```
+
+#### 特征相关
+
+特征配置具体见:[特征](../feature/feature.rst)
+
+```protobuf
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 50
+ }
+ features: {
+ input_names: "c1"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+
+ ...
+
+ features: {
+ input_names: "c20"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c21"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+}
+```
+
+#### 训练相关
+
+训练配置具体见:[训练](../train.md)
+
+```protobuf
+# 训练相关的参数
+train_config {
+ # 每200轮打印一行log
+ log_step_count_steps: 200
+ # 优化器相关的参数
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ # 使用SyncReplicasOptimizer进行分布式训练(同步模式)
+ sync_replicas: true
+ # num_steps = total_sample_num * num_epochs / batch_size / num_workers
+ num_steps: 1000
+}
+```
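+
+按上面注释中的公式,可以粗略估算一次实验所需的num_steps(样本数、worker数等均为假设的示例值):
+
+```bash
+# 示意:num_steps = total_sample_num * num_epochs / batch_size / num_workers
+# 假设训练样本数为100000, num_epochs=10, batch_size=1024, 3个worker
+echo $(( 100000 * 10 / 1024 / 3 ))   # 约等于 325
+```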
+
+#### 评估相关
+
+评估配置具体见:[评估](../eval.md)
+
+```protobuf
+eval_config {
+ # 仅仅评估1000个样本,这里是为了示例速度考虑,实际使用时需要删除
+ num_examples: 1000
+ metrics_set: {
+ # metric为auc
+ auc {}
+ }
+}
+```
+
+#### 模型相关
+
+```protobuf
+model_config:{
+ model_class: "MultiTower"
+ feature_groups: {
+ group_name: "item"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "user"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "user_item"
+ feature_names: "hour"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:DEEP
+ }
+
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [384, 320, 256, 192, 128]
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [384, 320, 256, 192, 128]
+ }
+ }
+ towers {
+ input: "user_item"
+ dnn {
+ hidden_units: [384, 320, 256, 192, 128]
+ }
+ }
+
+ final_dnn {
+ hidden_units: [256, 192, 128, 64]
+ }
+ l2_regularization: 0.0
+ }
+ embedding_regularization: 0.0
+}
+
+```
+
+配置文件下载:[dwd_avazu_ctr_deepmodel_ext.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config)
+
+#### 配置参考手册
+
+[EasyRecConfig参考手册](../reference.md)
diff --git a/docs/source/release.md b/docs/source/release.md
index 527eb1d84..45306b398 100644
--- a/docs/source/release.md
+++ b/docs/source/release.md
@@ -3,10 +3,11 @@
### PAI(Max Compute) EasyRec升级
```bash
-sh pai_jobs/deploy_ext.sh -V ${VERSION} -O
+sh pai_jobs/deploy_ext.sh -V ${VERSION} -G
+ls -lh pai_jobs/easy_rec_ext_${VERSION}_res.tar.gz
```
-将资源包上传至ODPS
+将资源包`pai_jobs/easy_rec_ext_${VERSION}_res.tar.gz`上传至ODPS
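+
+上传方式之一是使用odpscmd命令行,示意如下(假设odpscmd已配置到目标project,资源名与路径按实际情况替换):
+
+```bash
+# 示意:将资源包作为Archive资源上传,-f 表示覆盖同名资源
+odpscmd -e "add archive pai_jobs/easy_rec_ext_${VERSION}_res.tar.gz -f;"
+```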

diff --git a/docs/source/train.md b/docs/source/train.md
index be8c146d0..15907d3ab 100644
--- a/docs/source/train.md
+++ b/docs/source/train.md
@@ -1,8 +1,8 @@
# 训练
-### train_config
+## train_config
-- log_step_count_steps: 200 # 每200轮打印一行log
+- log_step_count_steps: 200 # 每200步打印一行log
- optimizer_config # 优化器相关的参数
@@ -20,18 +20,64 @@
}
```
+ - 多优化器支持:
+ - 可以配置两个optimizer, 分别对应embedding权重和dense权重;
+ - 实现参考EasyRecModel.get_grouped_vars和multi_optimizer.MultiOptimizer;
+ - 示例(samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config):
+ ```protobuf
+ train_config {
+ ...
+ optimizer_config { # for embedding_weights
+ adagrad_optimizer {
+ learning_rate {
+ constant_learning_rate {
+ learning_rate: 0.05
+ }
+ }
+ initial_accumulator_value: 1.0
+ }
+ }
+
+ optimizer_config: { # for dense weights
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ }
+ ```
+ - Note: [WideAndDeep](./models/wide_and_deep.md)模型的optimizer设置:
+ - 设置两个optimizer时, 第一个optimizer仅用于wide参数;
+ - 如果要给deep embedding单独设置optimizer, 需要设置3个optimizer.
+
- sync_replicas: true # 是否同步训练,默认是false
- 使用SyncReplicasOptimizer进行分布式训练(同步模式)
- 仅在train_distribute为NoStrategy时可以设置成true,其它情况应该设置为false
- PS异步训练也设置为false
-
-- train_distribute: 默认不开启Strategy(NoStrategy), strategy确定分布式执行的方式
-
- - NoStrategy 不使用Strategy
- - PSStrategy 异步ParameterServer模式
- - MirroredStrategy 单机多卡模式,仅在PAI上可以使用,本地和EMR上不能使用
- - MultiWorkerMirroredStrategy 多机多卡模式,在TF版本>=1.15时可以使用
+ - 注意在设置为 true 时,总共的训练步数为:min(total_sample_num \* num_epochs / batch_size, num_steps) / num_workers
+
+- train_distribute: 默认不开启Strategy(NoStrategy), strategy确定分布式执行的方式, 可以分成两种模式: PS-Worker模式 和 All-Reduce模式
+
+ - PS-Worker模式:
+ - NoStrategy: 根据sync_replicas的取值决定采用同步或者异步训练
+ - sync_replicas=true,采用ps worker同步训练
+ - 注意: 该模式容易导致ps存在通信瓶颈, 建议用混合并行的模式进行同步训练
+ - sync_replicas=false, 采用ps worker异步训练
+ - All-Reduce模式:
+ - 数据并行:
+ - MirroredStrategy: 单机多卡模式,仅在PAI上可以使用,本地和EMR上不能使用
+ - MultiWorkerMirroredStrategy: 多机多卡模式,在TF版本>=1.15时可以使用
+ - HorovodStrategy: horovod多机多卡并行, 需要安装horovod
+ - 混合并行: 数据并行 + Embedding分片, 需要安装horovod
+ - EmbeddingParallelStrategy: 在horovod多机多卡并行的基础上, 增加了Embedding分片的功能
+ - SokStrategy: 在horovod多机多卡并行的基础上, 增加了[SOK](https://github.com/NVIDIA-Merlin/HugeCTR/tree/main/sparse_operation_kit) Key-Value Embedding和Embedding分片的功能
+ - 注意: 该模式仅支持GPU模式, 需要安装SOK.
- num_gpus_per_worker: 仅在MirrorredStrategy, MultiWorkerMirroredStrategy, PSStrategy的时候有用
@@ -39,11 +85,10 @@
- 总共训练多少轮
- num_steps = total_sample_num * num_epochs / batch_size / num_workers
- - **分布式训练时一定要设置num_steps,否则评估任务会结束不了**
- fine_tune_checkpoint: 需要restore的checkpoint路径,也可以是包含checkpoint的目录,如果目录里面有多个checkpoint,将使用最新的checkpoint
-- fine_tune_ckpt_var_map: 需要restore的参数列表文件路径,文件的每一行是{variable_name in current model ckpt}\\t{variable name in old model ckpt}
+- fine_tune_ckpt_var_map: 需要restore的参数列表文件路径,文件的每一行是{variable_name in current model}\\t{variable name in old model ckpt}
- 需要设置fine_tune_ckpt_var_map的情形:
- current ckpt和old ckpt不完全匹配, 如embedding的名字不一样:
@@ -62,11 +107,11 @@
print(key)
```
-- save_checkpoints_steps: 每隔多少轮保存一次checkpoint, 默认是1000
+- save_checkpoints_steps: 每隔多少步保存一次checkpoint, 默认是1000。当训练数据量很大的时候,这个值要设置大一些
- save_checkpoints_secs: 每隔多少s保存一次checkpoint, 不可以和save_checkpoints_steps同时指定
-- keep_checkpoint_max: 最多保存多少个checkpoint, 默认是10
+- keep_checkpoint_max: 最多保存多少个checkpoint, 默认是10。当模型较大的时候可以设置为5,可节约存储
- log_step_count_steps: 每隔多少轮,打印一次训练信息,默认是10
@@ -74,9 +119,9 @@
- 更多参数请参考[easy_rec/python/protos/train.proto](./reference.md)
-### 训练命令
+## 训练命令
-#### Local
+### Local
```bash
python -m easy_rec.python.train_eval --pipeline_config_path dwd_avazu_ctr_deepmodel.config
@@ -86,11 +131,18 @@ python -m easy_rec.python.train_eval --pipeline_config_path dwd_avazu_ctr_deepmo
- --continue_train: restore之前的checkpoint,继续训练
- --model_dir: 如果指定了model_dir将会覆盖config里面的model_dir,一般在周期性调度的时候使用
- --edit_config_json: 使用json的方式对config的一些字段进行修改,如:
- ```sql
- --edit_config_json='{"train_config.fine_tune_checkpoint": "oss://easyrec/model.ckpt-50"}'
+ ```bash
+ --edit_config_json='{"train_config.fine_tune_checkpoint": "experiments/ctr/model.ckpt-50"}'
+ ```
+- Extend Args: 命令行参数修改config, 类似edit_config_json
+ - 支持train_config.*, eval_config.*, data_config.*, feature_config.*
+ - 示例:
+ ```bash
+ --train_config.fine_tune_checkpoint=experiments/ctr/model.ckpt-50
+ --data_config.negative_sampler.input_path=data/test/tb_data/taobao_ad_feature_gl
```
-#### On PAI
+### On PAI
```sql
pai -name easy_rec_ext -project algo_public
@@ -126,32 +178,51 @@ pai -name easy_rec_ext -project algo_public
- 如果是pai内部版,则不需要指定arn和ossHost, arn和ossHost放在-Dbuckets里面
- -Dbuckets=oss://easyrec/?role_arn=acs:ram::xxx:role/ev-ext-test-oss&host=oss-cn-beijing-internal.aliyuncs.com
-#### On EMR
+### On DLC
-单机单卡模式:
+- 基于Kubeflow的云原生的训练方式
+- [参考文档](./quick_start/dlc_tutorial.md)
-```bash
-el_submit -t standalone -a easy_rec_train -f dwd_avazu_ctr_deepmodel.config -m local -wn 1 -wc 6 -wm 20000 -wg 1 -c "python -m easy_rec.python.train_eval --pipeline_config_path dwd_avazu_ctr_deepmodel.config --continue_train"
-```
+### On EMR
-- 参数同Local模式
+- 基于开源大数据平台的训练方式
+- [参考文档](https://help.aliyun.com/zh/emr/emr-on-ecs/user-guide/use-easyrec-to-perform-model-training-evaluation-and-prediction-on-data-from-hive-tables)
-多worker模式:
+## 混合并行(EmbeddingParallel)
-- 需要在配置文件中设置train_config.train_distribute为MultiWorkerMirroredStrategy
+混合并行模式下Embedding参数会分片, 均匀分布到各个worker上, 通过all2all的通信方式来聚合不同worker上的Embedding。MLP参数在每个worker上都有完整的一份复制, 在参数更新时,会通过allreduce的方式同步不同worker的更新。
-```bash
-el_submit -t standalone -a easy_rec_train -f dwd_avazu_ctr_deepmodel.config -m local -wn 1 -wc 6 -wm 20000 -wg 2 -c "python -m easy_rec.python.train_eval --pipeline_config_path dwd_avazu_ctr_deepmodel.config --continue_train"
-```
+### 依赖
-- 参数同Local模式
+- 混合并行使用Horovod做底层的通信, 因此需要安装Horovod, 可以直接使用下面的镜像
+- mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:sok-tf212-gpus-v5
+ ```
+ sudo docker run --gpus=all --privileged -v /home/easyrec/:/home/easyrec/ -ti mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:sok-tf212-gpus-v5 bash
+ ```
-PS模式:
+### 配置
-- 需要在配置文件中设置train_config.sync_replicas为true
+- 修改train_config.train_distribute为EmbeddingParallelStrategy
+ ```
+ train_config {
+ ...
+ train_distribute: EmbeddingParallelStrategy
+ ...
+ }
+ ```
-```bash
-el_submit -t tensorflow-ps -a easy_rec_train -f dwd_avazu_ctr_deepmodel.config -m local -pn 1 -pc 4 -pm 20000 -wn 3 -wc 6 -wm 20000 -c "python -m easy_rec.python.train_eval --pipeline_config_path dwd_avazu_ctr_deepmodel.config --continue_train"
-```
+### 命令
-- 参数同Local模式
+- 训练
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 horovodrun -np 4 python -m easy_rec.python.train_eval --pipeline_config_path samples/model_config/dlrm_on_criteo_parquet_ep_v2.config
+ ```
+- 评估
+ ```
+ CUDA_VISIBLE_DEVICES=0 horovodrun -np 1 python -m easy_rec.python.eval --pipeline_config_path samples/model_config/dlrm_on_criteo_parquet_ep_v2.config
+ ```
+ - 注意: 评估目前仅支持单个worker评估
+- 导出
+ ```
+ CUDA_VISIBLE_DEVICES=0 horovodrun -np 1 python -m easy_rec.python.export --pipeline_config_path samples/model_config/dlrm_on_criteo_parquet_ep_v2.config --export_dir dlrm_criteo_export/
+ ```
diff --git a/docs/source/vector_retrieve.md b/docs/source/vector_retrieve.md
index d88cc9704..8d3f7b909 100644
--- a/docs/source/vector_retrieve.md
+++ b/docs/source/vector_retrieve.md
@@ -3,7 +3,7 @@
## Pai 命令
```sql
-pai -name easy_rec_ext -project algo_public
+pai -name easy_rec_ext -project algo_public_dev
-Dcmd=vector_retrieve
-Dquery_table=odps://pai_online_project/tables/query_vector_table
-Ddoc_table=odps://pai_online_project/tables/doc_vector_table
@@ -37,13 +37,13 @@ pai -name easy_rec_ext -project algo_public
## 使用示例
-### 1. 创建查询表
+### 1. 创建索引表
```sql
create table doc_table(pk BIGINT,vector string) partitioned by (pt string);
-INSERT OVERWRITE TABLE query_table PARTITION(pt='20190410')
-VALUES
+INSERT OVERWRITE TABLE doc_table PARTITION(pt='20190410')
+VALUES
(1, '0.1,0.2,-0.4,0.5'),
(2, '-0.1,0.8,0.4,0.5'),
(3, '0.59,0.2,0.4,0.15'),
@@ -53,13 +53,13 @@ VALUES
;
```
-### 2. 创建索引表
+### 2. 创建查询表
```sql
create table query_table(pk BIGINT,vector string) partitioned by (pt string);
-INSERT OVERWRITE TABLE doc_table PARTITION(pt='20190410')
-VALUES
+INSERT OVERWRITE TABLE query_table PARTITION(pt='20190410')
+VALUES
(1, '0.1,0.2,0.4,0.5'),
(2, '-0.1,0.2,0.4,0.5'),
(3, '0.5,0.2,0.4,0.5'),
@@ -74,7 +74,6 @@ VALUES
```sql
pai -name easy_rec_ext -project algo_public_dev
-Dcmd='vector_retrieve'
--DentryFile='run.py'
-Dquery_table='odps://${project}/tables/query_table/pt=20190410'
-Ddoc_table='odps://${project}/tables/doc_table/pt=20190410'
-Doutput_table='odps://${project}/tables/knn_result_table/pt=20190410'
@@ -92,7 +91,25 @@ pai -name easy_rec_ext -project algo_public_dev
\"cpu\" : 600
}
}';
-;
+```
+
+FAQ: 遇到以下错误怎么办?
+
+```
+File "run.py", line 517, in main
+ raise ValueError('cmd should be one of train/evaluate/export/predict')
+ValueError: cmd should be one of train/evaluate/export/predict
+```
+
+这个错误是因为包含`向量近邻检索`的最新的EasyRec版本暂时还没有正式发布。
+
+解决方案:从 [Github](https://github.com/alibaba/EasyRec)
+的master分支拉取最新代码,使用`bash pai_jobs/deploy_ext.sh -V ${version}`命令打一个最新的资源包`easy_rec_ext_${version}_res.tar.gz`,
+上传到MaxCompute作为Archive资源,最后,在上述命令中加两个如下的参数即可解决。
+
+```
+-Dversion='${version}'
+-Dres_project=${maxcompute_project}
```
### 4. 查看结果
@@ -113,4 +130,4 @@ SELECT * from knn_result_table where pt='20190410';
-- 20 2 0.3800000250339508
-- 30 3 0.5370000004768372
-- 30 30 0.4973999857902527
-```
\ No newline at end of file
+```
diff --git "a/docs/source/\345\274\200\345\217\221.md" "b/docs/source/\345\274\200\345\217\221.md"
deleted file mode 100644
index 7a477495c..000000000
--- "a/docs/source/\345\274\200\345\217\221.md"
+++ /dev/null
@@ -1,85 +0,0 @@
-# 开发
-
-### 代码风格
-
-我们采用 [PEP8](https://www.python.org/dev/peps/pep-0008/) 作为首选代码风格。
-
-我们使用以下工具进行 美化纠错 和格式化:
-
-- [flake8](http://flake8.pycqa.org/en/latest/): 美化纠错(linter)
-- [yapf](https://github.com/google/yapf):格式化程序
-- [isort](https://github.com/timothycrosley/isort):对 import 进行排序整合
-
-我们在每次提交时都会自动使用 [pre-commit hook](https://pre-commit.com/) , 来检查和格式化 `flake8`、`yapf`、`isort`、`trailing whitespaces`、修复 `end-of-files`问题,对 `requirments.txt` 进行排序。
-
-yapf 和 isort 的样式配置可以在[setup.cfg](setup.cfg) 中找到。
-
-pre-commit hook 的配置存储在 [.pre-commit-config](.pre-commit-config.yaml) 中。
-
-在克隆git仓库后,您需要安装初始化pre-commit hook:
-
-```bash
-pip install -U pre-commit
-```
-
-定位到存储库文件夹
-
-```bash
-pre-commit install
-```
-
-在此之后,每次提交检查代码 linters 和格式化程序将被强制执行。
-
-如果您只想格式化和整理代码,则可以运行
-
-```bash
-pre-commit run -a
-```
-
-### 测试
-
-#### 单元测试
-
-TEST_DEVICES=0,1 sh scripts/ci_test.sh
-
-```bash
-TEST_DEVICES=0,1 sh scripts/ci_test.sh
-```
-
-#### Odps 测试
-
-```bash
-TEMPDIR=/tmp python -m easy_rec.python.test.odps_run --oss_config ~/.ossutilconfig [--odps_config {ODPS_CONFIG} --algo_project {ALOG_PROJ} --arn acs:ram::xxx:role/yyy TestPipelineOnOdps.*]
-```
-
-#### 测试数据
-
-如果您要添加新数据,请在“git commit”之前执行以下操作,以将其提交到 git-lfs:
-
-```bash
-python git-lfs/git_lfs.py add data/test/new_data
-python git-lfs/git_lfs.py push
-```
-
-### 文档
-
-我们支持 [MarkDown](https://guides.github.com/features/mastering-markdown/) 格式和 [reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html) 格式的文档。
-
-如果文档包含公式或表格,我们建议您使用 reStructuredText 格式或使用
-[md-to-rst](https://cloudconvert.com/md-to-rst) 将现有的 Markdown 文件转换为 reStructuredText 。
-
-**构建文档** # 在python3环境下运行
-
-```bash
-sh scripts/build_docs.sh
-```
-
-### 构建安装包
-
-构建pip包
-
-```bash
-python setup.py sdist bdist_wheel
-```
-
-### [部署](./release.md)
diff --git a/easy_rec/__init__.py b/easy_rec/__init__.py
index 6b9c30155..cbbc20e9c 100644
--- a/easy_rec/__init__.py
+++ b/easy_rec/__init__.py
@@ -1,11 +1,11 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+
import logging
import os
+import platform
import sys
-import tensorflow as tf
-
from easy_rec.version import __version__
curr_dir, _ = os.path.split(__file__)
@@ -15,27 +15,54 @@
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
-from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
-from easy_rec.python.main import evaluate # isort:skip # noqa: E402
-from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
-from easy_rec.python.main import export # isort:skip # noqa: E402
-from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
+# Avoid import tensorflow which conflicts with the version used in EasyRecProcessor
+if 'PROCESSOR_TEST' not in os.environ:
+ from tensorflow.python.platform import tf_logging
+ # In DeepRec, logger.propagate of tf_logging is False, should be True
+ tf_logging._logger.propagate = True
-try:
- import tensorflow_io.oss
-except Exception:
- pass
+ def get_ops_dir():
+ import tensorflow as tf
+ if platform.system() == 'Linux':
+ ops_dir = os.path.join(curr_dir, 'python/ops')
+ if 'PAI' in tf.__version__:
+ ops_dir = os.path.join(ops_dir, '1.12_pai')
+ elif tf.__version__.startswith('1.12'):
+ ops_dir = os.path.join(ops_dir, '1.12')
+ elif tf.__version__.startswith('1.15'):
+ if 'IS_ON_PAI' in os.environ:
+ ops_dir = os.path.join(ops_dir, 'DeepRec')
+ else:
+ ops_dir = os.path.join(ops_dir, '1.15')
+ else:
+ tmp_version = tf.__version__.split('.')
+ tmp_version = '.'.join(tmp_version[:2])
+ ops_dir = os.path.join(ops_dir, tmp_version)
+ return ops_dir
+ else:
+ return None
-print('easy_rec version: %s' % __version__)
-print('Usage: easy_rec.help()')
+ ops_dir = get_ops_dir()
+ if ops_dir is not None and not os.path.exists(ops_dir):
+ logging.warning('ops_dir[%s] does not exist' % ops_dir)
+ ops_dir = None
-_global_config = {}
+ from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
+ from easy_rec.python.main import evaluate # isort:skip # noqa: E402
+ from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
+ from easy_rec.python.main import export # isort:skip # noqa: E402
+ from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
+ from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402
-ops_dir = os.path.join(curr_dir, 'python/ops')
-if tf.__version__.startswith('1.12'):
- ops_dir = os.path.join(ops_dir, '1.12')
-elif tf.__version__.startswith('1.15'):
- ops_dir = os.path.join(ops_dir, '1.15')
+ try:
+ import tensorflow_io.oss
+ except Exception:
+ pass
+
+ print('easy_rec version: %s' % __version__)
+ print('Usage: easy_rec.help()')
+
+_global_config = {}
def help():
diff --git a/easy_rec/python/builders/loss_builder.py b/easy_rec/python/builders/loss_builder.py
index df689a488..720dfdd9e 100644
--- a/easy_rec/python/builders/loss_builder.py
+++ b/easy_rec/python/builders/loss_builder.py
@@ -2,42 +2,217 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
+import numpy as np
import tensorflow as tf
+from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
+from easy_rec.python.loss.jrc_loss import jrc_loss
+from easy_rec.python.loss.listwise_loss import listwise_distill_loss
+from easy_rec.python.loss.listwise_loss import listwise_rank_loss
+from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss
+from easy_rec.python.loss.pairwise_loss import pairwise_hinge_loss
+from easy_rec.python.loss.pairwise_loss import pairwise_logistic_loss
from easy_rec.python.loss.pairwise_loss import pairwise_loss
from easy_rec.python.protos.loss_pb2 import LossType
+from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA
+
+from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
-def build(loss_type, label, pred, loss_weight=1.0, num_class=1, **kwargs):
+def build(loss_type,
+ label,
+ pred,
+ loss_weight=1.0,
+ num_class=1,
+ loss_param=None,
+ **kwargs):
+ loss_name = kwargs.pop('loss_name') if 'loss_name' in kwargs else 'unknown'
if loss_type == LossType.CLASSIFICATION:
if num_class == 1:
return tf.losses.sigmoid_cross_entropy(
label, logits=pred, weights=loss_weight, **kwargs)
else:
+ assert label.dtype in [tf.int32, tf.int64], \
+ 'label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy.'
return tf.losses.sparse_softmax_cross_entropy(
labels=label, logits=pred, weights=loss_weight, **kwargs)
elif loss_type == LossType.CROSS_ENTROPY_LOSS:
return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs)
+ elif loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
+ losses = tf.keras.backend.binary_crossentropy(label, pred, from_logits=True)
+ return tf.reduce_mean(losses)
elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
logging.info('%s is used' % LossType.Name(loss_type))
return tf.losses.mean_squared_error(
labels=label, predictions=pred, weights=loss_weight, **kwargs)
+ elif loss_type == LossType.ZILN_LOSS:
+ loss = zero_inflated_lognormal_loss(label, pred)
+ if np.isscalar(loss_weight) and loss_weight != 1.0:
+ return loss * loss_weight
+ return loss
+ elif loss_type == LossType.JRC_LOSS:
+ session = kwargs.get('session_ids', None)
+ if loss_param is None:
+ return jrc_loss(label, pred, session, name=loss_name)
+ return jrc_loss(
+ label,
+ pred,
+ session,
+ loss_param.alpha,
+ loss_weight_strategy=loss_param.loss_weight_strategy,
+ sample_weights=loss_weight,
+ same_label_loss=loss_param.same_label_loss,
+ name=loss_name)
elif loss_type == LossType.PAIR_WISE_LOSS:
- return pairwise_loss(pred, label)
+ session = kwargs.get('session_ids', None)
+ margin = 0 if loss_param is None else loss_param.margin
+ temp = 1.0 if loss_param is None else loss_param.temperature
+ return pairwise_loss(
+ label,
+ pred,
+ session_ids=session,
+ margin=margin,
+ temperature=temp,
+ weights=loss_weight,
+ name=loss_name)
+ elif loss_type == LossType.PAIRWISE_LOGISTIC_LOSS:
+ session = kwargs.get('session_ids', None)
+ temp = 1.0 if loss_param is None else loss_param.temperature
+ ohem_ratio = 1.0 if loss_param is None else loss_param.ohem_ratio
+ hinge_margin = None
+ if loss_param is not None and loss_param.HasField('hinge_margin'):
+ hinge_margin = loss_param.hinge_margin
+ lbl_margin = False if loss_param is None else loss_param.use_label_margin
+ return pairwise_logistic_loss(
+ label,
+ pred,
+ session_ids=session,
+ temperature=temp,
+ hinge_margin=hinge_margin,
+ ohem_ratio=ohem_ratio,
+ weights=loss_weight,
+ use_label_margin=lbl_margin,
+ name=loss_name)
+ elif loss_type == LossType.PAIRWISE_HINGE_LOSS:
+ session = kwargs.get('session_ids', None)
+ temp, ohem_ratio, margin = 1.0, 1.0, 1.0
+ label_is_logits, use_label_margin, use_exponent = True, True, False
+ if loss_param is not None:
+ temp = loss_param.temperature
+ ohem_ratio = loss_param.ohem_ratio
+ margin = loss_param.margin
+ label_is_logits = loss_param.label_is_logits
+ use_label_margin = loss_param.use_label_margin
+ use_exponent = loss_param.use_exponent
+ return pairwise_hinge_loss(
+ label,
+ pred,
+ session_ids=session,
+ temperature=temp,
+ margin=margin,
+ ohem_ratio=ohem_ratio,
+ weights=loss_weight,
+ label_is_logits=label_is_logits,
+ use_label_margin=use_label_margin,
+ use_exponent=use_exponent,
+ name=loss_name)
+ elif loss_type == LossType.PAIRWISE_FOCAL_LOSS:
+ session = kwargs.get('session_ids', None)
+ if loss_param is None:
+ return pairwise_focal_loss(
+ label, pred, session_ids=session, weights=loss_weight, name=loss_name)
+ hinge_margin = None
+ if loss_param.HasField('hinge_margin'):
+ hinge_margin = loss_param.hinge_margin
+ return pairwise_focal_loss(
+ label,
+ pred,
+ session_ids=session,
+ gamma=loss_param.gamma,
+ alpha=loss_param.alpha if loss_param.HasField('alpha') else None,
+ hinge_margin=hinge_margin,
+ ohem_ratio=loss_param.ohem_ratio,
+ temperature=loss_param.temperature,
+ weights=loss_weight,
+ name=loss_name)
+ elif loss_type == LossType.LISTWISE_RANK_LOSS:
+ session = kwargs.get('session_ids', None)
+ trans_fn, temp, label_is_logits, scale = None, 1.0, False, False
+ if loss_param is not None:
+ temp = loss_param.temperature
+ label_is_logits = loss_param.label_is_logits
+ scale = loss_param.scale_logits
+ if loss_param.HasField('transform_fn'):
+ trans_fn = loss_param.transform_fn
+ return listwise_rank_loss(
+ label,
+ pred,
+ session,
+ temperature=temp,
+ label_is_logits=label_is_logits,
+ transform_fn=trans_fn,
+ scale_logits=scale,
+ weights=loss_weight)
+ elif loss_type == LossType.LISTWISE_DISTILL_LOSS:
+ session = kwargs.get('session_ids', None)
+ trans_fn, temp, label_clip_max_value, scale = None, 1.0, 512.0, False
+ if loss_param is not None:
+ temp = loss_param.temperature
+ label_clip_max_value = loss_param.label_clip_max_value
+ scale = loss_param.scale_logits
+ if loss_param.HasField('transform_fn'):
+ trans_fn = loss_param.transform_fn
+ return listwise_distill_loss(
+ label,
+ pred,
+ session,
+ temperature=temp,
+ label_clip_max_value=label_clip_max_value,
+ transform_fn=trans_fn,
+ scale_logits=scale,
+ weights=loss_weight)
+ elif loss_type == LossType.F1_REWEIGHTED_LOSS:
+ f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square
+ label_smoothing = 0 if loss_param is None else loss_param.label_smoothing
+ return f1_reweight_sigmoid_cross_entropy(
+ label,
+ pred,
+ f1_beta_square,
+ weights=loss_weight,
+ label_smoothing=label_smoothing)
+ elif loss_type == LossType.BINARY_FOCAL_LOSS:
+ if loss_param is None:
+ return sigmoid_focal_loss_with_logits(
+ label, pred, sample_weights=loss_weight, name=loss_name)
+ gamma = loss_param.gamma
+ alpha = None
+ if loss_param.HasField('alpha'):
+ alpha = loss_param.alpha
+ return sigmoid_focal_loss_with_logits(
+ label,
+ pred,
+ gamma=gamma,
+ alpha=alpha,
+ ohem_ratio=loss_param.ohem_ratio,
+ sample_weights=loss_weight,
+ label_smoothing=loss_param.label_smoothing,
+ name=loss_name)
else:
raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))
-def build_kd_loss(kds, prediction_dict, label_dict):
+def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
"""Build knowledge distillation loss.
Args:
kds: list of knowledge distillation object of type KD.
prediction_dict: dict of predict_name to predict tensors.
label_dict: ordered dict of label_name to label tensors.
+ feature_dict: dict of feature name to feature value
Return:
     knowledge distillation loss will be added to loss_dict with key: kd_loss.
@@ -53,35 +228,106 @@ def build_kd_loss(kds, prediction_dict, label_dict):
loss_name = 'kd_loss_' + kd.pred_name.replace('/', '_')
loss_name += '_' + kd.soft_label_name.replace('/', '_')
+ loss_weight = kd.loss_weight
+ if kd.HasField('task_space_indicator_name') and kd.HasField(
+ 'task_space_indicator_value'):
+ in_task_space = tf.to_float(
+ tf.equal(feature_dict[kd.task_space_indicator_name],
+ kd.task_space_indicator_value))
+ loss_weight = loss_weight * (
+ kd.in_task_space_weight * in_task_space + kd.out_task_space_weight *
+ (1 - in_task_space))
+
label = label_dict[kd.soft_label_name]
pred = prediction_dict[kd.pred_name]
+ epsilon = tf.keras.backend.epsilon()
+ num_class = 1 if len(pred.get_shape()) < 2 else pred.get_shape()[-1]
- if kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
+ if kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
+ if not kd.label_is_logits: # label is prob
+ label = tf.clip_by_value(label, epsilon, 1 - epsilon)
+ label = tf.log(label / (1 - label))
+ if not kd.pred_is_logits:
+ pred = tf.clip_by_value(pred, epsilon, 1 - epsilon)
+ pred = tf.log(pred / (1 - pred))
+ if kd.temperature > 0:
+ label = label / kd.temperature
+ pred = pred / kd.temperature
+ label = tf.nn.sigmoid(label) # convert to prob
+ elif kd.loss_type == LossType.KL_DIVERGENCE_LOSS:
+ if not kd.label_is_logits: # label is prob
+ if num_class == 1: # for binary classification
+ label = tf.clip_by_value(label, epsilon, 1 - epsilon)
+ label = tf.log(label / (1 - label))
+ else:
+ label = tf.math.log(label + epsilon)
+ label -= tf.reduce_max(label)
+ if not kd.pred_is_logits:
+ if num_class == 1: # for binary classification
+ pred = tf.clip_by_value(pred, epsilon, 1 - epsilon)
+ pred = tf.log(pred / (1 - pred))
+ else:
+ pred = tf.math.log(pred + epsilon)
+ pred -= tf.reduce_max(pred)
+ if kd.temperature > 0:
+ label = label / kd.temperature
+ pred = pred / kd.temperature
+ if num_class > 1:
+ label = tf.nn.softmax(label)
+ pred = tf.nn.softmax(pred)
+ else:
+ label = tf.nn.sigmoid(label) # convert to prob
+ pred = tf.nn.sigmoid(pred) # convert to prob
+ elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
if not kd.label_is_logits:
- label = tf.math.log(label + 1e-7)
+ label = tf.math.log(label + epsilon)
if not kd.pred_is_logits:
- pred = tf.math.log(pred + 1e-7)
-
- if kd.temperature > 0 and kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
- label = label / kd.temperature
- pred = pred / kd.temperature
-
- if kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
- num_class = 1 if len(pred.get_shape()) < 2 else pred.get_shape()[-1]
+ pred = tf.math.log(pred + epsilon)
+ if kd.temperature > 0:
+ label = label / kd.temperature
+ pred = pred / kd.temperature
if num_class > 1:
label = tf.nn.softmax(label)
pred = tf.nn.softmax(pred)
elif num_class == 1:
label = tf.nn.sigmoid(label)
- pred = tf.nn.sigmoid(label)
+ pred = tf.nn.sigmoid(pred)
- if kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
+ if kd.loss_type == LossType.KL_DIVERGENCE_LOSS:
+ if num_class == 1:
+ label = tf.expand_dims(label, 1) # [B, 1]
+ labels = tf.concat([1 - label, label], axis=1) # [B, 2]
+ pred = tf.expand_dims(pred, 1) # [B, 1]
+ preds = tf.concat([1 - pred, pred], axis=1) # [B, 2]
+ else:
+ labels = label
+ preds = pred
+ losses = tf.keras.losses.KLD(labels, preds)
+ loss_dict[loss_name] = tf.reduce_mean(
+ losses, name=loss_name) * loss_weight
+ elif kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
+ losses = tf.keras.backend.binary_crossentropy(
+ label, pred, from_logits=True)
+ loss_dict[loss_name] = tf.reduce_mean(
+ losses, name=loss_name) * loss_weight
+ elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
loss_dict[loss_name] = tf.losses.log_loss(
- label, pred, weights=kd.loss_weight)
+ label, pred, weights=loss_weight)
elif kd.loss_type == LossType.L2_LOSS:
loss_dict[loss_name] = tf.losses.mean_squared_error(
- labels=label, predictions=pred, weights=kd.loss_weight)
+ labels=label, predictions=pred, weights=loss_weight)
else:
- assert False, 'unsupported loss type for kd: %s' % LossType.Name(
- kd.loss_type)
+ loss_param = kd.WhichOneof('loss_param')
+ kwargs = {}
+ if loss_param is not None:
+ loss_param = getattr(kd, loss_param)
+ if hasattr(loss_param, 'session_name'):
+ kwargs['session_ids'] = feature_dict[loss_param.session_name]
+ loss_dict[loss_name] = build(
+ kd.loss_type,
+ label,
+ pred,
+ loss_weight=loss_weight,
+ loss_param=loss_param,
+ **kwargs)
return loss_dict
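For reference, a minimal NumPy sketch (illustrative names only, not part of EasyRec) of the probability-to-logit conversion with temperature that the BINARY_CROSS_ENTROPY_LOSS branch above applies to teacher labels:

```python
import numpy as np


def soften_binary_label(prob, temperature=2.0, eps=1e-7):
  # clip -> logit -> divide by temperature -> back to probability,
  # mirroring the label handling in the BINARY_CROSS_ENTROPY_LOSS branch
  prob = np.clip(prob, eps, 1.0 - eps)
  logit = np.log(prob / (1.0 - prob))
  return 1.0 / (1.0 + np.exp(-logit / temperature))


print(soften_binary_label(0.9))  # 0.75: a softer target than the raw 0.9
```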
diff --git a/easy_rec/python/builders/optimizer_builder.py b/easy_rec/python/builders/optimizer_builder.py
index 7a2331b32..8b9a8251b 100644
--- a/easy_rec/python/builders/optimizer_builder.py
+++ b/easy_rec/python/builders/optimizer_builder.py
@@ -88,6 +88,14 @@ def build(optimizer_config):
beta1=config.beta1,
beta2=config.beta2)
+ if optimizer_type == 'lazy_adam_optimizer':
+ config = optimizer_config.lazy_adam_optimizer
+ learning_rate = _create_learning_rate(config.learning_rate)
+ summary_vars.append(learning_rate)
+ from easy_rec.python.compat.adam_s import AdamOptimizerS
+ optimizer = AdamOptimizerS(
+ learning_rate=learning_rate, beta1=config.beta1, beta2=config.beta2)
+
if optimizer_type == 'momentumw_optimizer':
config = optimizer_config.momentumw_optimizer
learning_rate = _create_learning_rate(config.learning_rate)
@@ -103,7 +111,9 @@ def build(optimizer_config):
config = optimizer_config.adagrad_optimizer
learning_rate = _create_learning_rate(config.learning_rate)
summary_vars.append(learning_rate)
- optimizer = tf.train.AdagradOptimizer(learning_rate)
+ optimizer = tf.train.AdagradOptimizer(
+ learning_rate,
+ initial_accumulator_value=config.initial_accumulator_value)
if optimizer_type == 'adam_async_optimizer':
config = optimizer_config.adam_async_optimizer
diff --git a/easy_rec/python/compat/adam_s.py b/easy_rec/python/compat/adam_s.py
new file mode 100644
index 000000000..a0ef60b80
--- /dev/null
+++ b/easy_rec/python/compat/adam_s.py
@@ -0,0 +1,245 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adam for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class AdamOptimizerS(optimizer.Optimizer):
+ """Optimizer that implements the Adam algorithm.
+
+ References:
+ Adam - A Method for Stochastic Optimization:
+ [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
+ ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8,
+ use_locking=False,
+ name='Adam'):
+ r"""Construct a new Adam optimizer.
+
+ Initialization:
+
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
+
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section 2 of the paper:
+
+ $$t := t + 1$$
+ $$\text{lr}_t := \mathrm{learning_rate} *
+ \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
+
+ $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
+ $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
+ $$\text{variable} := \text{variable} -
+ \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$
+
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1. Note that since AdamOptimizerS uses the
+ formulation just before Section 2.1 of the Kingma and Ba paper rather than
+ the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+ hat" in the paper.
+
+ The sparse implementation of this algorithm (used when the gradient is an
+ IndexedSlices object, typically because of `tf.gather` or an embedding
+ lookup in the forward pass) does apply momentum to variable slices even if
+ they were not used in the forward pass (meaning they have a gradient equal
+ to zero). Momentum decay (beta1) is also applied to the entire momentum
+ accumulator. This means that the sparse behavior is equivalent to the dense
+ behavior (in contrast to some momentum implementations which ignore momentum
+ unless a variable slice was actually used).
+
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ beta1: A float value or a constant float tensor. The exponential decay
+ rate for the 1st moment estimates.
+ beta2: A float value or a constant float tensor. The exponential decay
+ rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability. This epsilon is
+ "epsilon hat" in the Kingma and Ba paper (in the formula just before
+ Section 2.1), not the epsilon in Algorithm 1 of the paper.
+ use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
+ `epsilon` can each be a callable that takes no arguments and returns the
+ actual value to use. This can be useful for changing these values across
+ different invocations of optimizer functions.
+ @end_compatibility
+ """
+ super(AdamOptimizerS, self).__init__(use_locking, name)
+ self._lr = learning_rate
+ self._beta1 = beta1
+ self._beta2 = beta2
+ self._epsilon = epsilon
+
+ # Tensor versions of the constructor arguments, created in _prepare().
+ self._lr_t = None
+ self._beta1_t = None
+ self._beta2_t = None
+ self._epsilon_t = None
+
+ def _get_beta_accumulators(self):
+ with ops.init_scope():
+ if context.executing_eagerly():
+ graph = None
+ else:
+ graph = ops.get_default_graph()
+ return (self._get_non_slot_variable('beta1_power', graph=graph),
+ self._get_non_slot_variable('beta2_power', graph=graph))
+
+ def _create_slots(self, var_list):
+ # Create the beta1 and beta2 accumulators on the same device as the first
+ # variable. Sort the var_list to make sure this device is consistent across
+ # workers (these need to go on the same PS, otherwise some updates are
+ # silently ignored).
+ first_var = min(var_list, key=lambda x: x.name)
+ self._create_non_slot_variable(
+ initial_value=self._beta1, name='beta1_power', colocate_with=first_var)
+ self._create_non_slot_variable(
+ initial_value=self._beta2, name='beta2_power', colocate_with=first_var)
+
+ # Create slots for the first and second moments.
+ for v in var_list:
+ self._zeros_slot(v, 'm', self._name)
+ self._zeros_slot(v, 'v', self._name)
+
+ def _prepare(self):
+ lr = self._call_if_callable(self._lr)
+ beta1 = self._call_if_callable(self._beta1)
+ beta2 = self._call_if_callable(self._beta2)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._lr_t = ops.convert_to_tensor(lr, name='learning_rate')
+ self._beta1_t = ops.convert_to_tensor(beta1, name='beta1')
+ self._beta2_t = ops.convert_to_tensor(beta2, name='beta2')
+ self._epsilon_t = ops.convert_to_tensor(epsilon, name='epsilon')
+
+ def _apply_dense(self, grad, var):
+ m = self.get_slot(var, 'm')
+ v = self.get_slot(var, 'v')
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ return training_ops.apply_adam(
+ var,
+ m,
+ v,
+ math_ops.cast(beta1_power, var.dtype.base_dtype),
+ math_ops.cast(beta2_power, var.dtype.base_dtype),
+ math_ops.cast(self._lr_t, var.dtype.base_dtype),
+ math_ops.cast(self._beta1_t, var.dtype.base_dtype),
+ math_ops.cast(self._beta2_t, var.dtype.base_dtype),
+ math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var):
+ m = self.get_slot(var, 'm')
+ v = self.get_slot(var, 'v')
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ return training_ops.resource_apply_adam(
+ var.handle,
+ m.handle,
+ v.handle,
+ math_ops.cast(beta1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta2_power, grad.dtype.base_dtype),
+ math_ops.cast(self._lr_t, grad.dtype.base_dtype),
+ math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
+ math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
+ math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse_shared(self, grad, var, indices, scatter_add):
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+ lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+ beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+ beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+ epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+ # m_t = beta1 * m + (1 - beta1) * g_t
+ m = self.get_slot(var, 'm')
+ m_scaled_g_values = grad * (1 - beta1_t)
+ # m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
+ m_decay = array_ops.gather(m, indices) * beta1_t
+ m_part_n = m_scaled_g_values + m_decay
+ m_t = state_ops.scatter_update(m, indices, m_part_n)
+ # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+ v = self.get_slot(var, 'v')
+ v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+ v_decay = array_ops.gather(v, indices) * beta2_t
+ v_part_n = v_scaled_g_values + v_decay
+ v_t = state_ops.scatter_update(v, indices, v_part_n)
+ # v_sqrt = math_ops.sqrt(v_t)
+ # var_update = state_ops.assign_sub(
+ # var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
+ v_part_sqrt = math_ops.sqrt(v_part_n)
+ var_update = scatter_add(var, indices,
+ -lr * m_part_n / (v_part_sqrt + epsilon_t))
+ return control_flow_ops.group(*[var_update, m_t, v_t])
+
+ def _apply_sparse(self, grad, var):
+ return self._apply_sparse_shared(
+ grad.values,
+ var,
+ grad.indices,
+ lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
+ x,
+ i,
+ v,
+ use_locking=self._use_locking))
+
+ def _resource_scatter_add(self, x, i, v):
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ return self._apply_sparse_shared(grad, var, indices,
+ self._resource_scatter_add)
+
+ def _finish(self, update_ops, name_scope):
+ # Update the power accumulators.
+ with ops.control_dependencies(update_ops):
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ with ops.colocate_with(beta1_power):
+ update_beta1 = beta1_power.assign(
+ beta1_power * self._beta1_t, use_locking=self._use_locking)
+ update_beta2 = beta2_power.assign(
+ beta2_power * self._beta2_t, use_locking=self._use_locking)
+ return control_flow_ops.group(
+ *update_ops + [update_beta1, update_beta2], name=name_scope)
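A minimal TF1-style usage sketch of AdamOptimizerS (illustrative only; the variable and loss below are not part of EasyRec). Gradients produced by an embedding lookup are IndexedSlices, so they take the sparse scatter-update path implemented above:

```python
import tensorflow as tf
from easy_rec.python.compat.adam_s import AdamOptimizerS

emb_table = tf.get_variable('emb_table', shape=[1000, 16])
ids = tf.constant([3, 42, 7], dtype=tf.int64)
emb = tf.nn.embedding_lookup(emb_table, ids)  # gradient w.r.t. emb_table is IndexedSlices
loss = tf.reduce_sum(tf.square(emb))
train_op = AdamOptimizerS(learning_rate=1e-3).minimize(loss)
```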
diff --git a/easy_rec/python/compat/array_ops.py b/easy_rec/python/compat/array_ops.py
new file mode 100644
index 000000000..d788bc8c1
--- /dev/null
+++ b/easy_rec/python/compat/array_ops.py
@@ -0,0 +1,229 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_math_ops
+
+
+def convert_to_int_tensor(tensor, name, dtype=tf.int32):
+ """Converts the given value to an integer Tensor."""
+ tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
+ if tensor.dtype.is_integer:
+ tensor = gen_math_ops.cast(tensor, dtype)
+ else:
+ raise TypeError('%s must be an integer tensor; dtype=%s' %
+ (name, tensor.dtype))
+ return tensor
+
+
+def _with_nonzero_rank(data):
+ """If `data` is scalar, then add a dimension; otherwise return as-is."""
+ if data.shape.ndims is not None:
+ if data.shape.ndims == 0:
+ return tf.stack([data])
+ else:
+ return data
+ else:
+ data_shape = tf.shape(data)
+ data_ndims = tf.rank(data)
+ return tf.reshape(data, tf.concat([[1], data_shape], axis=0)[-data_ndims:])
+
+
+def get_positive_axis(axis, ndims):
+ """Validate an `axis` parameter, and normalize it to be positive.
+
+ If `ndims` is known (i.e., not `None`), then check that `axis` is in the
+ range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
+ `axis + ndims` (otherwise).
+ If `ndims` is not known, and `axis` is positive, then return it as-is.
+ If `ndims` is not known, and `axis` is negative, then report an error.
+
+ Args:
+ axis: An integer constant
+ ndims: An integer constant, or `None`
+
+ Returns:
+ The normalized `axis` value.
+
+ Raises:
+ ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
+ `ndims is None`.
+ """
+ if not isinstance(axis, int):
+ raise TypeError('axis must be an int; got %s' % type(axis).__name__)
+ if ndims is not None:
+ if 0 <= axis < ndims:
+ return axis
+ elif -ndims <= axis < 0:
+ return axis + ndims
+ else:
+ raise ValueError('axis=%s out of bounds: expected %s<=axis<%s' %
+ (axis, -ndims, ndims))
+ elif axis < 0:
+ raise ValueError('axis may only be negative if ndims is statically known.')
+ return axis
+
+
+def tile_one_dimension(data, axis, multiple):
+ """Tiles a single dimension of a tensor."""
+ # Assumes axis is a nonnegative int.
+ if data.shape.ndims is not None:
+ multiples = [1] * data.shape.ndims
+ multiples[axis] = multiple
+ else:
+ ones_value = tf.ones(tf.rank(data), tf.int32)
+ multiples = tf.concat(
+ [ones_value[:axis], [multiple], ones_value[axis + 1:]], axis=0)
+ return tf.tile(data, multiples)
+
+
+def _all_dimensions(x):
+ """Returns a 1D-tensor listing all dimensions in x."""
+ # Fast path: avoid creating Rank and Range ops if ndims is known.
+ if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
+ return constant_op.constant(np.arange(x.get_shape().ndims), dtype=tf.int32)
+ if (isinstance(x, sparse_tensor.SparseTensor) and
+ x.dense_shape.get_shape().is_fully_defined()):
+ r = x.dense_shape.get_shape().dims[0].value # sparse.dense_shape is 1-D.
+ return constant_op.constant(np.arange(r), dtype=tf.int32)
+
+ # Otherwise, we rely on `range` and `rank` to do the right thing at runtime.
+ return gen_math_ops._range(0, tf.rank(x), 1)
+
+
+# This op is intended to exactly match the semantics of numpy.repeat, with
+# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
+# when axis is not specified. Rather than implement that special behavior, we
+# simply make `axis` be a required argument.
+#
+# External (OSS) `tf.repeat` feature request:
+# https://github.com/tensorflow/tensorflow/issues/8246
+def repeat_with_axis(data, repeats, axis, name=None):
+ """Repeats elements of `data`.
+
+ Args:
+ data: An `N`-dimensional tensor.
+ repeats: A 1-D integer tensor specifying how many times each element in
+ `axis` should be repeated. `len(repeats)` must equal `data.shape[axis]`.
+ Supports broadcasting from a scalar value.
+ axis: `int`. The axis along which to repeat values. Must be less than
+ `max(N, 1)`.
+ name: A name for the operation.
+
+ Returns:
+ A tensor with `max(N, 1)` dimensions. Has the same shape as `data`,
+ except that dimension `axis` has size `sum(repeats)`.
+ #### Examples:
+ ```python
+ >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
+ ['a', 'a', 'a', 'c', 'c']
+ >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
+ [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
+ >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
+ [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
+ ```
+ """
+ if not isinstance(axis, int):
+ raise TypeError('axis must be an int; got %s' % type(axis).__name__)
+
+ with ops.name_scope(name, 'Repeat', [data, repeats]):
+ data = ops.convert_to_tensor(data, name='data')
+ repeats = convert_to_int_tensor(repeats, name='repeats')
+ repeats.shape.with_rank_at_most(1)
+
+ # If `data` is a scalar, then upgrade it to a vector.
+ data = _with_nonzero_rank(data)
+ data_shape = tf.shape(data)
+
+ # If `axis` is negative, then convert it to a positive value.
+ axis = get_positive_axis(axis, data.shape.ndims)
+
+ # Check data Tensor shapes.
+ if repeats.shape.ndims == 1:
+ data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
+
+ # If we know that `repeats` is a scalar, then we can just tile & reshape.
+ if repeats.shape.ndims == 0:
+ expanded = tf.expand_dims(data, axis + 1)
+ tiled = tile_one_dimension(expanded, axis + 1, repeats)
+ result_shape = tf.concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
+ axis=0)
+ return tf.reshape(tiled, result_shape)
+
+ # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
+ if repeats.shape.ndims != axis + 1:
+ repeats_shape = tf.shape(repeats)
+ repeats_ndims = tf.rank(repeats)
+ broadcast_shape = tf.concat(
+ [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
+ repeats = tf.broadcast_to(repeats, broadcast_shape)
+ repeats.set_shape([None] * (axis + 1))
+
+ # Create a "sequence mask" based on `repeats`, where slices across `axis`
+ # contain one `True` value for each repetition. E.g., if
+ # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
+ max_repeat = gen_math_ops.maximum(
+ 0, gen_math_ops._max(repeats, _all_dimensions(repeats)))
+ mask = tf.sequence_mask(repeats, max_repeat)
+
+ # Add a new dimension around each value that needs to be repeated, and
+ # then tile that new dimension to match the maximum number of repetitions.
+ expanded = tf.expand_dims(data, axis + 1)
+ tiled = tile_one_dimension(expanded, axis + 1, max_repeat)
+
+ # Use `boolean_mask` to discard the extra repeated values. This also
+ # flattens all dimensions up through `axis`.
+ masked = tf.boolean_mask(tiled, mask)
+
+ # Reshape the output tensor to add the outer dimensions back.
+ if axis == 0:
+ result = masked
+ else:
+ result_shape = tf.concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
+ axis=0)
+ result = tf.reshape(masked, result_shape)
+
+ # Preserve shape information.
+ if data.shape.ndims is not None:
+ new_axis_size = 0 if repeats.shape[0] == 0 else None
+ result.set_shape(data.shape[:axis].concatenate(
+ [new_axis_size]).concatenate(data.shape[axis + 1:]))
+
+ return result
+
+
+def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin
+ """Repeat elements of `input`.
+
+ Args:
+ input: An `N`-dimensional Tensor.
+ repeats: An 1-D `int` Tensor. The number of repetitions for each element.
+ repeats is broadcasted to fit the shape of the given axis. `len(repeats)`
+ must equal `input.shape[axis]` if axis is not None.
+ axis: An int. The axis along which to repeat values. By default (axis=None),
+ use the flattened input array, and return a flat output array.
+ name: A name for the operation.
+
+ Returns:
+ A Tensor which has the same shape as `input`, except along the given axis.
+ If axis is None then the output array is flattened to match the flattened
+ input array.
+ #### Examples:
+ ```python
+ >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
+ ['a', 'a', 'a', 'c', 'c']
+ >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
+ [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
+ >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
+ [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
+ >>> repeat(3, repeats=4)
+ [3, 3, 3, 3]
+ >>> repeat([[1,2], [3,4]], repeats=2)
+ [1, 1, 2, 2, 3, 3, 4, 4]
+ ```
+ """
+ if axis is None:
+ input = tf.reshape(input, [-1])
+ axis = 0
+ return repeat_with_axis(input, repeats, axis, name)
diff --git a/easy_rec/python/compat/dynamic_variable.py b/easy_rec/python/compat/dynamic_variable.py
new file mode 100644
index 000000000..83414331c
--- /dev/null
+++ b/easy_rec/python/compat/dynamic_variable.py
@@ -0,0 +1,542 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+
+import tensorflow as tf
+from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
+from sparse_operation_kit.experiment.communication import num_gpus
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+# from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops.resource_variable_ops import ResourceVariable
+from tensorflow.python.ops.resource_variable_ops import variable_accessed
+
+# from tensorflow.python.util import object_identity
+
+dynamic_variable_count = 0
+
+_resource_var_from_proto = ResourceVariable.from_proto
+
+
+class DynamicVariable(ResourceVariable):
+ """Abbreviated as ``sok.experiment.DynamicVariable``.
+
+ A variable that allocates memory dynamically.
+
+ Parameters
+ ----------
+ dimension: int
+      The last dimension of this variable (that is, the embedding vector
+      size of the embedding table).
+
+ initializer: string
+      a string specifying how to initialize this variable.
+      Currently, only "random" or the string form of a float value
+      (meaning a constant initializer) is supported. Default value is "random".
+
+ var_type: string
+      a string specifying whether to use DET or HKV as the backend.
+      If HKV is used as the backend, only tf.int64 is supported as key_type,
+      and init_capacity and max_capacity must be set to powers of 2.
+
+ key_type: dtype
+ specify the data type of indices. Unlike the static variable of
+ tensorflow, this variable is dynamically allocated and contains
+ a hash table inside it. So the data type of indices must be
+ specified to construct the hash table. Default value is tf.int64.
+
+ dtype: dtype
+ specify the data type of values. Default value is tf.float32.
+
+ Example
+ -------
+ .. code-block:: python
+
+ import numpy as np
+ import tensorflow as tf
+ import horovod.tensorflow as hvd
+ from sparse_operation_kit import experiment as sok
+
+ v = sok.DynamicVariable(dimension=3, initializer="13")
+ print("v.shape:", v.shape)
+ print("v.size:", v.size)
+
+ indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)
+
+ embedding = tf.nn.embedding_lookup(v, indices)
+ print("embedding:", embedding)
+ print("v.shape:", v.shape)
+ print("v.size:", v.size)
+ """
+
+ def __init__(self,
+ dimension,
+ initializer=None,
+ var_type=None,
+ name=None,
+ constraint=None,
+ trainable=True,
+ key_type=None,
+ dtype=None,
+ mode=None,
+ variable_def=None,
+ import_scope=None,
+ **kwargs):
+ self._indices = None
+ if variable_def is not None:
+ super(DynamicVariable, self)._init_from_proto(
+ variable_def, import_scope=import_scope, validate_shape=False)
+ g = ops.get_default_graph()
+ handle = g.as_graph_element(
+ ops.prepend_name_scope(
+ variable_def.variable_name, import_scope=import_scope),
+ allow_operation=False)
+ self._dimension = handle.op.get_attr('shape').dim[-1].size
+ self._key_type = handle.op.get_attr('key_type')
+      self._handle_dtype = handle.op.get_attr('dtype')
+ self._mode = None
+ self._config = {}
+ self._name = variable_def.variable_name.split(':')[0]
+ self._trainable = variable_def.trainable
+ self._dummy_handle = handle
+ self._handle = handle
+
+ # init op
+ init_op = g.as_graph_element(variable_def.initializer_name)
+ self._initializer_op = init_op
+
+ init_tf = init_op.control_inputs[0]
+ # init_dummy = init_op.control_inputs[1]
+
+ self._tf_handle = init_tf.inputs[0]
+ return
+
+ self._key_type = key_type if key_type is not None else tf.int64
+ self._handle_dtype = dtype if dtype is not None else tf.float32
+ self._dimension = dimension
+ self._mode = mode
+ self._config = json.dumps(kwargs)
+ self._config_dict = kwargs
+ if var_type == 'hybrid' and self._key_type != tf.int64:
+ raise NotImplementedError(
+ 'only key_type tf.int64 is supported in HKV backend')
+ if name is None:
+ global dynamic_variable_count
+ name = 'sok_dynamic_Variable_' + str(dynamic_variable_count)
+ dynamic_variable_count += 1
+ var_type = 'hbm' if var_type is None else var_type
+ self._var_type = var_type
+ self._base = super(DynamicVariable, self)
+ self._base.__init__(
+ initial_value=[[0.0] * dimension],
+ trainable=trainable,
+ name=name + '/proxy',
+ dtype=self._handle_dtype,
+ constraint=constraint,
+ distribute_strategy=None,
+ synchronization=None,
+ aggregation=None,
+ shape=[None, dimension],
+ )
+
+ with ops.init_scope():
+ # name = "DynamicVariable" if name is None else name
+ with ops.name_scope(name) as name_scope:
+ self._dummy_name = ops.name_from_scope_name(name_scope)
+ if context.executing_eagerly():
+ self._dummy_name = '%s_%d' % (name, ops.uid())
+ with ops.NullContextmanager():
+ shape = [None, dimension]
+ initializer = '' if initializer is None else initializer
+ self._initializer = initializer
+ handle = dynamic_variable_ops.dummy_var_handle(
+ container='DummyVariableContainer',
+ shared_name=self._dummy_name,
+ key_type=self._key_type,
+ dtype=self._handle_dtype,
+ shape=shape,
+ )
+ if type(initializer) is str:
+ init_op = dynamic_variable_ops.dummy_var_initialize(
+ handle,
+ initializer=initializer,
+ var_type=var_type,
+ unique_name=self._dummy_name,
+ key_type=self._key_type,
+ dtype=self._handle_dtype,
+ config=self._config,
+ )
+ else:
+ with tf.control_dependencies([initializer._initializer_op]):
+ initial_val = initializer.read_value()
+ init_op = dynamic_variable_ops.dummy_var_initialize(
+ handle,
+ initializer=initial_val,
+ var_type=var_type,
+ unique_name=self._dummy_name,
+ key_type=self._key_type,
+ dtype=self._handle_dtype,
+ config=self._config,
+ )
+ # TODO: Add is_initialized_op
+ # is_initialized_op = ops.convert_to_tensor(True)
+
+ self._tf_handle = self._handle
+ self._dummy_handle = handle
+ # Note that the default handle will be sok's handle
+ self._handle = self._dummy_handle
+ self._initializer_op = tf.group([self._initializer_op, init_op])
+ # self._is_initialized_op = tf.group([self._is_initialized_op, is_initialized_op])
+
+ handle_data = (
+ resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
+ .HandleData())
+ handle_data.is_set = True
+ handle_data.shape_and_type.append(
+ resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
+ .HandleShapeAndType(
+ shape=self.shape.as_proto(), dtype=self.dtype.as_datatype_enum))
+ resource_variable_ops._set_handle_shapes_and_types(
+ self._handle,
+ handle_data,
+ graph_mode=False if context.executing_eagerly() else True)
+
+ def is_static(self):
+ return self._handle is self._tf_handle
+
+ def to_static(self, indices, lookup_only=False):
+ if not self.is_static() and self._indices is None:
+ buffer = self.sparse_read(indices, lookup_only)
+ self._indices = indices
+ self._handle = self._tf_handle
+ return self.assign(buffer)
+ else:
+ raise RuntimeError('to_static() must be called in dynamic mode.')
+
+ def to_dynamic(self):
+ if self.is_static():
+ buffer = self.read_value()
+ sparse_delta = ops.IndexedSlices(buffer, self._indices, self.shape)
+ self._indices = None
+ self._handle = self._dummy_handle
+ return self.scatter_update(sparse_delta)
+ else:
+ raise RuntimeError('to_dynamic() must be called in static mode.')
+
+ @property
+ def name(self):
+ return self._dummy_handle.name
+
+ def __repr__(self):
+ if self.is_static():
+ return self._base.__repr__()
+ return "" % (
+ self._dummy_name,
+ self.shape,
+ self.dtype.name,
+ )
+
+ @property
+ def size(self):
+ return dynamic_variable_ops.dummy_var_shape(
+ self._dummy_handle, key_type=self._key_type, dtype=self._handle_dtype)
+
+ @property
+ def indices(self):
+ return self._indices
+
+ @property
+ def dimension(self):
+ return self._dimension
+
+ def get_shape(self):
+ return [self._dimension]
+
+ @property
+ def key_type(self):
+ return self._key_type
+
+ @property
+ def handle_dtype(self):
+ return self._handle_dtype
+
+ @property
+ def backend_type(self):
+ return self._var_type
+
+ @property
+ def config_dict(self):
+ return self._config_dict
+
+ @property
+ def mode(self):
+ return self._mode
+
+ @property
+ def num_gpus(self):
+ return num_gpus()
+
+ @property
+ def initializer_str(self):
+ return self._initializer
+
+ def key_map(self, indices):
+ return indices
+
+ # -------------------------------------------------------------------------
+ # Methods supported both in static mode and dynamic mode
+ # -------------------------------------------------------------------------
+
+ def sparse_read(self, indices, name=None, lookup_only=False):
+ if self.is_static():
+ return self._base.sparse_read(indices, name)
+
+ variable_accessed(self)
+ if indices.dtype == tf.int32:
+ indices = tf.cast(indices, tf.int64)
+ return dynamic_variable_ops.dummy_var_sparse_read(
+ self._dummy_handle,
+ indices,
+ dtype=self._handle_dtype,
+ lookup_only=lookup_only)
+
+ def scatter_sub(self, sparse_delta, use_locking=False, name=None):
+ if self.is_static():
+ return self._base.scatter_sub(sparse_delta, use_locking, name)
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
+ return dynamic_variable_ops.dummy_var_scatter_add(
+ self._dummy_handle,
+ sparse_delta.indices,
+ ops.convert_to_tensor(-sparse_delta.values, self.dtype),
+ )
+
+ def scatter_add(self, sparse_delta, use_locking=False, name=None):
+ if self.is_static():
+ return self._base.scatter_add(sparse_delta, use_locking, name)
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
+ return dynamic_variable_ops.dummy_var_scatter_add(
+ self._dummy_handle,
+ sparse_delta.indices,
+ ops.convert_to_tensor(sparse_delta.values, self.dtype),
+ )
+
+ def scatter_update(self, sparse_delta, use_locking=False, name=None):
+ if self.is_static():
+ return self._base.scatter_update(sparse_delta, use_locking, name)
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
+ return dynamic_variable_ops.dummy_var_scatter_update(
+ self._dummy_handle,
+ sparse_delta.indices,
+ ops.convert_to_tensor(sparse_delta.values, self.dtype),
+ )
+
+ # -------------------------------------------------------------------------
+ # Methods not supported both in static mode and dynamic mode
+ # -------------------------------------------------------------------------
+
+ def __deepcopy__(self, *args, **kwargs):
+ raise NotImplementedError('__deepcopy__() is not supported.')
+
+ def __reduce__(self, *args, **kwargs):
+ raise NotImplementedError('__reduce__() is not supported.')
+
+ def to_proto(self, *args, **kwargs):
+ return super(DynamicVariable, self).to_proto(*args, **kwargs)
+ # raise NotImplementedError("to_proto() is not supported.")
+
+ @staticmethod
+ def from_proto(variable_def, import_scope=None):
+ if '/DummyVarHandle' in variable_def.variable_name:
+ return DynamicVariable(
+ dimension=0, variable_def=variable_def, import_scope=import_scope)
+ else:
+ return _resource_var_from_proto(variable_def, import_scope)
+ # raise NotImplementedError("from_proto() is not supported.")
+
+ def set_shape(self, *args, **kwargs):
+ raise NotImplementedError('set_shape() is not supported.')
+
+ # -------------------------------------------------------------------------
+ # Methods only supported in static mode
+ # -------------------------------------------------------------------------
+
+ def is_initialized(self, name):
+ return True
+ if self.is_static():
+ return self._base.is_initialized(name)
+ raise NotImplementedError(
+ 'is_initialized() is not supported in dynamic mode.')
+
+ def _read_variable_op(self):
+ if self.is_static():
+ return self._base._read_variable_op()
+ raise NotImplementedError(
+ '_read_variable_op() is not supported in dynamic mode.')
+
+ def value(self):
+ if self.is_static():
+ return self._base.value()
+ raise NotImplementedError('value() is not supported in dynamic mode.')
+
+ def _dense_var_to_tensor(self, *args, **kwargs):
+ if self.is_static():
+ return self._base._dense_var_to_tensor(*args, **kwargs)
+ raise NotImplementedError(
+ '_dense_var_to_tensor() is not supported in dynamic mode.')
+
+ def _gather_saveables_for_checkpoint(self):
+ if self.is_static():
+ return self._base._gather_saveables_for_checkpoint()
+ raise NotImplementedError(
+ '_gather_saveables_for_checkpoint() is not supported in dynamic mode.')
+
+ def gather_nd(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.gather_nd(*args, **kwargs)
+ raise NotImplementedError('gather_nd() is not supported in dynamic mode.')
+
+ def assign_add(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.assign_add(*args, **kwargs)
+ raise NotImplementedError('assign_add() is not supported in dynamic mode.')
+
+ def assign(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.assign(*args, **kwargs)
+ raise NotImplementedError('assign() is not supported in dynamic mode.')
+
+ def scatter_max(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.scatter_max(*args, **kwargs)
+ raise NotImplementedError('scatter_max() is not supported in dynamic mode.')
+
+ def scatter_min(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.scatter_min(*args, **kwargs)
+ raise NotImplementedError('scatter_min() is not supported in dynamic mode.')
+
+ def scatter_mul(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.scatter_mul(*args, **kwargs)
+ raise NotImplementedError('scatter_mul() is not supported in dynamic mode.')
+
+ def scatter_dim(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.scatter_dim(*args, **kwargs)
+ raise NotImplementedError('scatter_dim() is not supported in dynamic mode.')
+
+ def batch_scatter_update(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.batch_scatter_update(*args, **kwargs)
+ raise NotImplementedError(
+ 'batch_scatter_update() is not supported in dynamic mode.')
+
+ def scatter_nd_sub(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.scatter_nd_sub(*args, **kwargs)
+ raise NotImplementedError(
+ 'scatter_nd_sub() is not supported in dynamic mode.')
+
+ def scatter_nd_update(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.scatter_nd_update(*args, **kwargs)
+ raise NotImplementedError(
+ 'scatter_nd_update() is not supported in dynamic mode.')
+
+ def _strided_slice_assign(self, *args, **kwargs):
+ if self.is_static():
+ return self._base._strided_slice_assign(*args, **kwargs)
+ raise NotImplementedError(
+ '_strided_slice_assign() is not supported in dynamic mode.')
+
+ def __int__(self, *args, **kwargs):
+ if self.is_static():
+ return self._base.__int__(*args, **kwargs)
+ raise NotImplementedError('__int__() is not supported in dynamic mode.')
+
+
+ResourceVariable.from_proto = DynamicVariable.from_proto
+
+# @tf.RegisterGradient("DummyVarSparseRead")
+# def _SparseReadGrad(op, grad):
+# """Gradient for sparse_read."""
+# handle = op.inputs[0]
+# indices = op.inputs[1]
+# key_type = op.get_attr("key_type")
+# dtype = op.get_attr("dtype")
+# variable_shape = dynamic_variable_ops.dummy_var_shape(handle, key_type=key_type, dtype=dtype)
+# size = array_ops.expand_dims(array_ops.size(indices), 0)
+# values_shape = array_ops.concat([size, variable_shape[1:]], 0)
+# grad = array_ops.reshape(grad, values_shape)
+# indices = array_ops.reshape(indices, size)
+# return (ops.IndexedSlices(grad, indices, variable_shape), None)
+
+
+def export(var):
+ """Abbreviated as ``sok.experiment.export``.
+
+ Export the indices and value tensor from the given variable.
+
+ Parameters
+ ----------
+ var: sok.DynamicVariable
+ The variable to extract indices and values.
+
+ Returns
+ -------
+ indices: tf.Tensor
+ The indices of the given variable.
+
+ values: tf.Tensor
+ the values of the given variable.
+ """
+ if isinstance(var, DynamicVariable):
+ indices, values = dynamic_variable_ops.dummy_var_export(
+ var.handle, key_type=var.key_type, dtype=var.handle_dtype)
+ with tf.device('CPU'):
+ indices = tf.identity(indices)
+ values = tf.identity(values)
+ return indices, values
+
+
+def assign(var, indices, values):
+ """Abbreviated as ``sok.experiment.assign``.
+
+ Assign the indices and value tensor to the target variable.
+
+ Parameters
+ ----------
+ var: sok.DynamicVariable
+ The target variable of assign.
+
+ indices: tf.Tensor
+ indices to be assigned to the variable.
+
+ values: tf.Tensor
+ values to be assigned to the variable
+
+ Returns
+ -------
+ variable: sok.DynamicVariable
+ """
+ if isinstance(var, DynamicVariable):
+    indices = tf.cast(indices, var._key_type)
+ return dynamic_variable_ops.dummy_var_assign(var.handle, indices, values)
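A short sketch (assuming SOK is installed; the variables below are illustrative) of moving key/value pairs between two dynamic tables with export and assign:

```python
from easy_rec.python.compat import dynamic_variable as dv

src = dv.DynamicVariable(dimension=8, initializer='0')
dst = dv.DynamicVariable(dimension=8, initializer='0')
ids, vals = dv.export(src)           # pull all (key, value) pairs out as CPU tensors
copy_op = dv.assign(dst, ids, vals)  # write them into the second table
```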
diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py
index d68ee618a..fc850fb62 100644
--- a/easy_rec/python/compat/early_stopping.py
+++ b/easy_rec/python/compat/early_stopping.py
@@ -15,9 +15,15 @@
"""Utilities for early stopping."""
import collections
+import datetime
+import logging
import operator
import os
+import threading
+import time
+import tensorflow as tf
+from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
@@ -29,10 +35,15 @@
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import estimator_export
+from easy_rec.python.utils.config_util import parse_time
from easy_rec.python.utils.load_class import load_by_path
+if LooseVersion(tf.__version__) >= LooseVersion('2.12.0'):
+ from tensorflow_estimator.python.estimator.estimator_export import estimator_export
+else:
+ from tensorflow.python.util.tf_export import estimator_export
+
_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'
EARLY_STOP_SIG_SCOPE = 'signal_early_stopping'
@@ -299,7 +310,8 @@ def custom_early_stop_hook(estimator,
if eval_dir is None:
eval_dir = estimator.eval_dir()
- if isinstance(custom_stop_func, str) or isinstance(custom_stop_func, unicode):
+ if isinstance(custom_stop_func, str) or isinstance(custom_stop_func,
+ type(u'')):
custom_stop_func = load_by_path(custom_stop_func)
def _custom_stop_fn():
@@ -548,3 +560,94 @@ def after_run(self, run_context, run_values):
if should_early_stop:
tf_logging.info('Early stopping requested, suspending run.')
run_context.request_stop()
+
+
+class OssStopSignalHook(session_run_hook.SessionRunHook):
+
+ def __init__(self, model_dir, run_every_secs=10, run_every_steps=None):
+ self._stop_sig_file = os.path.join(model_dir, 'OSS_STOP_SIGNAL')
+ self._stop = False
+ self._check_run = True
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=run_every_secs, every_steps=run_every_steps)
+ sleep_time = run_every_secs if run_every_secs is not None else 1
+ self._curr_step = 0
+
+ def _check_stop():
+ while self._check_run:
+ if self._timer.should_trigger_for_step(self._curr_step):
+ self._timer.update_last_triggered_step(self._curr_step)
+ if gfile.Exists(self._stop_sig_file):
+ self._stop = True
+ logging.info('OssStopSignalHook: stop on signal %s' %
+ self._stop_sig_file)
+ break
+ else:
+ time.sleep(sleep_time)
+
+ self._th = threading.Thread(target=_check_stop)
+ self._th.start()
+
+ self._global_step_tensor = None
+ self._stop_var = _get_or_create_stop_var()
+ self._stop_op = None
+
+ def begin(self):
+ self._global_step_tensor = training_util.get_global_step()
+ self._stop_op = state_ops.assign(self._stop_var, True)
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ if self._stop:
+ run_context.request_stop()
+ run_context.session.run(self._stop_op)
+ self._curr_step = run_values.results
+
+ def end(self, session):
+ self._check_run = False
+ self._th.join()
+
+
+def oss_stop_hook(estimator, run_every_secs=10, run_every_steps=None):
+ """Creates oss stop hook.
+
+ Returns a `SessionRunHook` that stops training when model_dir/OSS_STOP_SIGNAL is created.
+ """
+ if estimator.config.is_chief:
+ return OssStopSignalHook(
+ estimator.model_dir,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+ else:
+ return _CheckForStoppingHook()
+
+
+class DeadlineStopHook(session_run_hook.SessionRunHook):
+
+ def __init__(self, deadline_ts):
+ self._deadline_ts = deadline_ts
+ self._stop_var = _get_or_create_stop_var()
+ self._stop_op = None
+
+ def begin(self):
+ self._stop_op = state_ops.assign(self._stop_var, True)
+
+ def after_run(self, run_context, run_values):
+ curr_ts = time.mktime(datetime.datetime.now().timetuple())
+ if curr_ts > self._deadline_ts:
+ run_context.request_stop()
+ run_context.session.run(self._stop_op)
+
+
+def deadline_stop_hook(estimator, dead_line):
+ """Creates oss stop hook.
+
+ Returns a `SessionRunHook` that stops training when timestamp > deadline_ts.
+ """
+ deadline_ts = parse_time(dead_line)
+ if estimator.config.is_chief:
+ return DeadlineStopHook(deadline_ts)
+ else:
+ return _CheckForStoppingHook()
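A minimal sketch of wiring these hooks into training; the estimator, input_fn and the timestamp format accepted by parse_time are assumptions, not fixed by this patch:

```python
from easy_rec.python.compat import early_stopping

hooks = [
    early_stopping.oss_stop_hook(estimator, run_every_secs=30),
    early_stopping.deadline_stop_hook(estimator, '20240630 23:59:59'),
]
estimator.train(input_fn=train_input_fn, hooks=hooks)
```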
diff --git a/easy_rec/python/compat/embedding_parallel_saver.py b/easy_rec/python/compat/embedding_parallel_saver.py
new file mode 100644
index 000000000..96c08c592
--- /dev/null
+++ b/easy_rec/python/compat/embedding_parallel_saver.py
@@ -0,0 +1,316 @@
+# -*- encoding:utf-8 -*-
+
+import logging
+import os
+
+import numpy as np
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+# from tensorflow.python.ops import math_ops
+# from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.training import saver
+
+from easy_rec.python.utils import constant
+
+try:
+ import horovod.tensorflow as hvd
+ from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
+ from easy_rec.python.compat import dynamic_variable
+except Exception:
+ dynamic_variable_ops = None
+ dynamic_variable = None
+
+try:
+ from tensorflow.python.framework.load_library import load_op_library
+ import easy_rec
+ load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
+ load_embed_lib = load_op_library(load_embed_lib_path)
+except Exception as ex:
+ logging.warning('load libload_embed.so failed: %s' % str(ex))
+ load_embed_lib = None
+
+
+def _get_embed_part_id(embed_file):
+ embed_file = embed_file.split('/')[-1]
+ embed_file = embed_file.split('.')[0]
+ embed_id = embed_file.split('-')[-1]
+ return int(embed_id)
+
+
+class EmbeddingParallelSaver(saver.Saver):
+
+ def __init__(self,
+ var_list=None,
+ reshape=False,
+ sharded=False,
+ max_to_keep=5,
+ keep_checkpoint_every_n_hours=10000.0,
+ name=None,
+ restore_sequentially=False,
+ saver_def=None,
+ builder=None,
+ defer_build=False,
+ allow_empty=False,
+ write_version=saver_pb2.SaverDef.V2,
+ pad_step_number=False,
+ save_relative_paths=False,
+ filename=None):
+ self._kv_vars = []
+ self._embed_vars = []
+ tf_vars = []
+ embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
+ for var in var_list:
+ if dynamic_variable is not None and isinstance(
+ var, dynamic_variable.DynamicVariable):
+ self._kv_vars.append(var)
+ elif var.name in embed_para_vars:
+ logging.info('save shard embedding %s part_id=%d part_shape=%s' %
+ (var.name, hvd.rank(), var.get_shape()))
+ self._embed_vars.append(var)
+ else:
+ tf_vars.append(var)
+ super(EmbeddingParallelSaver, self).__init__(
+ tf_vars,
+ reshape=reshape,
+ sharded=sharded,
+ max_to_keep=max_to_keep,
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+ name=name,
+ restore_sequentially=restore_sequentially,
+ saver_def=saver_def,
+ builder=builder,
+ defer_build=defer_build,
+ allow_empty=allow_empty,
+ write_version=write_version,
+ pad_step_number=pad_step_number,
+ save_relative_paths=save_relative_paths,
+ filename=filename)
+ self._is_build = False
+
+ def _has_embed_vars(self):
+ return (len(self._kv_vars) + len(self._embed_vars)) > 0
+
+ def _save_dense_embedding(self, embed_var):
+ logging.info('task[%d] save_dense_embed: %s' % (hvd.rank(), embed_var.name))
+
+ def _save_embed(embed, filename, var_name):
+ task_id = hvd.rank()
+ filename = filename.decode('utf-8')
+ var_name = var_name.decode('utf-8').replace('/', '__')
+ embed_dir = filename + '-embedding/'
+ logging.info('task[%d] save_dense_embed: %s to %s' %
+ (task_id, var_name, embed_dir))
+ if not gfile.Exists(embed_dir):
+ gfile.MakeDirs(embed_dir)
+ embed_file = filename + '-embedding/embed-' + var_name + '-part-%d.bin' % task_id
+ with gfile.GFile(embed_file, 'wb') as fout:
+ fout.write(embed.tobytes())
+
+ if task_id == 0:
+ # clear old embedding tables
+ embed_pattern = filename + '-embedding/embed-' + var_name + '-part-*.bin'
+ embed_files = gfile.Glob(embed_pattern)
+ for embed_file in embed_files:
+ embed_id = _get_embed_part_id(embed_file)
+ if embed_id >= hvd.size():
+ gfile.DeleteRecursively(embed_file)
+ return np.asarray([embed_file], order='C', dtype=np.object)
+
+ file_name = ops.get_default_graph().get_tensor_by_name(
+ self.saver_def.filename_tensor_name)
+ save_paths = script_ops.py_func(_save_embed,
+ [embed_var, file_name, embed_var.name],
+ dtypes.string)
+ return save_paths
+
+ def _load_dense_embedding(self, embed_var):
+ file_name = ops.get_default_graph().get_tensor_by_name(
+ self.saver_def.filename_tensor_name)
+ embed_dim = embed_var.get_shape()[-1]
+ embed_part_size = embed_var.get_shape()[0]
+
+ def _load_embed(embed, embed_dim, embed_part_size, part_id, part_num,
+ filename, var_name):
+ filename = filename.decode('utf-8')
+ var_name = var_name.decode('utf-8').replace('/', '__')
+ embed_pattern = filename + '-embedding/embed-' + var_name + '-part-*.bin'
+ embed_files = gfile.Glob(embed_pattern)
+
+ embed_files.sort(key=_get_embed_part_id)
+
+ logging.info('task[%d] embed_files=%s embed_dim=%d embed_part_size=%d' %
+ (part_id, ','.join(embed_files), embed_dim, embed_part_size))
+
+ part_embed_vals = np.zeros([embed_part_size, embed_dim], dtype=np.float32)
+ part_update_cnt = 0
+ for embed_file in embed_files:
+ part_id_o = _get_embed_part_id(embed_file)
+ with gfile.GFile(embed_file, 'rb') as fin:
+ embed_val = np.frombuffer(fin.read(), np.float32)
+ embed_val = embed_val.reshape([-1, embed_dim])
+ embed_ids_o = np.arange(len(embed_val))
+ embed_ids_o = part_id_o + embed_ids_o * len(embed_files)
+ sel_ids = np.where(
+ np.logical_and((embed_ids_o % part_num) == part_id,
+ embed_ids_o < embed_part_size * part_num))[0]
+ part_update_cnt += len(sel_ids)
+ embed_ids = embed_ids_o[sel_ids]
+ embed_ids_n = np.array(embed_ids / part_num, dtype=np.int64)
+ part_embed_vals[embed_ids_n] = embed_val[sel_ids]
+ logging.info('task[%d] load_part_cnt=%d' % (part_id, part_update_cnt))
+ return part_embed_vals
+
+ with ops.control_dependencies([embed_var._initializer_op]):
+ if load_embed_lib is not None:
+ embed_val = load_embed_lib.load_embed(
+ task_index=hvd.rank(),
+ task_num=hvd.size(),
+ embed_dim=embed_dim,
+ embed_part_size=embed_part_size,
+ var_name='embed-' + embed_var.name.replace('/', '__'),
+ ckpt_path=file_name)
+ else:
+ embed_val = script_ops.py_func(_load_embed, [
+ embed_var, embed_dim, embed_part_size,
+ hvd.rank(),
+ hvd.size(), file_name, embed_var.name
+ ], dtypes.float32)
+ embed_val.set_shape(embed_var.get_shape())
+ return state_ops.assign(embed_var, embed_val)
+
+ def _save_kv_embedding(self, sok_var):
+ indices, values = dynamic_variable_ops.dummy_var_export(
+ sok_var.handle, key_type=sok_var.key_type, dtype=sok_var.handle_dtype)
+ file_name = ops.get_default_graph().get_tensor_by_name(
+ self.saver_def.filename_tensor_name)
+
+ def _save_key_vals(indices, values, filename, var_name):
+ var_name = var_name.decode('utf-8').replace('/', '__')
+ filename = filename.decode('utf-8')
+ sok_dir = filename + '-embedding/'
+ if not gfile.Exists(sok_dir):
+ gfile.MakeDirs(sok_dir)
+ task_id = hvd.rank()
+ key_file = filename + '-embedding/embed-' + var_name + '-part-%d.key' % task_id
+ with gfile.GFile(key_file, 'wb') as fout:
+ fout.write(indices.tobytes())
+ val_file = filename + '-embedding/embed-' + var_name + '-part-%d.val' % task_id
+ with gfile.GFile(val_file, 'wb') as fout:
+ fout.write(values.tobytes())
+
+ if task_id == 0:
+ key_file_pattern = filename + '-embedding/embed-' + var_name + '-part-*.key'
+ key_files = gfile.Glob(key_file_pattern)
+ for key_file in key_files:
+ embed_id = _get_embed_part_id(key_file)
+ if embed_id >= hvd.size():
+ gfile.DeleteRecursively(key_file)
+ val_file = key_file[:-4] + '.val'
+ if gfile.Exists(val_file):
+ gfile.DeleteRecursively(val_file)
+
+ return np.asarray([key_file, val_file], order='C', dtype=np.object)
+
+ save_paths = script_ops.py_func(_save_key_vals,
+ [indices, values, file_name, sok_var.name],
+ dtypes.string)
+ return save_paths
+
+ def _load_kv_embedding(self, sok_var):
+
+ def _load_key_vals(filename, var_name):
+ var_name = var_name.decode('utf-8').replace('/', '__')
+ filename = filename.decode('utf-8')
+ key_file_pattern = filename + '-embedding/embed-' + var_name + '-part-*.key'
+ logging.info('key_file_pattern=%s filename=%s var_name=%s var=%s' %
+ (key_file_pattern, filename, var_name, str(sok_var)))
+ key_files = gfile.Glob(key_file_pattern)
+ logging.info('key_file_pattern=%s file_num=%d' %
+ (key_file_pattern, len(key_files)))
+ all_keys = []
+ all_vals = []
+ for key_file in key_files:
+ with gfile.GFile(key_file, 'rb') as fin:
+ tmp_keys = np.frombuffer(fin.read(), dtype=np.int64)
+ tmp_ids = tmp_keys % hvd.size()
+ tmp_ids = np.where(tmp_ids == hvd.rank())[0]
+ if len(tmp_ids) == 0:
+          continue
+ all_keys.append(tmp_keys.take(tmp_ids, axis=0))
+ logging.info('part_keys.shape=%s %s %s' % (str(
+ tmp_keys.shape), str(tmp_ids.shape), str(all_keys[-1].shape)))
+
+        val_file = key_file[:-4] + '.val'
+ with gfile.GFile(val_file, 'rb') as fin:
+ tmp_vals = np.frombuffer(
+ fin.read(), dtype=np.float32).reshape([-1, sok_var._dimension])
+ all_vals.append(tmp_vals.take(tmp_ids, axis=0))
+ logging.info('part_vals.shape=%s %s %s' % (str(
+ tmp_vals.shape), str(tmp_ids.shape), str(all_vals[-1].shape)))
+
+ all_keys = np.concatenate(all_keys, axis=0)
+ all_vals = np.concatenate(all_vals, axis=0)
+
+ shuffle_ids = np.array(range(len(all_keys)))
+ np.random.shuffle(shuffle_ids)
+ all_keys = all_keys.take(shuffle_ids, axis=0)
+ all_vals = all_vals.take(shuffle_ids, axis=0)
+ return all_keys, all_vals
+
+ file_name = ops.get_default_graph().get_tensor_by_name(
+ self.saver_def.filename_tensor_name)
+ if load_embed_lib is not None:
+ keys, vals = load_embed_lib.load_kv_embed(
+ task_index=hvd.rank(),
+ task_num=hvd.size(),
+ embed_dim=sok_var._dimension,
+ var_name='embed-' + sok_var.name.replace('/', '__'),
+ ckpt_path=file_name)
+ else:
+ logging.warning('libload_embed.so not loaded, will use python script_ops')
+ keys, vals = script_ops.py_func(_load_key_vals, [file_name, sok_var.name],
+ (dtypes.int64, dtypes.float32))
+ with ops.control_dependencies([sok_var._initializer_op]):
+ return dynamic_variable_ops.dummy_var_assign(sok_var.handle, keys, vals)
+
+ def build(self):
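+    """Append embedding-parallel save/restore ops to the ops built by the base Saver."""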
+ if self._is_built:
+ return
+ super(EmbeddingParallelSaver, self).build()
+ if self.saver_def.restore_op_name and self._has_embed_vars():
+      # restore embedding variables from the checkpoint
+ restore_ops = []
+ for sok_var in self._kv_vars:
+ restore_ops.append(self._load_kv_embedding(sok_var))
+ for embed_var in self._embed_vars:
+ restore_ops.append(self._load_dense_embedding(embed_var))
+ old_restore_op = ops.get_default_graph().get_operation_by_name(
+ self.saver_def.restore_op_name)
+ restore_ops.append(old_restore_op)
+ restore_op_n = control_flow_ops.group(restore_ops)
+ self.saver_def.restore_op_name = restore_op_n.name
+
+ if self.saver_def.save_tensor_name and self._has_embed_vars():
+ file_name = ops.get_default_graph().get_tensor_by_name(
+ self.saver_def.filename_tensor_name)
+ save_part_ops = []
+ for sok_var in self._kv_vars:
+ save_part_op = self._save_kv_embedding(sok_var)
+ save_part_ops.append(save_part_op)
+ for embed_var in self._embed_vars:
+ save_part_op = self._save_dense_embedding(embed_var)
+ save_part_ops.append(save_part_op)
+ old_save_op = ops.get_default_graph().get_tensor_by_name(
+ self.saver_def.save_tensor_name)
+      # only the first worker needs to save the non-embedding variables
+ if hvd.rank() == 0:
+ save_part_ops.append(old_save_op)
+ with ops.control_dependencies(save_part_ops):
+ save_op_n = array_ops.identity(file_name)
+ self.saver_def.save_tensor_name = save_op_n.name
diff --git a/easy_rec/python/compat/estimator_train.py b/easy_rec/python/compat/estimator_train.py
index 076878a3e..1ec71e491 100644
--- a/easy_rec/python/compat/estimator_train.py
+++ b/easy_rec/python/compat/estimator_train.py
@@ -8,7 +8,9 @@
from tensorflow.python.estimator.training import _assert_eval_spec
from tensorflow.python.estimator.training import _ContinuousEvalListener
from tensorflow.python.estimator.training import _TrainingExecutor
+from tensorflow.python.util import compat
+from easy_rec.python.compat.exporter import FinalExporter
from easy_rec.python.utils import estimator_utils
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training # NOQA
@@ -80,8 +82,35 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
'(with task id 0). Given task id {}'.format(config.task_id))
result = executor.run()
+
+  # fix for the bug that the evaluator fails to export when num_epoch is reached
+  # before num_steps, or when num_steps is set to infinite
+ if estimator_utils.is_evaluator():
+ export_dir_base = os.path.join(
+ compat.as_str_any(estimator.model_dir), compat.as_str_any('export'))
+ for exporter in eval_spec.exporters:
+ if isinstance(exporter, FinalExporter):
+ export_path = os.path.join(
+ compat.as_str_any(export_dir_base),
+ compat.as_str_any(exporter.name))
+ # avoid duplicate export
+ if gfile.IsDirectory(export_path + '/'):
+ continue
+ exporter.export(
+ estimator=estimator,
+ export_path=export_path,
+ checkpoint_path=estimator_utils.latest_checkpoint(
+ estimator.model_dir),
+ eval_result=None,
+ is_the_final_export=True)
+
if estimator_utils.is_chief():
with gfile.GFile(train_done_listener.train_done_file, 'w') as fout:
fout.write('Train Done.')
return result
+
+
+def estimator_train_done(estimator):
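+  """Return True if the ESTIMATOR_TRAIN_DONE marker file exists under the estimator's model_dir."""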
+ train_done_file = os.path.join(estimator.model_dir, 'ESTIMATOR_TRAIN_DONE')
+ return gfile.Exists(train_done_file)
diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py
index fbadcf2d8..73a568d9c 100644
--- a/easy_rec/python/compat/feature_column/feature_column.py
+++ b/easy_rec/python/compat/feature_column/feature_column.py
@@ -135,6 +135,7 @@
import abc
import collections
import math
+import os
import numpy as np
import six
@@ -145,9 +146,11 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import training
from tensorflow.python.layers import base
+# from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
@@ -160,13 +163,198 @@
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+# from tensorflow.python.ops.ragged import ragged_tensor
+# from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
-from easy_rec.python.compat import embedding_ops as ev_embedding_ops
from easy_rec.python.compat.feature_column import utils as fc_utils
+from easy_rec.python.utils import conditional
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import embedding_utils
+
+try:
+ from easy_rec.python.compat import dynamic_variable
+except Exception:
+ dynamic_variable = None
+
+try:
+ import horovod.tensorflow as hvd
+except Exception:
+ hvd = None
+
+
+def embedding_lookup_ragged(embedding_weights,
+ ragged_ids,
+ ragged_weights,
+ combiner,
+ max_norm=None,
+ name=None):
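+  """Embedding lookup for RaggedTensor ids with optional per-id weights, combined per row by sum/mean/sqrtn."""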
+ segment_ids = ragged_ids.value_rowids()
+ if segment_ids.dtype != dtypes.int32:
+ segment_ids = math_ops.cast(segment_ids, dtypes.int32)
+ ids = ragged_ids.flat_values
+ ids, idx = array_ops.unique(ids)
+ embeddings = embedding_ops.embedding_lookup(
+ embedding_weights, ids, partition_strategy='mod', max_norm=max_norm)
+ if ragged_weights is not None:
+ weights = ragged_weights.flat_values
+ embeddings = array_ops.gather(embeddings, idx)
+ original_dtype = embeddings.dtype
+ if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
+ # Cast low-precision embeddings to float32 during the computation to
+ # avoid numerical issues.
+ embeddings = math_ops.cast(embeddings, dtypes.float32)
+ if weights.dtype != embeddings.dtype:
+ weights = math_ops.cast(weights, embeddings.dtype)
+    weights = array_ops.expand_dims(weights, -1)
+ embeddings = embeddings * weights
+ if combiner == 'sum':
+ return math_ops.segment_sum(embeddings, segment_ids, name=name)
+ elif combiner == 'mean':
+ embeddings = math_ops.segment_sum(embeddings, segment_ids)
+ weight_sum = math_ops.segment_sum(weights, segment_ids)
+ embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name)
+ elif combiner == 'sqrtn':
+ embeddings = math_ops.segment_sum(embeddings, segment_ids)
+ weights_squared = math_ops.pow(weights, 2)
+ weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
+ weight_sum_sqrt = math_ops.sqrt(weight_sum)
+ embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt, name=name)
+ else:
+ assert False, 'Unrecognized combiner'
+ if embeddings.dtype != original_dtype:
+ embeddings = math_ops.cast(embeddings, original_dtype)
+ return embeddings
+ else:
+ assert idx is not None
+ if combiner == 'sum':
+ embeddings = math_ops.sparse_segment_sum(
+ embeddings, idx, segment_ids, name=name)
+ elif combiner == 'mean':
+ embeddings = math_ops.sparse_segment_mean(
+ embeddings, idx, segment_ids, name=name)
+ elif combiner == 'sqrtn':
+ embeddings = math_ops.sparse_segment_sqrt_n(
+ embeddings, idx, segment_ids, name=name)
+ else:
+ assert False, 'Unrecognized combiner'
+ return embeddings
+
+
+# model parallel embedding lookup
+def embedding_parallel_lookup(embedding,
+ lookup_indices,
+ output_ids,
+ is_training,
+ output_tensors=None,
+ batch_size=None):
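+  """Model-parallel embedding lookup.
+
+  Ids are deduplicated and, when hvd.size() > 1, partitioned by id % hvd.size()
+  and exchanged via hvd.alltoall before the local lookup; per-row results are
+  combined with sparse_segment_sum.
+  """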
+ N = len(output_ids)
+ if batch_size is None:
+ num_segments = None
+ else:
+ num_segments = N * batch_size
+ # first concat all the ids and unique
+ if isinstance(lookup_indices, dict) and 'sparse_fea' in lookup_indices.keys():
+ # all_uniq_ids, uniq_idx, segment_lens = features['sparse_fea']
+ all_ids, segment_lens = lookup_indices['sparse_fea']
+ all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
+ cumsum_lens = math_ops.cumsum(segment_lens)
+ segment_ids = array_ops.searchsorted(
+ cumsum_lens, math_ops.range(cumsum_lens[-1]), side='right')
+ elif isinstance(lookup_indices, dict) and 'ragged_ids' in lookup_indices.keys(
+ ) and 'ragged_lens' in lookup_indices.keys():
+ all_ids, segment_lens = lookup_indices['ragged_ids'], lookup_indices[
+ 'ragged_lens']
+ all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
+ cumsum_lens = math_ops.cumsum(segment_lens)
+ segment_ids = array_ops.searchsorted(
+ cumsum_lens, math_ops.range(cumsum_lens[-1]), side='right')
+ elif isinstance(lookup_indices[0], sparse_tensor_lib.SparseTensor):
+ with ops.device('/cpu:0'):
+ all_ids = array_ops.concat([x.values for x in lookup_indices], axis=0)
+ segment_ids = array_ops.concat([x.indices[:, 0] for x in lookup_indices],
+ axis=0)
+ all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
+ elif 'RaggedTensor' in str(type(lookup_indices[0])):
+ with ops.device('/cpu:0'):
+ all_ids = array_ops.concat([x.values for x in lookup_indices], axis=0)
+ segment_lens = array_ops.concat([x.row_lengths() for x in lookup_indices],
+ axis=0)
+ all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
+ cumsum_lens = math_ops.cumsum(segment_lens)
+ segment_ids = array_ops.searchsorted(
+ cumsum_lens, math_ops.range(cumsum_lens[-1]), side='right')
+ else:
+ assert False, 'invalid indices type: %s' % str(type(lookup_indices[0]))
+
+ num_parts = hvd.size()
+ if num_parts > 1:
+ # dynamic partition
+ p_assignments = math_ops.cast(all_uniq_ids % num_parts, dtypes.int32)
+ gather_ids = data_flow_ops.dynamic_partition(all_uniq_ids, p_assignments,
+ num_parts)
+ original_ids = math_ops.range(array_ops.size(all_uniq_ids))
+ original_part_ids = data_flow_ops.dynamic_partition(original_ids,
+ p_assignments,
+ num_parts)
+ # all2all
+ split_sizes = array_ops.concat([array_ops.shape(x) for x in gather_ids],
+ axis=0)
+ send_ids = array_ops.concat(gather_ids, axis=0)
+ recv_ids, recv_lens = hvd.alltoall(send_ids, split_sizes)
+
+ # read embedding from dynamic variable
+ if isinstance(embedding, dynamic_variable.DynamicVariable):
+ send_embed = embedding.sparse_read(
+ recv_ids, lookup_only=(not is_training))
+ else:
+      # ids are assigned to workers round-robin (worker k holds ids k, k + num_parts,
+      # k + 2 * num_parts, ...), so the local row index of an id is id // num_parts
+ recv_ids = math_ops.cast(recv_ids / num_parts, dtypes.int64)
+ send_embed = array_ops.gather(embedding, recv_ids)
+
+ # all2all
+ recv_embeddings, _ = hvd.alltoall(send_embed, recv_lens)
+ recv_embeddings = array_ops.split(
+ recv_embeddings, num_or_size_splits=split_sizes)
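+    # stitch the per-partition results back into the original unique-id order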
+ recv_embeddings = data_flow_ops.parallel_dynamic_stitch(
+ original_part_ids, recv_embeddings, name='parallel_dynamic_stitch')
+ embeddings = math_ops.sparse_segment_sum(
+ recv_embeddings,
+ uniq_idx,
+ segment_ids,
+ num_segments=num_segments,
+ name='sparse_segment_sum')
+ else:
+ if isinstance(embedding, dynamic_variable.DynamicVariable):
+ recv_embeddings = embedding.sparse_read(
+ all_uniq_ids, lookup_only=(not is_training))
+ else:
+ recv_embeddings = array_ops.gather(embedding, all_uniq_ids)
+ embeddings = math_ops.sparse_segment_sum(
+ recv_embeddings,
+ uniq_idx,
+ segment_ids,
+ num_segments=num_segments,
+ name='sparse_segment_sum')
+
+ embed_dim = embedding.get_shape()[-1]
+ output_tensor = array_ops.reshape(embeddings, [N, -1, embed_dim])
+
+ if output_tensors is not None:
+ outputs = array_ops.split(output_tensor, num_or_size_splits=N, axis=0)
+ for output, output_id in zip(outputs, output_ids):
+ output_tensors[output_id] = array_ops.squeeze(output, axis=0)
+
+ if batch_size is None:
+ batch_size = -1
+ return array_ops.reshape(
+ array_ops.transpose(output_tensor, perm=[1, 0, 2]),
+ [batch_size, N * embed_dim])
def _internal_input_layer(features,
@@ -177,7 +365,8 @@ def _internal_input_layer(features,
scope=None,
cols_to_output_tensors=None,
from_template=False,
- feature_name_to_output_tensors=None):
+ feature_name_to_output_tensors=None,
+ is_training=True):
"""See input_layer, `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
for column in feature_columns:
@@ -195,9 +384,12 @@ def _internal_input_layer(features,
def _get_logits(): # pylint: disable=missing-docstring
builder = _LazyBuilder(features)
output_tensors = []
- ordered_columns = []
- for column in sorted(feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
+
+ tmp_cols = feature_columns
+ if embedding_utils.sort_col_by_name():
+ logging.info('will sort columns[len=%d] by name' % len(tmp_cols))
+ tmp_cols = sorted(tmp_cols, key=lambda x: x.name)
+ for column in tmp_cols:
with variable_scope.variable_scope(
None, default_name=column._var_scope_name): # pylint: disable=protected-access
tensor = column._get_dense_tensor( # pylint: disable=protected-access
@@ -217,11 +409,221 @@ def _get_logits(): # pylint: disable=missing-docstring
scope=variable_scope.get_variable_scope().name)
if cols_to_output_tensors is not None:
cols_to_output_tensors[column] = output_tensor
- if feature_name_to_output_tensors is not None and column.raw_name in feature_name_to_output_tensors:
+ if feature_name_to_output_tensors is not None:
feature_name_to_output_tensors[column.raw_name] = output_tensor
- _verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
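+    # model-parallel variant of _get_logits: each worker holds a shard of every
+    # embedding table and lookups go through embedding_parallel_lookup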
+ assert hvd is not None, 'horovod is not installed'
+ builder = _LazyBuilder(features)
+
+ if embedding_utils.embedding_on_cpu():
+ embedding_device = '/cpu:0'
+ else:
+ embedding_device = '/gpu:0'
+
+ def _get_var_type(column):
+ if column.ev_params.use_cache:
+ return 'hybrid'
+ else:
+ return None
+
+ output_tensors = []
+ ordered_columns = []
+
+ lookup_embeddings = []
+ lookup_indices = None
+ lookup_combiners = []
+ lookup_cols = []
+ lookup_output_ids = []
+ lookup_wgts = []
+
+ dense_cols = []
+ dense_output_ids = []
+
+ shared_weights = {}
+ dense_cnt = 0
+
+ batch_sizes = []
+ for column in feature_columns:
+ ordered_columns.append(column)
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name): # pylint: disable=protected-access
+        # for features which do not require embedding
+ if 'Embedding' not in str(type(column)):
+ dense_cols.append(column)
+ dense_output_ids.append(len(output_tensors))
+ output_tensors.append(None)
+ dense_cnt += 1
+ continue
+
+        # for features that require embedding
+ num_buckets = column.categorical_column.num_buckets + hvd.size() - 1
+ per_worker_buckets = num_buckets // hvd.size()
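+          # ceiling division so the shard can hold any id mapped to row id // hvd.size()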
+ embedding_shape = (per_worker_buckets, column.dimension)
+ if 'SharedEmbedding' in str(type(column)):
+ shared_name = column.shared_embedding_collection_name
+ if shared_name in shared_weights:
+ embedding_weights = shared_weights[shared_name]
+ else:
+ with ops.device(embedding_device):
+ if column.ev_params is not None:
+ assert dynamic_variable is not None, 'sok is not installed'
+ embedding_weights = dynamic_variable.DynamicVariable(
+ name='embedding_weights',
+ dimension=column.dimension,
+ initializer='random {"stddev":0.0025}', # column.initializer,
+ var_type=_get_var_type(column),
+ trainable=column.trainable and trainable,
+ dtype=dtypes.float32,
+ init_capacity=column.ev_params.init_capacity,
+ max_capacity=column.ev_params.max_capacity)
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=column.initializer,
+ trainable=column.trainable and trainable,
+ partitioner=None,
+ collections=weight_collections)
+ shared_weights[shared_name] = embedding_weights
+ else:
+ with ops.device(embedding_device):
+ if column.ev_params is not None:
+ assert dynamic_variable is not None, 'sok is not installed'
+ embedding_weights = dynamic_variable.DynamicVariable(
+ name='embedding_weights',
+ dimension=column.dimension,
+ initializer='random {"stddev":0.0025}', # column.initializer,
+ var_type=_get_var_type(column),
+ trainable=column.trainable and trainable,
+ dtype=dtypes.float32,
+ init_capacity=column.ev_params.init_capacity,
+ max_capacity=column.ev_params.max_capacity)
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=column.initializer,
+ trainable=column.trainable and trainable,
+ partitioner=None,
+ collections=weight_collections)
+ lookup_embeddings.append(embedding_weights)
+ output_id = len(output_tensors)
+ output_tensors.append(None)
+ lookup_output_ids.append(output_id)
+ lookup_cols.append(column)
+ lookup_combiners.append(column.combiner)
+
+        # inputs may be packed ('sparse_fea' / 'ragged_ids') or per-column
+        # SparseTensor / RaggedTensor; per-column features are not gathered into
+        # one tensor and may have performance issues
+ if 'sparse_fea' in features.keys():
+ if lookup_indices is None:
+ lookup_indices = {'sparse_fea': features['sparse_fea']}
+ elif 'ragged_ids' in features.keys():
+ if lookup_indices is None:
+ lookup_indices = {
+ 'ragged_ids': features['ragged_ids'],
+ 'ragged_lens': features['ragged_lens']
+ }
+ if 'ragged_wgts' in features:
+ lookup_indices['ragged_wgts'] = features['ragged_wgts']
+ else:
+ if lookup_indices is None:
+ lookup_indices = []
+ with ops.device('/cpu:0'):
+ sparse_tensors = column.categorical_column._get_sparse_tensors(
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ lookup_indices.append(sparse_tensors.id_tensor)
+ if sparse_tensors.weight_tensor is not None:
+ lookup_wgts.append(sparse_tensors.weight_tensor)
+ if cols_to_vars is not None:
+ cols_to_vars[column] = ops.get_collection(
+ ops.GraphKeys.GLOBAL_VARIABLES,
+ scope=variable_scope.get_variable_scope().name)
+
+ if dense_cnt > 0:
+ if 'dense_fea' in features:
+ fea_dim_s = 0
+ for dense_output_id, dense_col in zip(dense_output_ids, dense_cols):
+ fea_dim_e = fea_dim_s + dense_col.shape[0]
+ output_tensors[dense_output_id] = features[
+ 'dense_fea'][:, fea_dim_s:fea_dim_e]
+ fea_dim_s = fea_dim_e
+ batch_sizes.append(array_ops.shape(features['dense_fea'])[0])
+ else:
+ for dense_output_id, dense_col in zip(dense_output_ids, dense_cols):
+ output_tensors[dense_output_id] = features[dense_col.raw_name]
+ batch_sizes.append(array_ops.shape(output_tensors[dense_output_id])[0])
+
+ for tmp_embed_var in set(lookup_embeddings):
+ ops.add_to_collection(constant.EmbeddingParallel, tmp_embed_var.name)
+
+ if len(batch_sizes) == 0:
+ batch_size = None
+ else:
+ batch_size = batch_sizes[0]
+ # do embedding parallel lookup
+ if len(lookup_output_ids) > 0:
+ packed_input = ('sparse_fea' in features or 'ragged_ids' in features)
+ if packed_input:
+ uniq_embed_cnt = len(set(lookup_embeddings))
+        assert uniq_embed_cnt == 1, 'only one unique embedding is supported for packed inputs'
+ outputs = embedding_parallel_lookup(lookup_embeddings[0],
+ lookup_indices, lookup_output_ids,
+ is_training, output_tensors,
+ batch_size)
+ else:
+ if batch_size is None:
+ all_indices = []
+ for lookup_indice in lookup_indices:
+ all_indices.append(lookup_indice.indices[-1:, 0])
+ all_indices = array_ops.concat(all_indices, axis=0)
+ batch_size = math_ops.reduce_max(all_indices) + 1
+ # group lookup_embeddings
+ grouped_inputs = {}
+ for embedding, lookup_indice, output_id in zip(lookup_embeddings,
+ lookup_indices,
+ lookup_output_ids):
+ if embedding not in grouped_inputs:
+ grouped_inputs[embedding] = {
+ 'lookup_indice': [lookup_indice],
+ 'output_id': [output_id]
+ }
+ else:
+ grouped_inputs[embedding]['lookup_indice'].append(lookup_indice)
+ grouped_inputs[embedding]['output_id'].append(output_id)
+
+ for embedding in grouped_inputs:
+ lookup_indices = grouped_inputs[embedding]['lookup_indice']
+ output_ids = grouped_inputs[embedding]['output_id']
+ outputs = embedding_parallel_lookup(embedding, lookup_indices,
+ output_ids, is_training,
+ output_tensors, batch_size)
+
+ for output_tensor, col in zip(output_tensors, feature_columns):
+ if feature_name_to_output_tensors is not None:
+ feature_name_to_output_tensors[col.raw_name] = output_tensor
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[col] = output_tensor
+
+ if packed_input and dense_cnt == 0:
+ return outputs
+ else:
+ return array_ops.concat(output_tensors, axis=1)
+ else:
+ for output_tensor, col in zip(output_tensors, feature_columns):
+ if feature_name_to_output_tensors is not None:
+ feature_name_to_output_tensors[col.raw_name] = output_tensor
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[col] = output_tensor
+ return array_ops.concat(output_tensors, axis=1)
+
# If we're constructing from the `make_template`, that by default adds a
# variable scope with the name of the layer. In that case, we dont want to
# add another `variable_scope` as that would break checkpoints.
@@ -230,7 +632,12 @@ def _get_logits(): # pylint: disable=missing-docstring
else:
with variable_scope.variable_scope(
scope, default_name='input_layer', values=features.values()):
- return _get_logits()
+ if embedding_utils.is_embedding_parallel():
+ return _get_logits_embedding_parallel()
+ else:
+ with conditional(embedding_utils.embedding_on_cpu(),
+ ops.device('/cpu:0')):
+ return _get_logits()
def input_layer(features,
@@ -239,7 +646,8 @@ def input_layer(features,
trainable=True,
cols_to_vars=None,
cols_to_output_tensors=None,
- feature_name_to_output_tensors=None):
+ feature_name_to_output_tensors=None,
+ is_training=True):
"""Returns a dense `Tensor` as input layer based on given `feature_columns`.
Generally a single example in training data is described with FeatureColumns.
@@ -303,7 +711,8 @@ def input_layer(features,
trainable=trainable,
cols_to_vars=cols_to_vars,
cols_to_output_tensors=cols_to_output_tensors,
- feature_name_to_output_tensors=feature_name_to_output_tensors)
+ feature_name_to_output_tensors=feature_name_to_output_tensors,
+ is_training=is_training)
# TODO(akshayka): InputLayer should be a subclass of Layer, and it
@@ -2139,6 +2548,9 @@ def _get_raw_feature_as_tensor(self, key):
ValueError: if the raw feature has rank 0.
"""
raw_feature = self._features[key]
+ if 'RaggedTensor' in str(type(raw_feature)):
+ return raw_feature
+
feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
raw_feature)
@@ -2250,8 +2662,8 @@ def _normalize_feature_columns(feature_columns):
if isinstance(feature_columns, _FeatureColumn):
feature_columns = [feature_columns]
- if isinstance(feature_columns, collections.Iterator):
- feature_columns = list(feature_columns)
+ # if isinstance(feature_columns, collections.Iterator):
+ # feature_columns = list(feature_columns)
if isinstance(feature_columns, dict):
raise ValueError('Expected feature_columns to be iterable, found dict.')
@@ -2519,7 +2931,7 @@ class _SharedEmbeddingColumn(
('categorical_column', 'dimension', 'combiner', 'initializer',
'shared_embedding_collection_name', 'ckpt_to_load_from',
'tensor_name_in_ckpt', 'max_norm', 'trainable', 'partitioner',
- 'use_embedding_variable'))):
+ 'ev_params'))):
"""See `embedding_column`."""
@property
@@ -2528,6 +2940,10 @@ def name(self):
self._name = '{}_shared_embedding'.format(self.categorical_column.name)
return self._name
+ @property
+ def raw_name(self):
+ return self.categorical_column.name
+
@property
def _var_scope_name(self):
return self.shared_embedding_collection_name
@@ -2575,7 +2991,7 @@ def _get_dense_tensor_internal(self,
'hood.'.format(shared_embedding_collection))
embedding_weights = shared_embedding_collection[0]
if embedding_weights.get_shape(
- ) != embedding_shape and not self.use_embedding_variable:
+ ) != embedding_shape and self.ev_params is None:
raise ValueError(
'Shared embedding collection {} contains variable {} of '
'unexpected shape {}. Expected shape is {}. '
@@ -2586,7 +3002,7 @@ def _get_dense_tensor_internal(self,
embedding_weights.name,
embedding_weights.get_shape(), embedding_shape))
else:
- if not self.use_embedding_variable:
+ if self.ev_params is None:
embedding_weights = variable_scope.get_variable(
name='embedding_weights',
shape=embedding_shape,
@@ -2599,19 +3015,29 @@ def _get_dense_tensor_internal(self,
# at eval or inference time, it is necessary to set
# the initializers to zeros, so that new key will
# get zero embedding
- import os
if os.environ.get('tf.estimator.mode', '') != \
os.environ.get('tf.estimator.ModeKeys.TRAIN', 'train'):
initializer = init_ops.zeros_initializer()
else:
initializer = self.initializer
+ extra_args = {}
+ if 'EmbeddingVariableConfig' in dir(variables):
+ ev_option = variables.EmbeddingVariableOption()
+ ev_option.filter_strategy = variables.CounterFilter(
+ filter_freq=self.ev_params.filter_freq)
+ extra_args['ev_option'] = ev_option
+ else:
+ extra_args['filter_options'] = variables.CounterFilterOptions(
+ self.ev_params.filter_freq)
embedding_weights = variable_scope.get_embedding_variable(
name='embedding_weights',
embedding_dim=self.dimension,
initializer=initializer,
trainable=self.trainable and trainable,
partitioner=self.partitioner,
- collections=weight_collections)
+ collections=weight_collections,
+ steps_to_live=self.ev_params.steps_to_live,
+ **extra_args)
ops.add_to_collection(self.shared_embedding_collection_name,
embedding_weights)
@@ -2622,23 +3048,24 @@ def _get_dense_tensor_internal(self,
checkpoint_utils.init_from_checkpoint(
self.ckpt_to_load_from, {self.tensor_name_in_ckpt: to_restore})
- # Return embedding lookup result.
- if self.use_embedding_variable:
- return ev_embedding_ops.safe_embedding_lookup_sparse(
- embedding_weights=embedding_weights,
- sparse_ids=sparse_ids,
- sparse_weights=sparse_weights,
- combiner=self.combiner,
- name='%s_weights' % self.name,
- max_norm=self.max_norm)
- else:
- return embedding_ops.safe_embedding_lookup_sparse(
+ if 'RaggedTensor' in str(type(sparse_ids)):
+ assert sparse_weights is None
+ return embedding_lookup_ragged(
embedding_weights=embedding_weights,
- sparse_ids=sparse_ids,
- sparse_weights=sparse_weights,
+ ragged_ids=sparse_ids,
+ ragged_weights=sparse_weights,
combiner=self.combiner,
- name='%s_weights' % self.name,
- max_norm=self.max_norm)
+ max_norm=self.max_norm,
+ name='%s_weights' % self.name)
+
+ # Return embedding lookup result.
+ return embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights=embedding_weights,
+ sparse_ids=sparse_ids,
+ sparse_weights=sparse_weights,
+ combiner=self.combiner,
+ name='%s_weights' % self.name,
+ max_norm=self.max_norm)
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if isinstance(self.categorical_column, _SequenceCategoricalColumn):
diff --git a/easy_rec/python/compat/feature_column/feature_column_v2.py b/easy_rec/python/compat/feature_column/feature_column_v2.py
index 27b9eabdb..01b9fc93f 100644
--- a/easy_rec/python/compat/feature_column/feature_column_v2.py
+++ b/easy_rec/python/compat/feature_column/feature_column_v2.py
@@ -131,9 +131,12 @@
import abc
import collections
import math
+import os
+import sys
import numpy as np
import six
+import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -164,15 +167,23 @@
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
-from easy_rec.python.compat import embedding_ops as ev_embedding_ops
+from easy_rec.python.compat import ops as compat_ops
from easy_rec.python.compat.feature_column import feature_column as fc_old
from easy_rec.python.compat.feature_column import utils as fc_utils
+from easy_rec.python.layers import utils as layer_utils
+
+from easy_rec.python.compat.feature_column.feature_column import embedding_lookup_ragged # NOQA
_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
'deprecated. Please use the new FeatureColumn '
'APIs instead.')
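+# SAFE_EMBEDDING=TRUE (default) keeps safe_embedding_lookup_sparse, which prunes
+# invalid ids and weights and fills in empty rows; any other value switches to
+# the plain embedding_lookup_sparse, which skips those checks.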
+if os.getenv('SAFE_EMBEDDING', 'TRUE') == 'TRUE':
+ embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
+else:
+ embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse
+
class StateManager(object):
"""Manages the state associated with FeatureColumns.
@@ -814,7 +825,7 @@ def embedding_column(categorical_column,
max_norm=None,
trainable=True,
partitioner=None,
- use_embedding_variable=False):
+ ev_params=None):
"""`DenseColumn` that converts from sparse, categorical input.
Use this when your inputs are sparse, but you want to convert them to a dense
@@ -910,7 +921,7 @@ def model_fn(features, ...):
max_norm=max_norm,
trainable=trainable,
partitioner=partitioner,
- use_embedding_variable=use_embedding_variable)
+ ev_params=ev_params)
def shared_embedding_columns(categorical_columns,
@@ -923,7 +934,7 @@ def shared_embedding_columns(categorical_columns,
max_norm=None,
trainable=True,
partitioner=None,
- use_embedding_variable=False):
+ ev_params=None):
"""List of dense columns that convert from sparse, categorical input.
This is similar to `embedding_column`, except that it produces a list of
@@ -1053,12 +1064,6 @@ def model_fn(features, ...):
if isinstance(
c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn)): # pylint: disable=protected-access
c = c.categorical_column
- if not isinstance(c, type(c0)):
- raise ValueError(
- 'To use shared_embedding_column, all categorical_columns must have '
- 'the same type, or be weighted_categorical_column of the same type. '
- 'Given column: {} of type: {} does not match given column: {} of '
- 'type: {}'.format(c0, type(c0), c, type(c)))
if num_buckets != c._num_buckets: # pylint: disable=protected-access
raise ValueError(
'To use shared_embedding_column, all categorical_columns must have '
@@ -1084,7 +1089,7 @@ def model_fn(features, ...):
max_norm=max_norm,
trainable=trainable,
partitioner=partitioner,
- use_embedding_variable=use_embedding_variable))
+ ev_params=ev_params))
return result
@@ -1258,7 +1263,8 @@ def numeric_column(key,
shape=(1,),
default_value=None,
dtype=dtypes.float32,
- normalizer_fn=None):
+ normalizer_fn=None,
+ feature_name=None):
"""Represents real valued or numerical features.
Example:
@@ -1322,7 +1328,8 @@ def numeric_column(key,
fc_utils.assert_key_is_string(key)
return NumericColumn(
- key,
+ feature_name=feature_name,
+ key=key,
shape=shape,
default_value=default_value,
dtype=dtype,
@@ -1414,7 +1421,8 @@ def bucketized_column(source_column, boundaries):
def categorical_column_with_hash_bucket(key,
hash_bucket_size,
- dtype=dtypes.string):
+ dtype=dtypes.string,
+ feature_name=None):
"""Represents sparse feature where ids are set by hashing.
Use this when your sparse features are in string or integer format, and you
@@ -1467,97 +1475,7 @@ def categorical_column_with_hash_bucket(key,
fc_utils.assert_key_is_string(key)
fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
- return HashedCategoricalColumn(key, hash_bucket_size, dtype)
-
-
-def categorical_column_with_vocabulary_file(key,
- vocabulary_file,
- vocabulary_size=None,
- num_oov_buckets=0,
- default_value=None,
- dtype=dtypes.string):
- """A `CategoricalColumn` with a vocabulary file.
-
- Use this when your inputs are in string or integer format, and you have a
- vocabulary file that maps each value to an integer ID. By default,
- out-of-vocabulary values are ignored. Use either (but not both) of
- `num_oov_buckets` and `default_value` to specify how to include
- out-of-vocabulary values.
-
- For input dictionary `features`, `features[key]` is either `Tensor` or
- `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string, which will be dropped by this feature column.
-
- Example with `num_oov_buckets`:
- File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
- abbreviation. All inputs with values in that file are assigned an ID 0-49,
- corresponding to its line number. All other values are hashed and assigned an
- ID 50-54.
-
- ```python
- states = categorical_column_with_vocabulary_file(
- key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
- num_oov_buckets=5)
- columns = [states, ...]
- features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
- linear_prediction = linear_model(features, columns)
- ```
-
- Example with `default_value`:
- File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
- other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
- in input, and other values missing from the file, will be assigned ID 0. All
- others are assigned the corresponding line number 1-50.
-
- ```python
- states = categorical_column_with_vocabulary_file(
- key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
- default_value=0)
- columns = [states, ...]
- features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
- linear_prediction, _, _ = linear_model(features, columns)
- ```
-
- And to make an embedding with either:
-
- ```python
- columns = [embedding_column(states, 3),...]
- features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
- dense_tensor = input_layer(features, columns)
- ```
-
- Args:
- key: A unique string identifying the input feature. It is used as the
- column name and the dictionary key for feature parsing configs, feature
- `Tensor` objects, and feature columns.
- vocabulary_file: The vocabulary file name.
- vocabulary_size: Number of the elements in the vocabulary. This must be no
- greater than length of `vocabulary_file`, if less than length, later
- values are ignored. If None, it is set to the length of `vocabulary_file`.
- num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
- buckets. All out-of-vocabulary inputs will be assigned IDs in the range
- `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
- the input value. A positive `num_oov_buckets` can not be specified with
- `default_value`.
- default_value: The integer ID value to return for out-of-vocabulary feature
- values, defaults to `-1`. This can not be specified with a positive
- `num_oov_buckets`.
- dtype: The type of features. Only string and integer types are supported.
-
- Returns:
- A `CategoricalColumn` with a vocabulary file.
-
- Raises:
- ValueError: `vocabulary_file` is missing or cannot be opened.
- ValueError: `vocabulary_size` is missing or < 1.
- ValueError: `num_oov_buckets` is a negative integer.
- ValueError: `num_oov_buckets` and `default_value` are both specified.
- ValueError: `dtype` is neither string nor integer.
- """
- return categorical_column_with_vocabulary_file_v2(key, vocabulary_file,
- vocabulary_size, dtype,
- default_value,
- num_oov_buckets)
+ return HashedCategoricalColumn(feature_name, key, hash_bucket_size, dtype)
def categorical_column_with_vocabulary_file_v2(key,
@@ -1565,7 +1483,8 @@ def categorical_column_with_vocabulary_file_v2(key,
vocabulary_size=None,
dtype=dtypes.string,
default_value=None,
- num_oov_buckets=0):
+ num_oov_buckets=0,
+ feature_name=None):
"""A `CategoricalColumn` with a vocabulary file.
Use this when your inputs are in string or integer format, and you have a
@@ -1671,6 +1590,7 @@ def categorical_column_with_vocabulary_file_v2(key,
fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
fc_utils.assert_key_is_string(key)
return VocabularyFileCategoricalColumn(
+ feature_name=feature_name,
key=key,
vocabulary_file=vocabulary_file,
vocabulary_size=vocabulary_size,
@@ -1683,7 +1603,8 @@ def categorical_column_with_vocabulary_list(key,
vocabulary_list,
dtype=None,
default_value=-1,
- num_oov_buckets=0):
+ num_oov_buckets=0,
+ feature_name=None):
"""A `CategoricalColumn` with in-memory vocabulary.
Use this when your inputs are in string or integer format, and you have an
@@ -1788,6 +1709,7 @@ def categorical_column_with_vocabulary_list(key,
fc_utils.assert_key_is_string(key)
return VocabularyListCategoricalColumn(
+ feature_name=feature_name,
key=key,
vocabulary_list=tuple(vocabulary_list),
dtype=dtype,
@@ -1795,7 +1717,10 @@ def categorical_column_with_vocabulary_list(key,
num_oov_buckets=num_oov_buckets)
-def categorical_column_with_identity(key, num_buckets, default_value=None):
+def categorical_column_with_identity(key,
+ num_buckets,
+ default_value=None,
+ feature_name=None):
"""A `CategoricalColumn` that returns identity values.
Use this when your inputs are integers in the range `[0, num_buckets)`, and
@@ -1859,7 +1784,10 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
default_value, num_buckets, key))
fc_utils.assert_key_is_string(key)
return IdentityCategoricalColumn(
- key=key, number_buckets=num_buckets, default_value=default_value)
+ feature_name=feature_name,
+ key=key,
+ number_buckets=num_buckets,
+ default_value=default_value)
def indicator_column(categorical_column):
@@ -1971,7 +1899,7 @@ def weighted_categorical_column(categorical_column,
dtype=dtype)
-def crossed_column(keys, hash_bucket_size, hash_key=None):
+def crossed_column(keys, hash_bucket_size, hash_key=None, feature_name=None):
"""Returns a column for performing crosses of categorical features.
Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
@@ -2095,7 +2023,10 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
'Hashing before crossing will increase probability of collision. '
'Instead, use the feature name as a string. Given: {}'.format(key))
return CrossedColumn(
- keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
+ feature_name=feature_name,
+ keys=tuple(keys),
+ hash_bucket_size=hash_bucket_size,
+ hash_key=hash_key)
@six.add_metaclass(abc.ABCMeta)
@@ -2440,7 +2371,7 @@ def _create_categorical_column_weighted_sum(column, transformation_cache,
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- return embedding_ops.safe_embedding_lookup_sparse(
+ return embedding_lookup_sparse(
weight_var,
id_tensor,
sparse_weights=weight_tensor,
@@ -2621,6 +2552,9 @@ def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
Raises:
ValueError: when `input_tensor`'s rank is `None`.
"""
+ if 'RaggedTensor' in str(type(input_tensor)):
+ return input_tensor
+
input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
input_tensor)
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
@@ -2699,9 +2633,9 @@ def _normalize_feature_columns(feature_columns):
class NumericColumn(
DenseColumn,
fc_old._DenseColumn, # pylint: disable=protected-access
- collections.namedtuple(
- 'NumericColumn',
- ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
+ collections.namedtuple('NumericColumn',
+ ('feature_name', 'key', 'shape', 'default_value',
+ 'dtype', 'normalizer_fn'))):
"""see `numeric_column`."""
@property
@@ -2711,7 +2645,7 @@ def _is_v2_column(self):
@property
def name(self):
"""See `FeatureColumn` base class."""
- return self.key
+ return self.feature_name if self.feature_name else self.key
@property
def raw_name(self):
@@ -2982,57 +2916,66 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
return cls(**kwargs)
-class EmbeddingColumn(
+class SequenceBucketizedColumn(
DenseColumn,
- SequenceDenseColumn,
+ CategoricalColumn,
fc_old._DenseColumn, # pylint: disable=protected-access
- fc_old._SequenceDenseColumn, # pylint: disable=protected-access
- collections.namedtuple(
- 'EmbeddingColumn',
- ('categorical_column', 'dimension', 'combiner', 'initializer',
- 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
- 'partitioner', 'use_embedding_variable'))):
- """See `embedding_column`."""
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('SequenceBucketizedColumn',
+ ('source_column', 'boundaries'))):
+ """See `bucketized_column`."""
@property
def _is_v2_column(self):
- return (isinstance(self.categorical_column, FeatureColumn) and
- self.categorical_column._is_v2_column) # pylint: disable=protected-access
+ return (isinstance(self.source_column, FeatureColumn) and
+ self.source_column._is_v2_column) # pylint: disable=protected-access
@property
def name(self):
"""See `FeatureColumn` base class."""
- return '{}_embedding'.format(self.categorical_column.name)
+ return '{}_bucketized'.format(self.source_column.name)
@property
def raw_name(self):
"""See `FeatureColumn` base class."""
- return self.categorical_column.raw_name
+ return self.source_column.raw_name
@property
def parse_example_spec(self):
"""See `FeatureColumn` base class."""
- return self.categorical_column.parse_example_spec
+ return self.source_column.parse_example_spec
@property
@deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
_FEATURE_COLUMN_DEPRECATION)
def _parse_example_spec(self):
- return self.categorical_column._parse_example_spec # pylint: disable=protected-access
-
- def transform_feature(self, transformation_cache, state_manager):
- """Transforms underlying `categorical_column`."""
- return transformation_cache.get(self.categorical_column, state_manager)
+ return self.source_column._parse_example_spec # pylint: disable=protected-access
@deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
_FEATURE_COLUMN_DEPRECATION)
def _transform_feature(self, inputs):
- return inputs.get(self.categorical_column)
+ """Returns bucketized categorical `source_column` tensor."""
+ source_tensor = inputs.get(self.source_column)
+ bucketize_values = math_ops._bucketize(
+ source_tensor.values, boundaries=self.boundaries)
+ bucketize_tensor = sparse_tensor_lib.SparseTensor(
+ indices=source_tensor.indices,
+ values=bucketize_values,
+ dense_shape=source_tensor.dense_shape)
+ return bucketize_tensor
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns bucketized categorical `source_column` tensor."""
+ source_tensor = transformation_cache.get(self.source_column, state_manager)
+ return math_ops._bucketize( # pylint: disable=protected-access
+ source_tensor,
+ boundaries=self.boundaries)
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
- return tensor_shape.TensorShape([self.dimension])
+ return tensor_shape.TensorShape(
+ tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
@property
@deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
@@ -3040,102 +2983,576 @@ def variable_shape(self):
def _variable_shape(self):
return self.variable_shape
- def create_state(self, state_manager):
- """Creates the embedding lookup variable."""
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- state_manager.create_variable(
- self,
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- trainable=self.trainable,
- use_resource=True,
- initializer=self.initializer)
+ def _get_dense_tensor_for_input_tensor(self, input_tensor):
+ return array_ops.one_hot(
+ indices=math_ops.cast(input_tensor, dtypes.int64),
+ depth=len(self.boundaries) + 1,
+ on_value=1.,
+ off_value=0.)
- def _get_dense_tensor_internal_helper(self, sparse_tensors,
- embedding_weights):
- sparse_ids = sparse_tensors.id_tensor
- sparse_weights = sparse_tensors.weight_tensor
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns one hot encoded dense `Tensor`."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
- if self.ckpt_to_load_from is not None:
- to_restore = embedding_weights
- if isinstance(to_restore, variables.PartitionedVariable):
- to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
- checkpoint_utils.init_from_checkpoint(
- self.ckpt_to_load_from, {self.tensor_name_in_ckpt: to_restore})
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
- # Return embedding lookup result.
- if not self.use_embedding_variable:
- return embedding_ops.safe_embedding_lookup_sparse(
- embedding_weights=embedding_weights,
- sparse_ids=sparse_ids,
- sparse_weights=sparse_weights,
- combiner=self.combiner,
- name='%s_weights' % self.name,
- max_norm=self.max_norm)
- else:
- return ev_embedding_ops.safe_embedding_lookup_sparse(
- embedding_weights,
- sparse_ids,
- sparse_weights,
- combiner=self.combiner,
- name='%s_weights' % self.name,
- max_norm=self.max_norm)
+ @property
+ def num_buckets(self):
+ """See `CategoricalColumn` base class."""
+ # By construction, source_column is always one-dimensional.
+ return (len(self.boundaries) + 1) * self.source_column.shape[0]
- def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
- """Private method that follows the signature of get_dense_tensor."""
- embedding_weights = state_manager.get_variable(
- self, name='embedding_weights')
- return self._get_dense_tensor_internal_helper(sparse_tensors,
- embedding_weights)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
- def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
- trainable):
- """Private method that follows the signature of _get_dense_tensor."""
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- if (weight_collections and
- ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
- weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if not self.use_embedding_variable:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable and trainable,
- partitioner=self.partitioner,
- collections=weight_collections)
- else:
- # at eval or inference time, it is necessary to set
- # the initializers to zeros, so that new key will
- # get zero embedding
- import os
- if os.environ.get('tf.estimator.mode', '') != \
- os.environ.get('tf.estimator.ModeKeys.TRAIN', 'train'):
- initializer = init_ops.zeros_initializer()
- else:
- initializer = self.initializer
- embedding_weights = variable_scope.get_embedding_variable(
- name='embedding_weights',
- embedding_dim=self.dimension,
- initializer=initializer,
- trainable=self.trainable and trainable,
- partitioner=self.partitioner,
- collections=weight_collections)
- return self._get_dense_tensor_internal_helper(sparse_tensors,
- embedding_weights)
+ def _get_sparse_tensors_for_input_tensor(self, input_sparse_tensor):
+ input_tensor = input_sparse_tensor.values
+ input_indices = input_sparse_tensor.indices
+ batch_size = array_ops.shape(input_tensor)[0]
+ # By construction, source_column is always one-dimensional.
+ source_dimension = self.source_column.shape[0]
- def get_dense_tensor(self, transformation_cache, state_manager):
- """Returns tensor after doing the embedding lookup.
+ i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
+ # Flatten the bucket indices and unique them across dimensions
+ # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets
+ bucket_indices = (
+ array_ops.reshape(input_tensor,
+ (-1,)) + (len(self.boundaries) + 1) * i2)
- Args:
- transformation_cache: A `FeatureTransformationCache` object to access
- features.
- state_manager: A `StateManager` to create / access resources such as
- lookup tables.
+ sparse_tensor = sparse_tensor_lib.SparseTensor(
+ indices=input_indices,
+ values=bucket_indices,
+ dense_shape=input_sparse_tensor.dense_shape)
+ # Compute the third dimension explicitly instead of setting it to -1, as
+ # that doesn't work for dynamically shaped tensors with 0-length at runtime.
+ # This happens for empty sequences.
+ shape = array_ops.shape(sparse_tensor)
+ target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
+ ret_seq_tensor = sparse_ops.sparse_reshape(sparse_tensor, target_shape)
+ return CategoricalColumn.IdWeightPair(ret_seq_tensor, None)
- Returns:
- Embedding lookup tensor.
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
+ @property
+ def parents(self):
+ """See 'FeatureColumn` base class."""
+ return [self.source_column]
+
+ def _get_config(self):
+ """See 'FeatureColumn` base class."""
+ config = dict(zip(self._fields, self))
+ config['source_column'] = serialize_feature_column(self.source_column)
+ return config
+
+ @classmethod
+ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
+ """See 'FeatureColumn` base class."""
+ _check_config_keys(config, cls._fields)
+ kwargs = config.copy()
+ kwargs['source_column'] = deserialize_feature_column(
+ config['source_column'], custom_objects, columns_by_name)
+ return cls(**kwargs)
+
+
+class SequenceNumericColumn(
+ DenseColumn,
+ CategoricalColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('SequenceNumericColumn',
+ ('source_column', 'sequence_length'))):
+ """See `SequenceNumericColumn`."""
+
+ @property
+ def _is_v2_column(self):
+ return (isinstance(self.source_column, FeatureColumn) and
+ self.source_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_bucketized'.format(self.source_column.name)
+
+ @property
+ def raw_name(self):
+ """See `FeatureColumn` base class."""
+ return self.source_column.raw_name
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.source_column.parse_example_spec
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.source_column._parse_example_spec # pylint: disable=protected-access
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+    """Returns the `source_column` tensor unchanged."""
+ source_tensor = inputs.get(self.source_column)
+ return source_tensor
+
+ def transform_feature(self, transformation_cache, state_manager):
+    """Returns the `source_column` tensor unchanged."""
+ source_tensor = transformation_cache.get(self.source_column, state_manager)
+ return source_tensor
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.TensorShape(
+ tuple(self.source_column.shape) + (self.sequence_length,))
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
+ def _get_dense_tensor_for_input_tensor(self, input_tensor):
+ return array_ops.one_hot(
+ indices=math_ops.cast(input_tensor, dtypes.int64),
+ depth=self.sequence_length,
+ on_value=1.,
+ off_value=0.)
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns one hot encoded dense `Tensor`."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
+ def _get_sequence_dense_tensor(self, inputs):
+ input_tensor = inputs.get(self)
+ sparse_tensors = self._get_sparse_tensors_for_input_tensor(
+ input_tensor).id_tensor
+ sequence_length = fc_utils.sequence_length_from_sparse_tensor(
+ sparse_tensors)
+ sequence_length = tf.cast(sequence_length, tf.int32)
+ shape = array_ops.shape(sparse_tensors)
+ target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
+ ret_tensor = tf.sparse_to_dense(sparse_tensors.indices, target_shape,
+ sparse_tensors.values)
+ return CategoricalColumn.IdWeightPair(ret_tensor, sequence_length)
+
+ @property
+ def num_buckets(self):
+ """See `CategoricalColumn` base class."""
+ # By construction, source_column is always one-dimensional.
+ return self.sequence_length * self.source_column.shape[0]
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
+ def _get_sparse_tensors_for_input_tensor(self, sparse_tensor):
+ # Compute the third dimension explicitly instead of setting it to -1, as
+ # that doesn't work for dynamically shaped tensors with 0-length at runtime.
+ # This happens for empty sequences.
+ shape = array_ops.shape(sparse_tensor)
+ target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
+ ret_seq_tensor = sparse_ops.sparse_reshape(sparse_tensor, target_shape)
+ return CategoricalColumn.IdWeightPair(ret_seq_tensor, None)
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
+ @property
+ def parents(self):
+ """See 'FeatureColumn` base class."""
+ return [self.source_column]
+
+ def _get_config(self):
+ """See 'FeatureColumn` base class."""
+ config = dict(zip(self._fields, self))
+ config['source_column'] = serialize_feature_column(self.source_column)
+ return config
+
+ @classmethod
+ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
+ """See 'FeatureColumn` base class."""
+ _check_config_keys(config, cls._fields)
+ kwargs = config.copy()
+ kwargs['source_column'] = deserialize_feature_column(
+ config['source_column'], custom_objects, columns_by_name)
+ return cls(**kwargs)
+
+
+class SequenceWeightedCategoricalColumn(
+ CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple(
+ 'SequenceWeightedCategoricalColumn',
+ ('categorical_column', 'weight_feature_key', 'dtype'))):
+ """See `weighted_categorical_column`."""
+
+ @property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_weighted_by_{}'.format(self.categorical_column.name,
+ self.weight_feature_key)
+
+ @property
+ def raw_name(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.raw_name
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ config = self.categorical_column.parse_example_spec
+ if self.weight_feature_key in config:
+ raise ValueError('Parse config {} already exists for {}.'.format(
+ config[self.weight_feature_key], self.weight_feature_key))
+ config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
+ return config
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ config = self.categorical_column._parse_example_spec # pylint: disable=protected-access
+ if self.weight_feature_key in config:
+ raise ValueError('Parse config {} already exists for {}.'.format(
+ config[self.weight_feature_key], self.weight_feature_key))
+ config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
+ return config
+
+ @property
+ def num_buckets(self):
+ """See `DenseColumn` base class."""
+ return self.categorical_column.num_buckets
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _transform_weight_tensor(self, weight_tensor):
+ if weight_tensor is None:
+ raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
+ weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ weight_tensor)
+ if self.dtype != weight_tensor.dtype.base_dtype:
+ raise ValueError('Bad dtype, expected {}, but got {}.'.format(
+ self.dtype, weight_tensor.dtype))
+ if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
+ # The weight tensor can be a regular Tensor. In this case, sparsify it.
+ weight_tensor = _to_sparse_input_and_drop_ignore_values(
+ weight_tensor, ignore_value=0.0)
+ if not weight_tensor.dtype.is_floating:
+ weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
+ shape = tf.shape(weight_tensor)
+ target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
+ weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
+ return weight_tensor
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = transformation_cache.get(self.weight_feature_key,
+ state_manager)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
+ return (transformation_cache.get(self.categorical_column,
+ state_manager), weight_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = inputs.get(self.weight_feature_key)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
+ return (inputs.get(self.categorical_column), weight_tensor)
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ tensors = transformation_cache.get(self, state_manager)
+ return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ tensors = inputs.get(self)
+ return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+
+ @property
+ def parents(self):
+ """See 'FeatureColumn` base class."""
+ return [self.categorical_column, self.weight_feature_key]
+
+ def _get_config(self):
+ """See 'FeatureColumn` base class."""
+ config = dict(zip(self._fields, self))
+ config['categorical_column'] = serialize_feature_column(
+ self.categorical_column)
+ config['dtype'] = self.dtype.name
+ return config
+
+ @classmethod
+ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
+ """See 'FeatureColumn` base class."""
+ _check_config_keys(config, cls._fields)
+ kwargs = config.copy()
+ kwargs['categorical_column'] = deserialize_feature_column(
+ config['categorical_column'], custom_objects, columns_by_name)
+ kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
+ return cls(**kwargs)
+
+
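For reference, the (id tensor, weight tensor) pair returned by `get_sparse_tensors` above is ultimately consumed by a weighted sparse embedding lookup. Below is a small standalone illustration using the stock `tf.nn.embedding_lookup_sparse` (the column itself routes through the `embedding_lookup_sparse` wrapper in this file); all values are illustrative.

import tensorflow as tf

embedding_weights = tf.constant(
    [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])  # 4 ids, dimension 2

# Ids and per-id weights for a batch of two examples.
sp_ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([1, 2, 3], dtype=tf.int64),
    dense_shape=[2, 2])
sp_weights = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=[0.5, 1.5, 1.0],
    dense_shape=[2, 2])

# 'sum' combiner: row 0 -> 0.5 * e[1] + 1.5 * e[2] = [3.5, 3.5], row 1 -> [3.0, 3.0]
out = tf.nn.embedding_lookup_sparse(
    embedding_weights, sp_ids, sp_weights, combiner='sum')
print(out)  # eager mode; evaluate inside a Session under TF1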
+class EmbeddingColumn(
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
+ collections.namedtuple(
+ 'EmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'initializer',
+ 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
+ 'partitioner', 'ev_params'))):
+ """See `embedding_column`."""
+
+ @property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_embedding'.format(self.categorical_column.name)
+
+ @property
+ def raw_name(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.raw_name
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Transforms underlying `categorical_column`."""
+ return transformation_cache.get(self.categorical_column, state_manager)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ return inputs.get(self.categorical_column)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.TensorShape([self.dimension])
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
+ def create_state(self, state_manager):
+ """Creates the embedding lookup variable."""
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ state_manager.create_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ trainable=self.trainable,
+ use_resource=True,
+ initializer=self.initializer)
+
+ def _get_dense_tensor_internal_helper(self, sparse_tensors,
+ embedding_weights):
+ sparse_ids = sparse_tensors.id_tensor
+ sparse_weights = sparse_tensors.weight_tensor
+
+ if self.ckpt_to_load_from is not None:
+ to_restore = embedding_weights
+ if isinstance(to_restore, variables.PartitionedVariable):
+ to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
+ checkpoint_utils.init_from_checkpoint(
+ self.ckpt_to_load_from, {self.tensor_name_in_ckpt: to_restore})
+
+ if 'RaggedTensor' in str(type(sparse_ids)):
+ return embedding_lookup_ragged(
+ embedding_weights,
+ sparse_ids,
+ sparse_weights,
+ combiner=self.combiner,
+ max_norm=self.max_norm,
+ name='%s_weights' % self.name)
+
+ # Return embedding lookup result.
+ return embedding_lookup_sparse(
+ embedding_weights,
+ sparse_ids,
+ sparse_weights,
+ combiner=self.combiner,
+ name='%s_weights' % self.name,
+ max_norm=self.max_norm)
+
+ def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
+ """Private method that follows the signature of get_dense_tensor."""
+ embedding_weights = state_manager.get_variable(
+ self, name='embedding_weights')
+ return self._get_dense_tensor_internal_helper(sparse_tensors,
+ embedding_weights)
+
+ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
+ trainable):
+ """Private method that follows the signature of _get_dense_tensor."""
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if (weight_collections and
+ ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
+ weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ if self.ev_params is None:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32, # bfloat16,
+ initializer=self.initializer,
+ trainable=self.trainable and trainable,
+ partitioner=self.partitioner,
+ collections=weight_collections)
+ else:
+ # at eval or inference time, it is necessary to set
+ # the initializers to zeros, so that new keys will
+ # get zero embeddings
+ if os.environ.get('tf.estimator.mode', '') != \
+ os.environ.get('tf.estimator.ModeKeys.TRAIN', 'train'):
+ initializer = init_ops.zeros_initializer()
+ else:
+ initializer = self.initializer
+ extra_args = {}
+ if 'EmbeddingVariableConfig' in dir(variables):
+ ev_option = variables.EmbeddingVariableOption()
+ ev_option.filter_strategy = variables.CounterFilter(
+ filter_freq=self.ev_params.filter_freq)
+ extra_args['ev_option'] = ev_option
+ else:
+ extra_args['filter_options'] = variables.CounterFilterOptions(
+ self.ev_params.filter_freq)
+ embedding_weights = variable_scope.get_embedding_variable(
+ name='embedding_weights',
+ embedding_dim=self.dimension,
+ initializer=initializer,
+ trainable=self.trainable and trainable,
+ partitioner=self.partitioner,
+ collections=weight_collections,
+ steps_to_live=self.ev_params.steps_to_live,
+ **extra_args)
+
+ # Write the embedding configuration to RTP-specific collections. This informs RTP to
+ # optimize this embedding operation.
+ embedding_attrs = layer_utils.gen_embedding_attrs(
+ column=self,
+ variable=embedding_weights,
+ bucket_size=self.categorical_column._num_buckets,
+ combiner=self.combiner,
+ is_embedding_var=(self.ev_params is not None))
+ embedding_attrs['name'] = layer_utils.unique_name_in_collection(
+ compat_ops.GraphKeys.RANK_SERVICE_EMBEDDING, embedding_attrs['name'])
+ layer_utils.update_attr_to_collection(
+ compat_ops.GraphKeys.RANK_SERVICE_EMBEDDING, embedding_attrs)
+
+ # perform the embedding lookup
+ predictions = self._get_dense_tensor_internal_helper(
+ sparse_tensors, embedding_weights)
+
+ # Update the information about the output and input nodes of the embedding operation in the
+ # previously written RTP-specific collection entry. RTP uses this information to extract
+ # the embedding subgraph.
+ if isinstance(sparse_tensors.id_tensor, sparse_tensor_lib.SparseTensor):
+ layer_utils.append_tensor_to_collection(
+ compat_ops.GraphKeys.RANK_SERVICE_EMBEDDING, embedding_attrs['name'],
+ 'tensor', predictions)
+ layer_utils.append_tensor_to_collection(
+ compat_ops.GraphKeys.RANK_SERVICE_EMBEDDING, embedding_attrs['name'],
+ 'input', sparse_tensors.id_tensor)
+
+ return predictions
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns tensor after doing the embedding lookup.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Embedding lookup tensor.
Raises:
ValueError: `categorical_column` is SequenceCategoricalColumn.
@@ -3202,7 +3619,9 @@ def _get_sequence_dense_tensor(self,
trainable=None):
if not isinstance(
self.categorical_column,
- (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn,
+ SequenceBucketizedColumn, SequenceNumericColumn,
+ SequenceWeightedCategoricalColumn)): # pylint: disable=protected-access
raise ValueError(
'In embedding_column: {}. '
'categorical_column must be of type SequenceCategoricalColumn '
@@ -3367,7 +3786,7 @@ def _get_dense_tensor_internal(self, transformation_cache, state_manager):
embedding_weights = self.shared_embedding_column_creator.embedding_weights
# Return embedding lookup result.
- return embedding_ops.safe_embedding_lookup_sparse(
+ return embedding_lookup_sparse(
embedding_weights=embedding_weights,
sparse_ids=sparse_ids,
sparse_weights=sparse_weights,
@@ -3452,7 +3871,8 @@ class HashedCategoricalColumn(
CategoricalColumn,
fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('HashedCategoricalColumn',
- ('key', 'hash_bucket_size', 'dtype'))):
+ ('feature_name', 'key', 'hash_bucket_size', 'dtype'))
+):
"""see `categorical_column_with_hash_bucket`."""
@property
@@ -3462,7 +3882,7 @@ def _is_v2_column(self):
@property
def name(self):
"""See `FeatureColumn` base class."""
- return self.key
+ return self.feature_name if self.feature_name else self.key
@property
def raw_name(self):
@@ -3482,9 +3902,6 @@ def _parse_example_spec(self):
def _transform_input_tensor(self, input_tensor):
"""Hashes the values in the feature_column."""
- if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
- raise ValueError('SparseColumn input must be a SparseTensor.')
-
fc_utils.assert_string_or_int(
input_tensor.dtype,
prefix='column_name: {} input_tensor'.format(self.key))
@@ -3495,13 +3912,19 @@ def _transform_input_tensor(self, input_tensor):
'key: {}, column dtype: {}, tensor dtype: {}'.format(
self.key, self.dtype, input_tensor.dtype))
- if self.dtype == dtypes.string:
+ if input_tensor.dtype == dtypes.string:
sparse_values = input_tensor.values
else:
sparse_values = string_ops.as_string(input_tensor.values)
sparse_id_values = string_ops.string_to_hash_bucket_fast(
sparse_values, self.hash_bucket_size, name='lookup')
+
+ if 'RaggedTensor' in str(type(input_tensor)):
+ from tensorflow.python.ops.ragged import ragged_tensor
+ return ragged_tensor.RaggedTensor.from_row_splits(
+ values=sparse_id_values, row_splits=input_tensor.row_splits)
+
return sparse_tensor_lib.SparseTensor(input_tensor.indices,
sparse_id_values,
input_tensor.dense_shape)
@@ -3567,9 +3990,10 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
class VocabularyFileCategoricalColumn(
CategoricalColumn,
fc_old._CategoricalColumn, # pylint: disable=protected-access
- collections.namedtuple('VocabularyFileCategoricalColumn',
- ('key', 'vocabulary_file', 'vocabulary_size',
- 'num_oov_buckets', 'dtype', 'default_value'))):
+ collections.namedtuple(
+ 'VocabularyFileCategoricalColumn',
+ ('feature_name', 'key', 'vocabulary_file', 'vocabulary_size',
+ 'num_oov_buckets', 'dtype', 'default_value'))):
"""See `categorical_column_with_vocabulary_file`."""
@property
@@ -3579,7 +4003,7 @@ def _is_v2_column(self):
@property
def name(self):
"""See `FeatureColumn` base class."""
- return self.key
+ return self.feature_name if self.feature_name else self.key
@property
def raw_name(self):
@@ -3685,10 +4109,9 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
class VocabularyListCategoricalColumn(
CategoricalColumn,
fc_old._CategoricalColumn, # pylint: disable=protected-access
- collections.namedtuple(
- 'VocabularyListCategoricalColumn',
- ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
-):
+ collections.namedtuple('VocabularyListCategoricalColumn',
+ ('feature_name', 'key', 'vocabulary_list', 'dtype',
+ 'default_value', 'num_oov_buckets'))):
"""See `categorical_column_with_vocabulary_list`."""
@property
@@ -3698,7 +4121,7 @@ def _is_v2_column(self):
@property
def name(self):
"""See `FeatureColumn` base class."""
- return self.key
+ return self.feature_name if self.feature_name else self.key
@property
def raw_name(self):
@@ -3803,8 +4226,9 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
class IdentityCategoricalColumn(
CategoricalColumn,
fc_old._CategoricalColumn, # pylint: disable=protected-access
- collections.namedtuple('IdentityCategoricalColumn',
- ('key', 'number_buckets', 'default_value'))):
+ collections.namedtuple(
+ 'IdentityCategoricalColumn',
+ ('feature_name', 'key', 'number_buckets', 'default_value'))):
"""See `categorical_column_with_identity`."""
@property
@@ -3814,7 +4238,7 @@ def _is_v2_column(self):
@property
def name(self):
"""See `FeatureColumn` base class."""
- return self.key
+ return self.feature_name if self.feature_name else self.key
@property
def raw_name(self):
@@ -3838,30 +4262,34 @@ def _transform_input_tensor(self, input_tensor):
raise ValueError('Invalid input, not integer. key: {} dtype: {}'.format(
self.key, input_tensor.dtype))
+ if 'RaggedTensor' in str(type(input_tensor)):
+ return input_tensor
+
values = math_ops.cast(input_tensor.values, dtypes.int64, name='values')
- num_buckets = math_ops.cast(
- self.num_buckets, dtypes.int64, name='num_buckets')
- zero = math_ops.cast(0, dtypes.int64, name='zero')
- if self.default_value is None:
- # Fail if values are out-of-range.
- assert_less = check_ops.assert_less(
- values,
- num_buckets,
- data=(values, num_buckets),
- name='assert_less_than_num_buckets')
- assert_greater = check_ops.assert_greater_equal(
- values, zero, data=(values,), name='assert_greater_or_equal_0')
- with ops.control_dependencies((assert_less, assert_greater)):
- values = array_ops.identity(values)
- else:
- # Assign default for out-of-range values.
- values = array_ops.where(
- math_ops.logical_or(
- values < zero, values >= num_buckets, name='out_of_range'),
- array_ops.fill(
- dims=array_ops.shape(values),
- value=math_ops.cast(self.default_value, dtypes.int64),
- name='default_values'), values)
+ if self.num_buckets < sys.maxsize:
+ num_buckets = math_ops.cast(
+ self.num_buckets, dtypes.int64, name='num_buckets')
+ zero = math_ops.cast(0, dtypes.int64, name='zero')
+ if self.default_value is None:
+ # Fail if values are out-of-range.
+ assert_less = check_ops.assert_less(
+ values,
+ num_buckets,
+ data=(values, num_buckets),
+ name='assert_less_than_num_buckets')
+ assert_greater = check_ops.assert_greater_equal(
+ values, zero, data=(values,), name='assert_greater_or_equal_0')
+ with ops.control_dependencies((assert_less, assert_greater)):
+ values = array_ops.identity(values)
+ else:
+ # Assign default for out-of-range values.
+ values = array_ops.where(
+ math_ops.logical_or(
+ values < zero, values >= num_buckets, name='out_of_range'),
+ array_ops.fill(
+ dims=array_ops.shape(values),
+ value=math_ops.cast(self.default_value, dtypes.int64),
+ name='default_values'), values)
return sparse_tensor_lib.SparseTensor(
indices=input_tensor.indices,
@@ -4053,8 +4481,9 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
class CrossedColumn(
CategoricalColumn,
fc_old._CategoricalColumn, # pylint: disable=protected-access
- collections.namedtuple('CrossedColumn',
- ('keys', 'hash_bucket_size', 'hash_key'))):
+ collections.namedtuple(
+ 'CrossedColumn',
+ ('feature_name', 'keys', 'hash_bucket_size', 'hash_key'))):
"""See `crossed_column`."""
@property
@@ -4071,6 +4500,8 @@ def _is_v2_column(self):
@property
def name(self):
"""See `FeatureColumn` base class."""
+ if self.feature_name:
+ return self.feature_name
feature_names = []
for key in _collect_leaf_level_keys(self):
if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access
@@ -4527,13 +4958,21 @@ def _parse_example_spec(self):
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class."""
- return self.categorical_column.transform_feature(transformation_cache,
- state_manager)
+ ret_tensor = self.categorical_column.transform_feature(
+ transformation_cache, state_manager)
+ shape = array_ops.shape(ret_tensor)
+ target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
+ ret_tensor = sparse_ops.sparse_reshape(ret_tensor, target_shape)
+ return ret_tensor
@deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
_FEATURE_COLUMN_DEPRECATION)
def _transform_feature(self, inputs):
- return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access
+ ret_tensor = self.categorical_column._transform_feature(inputs)
+ shape = array_ops.shape(ret_tensor)
+ target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
+ ret_tensor = sparse_ops.sparse_reshape(ret_tensor, target_shape)
+ return ret_tensor
@property
def num_buckets(self):
@@ -4779,3 +5218,13 @@ def deserialize_feature_columns(configs, custom_objects=None):
deserialize_feature_column(c, custom_objects, columns_by_name)
for c in configs
]
+
+
+def is_embedding_column(fc):
+ if isinstance(fc, EmbeddingColumn):
+ return True
+ if isinstance(fc, fc_old._SharedEmbeddingColumn):
+ return True
+ if isinstance(fc, SharedEmbeddingColumn):
+ return True
+ return False
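A small standalone illustration of the out-of-range handling in `IdentityCategoricalColumn._transform_input_tensor` above: ids below zero or at/above `num_buckets` are replaced with `default_value` (the values below are illustrative).

import tensorflow as tf

values = tf.constant([0, 3, 7, -1], dtype=tf.int64)
num_buckets = tf.constant(5, dtype=tf.int64)
default_value = tf.constant(0, dtype=tf.int64)

out_of_range = tf.logical_or(values < 0, values >= num_buckets)
fixed = tf.where(out_of_range,
                 tf.fill(tf.shape(values), default_value),
                 values)
print(fixed)  # [0 3 0 0] in eager mode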
diff --git a/easy_rec/python/compat/feature_column/sequence_feature_column.py b/easy_rec/python/compat/feature_column/sequence_feature_column.py
index 10382bbb6..b0fcdc9f7 100644
--- a/easy_rec/python/compat/feature_column/sequence_feature_column.py
+++ b/easy_rec/python/compat/feature_column/sequence_feature_column.py
@@ -29,9 +29,11 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
+from easy_rec.python.compat.feature_column import feature_column as fc_v1
from easy_rec.python.compat.feature_column import feature_column_v2 as fc
from easy_rec.python.compat.feature_column import utils as fc_utils
@@ -191,7 +193,8 @@ def concatenate_context_input(context_input, sequence_input):
def sequence_categorical_column_with_identity(key,
num_buckets,
- default_value=None):
+ default_value=None,
+ feature_name=None):
"""Returns a feature column that represents sequences of integers.
Pass this to `embedding_column` or `indicator_column` to convert sequence
@@ -233,12 +236,57 @@ def sequence_categorical_column_with_identity(key,
"""
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_identity(
- key=key, num_buckets=num_buckets, default_value=default_value))
+ feature_name=feature_name,
+ key=key,
+ num_buckets=num_buckets,
+ default_value=default_value))
+
+
+def sequence_numeric_column_with_bucketized_column(source_column, boundaries):
+ if not isinstance(source_column, (SequenceNumericColumn,)): # pylint: disable=protected-access
+ raise ValueError(
+ 'source_column must be a column generated with sequence_numeric_column(). '
+ 'Given: {}'.format(source_column))
+ if len(source_column.shape) > 1:
+ raise ValueError('source_column must be one-dimensional column. '
+ 'Given: {}'.format(source_column))
+ if not boundaries:
+ raise ValueError('boundaries must not be empty.')
+ if not (isinstance(boundaries, list) or isinstance(boundaries, tuple)):
+ raise ValueError('boundaries must be a sorted list.')
+ for i in range(len(boundaries) - 1):
+ if boundaries[i] >= boundaries[i + 1]:
+ raise ValueError('boundaries must be a sorted list.')
+ return fc.SequenceBucketizedColumn(source_column, tuple(boundaries))
+
+
+def sequence_numeric_column_with_raw_column(source_column, sequence_length):
+ if not isinstance(source_column, (SequenceNumericColumn,)): # pylint: disable=protected-access
+ raise ValueError(
+ 'source_column must be a column generated with sequence_numeric_column(). '
+ 'Given: {}'.format(source_column))
+ if len(source_column.shape) > 1:
+ raise ValueError('source_column must be one-dimensional column. '
+ 'Given: {}'.format(source_column))
+
+ return fc.SequenceNumericColumn(source_column, sequence_length)
+
+
+def sequence_weighted_categorical_column(categorical_column,
+ weight_feature_key,
+ dtype=dtypes.float32):
+ if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
+ raise ValueError('dtype {} is not convertible to float.'.format(dtype))
+ return fc.SequenceWeightedCategoricalColumn(
+ categorical_column=categorical_column,
+ weight_feature_key=weight_feature_key,
+ dtype=dtype)
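A hypothetical usage sketch of the helpers defined above, assuming the easy_rec compat package is importable; feature names, bucket sizes, and boundaries are illustrative only.

from easy_rec.python.compat.feature_column import sequence_feature_column as seq_fc

# Per-step prices as a sequence numeric column, then bucketized so the column
# can feed an embedding or indicator column downstream.
price_seq = seq_fc.sequence_numeric_column('price_seq', shape=(1,))
price_buckets = seq_fc.sequence_numeric_column_with_bucketized_column(
    price_seq, boundaries=[1.0, 10.0, 100.0])

# A sequence of hashed tag ids, each weighted by a parallel weight feature.
tag_seq = seq_fc.sequence_categorical_column_with_hash_bucket(
    'tag_seq', hash_bucket_size=1000)
weighted_tags = seq_fc.sequence_weighted_categorical_column(
    tag_seq, weight_feature_key='tag_wgt_seq')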
def sequence_categorical_column_with_hash_bucket(key,
hash_bucket_size,
- dtype=dtypes.string):
+ dtype=dtypes.string,
+ feature_name=None):
"""A sequence of categorical terms where ids are set by hashing.
Pass this to `embedding_column` or `indicator_column` to convert sequence
@@ -277,7 +325,10 @@ def sequence_categorical_column_with_hash_bucket(key,
"""
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_hash_bucket(
- key=key, hash_bucket_size=hash_bucket_size, dtype=dtype))
+ feature_name=feature_name,
+ key=key,
+ hash_bucket_size=hash_bucket_size,
+ dtype=dtype))
def sequence_categorical_column_with_vocabulary_file(key,
@@ -285,7 +336,8 @@ def sequence_categorical_column_with_vocabulary_file(key,
vocabulary_size=None,
num_oov_buckets=0,
default_value=None,
- dtype=dtypes.string):
+ dtype=dtypes.string,
+ feature_name=None):
"""A sequence of categorical terms where ids use a vocabulary file.
Pass this to `embedding_column` or `indicator_column` to convert sequence
@@ -339,6 +391,7 @@ def sequence_categorical_column_with_vocabulary_file(key,
"""
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_file(
+ feature_name=feature_name,
key=key,
vocabulary_file=vocabulary_file,
vocabulary_size=vocabulary_size,
@@ -351,7 +404,8 @@ def sequence_categorical_column_with_vocabulary_list(key,
vocabulary_list,
dtype=None,
default_value=-1,
- num_oov_buckets=0):
+ num_oov_buckets=0,
+ feature_name=None):
"""A sequence of categorical terms where ids use an in-memory list.
Pass this to `embedding_column` or `indicator_column` to convert sequence
@@ -404,6 +458,7 @@ def sequence_categorical_column_with_vocabulary_list(key,
"""
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_list(
+ feature_name=feature_name,
key=key,
vocabulary_list=vocabulary_list,
dtype=dtype,
@@ -415,7 +470,8 @@ def sequence_numeric_column(key,
shape=(1,),
default_value=0.,
dtype=dtypes.float32,
- normalizer_fn=None):
+ normalizer_fn=None,
+ feature_name=None):
"""Returns a feature column that represents sequences of numeric data.
Example:
@@ -465,7 +521,8 @@ def sequence_numeric_column(key,
'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
return SequenceNumericColumn(
- key,
+ feature_name=feature_name,
+ key=key,
shape=shape,
default_value=default_value,
dtype=dtype,
@@ -485,10 +542,10 @@ def _assert_all_equal_and_return(tensors, name=None):
class SequenceNumericColumn(
- fc.SequenceDenseColumn,
- collections.namedtuple(
- 'SequenceNumericColumn',
- ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
+ fc.SequenceDenseColumn, fc_v1._FeatureColumn,
+ collections.namedtuple('SequenceNumericColumn',
+ ('feature_name', 'key', 'shape', 'default_value',
+ 'dtype', 'normalizer_fn'))):
"""Represents sequences of numeric data."""
@property
@@ -497,6 +554,11 @@ def _is_v2_column(self):
@property
def name(self):
+ """See `FeatureColumn` base class."""
+ return self.feature_name if self.feature_name else self.key
+
+ @property
+ def raw_name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -505,6 +567,13 @@ def parse_example_spec(self):
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+ def _transform_feature(self, inputs):
+ input_tensor = inputs.get(self.key)
+ return self._transform_input_tensor(input_tensor)
+
+ def _transform_input_tensor(self, input_tensor):
+ return math_ops.cast(input_tensor, dtypes.float32)
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class.
@@ -522,7 +591,7 @@ def transform_feature(self, transformation_cache, state_manager):
input_tensor = transformation_cache.get(self.key, state_manager)
if self.normalizer_fn is not None:
input_tensor = self.normalizer_fn(input_tensor)
- return input_tensor
+ return self._transform_input_tensor(input_tensor)
@property
def variable_shape(self):
diff --git a/easy_rec/python/compat/layers.py b/easy_rec/python/compat/layers.py
new file mode 100644
index 000000000..651eefac8
--- /dev/null
+++ b/easy_rec/python/compat/layers.py
@@ -0,0 +1,329 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Higher level ops for building layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+
+
+def layer_norm(inputs,
+ center=True,
+ scale=True,
+ activation_fn=None,
+ reuse=None,
+ variables_collections=None,
+ outputs_collections=None,
+ trainable=True,
+ begin_norm_axis=1,
+ begin_params_axis=-1,
+ scope=None):
+ """Adds a Layer Normalization layer.
+
+ Based on the paper:
+
+ "Layer Normalization"
+ Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
+ https://arxiv.org/abs/1607.06450.
+
+ Can be used as a normalizer function for conv2d and fully_connected.
+
+ Given a tensor `inputs` of rank `R`, moments are calculated and normalization
+ is performed over axes `begin_norm_axis ... R - 1`. Scaling and centering,
+ if requested, is performed over axes `begin_params_axis .. R - 1`.
+
+ By default, `begin_norm_axis = 1` and `begin_params_axis = -1`,
+ meaning that normalization is performed over all but the first axis
+ (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable
+ parameters are calculated for the rightmost axis (the `C` if `inputs` is
+ `NHWC`). Scaling and recentering is performed via broadcast of the
+ `beta` and `gamma` parameters with the normalized tensor.
+
+ The shapes of `beta` and `gamma` are `inputs.shape[begin_params_axis:]`,
+ and this part of the inputs' shape must be fully defined.
+
+ Args:
+ inputs: A tensor having rank `R`. The normalization is performed over
+ axes `begin_norm_axis ... R - 1` and centering and scaling parameters
+ are calculated over `begin_params_axis ... R - 1`.
+ center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+ is ignored.
+ scale: If True, multiply by `gamma`. If False, `gamma` is
+ not used. When the next layer is linear (also e.g. `nn.relu`), this can be
+ disabled since the scaling can be done by the next layer.
+ activation_fn: Activation function, default set to None to skip it and
+ maintain a linear activation.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ variables_collections: Optional collections for the variables.
+ outputs_collections: Collections to add the outputs.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ begin_norm_axis: The first normalization dimension: normalization will be
+ performed along dimensions `begin_norm_axis : rank(inputs)`
+ begin_params_axis: The first parameter (beta, gamma) dimension: scale
+ and centering parameters will have dimensions
+ `begin_params_axis : rank(inputs)` and will be broadcast with the
+ normalized inputs accordingly.
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ A `Tensor` representing the output of the operation, having the same
+ shape and dtype as `inputs`.
+
+ Raises:
+ ValueError: If the rank of `inputs` is not known at graph build time,
+ or if `inputs.shape[begin_params_axis:]` is not fully defined at
+ graph build time.
+ """
+ with variable_scope.variable_scope(
+ scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
+ inputs = ops.convert_to_tensor(inputs)
+ inputs_shape = inputs.shape
+ inputs_rank = inputs_shape.ndims
+ if inputs_rank is None:
+ raise ValueError('Inputs %s has undefined rank.' % inputs.name)
+ dtype = inputs.dtype.base_dtype
+ if begin_norm_axis < 0:
+ begin_norm_axis = inputs_rank + begin_norm_axis
+ if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
+ raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
+ 'must be < rank(inputs) (%d)' %
+ (begin_params_axis, begin_norm_axis, inputs_rank))
+ params_shape = inputs_shape[begin_params_axis:]
+ if not params_shape.is_fully_defined():
+ raise ValueError(
+ 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
+ (inputs.name, begin_params_axis, inputs_shape))
+ # Allocate parameters for the beta and gamma of the normalization.
+ beta, gamma = None, None
+ if center:
+ beta_collections = get_variable_collections(variables_collections, 'beta')
+ beta = model_variable(
+ 'beta',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=init_ops.zeros_initializer(),
+ collections=beta_collections,
+ trainable=trainable)
+ if scale:
+ gamma_collections = get_variable_collections(variables_collections,
+ 'gamma')
+ gamma = model_variable(
+ 'gamma',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=init_ops.ones_initializer(),
+ collections=gamma_collections,
+ trainable=trainable)
+ # Calculate the moments on the last axis (layer activations).
+ norm_axes = list(range(begin_norm_axis, inputs_rank))
+ mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
+ # Compute layer normalization using the batch_normalization function.
+ variance_epsilon = 1e-12
+ outputs = nn.batch_normalization(
+ inputs,
+ mean,
+ variance,
+ offset=beta,
+ scale=gamma,
+ variance_epsilon=variance_epsilon)
+ outputs.set_shape(inputs_shape)
+ if activation_fn is not None:
+ outputs = activation_fn(outputs)
+ return collect_named_outputs(outputs_collections, sc.name, outputs)
+
+
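A NumPy reference of the computation performed by `layer_norm` above for a [batch, dim] input with `begin_norm_axis = begin_params_axis = 1`; `beta`/`gamma` default to 0/1 here, and the epsilon matches the 1e-12 used above.

import numpy as np

def layer_norm_ref(x, beta=None, gamma=None, eps=1e-12):
  # Normalize each row to zero mean and unit variance, then scale and shift.
  mean = x.mean(axis=-1, keepdims=True)
  var = x.var(axis=-1, keepdims=True)
  out = (x - mean) / np.sqrt(var + eps)
  if gamma is not None:
    out = out * gamma
  if beta is not None:
    out = out + beta
  return out

x = np.array([[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]])
print(layer_norm_ref(x))  # each row normalized over its last axis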
+def get_variable_collections(variables_collections, name):
+ if isinstance(variables_collections, dict):
+ variable_collections = variables_collections.get(name, None)
+ else:
+ variable_collections = variables_collections
+ return variable_collections
+
+
+def collect_named_outputs(collections, alias, outputs):
+ """Add `Tensor` outputs tagged with alias to collections.
+
+ It is useful to collect end-points or tags for summaries. Example of usage:
+ logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
+ assert 'inception_v3/logits' in logits.aliases
+
+ Args:
+ collections: A collection or list of collections. If None skip collection.
+ alias: String to append to the list of aliases of outputs, for example,
+ 'inception_v3/conv1'.
+ outputs: Tensor, an output tensor to collect
+
+ Returns:
+ The outputs Tensor to allow inline call.
+ """
+ if collections:
+ append_tensor_alias(outputs, alias)
+ ops.add_to_collections(collections, outputs)
+ return outputs
+
+
+def append_tensor_alias(tensor, alias):
+ """Append an alias to the list of aliases of the tensor.
+
+ Args:
+ tensor: A `Tensor`.
+ alias: String, to add to the list of aliases of the tensor.
+
+ Returns:
+ The tensor with a new alias appended to its list of aliases.
+ """
+ # Remove ending '/' if present.
+ if alias[-1] == '/':
+ alias = alias[:-1]
+ if hasattr(tensor, 'aliases'):
+ tensor.aliases.append(alias)
+ else:
+ tensor.aliases = [alias]
+ return tensor
+
+
+def variable(name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ device=None,
+ partitioner=None,
+ custom_getter=None,
+ use_resource=None):
+ """Gets an existing variable with these parameters or creates a new one.
+
+ Args:
+ name: the name of the new or existing variable.
+ shape: shape of the new or existing variable.
+ dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
+ initializer: initializer for the variable if one is created.
+ regularizer: a (Tensor -> Tensor or None) function; the result of
+ applying it on a newly created variable will be added to the collection
+ GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ collections: A list of collection names to which the Variable will be added.
+ If None it would default to `tf.GraphKeys.GLOBAL_VARIABLES`.
+ caching_device: Optional device string or function describing where the
+ Variable should be cached for reading. Defaults to the Variable's
+ device.
+ device: Optional device to place the variable. It can be a string or a
+ function that is called to get the device for the variable.
+ partitioner: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+ custom_getter: Callable that allows overwriting the internal
+ get_variable method and has to have the same signature.
+ use_resource: If `True` use a ResourceVariable instead of a Variable.
+
+ Returns:
+ The created or existing variable.
+ """
+ collections = list(collections if collections is not None else
+ [ops.GraphKeys.GLOBAL_VARIABLES])
+
+ # Remove duplicates
+ collections = list(set(collections))
+ getter = variable_scope.get_variable
+ if custom_getter is not None:
+ getter = functools.partial(
+ custom_getter, reuse=variable_scope.get_variable_scope().reuse)
+ with ops.device(device or ''):
+ return getter(
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ use_resource=use_resource)
+
+
+def model_variable(name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ device=None,
+ partitioner=None,
+ custom_getter=None,
+ use_resource=None):
+ """Gets an existing model variable with these parameters or creates a new one.
+
+ Args:
+ name: the name of the new or existing variable.
+ shape: shape of the new or existing variable.
+ dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
+ initializer: initializer for the variable if one is created.
+ regularizer: a (Tensor -> Tensor or None) function; the result of
+ applying it on a newly created variable will be added to the collection
+ GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ collections: A list of collection names to which the Variable will be added.
+ Note that the variable is always also added to the
+ `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections.
+ caching_device: Optional device string or function describing where the
+ Variable should be cached for reading. Defaults to the Variable's
+ device.
+ device: Optional device to place the variable. It can be a string or a
+ function that is called to get the device for the variable.
+ partitioner: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+ custom_getter: Callable that allows overwriting the internal
+ get_variable method and has to have the same signature.
+ use_resource: If `True` use a ResourceVariable instead of a Variable.
+
+ Returns:
+ The created or existing variable.
+ """
+ collections = list(collections or [])
+ collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
+ var = variable(
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ device=device,
+ partitioner=partitioner,
+ custom_getter=custom_getter,
+ use_resource=use_resource)
+ return var
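A minimal TF1-style usage sketch of `model_variable` (graph mode assumed); the scope name and shape are illustrative.

import tensorflow as tf
from easy_rec.python.compat import layers

g = tf.Graph()
with g.as_default():
  with tf.compat.v1.variable_scope('demo'):
    w = layers.model_variable('w', shape=[4, 8],
                              initializer=tf.compat.v1.zeros_initializer())
  print(w)  # added to GLOBAL_VARIABLES and MODEL_VARIABLES collections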
diff --git a/easy_rec/python/compat/ops.py b/easy_rec/python/compat/ops.py
new file mode 100644
index 000000000..548a0c27f
--- /dev/null
+++ b/easy_rec/python/compat/ops.py
@@ -0,0 +1,14 @@
+from tensorflow.python.framework import ops
+
+
+class GraphKeys(ops.GraphKeys):
+ # For rank service
+ RANK_SERVICE_FG_CONF = '__rank_service_fg_conf'
+ RANK_SERVICE_INPUT = '__rank_service_input'
+ RANK_SERVICE_OUTPUT = '__rank_service_output'
+ RANK_SERVICE_EMBEDDING = '__rank_service_embedding'
+ RANK_SERVICE_INPUT_SRC = '__rank_service_input_src'
+ RANK_SERVICE_REPLACE_OP = '__rank_service_replace'
+ RANK_SERVICE_SHAPE_OPT_FLAG = '__rank_service_shape_opt_flag'
+ # For compatibility between RTP and EasyRec
+ RANK_SERVICE_FEATURE_NODE = '__rank_service_feature_node'
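A minimal sketch of how these collections are used, assuming TF1-style graph construction: embedding tensors and attributes are added to `RANK_SERVICE_EMBEDDING` while the graph is built (see the feature-column code above) and read back later when the rank-service subgraph is extracted. The tensor below is a stand-in.

import tensorflow as tf
from easy_rec.python.compat.ops import GraphKeys

g = tf.Graph()
with g.as_default():
  emb = tf.zeros([8, 16], name='fake_embedding')  # stand-in for an embedding output
  tf.compat.v1.add_to_collection(GraphKeys.RANK_SERVICE_EMBEDDING, emb)
  print(tf.compat.v1.get_collection(GraphKeys.RANK_SERVICE_EMBEDDING))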
diff --git a/easy_rec/python/compat/optimizers.py b/easy_rec/python/compat/optimizers.py
index 21fede4b8..d31a4cd41 100644
--- a/easy_rec/python/compat/optimizers.py
+++ b/easy_rec/python/compat/optimizers.py
@@ -19,13 +19,18 @@
from __future__ import division
from __future__ import print_function
+import logging
+
import six
+import tensorflow as tf
# from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+# from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -36,6 +41,26 @@
from tensorflow.python.training import optimizer as optimizer_
from tensorflow.python.training import training as train
+from easy_rec.python.ops.incr_record import set_sparse_indices
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import estimator_utils
+
+try:
+ from tensorflow.python.framework import indexed_slices
+except Exception:
+ indexed_slices = ops
+
+try:
+ import horovod.tensorflow as hvd
+except Exception:
+ hvd = None
+
+try:
+ from sparse_operation_kit import experiment as sok
+ from easy_rec.python.compat import sok_optimizer
+except Exception:
+ sok = None
+
OPTIMIZER_CLS_NAMES = {
'Adagrad':
train.AdagradOptimizer,
@@ -75,7 +100,9 @@ def optimize_loss(loss,
summaries=None,
colocate_gradients_with_ops=False,
not_apply_grad_after_first_step=False,
- increment_global_step=True):
+ increment_global_step=True,
+ incr_save=False,
+ embedding_parallel=False):
"""Given loss and parameters for optimizer, returns a training op.
Various ways of passing optimizers include:
@@ -146,6 +173,9 @@ def optimize_loss(loss,
calls `optimize_loss` multiple times per training step (e.g. to optimize
different parts of the model), use this arg to avoid incrementing
`global_step` more times than necessary.
+ incr_save: whether to incrementally dump checkpoints.
+ embedding_parallel: whether to shard embeddings and place embedding parts on
+ different workers.
Returns:
Training op.
@@ -232,11 +262,15 @@ def optimize_loss(loss,
if not isinstance(opt, optimizer_.Optimizer):
raise ValueError('Unrecognized optimizer: function should return '
'subclass of Optimizer. Got %s.' % str(opt))
+ elif isinstance(optimizer, sok_optimizer.OptimizerWrapperV1) or \
+ isinstance(optimizer, sok_optimizer.OptimizerWrapperV2):
+ opt = optimizer
else:
raise ValueError('Unrecognized optimizer: should be string, '
'subclass of Optimizer, instance of '
'subclass of Optimizer or function with one argument. '
- 'Got %s.' % str(optimizer))
+ 'Got %s[type=%s].' %
+ (str(optimizer), str(type(optimizer))))
# All trainable variables, if specific variables are not specified.
if variables is None:
@@ -248,6 +282,68 @@ def optimize_loss(loss,
variables,
colocate_gradients_with_ops=colocate_gradients_with_ops)
+ if estimator_utils.has_hvd() and hvd.size() > 1:
+ if not embedding_parallel:
+ # embedding parameters not partitioned
+ reduced_grads = []
+ for g, v in gradients:
+ reduced_grads.append((hvd.allreduce(
+ g, op=hvd.Average,
+ compression=hvd.compression.NoneCompressor), v))
+ gradients = reduced_grads
+ else:
+ # embedding parameters are partitioned:
+ # the gradients for embeddings from different workers are
+ # already summed together in the backward pass through
+ # hvd.alltoall; as the loss is not divided, the gradients
+ # need to be normalized by dividing by the number of workers
+ embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
+ part_grads = []
+ part_vars = []
+ part_sparse_grads = []
+ part_sparse_vars = []
+ reduced_grads = []
+ for g, v in gradients:
+ if v.name not in embed_para_vars:
+ if isinstance(g, indexed_slices.IndexedSlices):
+ part_sparse_grads.append(g)
+ part_sparse_vars.append(v)
+ else:
+ part_grads.append(g)
+ part_vars.append(v)
+ else:
+ reduced_grads.append((indexed_slices.IndexedSlices(
+ indices=g.indices, values=g.values / hvd.size()), v))
+
+ group_allreduce = False
+ if len(part_grads) > 0:
+ if group_allreduce:
+ reduced_part_grads = hvd.grouped_allreduce(
+ part_grads,
+ op=hvd.Average,
+ compression=hvd.compression.NoneCompressor)
+ for g, v in zip(reduced_part_grads, part_vars):
+ reduced_grads.append((g, v))
+ else:
+ for g, v in zip(part_grads, part_vars):
+ g = hvd.allreduce(
+ g, op=hvd.Average, compression=hvd.compression.NoneCompressor)
+ reduced_grads.append((g, v))
+ if len(part_sparse_grads) > 0:
+ if group_allreduce:
+ reduced_part_grads = hvd.grouped_allreduce(
+ part_sparse_grads,
+ op=hvd.Average,
+ compression=hvd.compression.NoneCompressor)
+ for g, v in zip(reduced_part_grads, part_sparse_vars):
+ reduced_grads.append((g, v))
+ else:
+ for g, v in zip(part_sparse_grads, part_sparse_vars):
+ g = hvd.allreduce(
+ g, op=hvd.Average, compression=hvd.compression.NoneCompressor)
+ reduced_grads.append((g, v))
+ gradients = reduced_grads
+
# Optionally add gradient noise.
if gradient_noise_scale is not None:
gradients = _add_scaled_noise_to_gradients(gradients,
@@ -261,13 +357,23 @@ def optimize_loss(loss,
'Empty list of (gradient, var) pairs encountered. This is most '
'likely to be caused by an improper value of gradient_multipliers.')
- if 'global_gradient_norm' in summaries or 'gradient_norm' in summaries:
- summary.scalar('global_norm/gradient_norm',
- clip_ops.global_norm(list(zip(*gradients))[0]))
+ # if 'global_gradient_norm' in summaries or 'gradient_norm' in summaries:
+ # summary.scalar('global_norm/gradient_norm',
+ # clip_ops.global_norm(list(zip(*gradients))[0]))
# Optionally clip gradients by global norm.
if isinstance(clip_gradients, float):
- gradients = _clip_gradients_by_norm(gradients, clip_gradients)
+ # gradients = _clip_gradients_by_norm(gradients, clip_gradients)
+ sparse_norm, dense_norm, grad_norm = _get_grad_norm(
+ gradients, embedding_parallel)
+ summary.scalar('global_norm/sparse_grad', sparse_norm)
+ summary.scalar('global_norm/dense_grad', dense_norm)
+ summary.scalar('global_norm/gradient_norm', grad_norm)
+ grads = [x[0] for x in gradients]
+ vars = [x[1] for x in gradients]
+ clipped_grads, _ = clip_ops.clip_by_global_norm(
+ grads, clip_gradients, use_norm=grad_norm)
+ gradients = list(zip(clipped_grads, vars))
elif callable(clip_gradients):
gradients = clip_gradients(gradients)
elif clip_gradients is not None:
@@ -279,24 +385,28 @@ def optimize_loss(loss,
summary.scalar('loss', loss)
# Add histograms for variables, gradients and gradient norms.
- for gradient, variable in gradients:
- if isinstance(gradient, ops.IndexedSlices):
- grad_values = gradient.values
- else:
- grad_values = gradient
-
- if grad_values is not None:
- var_name = variable.name.replace(':', '_')
- if 'gradients' in summaries:
- summary.histogram('gradients/%s' % var_name, grad_values)
- if 'gradient_norm' in summaries:
- summary.scalar('gradient_norm/%s' % var_name,
- clip_ops.global_norm([grad_values]))
+ if not embedding_parallel:
+ for gradient, variable in gradients:
+ if isinstance(gradient, indexed_slices.IndexedSlices):
+ grad_values = gradient.values
+ else:
+ grad_values = gradient
+
+ if grad_values is not None:
+ var_name = variable.name.replace(':', '_')
+ if 'gradients' in summaries:
+ summary.histogram('gradients/%s' % var_name, grad_values)
+ if 'gradient_norm' in summaries:
+ summary.scalar('gradient_norm/%s' % var_name,
+ clip_ops.global_norm([grad_values]))
if clip_gradients is not None and ('global_gradient_norm' in summaries or
'gradient_norm' in summaries):
- summary.scalar('global_norm/clipped_gradient_norm',
- clip_ops.global_norm(list(zip(*gradients))[0]))
+ sparse_norm, dense_norm, grad_norm = _get_grad_norm(
+ gradients, embedding_parallel)
+ summary.scalar('global_norm/clipped_sparse_grad', sparse_norm)
+ summary.scalar('global_norm/clipped_dense_grad', dense_norm)
+ summary.scalar('global_norm/clipped_gradient_norm', grad_norm)
# Create gradient updates.
def _apply_grad():
@@ -304,21 +414,73 @@ def _apply_grad():
gradients,
global_step=global_step if increment_global_step else None,
name='train')
- return control_flow_ops.with_dependencies([grad_updates], loss)
+
+ embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
+ slot_names = opt.get_slot_names()
+ for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
+ if var.name in embed_para_vars:
+ for slot_name in slot_names:
+ tmp_var = opt.get_slot(var, slot_name)
+ logging.info('add shard embedding optimizer var: %s' % tmp_var.name)
+ ops.add_to_collection(constant.EmbeddingParallel, tmp_var.name)
+
+ incr_save_ops = []
+ if incr_save:
+ for grad, var in gradients:
+ if isinstance(grad, indexed_slices.IndexedSlices):
+ indices = grad.indices
+ with ops.colocate_with(var), ops.control_dependencies(
+ [grad_updates]):
+ incr_save_op = set_sparse_indices(indices, var_name=var.op.name)
+ incr_save_ops.append(incr_save_op)
+ ops.add_to_collection('SPARSE_UPDATE_VARIABLES',
+ (var, grad.indices.dtype))
+ else:
+ ops.add_to_collection('DENSE_UPDATE_VARIABLES', var)
+ return tf.group(incr_save_ops)
+ else:
+ return grad_updates
if not_apply_grad_after_first_step:
- train_tensor = control_flow_ops.cond(global_step > 0, lambda: loss,
- _apply_grad)
+ _apply_grad()
+ train_tensor = loss
else:
- # Ensure the train_tensor computes grad_updates.
train_tensor = _apply_grad()
return train_tensor
+def _get_grad_norm(grads_and_vars, embedding_parallel=False):
+ part_norms = []
+ sparse_norms = []
+ dense_norms = []
+ emb_para_names = ops.get_collection(constant.EmbeddingParallel)
+ for grad, var in grads_and_vars:
+ if embedding_parallel and hvd is not None and hvd.size() > 1:
+ if var.name in emb_para_names:
+ part_norms.append(gen_nn_ops.l2_loss(grad.values))
+ continue
+ if isinstance(grad, indexed_slices.IndexedSlices):
+ sparse_norms.append(gen_nn_ops.l2_loss(grad.values))
+ else:
+ dense_norms.append(gen_nn_ops.l2_loss(grad))
+ reduced_norms = hvd.grouped_allreduce(
+ part_norms, op=hvd.Sum, compression=hvd.compression.NoneCompressor)
+ sparse_norms = sparse_norms + reduced_norms
+ all_norms = reduced_norms + dense_norms
+ sparse_norm = math_ops.sqrt(
+ math_ops.reduce_sum(array_ops.stack(sparse_norms) * 2.0))
+ dense_norm = math_ops.sqrt(
+ math_ops.reduce_sum(array_ops.stack(dense_norms) * 2.0))
+ grad_norm = math_ops.sqrt(
+ math_ops.reduce_sum(array_ops.stack(all_norms)) * 2.0)
+ return sparse_norm, dense_norm, grad_norm
+
+
def _clip_gradients_by_norm(grads_and_vars, clip_gradients):
"""Clips gradients by global norm."""
gradients, variables = zip(*grads_and_vars)
+
clipped_gradients, _ = clip_ops.clip_by_global_norm(gradients, clip_gradients)
return list(zip(clipped_gradients, variables))
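A NumPy sketch of the global-norm clipping math used above: every gradient is scaled by the same factor clip_norm / max(global_norm, clip_norm), where global_norm is the L2 norm over all gradients taken together (values below are illustrative).

import numpy as np

def clip_by_global_norm_ref(grads, clip_norm):
  global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
  scale = clip_norm / max(global_norm, clip_norm)
  return [g * scale for g in grads], global_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]
clipped, norm = clip_by_global_norm_ref(grads, clip_norm=6.5)
print(norm)     # 13.0
print(clipped)  # scaled by 0.5 -> [1.5, 2.0] and [6.0]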
@@ -410,10 +572,10 @@ def gradient_clipping(grads_and_vars):
for grad in grads:
if grad is None:
clipped_grads.append(None)
- elif isinstance(grad, ops.IndexedSlices):
+ elif isinstance(grad, indexed_slices.IndexedSlices):
clipped_grads.append(
- ops.IndexedSlices(grad.values * factor, grad.indices,
- grad.dense_shape))
+ indexed_slices.IndexedSlices(grad.values * factor, grad.indices,
+ grad.dense_shape))
else:
clipped_grads.append(grad * factor)
@@ -430,7 +592,7 @@ def _add_scaled_noise_to_gradients(grads_and_vars, gradient_noise_scale):
if gradient is None:
noisy_gradients.append(None)
continue
- if isinstance(gradient, ops.IndexedSlices):
+ if isinstance(gradient, indexed_slices.IndexedSlices):
gradient_shape = gradient.dense_shape
else:
gradient_shape = gradient.get_shape()
@@ -447,9 +609,10 @@ def _multiply_gradients(grads_and_vars, gradient_multipliers):
(var in gradient_multipliers or var.name in gradient_multipliers)):
key = var if var in gradient_multipliers else var.name
multiplier = gradient_multipliers[key]
- if isinstance(grad, ops.IndexedSlices):
+ if isinstance(grad, indexed_slices.IndexedSlices):
grad_values = grad.values * multiplier
- grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
+ grad = indexed_slices.IndexedSlices(grad_values, grad.indices,
+ grad.dense_shape)
else:
grad *= math_ops.cast(multiplier, grad.dtype)
multiplied_grads_and_vars.append((grad, var))
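An illustration of the normalization described in the embedding-parallel branch above: when per-worker gradients arrive already summed (e.g. through the alltoall backward pass), dividing by the number of workers recovers the same result an Average-allreduce would give. Plain NumPy, values illustrative.

import numpy as np

worker_grads = [np.array([1.0, 2.0]), np.array([3.0, 6.0])]  # gradients from 2 workers
summed = np.sum(worker_grads, axis=0)   # what the backward pass delivers
averaged = summed / len(worker_grads)   # normalize by worker count
print(averaged)                         # [2. 4.] == np.mean(worker_grads, axis=0)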
diff --git a/easy_rec/python/compat/queues.py b/easy_rec/python/compat/queues.py
new file mode 100644
index 000000000..c7063d966
--- /dev/null
+++ b/easy_rec/python/compat/queues.py
@@ -0,0 +1,311 @@
+#
+# Module implementing queues
+#
+# multiprocessing/queues.py
+#
+# Copyright (c) 2006-2008, R Oudkerk
+# Licensed to PSF under a Contributor Agreement.
+#
+
+import collections
+import errno
+import logging
+import os
+import sys
+import threading
+import time
+import weakref
+from multiprocessing import connection
+from multiprocessing.util import Finalize
+from multiprocessing.util import is_exiting
+from multiprocessing.util import register_after_fork
+from queue import Empty
+from queue import Full
+
+import six
+
+try:
+ from multiprocessing import context
+except ImportError:
+ context = None
+ pass
+
+if context is not None:
+ _ForkingPickler = context.reduction.ForkingPickler
+else:
+ _ForkingPickler = None
+
+#
+# Queue type using a pipe, buffer and thread
+#
+
+
+class Queue(object):
+
+ _sentinel = object()
+
+ def __init__(self, ctx, maxsize=0, name=''):
+ assert not six.PY2, 'python2 is not supported'
+ if maxsize <= 0:
+ # Can raise ImportError (see issues #3770 and #23400)
+ from multiprocessing.synchronize import SEM_VALUE_MAX as maxsize
+ self._maxsize = maxsize
+ self._reader, self._writer = connection.Pipe(duplex=False)
+ self._rlock = ctx.Lock()
+ self._opid = os.getpid()
+ if sys.platform == 'win32':
+ self._wlock = None
+ else:
+ self._wlock = ctx.Lock()
+ self._sem = ctx.BoundedSemaphore(maxsize)
+ # For use by concurrent.futures
+ self._ignore_epipe = False
+ self._reset()
+ self._name = name
+ self._run = True
+
+ if sys.platform != 'win32':
+ register_after_fork(self, Queue._after_fork)
+
+ def __getstate__(self):
+ context.assert_spawning(self)
+ return (self._ignore_epipe, self._maxsize, self._reader, self._writer,
+ self._rlock, self._wlock, self._sem, self._opid, self._name,
+ self._run)
+
+ def __setstate__(self, state):
+ (self._ignore_epipe, self._maxsize, self._reader, self._writer, self._rlock,
+ self._wlock, self._sem, self._opid, self._name, self._run) = state
+ self._reset()
+
+ def _after_fork(self):
+ logging.debug('Queue._after_fork()')
+ self._reset(after_fork=True)
+
+ def _reset(self, after_fork=False):
+ if after_fork:
+ self._notempty._at_fork_reinit()
+ else:
+ self._notempty = threading.Condition(threading.Lock())
+ self._buffer = collections.deque()
+ self._thread = None
+ self._jointhread = None
+ self._joincancelled = False
+ self._closed = False
+ self._close = None
+ self._send_bytes = self._writer.send_bytes
+ self._recv_bytes = self._reader.recv_bytes
+ self._poll = self._reader.poll
+
+ def put(self, obj, block=True, timeout=None):
+ if self._closed:
+ raise ValueError('Queue %s is closed' % self._name)
+ if not self._sem.acquire(block, timeout):
+ raise Full
+
+ with self._notempty:
+ if self._thread is None:
+ self._start_thread()
+ self._buffer.append(obj)
+ self._notempty.notify()
+
+ def get(self, block=True, timeout=None):
+ if self._closed:
+ raise ValueError('Queue %s is closed' % self._name)
+ if block and timeout is None:
+ with self._rlock:
+ res = self._recv_bytes()
+ self._sem.release()
+ else:
+ if block:
+ deadline = time.monotonic() + timeout
+ if not self._rlock.acquire(block, timeout):
+ raise Empty
+ try:
+ if block:
+ timeout = deadline - time.monotonic()
+ if not self._poll(timeout):
+ raise Empty
+ elif not self._poll():
+ raise Empty
+ res = self._recv_bytes()
+ self._sem.release()
+ finally:
+ self._rlock.release()
+ # unserialize the data after having released the lock
+ return _ForkingPickler.loads(res)
+
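A hypothetical single-process usage sketch (Python 3 only, and assuming the easy_rec module path is importable): the queue mirrors `multiprocessing.Queue` but takes an explicit context and a name, and `close(wait_send_finish=False)` stops the feeder without waiting for pending items.

import multiprocessing as mp
from easy_rec.python.compat.queues import Queue

q = Queue(mp.get_context('spawn'), maxsize=4, name='demo')
q.put({'batch_id': 0})
print(q.get())  # {'batch_id': 0}
q.close()
q.join_thread()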
+ def qsize(self):
+ # Raises NotImplementedError on Mac OSX because of broken sem_getvalue()
+ return self._maxsize - self._sem._semlock._get_value()
+
+ def empty(self):
+ return not self._poll()
+
+ def full(self):
+ return self._sem._semlock._is_zero()
+
+ def get_nowait(self):
+ return self.get(False)
+
+ def put_nowait(self, obj):
+ return self.put(obj, False)
+
+ def close(self, wait_send_finish=True):
+ self._closed = True
+ close = self._close
+ if not wait_send_finish and self._thread is not None and self._thread.is_alive(
+ ):
+ try:
+ if self._reader is not None:
+ self._reader.close()
+ except Exception:
+ pass
+ self._run = False
+ # clear queue
+ # with self._rlock:
+ # while self._thread.is_alive() and self._poll(1):
+ # res = self._recv_bytes()
+ # logging.info('Queue[name=' + self._name + '] clear one elem')
+ # logging.info('Queue[name=' + self._name + '] clear queue done')
+ if close:
+ self._close = None
+ close()
+
+ def join_thread(self):
+ logging.debug('Queue.join_thread()')
+ assert self._closed, 'Queue {0!r} not closed'.format(self)
+ if self._jointhread:
+ self._jointhread()
+
+ def cancel_join_thread(self):
+ logging.debug('Queue.cancel_join_thread()')
+ self._joincancelled = True
+ try:
+ self._jointhread.cancel()
+ except AttributeError:
+ pass
+
+ def _start_thread(self):
+ logging.debug('Queue._start_thread()')
+
+ # Start thread which transfers data from buffer to pipe
+ self._buffer.clear()
+ self._thread = threading.Thread(
+ target=self._feed,
+ args=(self._buffer, self._notempty, self._send_bytes, self._wlock,
+ self._reader.close, self._writer.close, self._ignore_epipe,
+ self._on_queue_feeder_error, self._sem),
+ name='QueueFeederThread')
+ self._thread.daemon = True
+
+ logging.debug('doing self._thread.start()')
+ self._thread.start()
+ logging.debug('... done self._thread.start()')
+
+ if not self._joincancelled:
+ self._jointhread = Finalize(
+ self._thread,
+ Queue._finalize_join, [weakref.ref(self._thread)],
+ exitpriority=-5)
+
+ # Send sentinel to the thread queue object when garbage collected
+ self._close = Finalize(
+ self,
+ Queue._finalize_close, [self._buffer, self._notempty],
+ exitpriority=10)
+
+ @staticmethod
+ def _finalize_join(twr):
+ logging.debug('joining queue thread')
+ thread = twr()
+ if thread is not None:
+ thread.join()
+ logging.debug('... queue thread joined')
+ else:
+ logging.debug('... queue thread already dead')
+
+ @staticmethod
+ def _finalize_close(buffer, notempty):
+ logging.debug('telling queue thread to quit')
+ with notempty:
+ buffer.append(Queue._sentinel)
+ notempty.notify()
+
+ def _feed(self, buffer, notempty, send_bytes, writelock, reader_close,
+ writer_close, ignore_epipe, onerror, queue_sem):
+ logging.debug('starting thread to feed data to pipe')
+ nacquire = notempty.acquire
+ nrelease = notempty.release
+ nwait = notempty.wait
+ bpopleft = buffer.popleft
+ sentinel = Queue._sentinel
+ if sys.platform != 'win32':
+ wacquire = writelock.acquire
+ wrelease = writelock.release
+ else:
+ wacquire = None
+
+ pid = os.getpid()
+ name = self._name
+ while self._run:
+ try:
+ nacquire()
+ try:
+ if not buffer:
+ nwait()
+ finally:
+ nrelease()
+ try:
+ while self._run:
+ obj = bpopleft()
+ if obj is sentinel:
+ # logging.info('Queue[' + self._name + '] feeder thread got sentinel -- exiting: ' + str(self._run))
+ reader_close()
+ writer_close()
+ return
+
+ # serialize the data before acquiring the lock
+ obj = _ForkingPickler.dumps(obj)
+ if wacquire is None:
+ send_bytes(obj)
+ else:
+ wacquire()
+ try:
+ send_bytes(obj)
+ finally:
+ wrelease()
+ except IndexError:
+ pass
+ except Exception as e:
+ if ignore_epipe and getattr(e, 'errno', 0) == errno.EPIPE:
+ logging.warning('Queue[' + name + '] exception: pid=' + str(pid) +
+ ' run=' + str(self._run) + ' e=' + str(e))
+ return
+          # Since this runs in a daemon thread the resources it uses
+          # may become unusable while the process is cleaning up.
+          # We ignore errors which happen after the process has
+          # started to clean up.
+ if is_exiting():
+ logging.warning('Queue[' + name + '] thread error in exiting: pid=' +
+ str(pid) + ' run=' + str(self._run) + ' e=' + str(e))
+ return
+ else:
+ # Since the object has not been sent in the queue, we need
+ # to decrease the size of the queue. The error acts as
+ # if the object had been silently removed from the queue
+ # and this step is necessary to have a properly working
+ # queue.
+ queue_sem.release()
+ onerror(e, obj)
+ # logging.info('Queue[' + name + '] send thread finish: pid=' + str(pid)
+ # + ' run=' + str(self._run))
+
+ @staticmethod
+ def _on_queue_feeder_error(e, obj):
+ """Private API hook called when feeding data in the background thread raises an exception.
+
+ For overriding by concurrent.futures.
+ """
+ import traceback
+ traceback.print_exc()
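
The class above mirrors multiprocessing.queues.Queue: put() appends to an in-process deque and a daemon feeder thread pickles objects into a pipe, while get() reads and unpickles from the pipe under a read lock. A minimal single-process sketch of the intended usage follows (editor's illustration, not part of the patch; `Queue` is the class defined above and the context is assumed to provide Lock/BoundedSemaphore):

import multiprocessing

ctx = multiprocessing.get_context('spawn')   # supplies Lock/BoundedSemaphore
q = Queue(ctx, maxsize=8, name='demo')       # the Queue class defined above

q.put({'step': 1, 'loss': 0.5})  # buffered, then pickled into the pipe by the feeder thread
print(q.get())                   # read back from the pipe and unpickled

q.close()        # appends the sentinel so the feeder thread exits
q.join_thread()  # only valid after close()
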
diff --git a/easy_rec/python/compat/sok_optimizer.py b/easy_rec/python/compat/sok_optimizer.py
new file mode 100644
index 000000000..7f368a9a1
--- /dev/null
+++ b/easy_rec/python/compat/sok_optimizer.py
@@ -0,0 +1,440 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+from tensorflow.python.eager import context
+# from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+# from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+
+from easy_rec.python.compat.dynamic_variable import DynamicVariable
+
+
+def OptimizerWrapper(optimizer):
+ """Abbreviated as ``sok.experiment.OptimizerWrapper``.
+
+  This is a wrapper for a TensorFlow optimizer so that it can update
+ dynamic_variable.DynamicVariable.
+
+ Parameters
+ ----------
+ optimizer: tensorflow optimizer
+ The original tensorflow optimizer.
+
+ Example
+ -------
+ .. code-block:: python
+
+ import numpy as np
+ import tensorflow as tf
+ import horovod.tensorflow as hvd
+ from sparse_operation_kit import experiment as sok
+
+    v = sok.DynamicVariable(dimension=3, initializer="13")
+
+ indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)
+
+ with tf.GradientTape() as tape:
+ embedding = tf.nn.embedding_lookup(v, indices)
+ print("embedding:", embedding)
+ loss = tf.reduce_sum(embedding)
+
+ grads = tape.gradient(loss, [v])
+
+ optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
+ optimizer = sok.OptimizerWrapper(optimizer)
+ optimizer.apply_gradients(zip(grads, [v]))
+
+ embedding = tf.nn.embedding_lookup(v, indices)
+ print("embedding:", embedding)
+ """
+  # Specific code path for TF >= 2.11, where the legacy Keras optimizers
+  # live under tf.keras.optimizers.legacy.
+ try:
+ if isinstance(optimizer, tf.keras.optimizers.legacy.Optimizer):
+ return OptimizerWrapperV2(optimizer)
+ except Exception:
+ pass
+
+ if isinstance(optimizer, tf.keras.optimizers.Optimizer):
+ return OptimizerWrapperV2(optimizer)
+ else:
+ return OptimizerWrapperV1(optimizer)
+
+
+class OptimizerWrapperV1(object):
+
+ def __init__(self, optimizer):
+ self._optimizer = optimizer
+ # slots
+ unused = tf.Variable([0.0],
+ dtype=tf.float32,
+ name='unused',
+ trainable=False)
+ self._optimizer._create_slots([unused])
+ names, slots = [], []
+ for name in self._optimizer.get_slot_names():
+ names.append(name)
+ slots.append(self._optimizer.get_slot(unused, name))
+ unused_key = self._var_key(unused)
+ for name in names:
+ assert unused_key in self._optimizer._slots[name]
+ self._optimizer._slots[name].pop(unused_key)
+ self._initial_vals = {}
+ for i, name in enumerate(names):
+ self._initial_vals[name] = slots[i]
+ # self._optimizer._prepare()
+
+ def compute_gradients(self,
+ loss,
+ var_list=None,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ grad_loss=None):
+ self._loss = loss
+ tmp_grads = gradients.gradients(loss, var_list)
+ return list(zip(tmp_grads, var_list))
+ # TODO: the following routine does not work with DynamicVariable
+ # return self._optimizer.compute_gradients(loss=loss, var_list=var_list,
+ # # gate_gradients=gate_gradients,
+ # aggregation_method=aggregation_method,
+ # colocate_gradients_with_ops=colocate_gradients_with_ops,
+ # grad_loss=grad_loss)
+
+ def _var_key(self, var):
+ if isinstance(var, DynamicVariable):
+ return (var._tf_handle.op.graph, var._tf_handle.op.name)
+ else:
+ return (var.op.graph, var.op.name)
+
+ def _create_slots(self, vars):
+ for var in vars:
+ if isinstance(var, DynamicVariable):
+ self._create_slots_dynamic(var)
+ else:
+ self._optimizer._create_slots(var)
+
+ def _create_slots_dynamic(self, var):
+ key = self._var_key(var)
+ for slot_name in self._initial_vals:
+ if key not in self._optimizer._slots[slot_name]:
+ if var.backend_type == 'hbm':
+ with ops.colocate_with(var):
+ slot = DynamicVariable(
+ dimension=var.dimension,
+ initializer=self._initial_vals[slot_name],
+ name='DynamicSlot',
+ trainable=False)
+ else:
+ tmp_config = var.config_dict
+ # tmp_initializer = var.initializer_str
+ with ops.colocate_with(var):
+ slot = DynamicVariable(
+ dimension=var.dimension,
+ initializer=self._initial_vals[slot_name],
+ var_type=var.backend_type,
+ name='DynamicSlot',
+ trainable=False,
+ **tmp_config)
+
+ self._optimizer._slots[slot_name][key] = slot
+
+ def get_slot_names(self):
+ return self._optimizer.get_slot_names()
+
+ def get_slot(self, var, slot_name):
+ key = self._var_key(var)
+ return self._optimizer._slots[slot_name][key]
+
+ @property
+ def _slots(self):
+ return self._optimizer._slots
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ gradients = grads_and_vars
+ sparse_vars = [x for x in gradients if 'DynamicVariable' in str(type(x[1]))]
+ dense_vars = [
+ x for x in gradients if 'DynamicVariable' not in str(type(x[1]))
+ ]
+
+ def _dummy_finish(update_ops, name_scope):
+ return update_ops
+
+ finish_func = self._optimizer._finish
+ self._optimizer._finish = _dummy_finish
+ with ops.control_dependencies([array_ops.identity(self._loss)]):
+ sparse_grad_updates = self.apply_sparse_gradients(sparse_vars, name=name)
+
+ dense_grad_updates = self._optimizer.apply_gradients(
+ dense_vars, global_step=None, name=name)
+ if sparse_grad_updates is not None and dense_grad_updates is not None:
+ grad_updates = sparse_grad_updates + dense_grad_updates
+ elif sparse_grad_updates is not None:
+ grad_updates = sparse_grad_updates
+ elif dense_grad_updates is not None:
+ grad_updates = dense_grad_updates
+
+ assert global_step is not None
+ with ops.control_dependencies([finish_func(grad_updates, 'update')]):
+ with ops.colocate_with(global_step):
+ if isinstance(global_step, resource_variable_ops.BaseResourceVariable):
+ # TODO(apassos): the implicit read in assign_add is slow; consider
+ # making it less so.
+ apply_updates = resource_variable_ops.assign_add_variable_op(
+ global_step.handle,
+ ops.convert_to_tensor(1, dtype=global_step.dtype),
+ name=name)
+ else:
+ apply_updates = state_ops.assign_add(global_step, 1, name=name)
+
+ if not context.executing_eagerly():
+ if isinstance(apply_updates, ops.Tensor):
+ apply_updates = apply_updates.op
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ if apply_updates not in train_op:
+ train_op.append(apply_updates)
+
+ return apply_updates
+
+ def apply_sparse_gradients(self, grads_and_vars, global_step=None, name=None):
+ # 1. Create slots and do sparse_read
+ to_static_ops = []
+ grad_list, var_list = [], []
+ for g, v in grads_and_vars:
+ if g is not None:
+ unique, indices = tf.unique(g.indices)
+ grad_list.append(ops.IndexedSlices(g.values, indices, g.dense_shape))
+ # TODO: Check multi-thread safety of DET
+ with tf.control_dependencies([g.values]):
+ to_static_ops.append(v.to_static(unique, False))
+ var_list.append(v)
+ key = self._var_key(v)
+ for slot_name in self._initial_vals:
+ if key not in self._optimizer._slots[slot_name]:
+ tmp_slot_var_name = v._dummy_handle.op.name + '/' + self._optimizer._name
+ if v.backend_type == 'hbm':
+ with ops.colocate_with(v):
+ slot = DynamicVariable(
+ dimension=v.dimension,
+ initializer=self._initial_vals[slot_name],
+ name=tmp_slot_var_name,
+ trainable=False,
+ )
+ else:
+ tmp_config = v.config_dict
+ # tmp_initializer = v.initializer_str
+ with ops.colocate_with(v):
+ slot = DynamicVariable(
+ dimension=v.dimension,
+ initializer=self._initial_vals[slot_name],
+ var_type=v.backend_type,
+ name=tmp_slot_var_name,
+ trainable=False,
+ **tmp_config)
+
+ self._optimizer._slots[slot_name][key] = slot
+ else:
+ slot = self._optimizer._slots[slot_name][key]
+ to_static_ops.append(slot.to_static(unique))
+
+ if len(grad_list) == 0:
+ return
+
+ # 3. Call tf-optimizer
+ with ops.control_dependencies(to_static_ops):
+ train_op = self._optimizer.apply_gradients(
+ zip(grad_list, var_list), global_step=global_step, name=name)
+
+ # 5. Write buffer back to dynamic variables
+ to_dynamic_ops = []
+ if not isinstance(train_op, list):
+ train_op = [train_op]
+ with ops.control_dependencies(train_op):
+ for v in var_list:
+ key = self._var_key(v)
+ to_dynamic_ops.append(v.to_dynamic())
+ for name in self._initial_vals:
+ slot = self._optimizer._slots[name][key]
+ to_dynamic_ops.append(slot.to_dynamic())
+
+ return to_dynamic_ops
+
+
+class OptimizerWrapperV2(object):
+
+ def __init__(self, optimizer):
+ self._optimizer = optimizer
+ # slots
+ if tf.__version__[0] == '1':
+ unused = tf.Variable([0.0],
+ name='unused',
+ trainable=False,
+ use_resource=True)
+ else:
+ unused = tf.Variable([0.0], name='unused', trainable=False)
+ self._optimizer._create_slots([unused])
+ names, slots = [], []
+ for name in self._optimizer.get_slot_names():
+ names.append(name)
+ slots.append(self._optimizer.get_slot(unused, name))
+ unused_key = self._var_key(unused)
+ if unused_key in self._optimizer._slots:
+ self._optimizer._slots.pop(unused_key)
+ self._initial_vals = {}
+ for i, name in enumerate(names):
+ self._initial_vals[name] = slots[i]
+ self._iterations = tf.Variable(0)
+
+ @property
+ def lr(self):
+ return self._optimizer.lr
+
+ def _create_slots(self, vars):
+ for tmp_var in vars:
+ if isinstance(tmp_var, DynamicVariable):
+ self._create_slots_dynamic(tmp_var)
+ else:
+ self._optimizer._create_slots(tmp_var)
+
+ def _create_slots_dynamic(self, var):
+ key = self._var_key(var)
+ if key not in self._optimizer._slots:
+ self._optimizer._slots[key] = {}
+ for slot_name in self._initial_vals:
+ if slot_name not in self._optimizer._slots[key]:
+ if var.backend_type == 'hbm':
+ slot = DynamicVariable(
+ dimension=var.dimension,
+ initializer=self._initial_vals[slot_name],
+ name='DynamicSlot',
+ trainable=False,
+ )
+ else:
+ tmp_config = var.config_dict
+ # tmp_initializer = var.initializer_str
+ slot = DynamicVariable(
+ dimension=var.dimension,
+ initializer=self._initial_vals[slot_name],
+ var_type=var.backend_type,
+ name='DynamicSlot',
+ trainable=False,
+ **tmp_config)
+ self._optimizer._slots[key][slot_name] = slot
+
+ def _var_key(self, var):
+ if hasattr(var, '_distributed_container'):
+ var = var._distributed_container()
+ if var._in_graph_mode:
+ return var._shared_name
+ return var._unique_id
+
+ def get_slot_names(self):
+ return self._optimizer.get_slot_names()
+
+ def get_slot(self, var, name):
+ return self._optimizer.get_slot(var, name)
+
+ @property
+ def _slots(self):
+ return self._optimizer._slots
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ # 1. Create slots and do sparse_read
+ to_static_ops = []
+ grad_list, var_list = [], []
+ for g, v in grads_and_vars:
+ if g is not None:
+ unique, indices = tf.unique(g.indices)
+ grad_list.append(ops.IndexedSlices(g.values, indices, g.dense_shape))
+ # TODO: Check multi-thread safety of DET
+ # with tf.control_dependencies([g.values]):
+ to_static_ops.append(v.to_static(unique))
+ var_list.append(v)
+ key = self._var_key(v)
+ if key not in self._optimizer._slots:
+ self._optimizer._slots[key] = {}
+ for slot_name in self._initial_vals:
+ if slot_name not in self._optimizer._slots[key]:
+ if v.backend_type == 'hbm':
+ slot = DynamicVariable(
+ dimension=v.dimension,
+ initializer=self._initial_vals[slot_name],
+ name='DynamicSlot',
+ trainable=False,
+ )
+ else:
+ tmp_config = v.config_dict
+ # tmp_initializer = v.initializer_str
+ slot = DynamicVariable(
+ dimension=v.dimension,
+ initializer=self._initial_vals[slot_name],
+ var_type=v.backend_type,
+ name='DynamicSlot',
+ trainable=False,
+ **tmp_config)
+
+ self._optimizer._slots[key][slot_name] = slot
+ else:
+ slot = self._optimizer._slots[key][slot_name]
+ to_static_ops.append(slot.to_static(unique))
+
+ if len(grad_list) == 0:
+ return
+
+ # 2. Switch iterations
+ iterations = self._optimizer._iterations
+ self._optimizer._iterations = self._iterations
+
+ # 3. Call tf-optimizer
+ with tf.control_dependencies(to_static_ops):
+ train_op = self._optimizer.apply_gradients(
+ zip(grad_list, var_list), name=name)
+
+ # 4. Switch iterations
+ self._optimizer._iterations = iterations
+
+ # 5. Write buffer back to dynamic variables
+ to_dynamic_ops = []
+ with tf.control_dependencies([train_op]):
+ for v in var_list:
+ key = self._var_key(v)
+ to_dynamic_ops.append(v.to_dynamic())
+ for name in self._initial_vals:
+ slot = self._optimizer._slots[key][name]
+ to_dynamic_ops.append(slot.to_dynamic())
+ return tf.group(to_dynamic_ops)
+
+
+class SGD(object):
+
+ def __init__(self, lr):
+ self._lr = tf.Variable(lr)
+
+ @property
+ def lr(self):
+ return self._lr
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ train_ops = []
+ for g, v in grads_and_vars:
+ if g is not None:
+ scaled_g = ops.IndexedSlices(g.values * self._lr, g.indices,
+ g.dense_shape)
+ train_ops.append(v.scatter_sub(scaled_g))
+ return tf.group(train_ops)
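
For graph-mode (TF1-style) training, OptimizerWrapperV1 above replaces compute_gradients with a plain tf.gradients call and routes sparse updates through DynamicVariable's to_static/to_dynamic round trip. A minimal sketch of that wiring (editor's illustration, not part of the patch; the constant initializer string follows the docstring example, the rest of the setup is an assumption):

import tensorflow as tf

from easy_rec.python.compat.dynamic_variable import DynamicVariable
from easy_rec.python.compat.sok_optimizer import OptimizerWrapper

v = DynamicVariable(dimension=8, initializer='13')  # constant initializer, as in the docstring
ids = tf.constant([0, 1, 2], dtype=tf.int64)
emb = tf.nn.embedding_lookup(v, ids)
loss = tf.reduce_sum(emb)
global_step = tf.train.get_or_create_global_step()

opt = OptimizerWrapper(tf.train.GradientDescentOptimizer(0.1))
grads_and_vars = opt.compute_gradients(loss, var_list=[v])
train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)  # global_step is required
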
diff --git a/easy_rec/python/compat/sync_replicas_optimizer.py b/easy_rec/python/compat/sync_replicas_optimizer.py
new file mode 100644
index 000000000..24c2921ba
--- /dev/null
+++ b/easy_rec/python/compat/sync_replicas_optimizer.py
@@ -0,0 +1,528 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Synchronize replicas for training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import queue_runner
+from tensorflow.python.training import session_manager
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.util.tf_export import tf_export
+
+
+# Please note that the gradients from replicas are averaged instead of summed
+# (as in the old sync_replicas_optimizer) so you need to increase the learning
+# rate according to the number of replicas. This change is introduced to be
+# consistent with how gradients are aggregated (averaged) within a batch in a
+# replica.
+@tf_export('train.SyncReplicasOptimizer')
+class SyncReplicasOptimizer(optimizer.Optimizer):
+ """Class to synchronize, aggregate gradients and pass them to the optimizer.
+
+ In a typical asynchronous training environment, it's common to have some
+  stale gradients. For example, with N-replica asynchronous training,
+ gradients will be applied to the variables N times independently. Depending
+ on each replica's training speed, some gradients might be calculated from
+ copies of the variable from several steps back (N-1 steps on average). This
+ optimizer avoids stale gradients by collecting gradients from all replicas,
+ averaging them, then applying them to the variables in one shot, after
+ which replicas can fetch the new variables and continue.
+
+ The following accumulators/queue are created:
+
+ * N `gradient accumulators`, one per variable to train. Gradients are pushed
+ to them and the chief worker will wait until enough gradients are collected
+ and then average them before applying to variables. The accumulator will
+ drop all stale gradients (more details in the accumulator op).
+ * 1 `token` queue where the optimizer pushes the new global_step value after
+ all variables are updated.
+
+ The following local variable is created:
+ * `sync_rep_local_step`, one per replica. Compared against the global_step in
+ each accumulator to check for staleness of the gradients.
+
+ The optimizer adds nodes to the graph to collect gradients and pause the
+ trainers until variables are updated.
+ For the Parameter Server job:
+
+ 1. An accumulator is created for each variable, and each replica pushes the
+ gradients into the accumulators instead of directly applying them to the
+ variables.
+ 2. Each accumulator averages once enough gradients (replicas_to_aggregate)
+ have been accumulated.
+ 3. Apply the averaged gradients to the variables.
+ 4. Only after all variables have been updated, increment the global step.
+ 5. Only after step 4, pushes `global_step` in the `token_queue`, once for
+ each worker replica. The workers can now fetch the global step, use it to
+ update its local_step variable and start the next batch.
+
+ For the replicas:
+
+ 1. Start a step: fetch variables and compute gradients.
+ 2. Once the gradients have been computed, push them into gradient
+     accumulators. Each accumulator will check the staleness and drop stale ones.
+ 3. After pushing all the gradients, dequeue an updated value of global_step
+ from the token queue and record that step to its local_step variable. Note
+ that this is effectively a barrier.
+ 4. Start the next batch.
+
+ ### Usage
+
+ ```python
+ # Create any optimizer to update the variables, say a simple SGD:
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Wrap the optimizer with sync_replicas_optimizer with 50 replicas: at each
+ # step the optimizer collects 50 gradients before applying to variables.
+ # Note that if you want to have 2 backup replicas, you can change
+ # total_num_replicas=52 and make sure this number matches how many physical
+ # replicas you started in your job.
+ opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
+ total_num_replicas=50)
+
+ # Some models have startup_delays to help stabilize the model but when using
+ # sync_replicas training, set it to 0.
+
+ # Now you can call `minimize()` or `compute_gradients()` and
+ # `apply_gradients()` normally
+ training_op = opt.minimize(total_loss, global_step=self.global_step)
+
+
+ # You can create the hook which handles initialization and queues.
+ sync_replicas_hook = opt.make_session_run_hook(is_chief)
+ ```
+
+ In the training program, every worker will run the train_op as if not
+ synchronized.
+
+ ```python
+ with training.MonitoredTrainingSession(
+ master=workers[worker_id].target, is_chief=is_chief,
+ hooks=[sync_replicas_hook]) as mon_sess:
+ while not mon_sess.should_stop():
+ mon_sess.run(training_op)
+ ```
+
+  To use SyncReplicasOptimizer with an `Estimator`, you need to pass
+  `sync_replicas_hook` when calling `fit`.
+ ```python
+ my_estimator = DNNClassifier(..., optimizer=opt)
+ my_estimator.fit(..., hooks=[sync_replicas_hook])
+ ```
+ """
+
+ sync_que_id = -1
+
+ def __init__(self,
+ opt,
+ replicas_to_aggregate,
+ total_num_replicas=None,
+ variable_averages=None,
+ variables_to_average=None,
+ use_locking=False,
+ name='sync_replicas',
+ **extra_args):
+ """Construct a sync_replicas optimizer.
+
+ Args:
+ opt: The actual optimizer that will be used to compute and apply the
+ gradients. Must be one of the Optimizer classes.
+ replicas_to_aggregate: number of replicas to aggregate for each variable
+ update.
+ total_num_replicas: Total number of tasks/workers/replicas, could be
+ different from replicas_to_aggregate.
+ If total_num_replicas > replicas_to_aggregate: it is backup_replicas +
+ replicas_to_aggregate.
+ If total_num_replicas < replicas_to_aggregate: Replicas compute
+ multiple batches per update to variables.
+ variable_averages: Optional `ExponentialMovingAverage` object, used to
+ maintain moving averages for the variables passed in
+ `variables_to_average`.
+ variables_to_average: a list of variables that need to be averaged. Only
+ needed if variable_averages is passed in.
+ use_locking: If True use locks for update operation.
+ name: string. Optional name of the returned operation.
+ """
+ if total_num_replicas is None:
+ total_num_replicas = replicas_to_aggregate
+
+ super(SyncReplicasOptimizer, self).__init__(use_locking, name)
+ logging.info(
+ 'SyncReplicasV2: replicas_to_aggregate=%s; total_num_replicas=%s',
+ replicas_to_aggregate, total_num_replicas)
+ self._opt = opt
+ self._replicas_to_aggregate = replicas_to_aggregate
+ self._gradients_applied = False
+ self._variable_averages = variable_averages
+ self._variables_to_average = variables_to_average
+ self._total_num_replicas = total_num_replicas
+ self._tokens_per_step = max(total_num_replicas, replicas_to_aggregate)
+ self._global_step = None
+ self._sync_token_queue = None
+ self._is_sync_que_closed = None
+ self._close_sync_que = None
+
+ # The synchronization op will be executed in a queue runner which should
+ # only be executed by one of the replicas (usually the chief).
+ self._chief_queue_runner = None
+
+ # Remember which accumulator is on which device to set the initial step in
+ # the accumulator to be global step. This list contains list of the
+ # following format: (accumulator, device).
+ self._accumulator_list = []
+
+ def compute_gradients(self, *args, **kwargs):
+ """Compute gradients of "loss" for the variables in "var_list".
+
+ This simply wraps the compute_gradients() from the real optimizer. The
+ gradients will be aggregated in the apply_gradients() so that user can
+ modify the gradients like clipping with per replica global norm if needed.
+ The global norm with aggregated gradients can be bad as one replica's huge
+ gradients can hurt the gradients from other replicas.
+
+ Args:
+ *args: Arguments for compute_gradients().
+ **kwargs: Keyword arguments for compute_gradients().
+
+ Returns:
+ A list of (gradient, variable) pairs.
+ """
+ return self._opt.compute_gradients(*args, **kwargs)
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This contains most of the synchronization implementation and also wraps the
+ apply_gradients() from the real optimizer.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ compute_gradients().
+ global_step: Optional Variable to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the Optimizer constructor.
+
+ Returns:
+ train_op: The op to dequeue a token so the replicas can exit this batch
+ and start the next one. This is executed by each replica.
+
+ Raises:
+ ValueError: If the grads_and_vars is empty.
+ ValueError: If global step is not provided, the staleness cannot be
+ checked.
+ """
+ if not grads_and_vars:
+ raise ValueError('Must supply at least one variable')
+
+ if global_step is None:
+ raise ValueError('Global step is required to check staleness')
+
+ self._global_step = global_step
+ train_ops = []
+ aggregated_grad = []
+ var_list = []
+
+ # local_anchor op will be placed on this worker task by default.
+ local_anchor = control_flow_ops.no_op()
+ # Colocating local_step variable prevents it being placed on the PS.
+ with ops.colocate_with(local_anchor):
+ self._local_step = variable_scope.variable(
+ initial_value=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=global_step.dtype.base_dtype,
+ name='sync_rep_local_step')
+
+ self.local_step_init_op = state_ops.assign(self._local_step, global_step)
+ chief_init_ops = [self.local_step_init_op]
+ self.ready_for_local_init_op = variables.report_uninitialized_variables(
+ variables.global_variables())
+
+ with ops.name_scope(None, self._name):
+ for grad, var in grads_and_vars:
+ var_list.append(var)
+ with ops.device(var.device):
+ # Dense gradients.
+ if grad is None:
+ aggregated_grad.append(None) # pass-through.
+ continue
+ elif isinstance(grad, ops.Tensor):
+ grad_accum = data_flow_ops.ConditionalAccumulator(
+ grad.dtype,
+ shape=var.get_shape(),
+ shared_name=var.name + '/grad_accum')
+ train_ops.append(
+ grad_accum.apply_grad(grad, local_step=self._local_step))
+ aggregated_grad.append(
+ grad_accum.take_grad(self._replicas_to_aggregate))
+ else:
+ if not isinstance(grad, ops.IndexedSlices):
+ raise ValueError('Unknown grad type!')
+ grad_accum = data_flow_ops.SparseConditionalAccumulator(
+ grad.dtype, shape=(), shared_name=var.name + '/grad_accum')
+ train_ops.append(
+ grad_accum.apply_indexed_slices_grad(
+ grad, local_step=self._local_step))
+ aggregated_grad.append(
+ grad_accum.take_indexed_slices_grad(
+ self._replicas_to_aggregate))
+
+ self._accumulator_list.append((grad_accum, var.device))
+
+ aggregated_grads_and_vars = zip(aggregated_grad, var_list)
+
+ # sync_op will be assigned to the same device as the global step.
+ with ops.device(global_step.device), ops.name_scope(''):
+ update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
+ global_step)
+
+ def _get_token_qname():
+ SyncReplicasOptimizer.sync_que_id += 1
+ if SyncReplicasOptimizer.sync_que_id == 0:
+ return 'sync_token_q'
+ else:
+ return 'sync_token_q_' + str(SyncReplicasOptimizer.sync_que_id)
+
+ # Create token queue.
+ token_qname = _get_token_qname()
+ logging.info('create sync_token_queue[%s]' % token_qname)
+ with ops.device(global_step.device), ops.name_scope(''):
+ sync_token_queue = (
+ data_flow_ops.FIFOQueue(
+ -1,
+ global_step.dtype.base_dtype,
+ shapes=(),
+ name=token_qname,
+ shared_name=token_qname))
+ self._sync_token_queue = sync_token_queue
+ self._is_sync_que_closed = sync_token_queue.is_closed()
+ self._close_sync_que = sync_token_queue.close(
+ cancel_pending_enqueues=True, name='close_sync_token_queue')
+
+ # dummy_queue is passed to the queue runner. Don't use the real queues
+ # because the queue runner doesn't automatically reopen it once it
+ # closed queues in PS devices.
+ dummy_queue = (
+ data_flow_ops.FIFOQueue(
+ 1,
+ types_pb2.DT_INT32,
+ shapes=(),
+ name='dummy_queue',
+ shared_name='dummy_queue'))
+
+ with ops.device(global_step.device), ops.name_scope(''):
+ # Replicas have to wait until they can get a token from the token queue.
+ with ops.control_dependencies(train_ops):
+ token = sync_token_queue.dequeue()
+ train_op = state_ops.assign(self._local_step, token)
+
+ with ops.control_dependencies([update_op]):
+ # Sync_op needs to insert tokens to the token queue at the end of the
+ # step so the replicas can fetch them to start the next step.
+ tokens = array_ops.fill([self._tokens_per_step], global_step)
+ sync_op = sync_token_queue.enqueue_many((tokens,))
+
+ if self._variable_averages is not None:
+ with ops.control_dependencies([sync_op]), ops.name_scope(''):
+ sync_op = self._variable_averages.apply(self._variables_to_average)
+
+ self._chief_queue_runner = queue_runner.QueueRunner(
+ dummy_queue, [sync_op])
+ ops.add_to_collection(ops.GraphKeys.QUEUE_RUNNERS,
+ self._chief_queue_runner)
+ for accum, dev in self._accumulator_list:
+ with ops.device(dev):
+ chief_init_ops.append(
+ accum.set_global_step(global_step, name='SetGlobalStep'))
+ self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
+ self._gradients_applied = True
+ return train_op
+
+ def get_chief_queue_runner(self):
+ """Returns the QueueRunner for the chief to execute.
+
+ This includes the operations to synchronize replicas: aggregate gradients,
+ apply to variables, increment global step, insert tokens to token queue.
+
+ Note that this can only be called after calling apply_gradients() which
+ actually generates this queuerunner.
+
+ Returns:
+ A `QueueRunner` for chief to execute.
+
+ Raises:
+ ValueError: If this is called before apply_gradients().
+ """
+ if self._gradients_applied is False:
+ raise ValueError('Should be called after apply_gradients().')
+
+ return self._chief_queue_runner
+
+ def get_slot(self, *args, **kwargs):
+ """Return a slot named "name" created for "var" by the Optimizer.
+
+ This simply wraps the get_slot() from the actual optimizer.
+
+ Args:
+ *args: Arguments for get_slot().
+ **kwargs: Keyword arguments for get_slot().
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ return self._opt.get_slot(*args, **kwargs)
+
+ def variables(self):
+ """Fetches a list of optimizer variables in the default graph.
+
+ This wraps `variables()` from the actual optimizer. It does not include
+ the `SyncReplicasOptimizer`'s local step.
+
+ Returns:
+ A list of variables.
+ """
+ return self._opt.variables()
+
+ def get_slot_names(self, *args, **kwargs):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ This simply wraps the get_slot_names() from the actual optimizer.
+
+ Args:
+ *args: Arguments for get_slot().
+ **kwargs: Keyword arguments for get_slot().
+
+ Returns:
+ A list of strings.
+ """
+ return self._opt.get_slot_names(*args, **kwargs)
+
+ def get_init_tokens_op(self, num_tokens=-1):
+ """Returns the op to fill the sync_token_queue with the tokens.
+
+ This is supposed to be executed in the beginning of the chief/sync thread
+ so that even if the total_num_replicas is less than replicas_to_aggregate,
+ the model can still proceed as the replicas can compute multiple steps per
+ variable update. Make sure:
+ `num_tokens >= replicas_to_aggregate - total_num_replicas`.
+
+ Args:
+ num_tokens: Number of tokens to add to the queue.
+
+ Returns:
+ An op for the chief/sync replica to fill the token queue.
+
+ Raises:
+ ValueError: If this is called before apply_gradients().
+ ValueError: If num_tokens are smaller than replicas_to_aggregate -
+ total_num_replicas.
+ """
+ if self._gradients_applied is False:
+ raise ValueError(
+ 'get_init_tokens_op() should be called after apply_gradients().')
+
+ tokens_needed = self._replicas_to_aggregate - self._total_num_replicas
+ if num_tokens == -1:
+ num_tokens = self._replicas_to_aggregate
+ elif num_tokens < tokens_needed:
+ raise ValueError(
+ 'Too few tokens to finish the first step: %d (given) vs %d (needed)' %
+ (num_tokens, tokens_needed))
+
+ if num_tokens > 0:
+ with ops.device(self._global_step.device), ops.name_scope(''):
+ tokens = array_ops.fill([num_tokens], self._global_step)
+ init_tokens = self._sync_token_queue.enqueue_many((tokens,))
+ else:
+ init_tokens = control_flow_ops.no_op(name='no_init_tokens')
+
+ return init_tokens
+
+ def make_session_run_hook(self, is_chief, num_tokens=-1):
+ """Creates a hook to handle SyncReplicasHook ops such as initialization."""
+ return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)
+
+
+class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
+ """A SessionRunHook handles ops related to SyncReplicasOptimizer."""
+
+ def __init__(self, sync_optimizer, is_chief, num_tokens):
+ """Creates hook to handle SyncReplicasOptimizer initialization ops.
+
+ Args:
+ sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
+ is_chief: `Bool`, whether is this a chief replica or not.
+ num_tokens: Number of tokens to add to the queue.
+ """
+ self._sync_optimizer = sync_optimizer
+ self._is_chief = is_chief
+ self._num_tokens = num_tokens
+
+ def begin(self):
+ if self._sync_optimizer._gradients_applied is False: # pylint: disable=protected-access
+ raise ValueError(
+          'SyncReplicasOptimizer.apply_gradients should be called before using '
+ 'the hook.')
+ if self._is_chief:
+ self._local_init_op = self._sync_optimizer.chief_init_op
+ self._ready_for_local_init_op = (
+ self._sync_optimizer.ready_for_local_init_op)
+ self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
+ self._num_tokens)
+ else:
+ self._local_init_op = self._sync_optimizer.local_step_init_op
+ self._ready_for_local_init_op = (
+ self._sync_optimizer.ready_for_local_init_op)
+ self._init_tokens_op = None
+
+ def after_create_session(self, session, coord):
+ """Runs SyncReplicasOptimizer initialization ops."""
+ local_init_success, msg = session_manager._ready( # pylint: disable=protected-access
+ self._ready_for_local_init_op, session,
+ 'Model is not ready for SyncReplicasOptimizer local init.')
+ if not local_init_success:
+ raise RuntimeError(
+ 'Init operations did not make model ready for SyncReplicasOptimizer '
+ 'local_init. Init op: %s, error: %s' %
+ (self._local_init_op.name, msg))
+ session.run(self._local_init_op)
+ is_closed = session.run(self._sync_optimizer._is_sync_que_closed)
+ assert not is_closed, 'sync_que is closed'
+ if self._init_tokens_op is not None:
+ session.run(self._init_tokens_op)
+
+ def end(self, session):
+ try:
+ is_closed = session.run(self._sync_optimizer._is_sync_que_closed)
+ if not is_closed:
+ logging.info('will close sync token que')
+ session.run(self._sync_optimizer._close_sync_que)
+ else:
+ logging.info('sync token que is closed')
+ except errors_impl.CancelledError:
+ logging.info('sync token que is closed')
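
The token bookkeeping in get_init_tokens_op reduces to simple arithmetic. A worked illustration (editor's note; the replica counts are hypothetical):

# Hypothetical counts, illustrating the constraint documented above:
# num_tokens >= replicas_to_aggregate - total_num_replicas.
replicas_to_aggregate = 50  # gradients aggregated per variable update
total_num_replicas = 48     # physical workers, two fewer than needed per step
tokens_needed = replicas_to_aggregate - total_num_replicas  # 2

num_tokens = -1             # the default argument
if num_tokens == -1:
  num_tokens = replicas_to_aggregate  # default always satisfies the constraint
assert num_tokens >= tokens_needed
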
diff --git a/easy_rec/python/compat/weight_decay_optimizers.py b/easy_rec/python/compat/weight_decay_optimizers.py
index d29dce5bb..47a755e0f 100755
--- a/easy_rec/python/compat/weight_decay_optimizers.py
+++ b/easy_rec/python/compat/weight_decay_optimizers.py
@@ -411,7 +411,7 @@ def __init__(self,
try:
- from tensorflow.python.training import AdamAsyncOptimizer
+ from tensorflow.train import AdamAsyncOptimizer
@tf_export('contrib.opt.AdamAsyncWOptimizer')
class AdamAsyncWOptimizer(DecoupledWeightDecayExtension, AdamAsyncOptimizer):
diff --git a/easy_rec/python/core/distribute_metrics.py b/easy_rec/python/core/distribute_metrics.py
deleted file mode 100644
index a5110c33f..000000000
--- a/easy_rec/python/core/distribute_metrics.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# -*- encoding:utf-8 -*-
-# Copyright (c) Alibaba, Inc. and its affiliates.
-from collections import defaultdict
-
-import numpy as np
-import tensorflow as tf
-from sklearn import metrics as sklearn_metrics
-
-if tf.__version__ >= '2.0':
- tf = tf.compat.v1
-
-
-def max_f1(label, predictions):
- """Calculate the largest F1 metric under different thresholds.
-
- Args:
- label: Ground truth (correct) target values.
- predictions: Estimated targets as returned by a model.
- """
- num_thresholds = 200
- kepsilon = 1e-7
- thresholds = [
- (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
- ]
- thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
-
- f1_scores = []
- precision_update_ops = []
- recall_update_ops = []
- for threshold in thresholds:
- pred = (predictions > threshold)
- precision, precision_update_op = tf.metrics.precision(
- labels=label, predictions=pred, name='precision_%s' % threshold)
- recall, recall_update_op = tf.metrics.recall(
- labels=label, predictions=pred, name='recall_%s' % threshold)
- f1_score = (2 * precision * recall) / (precision + recall + 1e-12)
- precision_update_ops.append(precision_update_op)
- recall_update_ops.append(recall_update_op)
- f1_scores.append(f1_score)
-
- f1 = tf.math.reduce_max(tf.stack(f1_scores))
- f1_update_op = tf.group(precision_update_ops + recall_update_ops)
- return f1, f1_update_op
-
-
-def _separated_auc_impl(labels, predictions, keys, reduction='mean'):
- """Computes the AUC group by the key separately.
-
- Args:
- labels: A `Tensor` whose shape matches `predictions`. Will be cast to
- `bool`.
- predictions: A floating point `Tensor` of arbitrary shape and whose values
- are in the range `[0, 1]`.
- keys: keys to be group by, A int or string `Tensor` whose shape matches `predictions`.
- reduction: reduction metric for auc of different keys
- * "mean": simple mean of different keys
- * "mean_by_sample_num": weighted mean with sample num of different keys
- * "mean_by_positive_num": weighted mean with positive sample num of different keys
- """
- assert reduction in ['mean', 'mean_by_sample_num', 'mean_by_positive_num'], \
- 'reduction method must in mean | mean_by_sample_num | mean_by_positive_num'
- separated_label = defaultdict(list)
- separated_prediction = defaultdict(list)
- separated_weights = defaultdict(int)
-
- def update_pyfunc(labels, predictions, keys):
- for label, prediction, key in zip(labels, predictions, keys):
- separated_label[key].append(label)
- separated_prediction[key].append(prediction)
- if reduction == 'mean':
- separated_weights[key] = 1
- elif reduction == 'mean_by_sample_num':
- separated_weights[key] += 1
- elif reduction == 'mean_by_positive_num':
- separated_weights[key] += label
-
- def value_pyfunc():
- metrics = []
- weights = []
- for key in separated_label.keys():
- per_label = np.asarray(separated_label[key])
- per_prediction = np.asarray(separated_prediction[key])
- if np.all(per_label == 1) or np.all(per_label == 0):
- continue
- metric = sklearn_metrics.roc_auc_score(per_label, per_prediction)
- metrics.append(metric)
- weights.append(separated_weights[key])
- if len(metrics) > 0:
- return np.average(metrics, weights=weights).astype(np.float32)
- else:
- return np.float32(0.0)
-
- update_op = tf.py_func(update_pyfunc, [labels, predictions, keys], [])
- value_op = tf.py_func(value_pyfunc, [], tf.float32)
- return value_op, update_op
-
-
-def gauc(labels, predictions, uids, reduction='mean'):
- """Computes the AUC group by user separately.
-
- Args:
- labels: A `Tensor` whose shape matches `predictions`. Will be cast to
- `bool`.
- predictions: A floating point `Tensor` of arbitrary shape and whose values
- are in the range `[0, 1]`.
- uids: user ids, A int or string `Tensor` whose shape matches `predictions`.
- reduction: reduction method for auc of different users
- * "mean": simple mean of different users
- * "mean_by_sample_num": weighted mean with sample num of different users
- * "mean_by_positive_num": weighted mean with positive sample num of different users
- """
- return _separated_auc_impl(labels, predictions, uids, reduction)
-
-
-def session_auc(labels, predictions, session_ids, reduction='mean'):
- """Computes the AUC group by session separately.
-
- Args:
- labels: A `Tensor` whose shape matches `predictions`. Will be cast to
- `bool`.
- predictions: A floating point `Tensor` of arbitrary shape and whose values
- are in the range `[0, 1]`.
- session_ids: session ids, A int or string `Tensor` whose shape matches `predictions`.
- reduction: reduction method for auc of different sessions
- * "mean": simple mean of different sessions
- * "mean_by_sample_num": weighted mean with sample num of different sessions
- * "mean_by_positive_num": weighted mean with positive sample num of different sessions
- """
- return _separated_auc_impl(labels, predictions, session_ids, reduction)
diff --git a/easy_rec/python/core/easyrec_metrics/__init__.py b/easy_rec/python/core/easyrec_metrics/__init__.py
new file mode 100644
index 000000000..cba3ebc08
--- /dev/null
+++ b/easy_rec/python/core/easyrec_metrics/__init__.py
@@ -0,0 +1,24 @@
+import logging
+import os
+
+import tensorflow as tf
+
+from easy_rec.python.utils import pai_util
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+distribute_eval = os.environ.get('distribute_eval')
+logging.info('distribute_eval = {}'.format(distribute_eval))
+if distribute_eval == 'True':
+ if pai_util.is_on_pai() or tf.__version__ <= '1.13':
+ logging.info('Will use distribute pai_tf metrics impl')
+ from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_pai as metrics_tf
+ else:
+ logging.info('Will use distribute tf metrics impl')
+ from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_tf as metrics_tf
+else:
+ if tf.__version__ >= '2.0':
+ from tensorflow.compat.v1 import metrics as metrics_tf
+ else:
+ from tensorflow import metrics as metrics_tf
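
Because the implementation is chosen at import time, the distribute_eval environment variable has to be set before easyrec_metrics is first imported. A small sketch (editor's illustration, not part of the patch; the label/prediction tensors are made up):

import os

os.environ['distribute_eval'] = 'True'  # must be set before the import below

import tensorflow as tf

from easy_rec.python.core.easyrec_metrics import metrics_tf

labels = tf.constant([1, 0, 1, 0])
predictions = tf.constant([0.9, 0.2, 0.7, 0.4])
# metrics_tf resolves to one of the distributed implementations above;
# when distribute_eval is unset (or not 'True') it falls back to tf.metrics.
auc_value, auc_update = metrics_tf.auc(labels, predictions)
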
diff --git a/easy_rec/python/core/metrics_impl_pai.py b/easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py
similarity index 99%
rename from easy_rec/python/core/metrics_impl_pai.py
rename to easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py
index addf8e961..ef6e10f86 100644
--- a/easy_rec/python/core/metrics_impl_pai.py
+++ b/easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py
@@ -18,7 +18,6 @@
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -598,10 +597,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'tp' in includes:
true_p = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='true_positives')
- is_true_positive = math_ops.to_int64(
- math_ops.logical_and(label_is_pos, pred_is_pos))
+ is_true_positive = math_ops.cast(
+ math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
if weights_tiled is not None:
is_true_positive *= weights_tiled
update_ops['tp'] = state_ops.assign_add(
@@ -610,10 +609,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'fn' in includes:
false_n = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='false_negatives')
- is_false_negative = math_ops.to_int64(
- math_ops.logical_and(label_is_pos, pred_is_neg))
+ is_false_negative = math_ops.cast(
+ math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
if weights_tiled is not None:
is_false_negative *= weights_tiled
update_ops['fn'] = state_ops.assign_add(
@@ -622,10 +621,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'tn' in includes:
true_n = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='true_negatives')
- is_true_negative = math_ops.to_int64(
- math_ops.logical_and(label_is_neg, pred_is_neg))
+ is_true_negative = math_ops.cast(
+ math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
if weights_tiled is not None:
is_true_negative *= weights_tiled
update_ops['tn'] = state_ops.assign_add(
@@ -634,10 +633,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'fp' in includes:
false_p = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='false_positives')
- is_false_positive = math_ops.to_int64(
- math_ops.logical_and(label_is_neg, pred_is_pos))
+ is_false_positive = math_ops.cast(
+ math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
if weights_tiled is not None:
is_false_positive *= weights_tiled
update_ops['fp'] = state_ops.assign_add(
@@ -731,7 +730,8 @@ def auc(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- print('use_pai_auc')
+ print('use_distribute_pai_auc')
+ logging.info('use_distribute_pai_auc')
if context.executing_eagerly():
raise RuntimeError('tf.metrics.auc is not supported when eager execution '
'is enabled.')
@@ -800,10 +800,6 @@ def interpolate_pr_auc(tp, fp, fn):
def compute_auc(tp, fn, tn, fp, name):
"""Computes the roc-auc or pr-auc based on confusion counts."""
- tp = tf.cast(tp, dtype=tf.float64)
- fn = tf.cast(fn, dtype=tf.float64)
- tn = tf.cast(tn, dtype=tf.float64)
- fp = tf.cast(fp, dtype=tf.float64)
if curve == 'PR':
if summation_method == 'trapezoidal':
logging.warning(
diff --git a/easy_rec/python/core/metrics_impl_tf.py b/easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py
similarity index 99%
rename from easy_rec/python/core/metrics_impl_tf.py
rename to easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py
index 920c85f34..1756826be 100644
--- a/easy_rec/python/core/metrics_impl_tf.py
+++ b/easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py
@@ -17,7 +17,6 @@
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
@@ -580,10 +579,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'tp' in includes:
true_p = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='true_positives')
is_true_positive = math_ops.cast(
- math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.int64)
+ math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
if weights_tiled is not None:
is_true_positive *= weights_tiled
update_ops['tp'] = state_ops.assign_add(
@@ -592,10 +591,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'fn' in includes:
false_n = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='false_negatives')
is_false_negative = math_ops.cast(
- math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.int64)
+ math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
if weights_tiled is not None:
is_false_negative *= weights_tiled
update_ops['fn'] = state_ops.assign_add(
@@ -604,10 +603,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'tn' in includes:
true_n = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='true_negatives')
is_true_negative = math_ops.cast(
- math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.int64)
+ math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
if weights_tiled is not None:
is_true_negative *= weights_tiled
update_ops['tn'] = state_ops.assign_add(
@@ -616,10 +615,10 @@ def _confusion_matrix_at_thresholds(labels,
if 'fp' in includes:
false_p = metric_variable([num_thresholds],
- dtypes.int64,
+ dtypes.float32,
name='false_positives')
is_false_positive = math_ops.cast(
- math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.int64)
+ math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
if weights_tiled is not None:
is_false_positive *= weights_tiled
update_ops['fp'] = state_ops.assign_add(
@@ -807,10 +806,6 @@ def interpolate_pr_auc(tp, fp, fn):
def compute_auc(tp, fn, tn, fp, name):
"""Computes the roc-auc or pr-auc based on confusion counts."""
- tp = tf.cast(tp, dtype=tf.float64)
- fp = tf.cast(fn, dtype=tf.float64)
- fn = tf.cast(fn, dtype=tf.float64)
- tn = tf.cast(tn, dtype=tf.float64)
if curve == 'PR':
if summation_method == 'trapezoidal':
logging.warning(
diff --git a/easy_rec/python/core/metrics.py b/easy_rec/python/core/metrics.py
index 1df6cc844..bd7cb0976 100644
--- a/easy_rec/python/core/metrics.py
+++ b/easy_rec/python/core/metrics.py
@@ -1,11 +1,21 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+import os
from collections import defaultdict
import numpy as np
import tensorflow as tf
from sklearn import metrics as sklearn_metrics
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from easy_rec.python.utils.estimator_utils import get_task_index_and_num
+from easy_rec.python.utils.io_util import read_data_from_json_path
+from easy_rec.python.utils.io_util import save_data_to_json_path
from easy_rec.python.utils.shape_utils import get_shape_list
if tf.__version__ >= '2.0':
@@ -19,6 +29,7 @@ def max_f1(label, predictions):
label: Ground truth (correct) target values.
predictions: Estimated targets as returned by a model.
"""
+ from easy_rec.python.core.easyrec_metrics import metrics_tf
num_thresholds = 200
kepsilon = 1e-7
thresholds = [
@@ -31,9 +42,9 @@ def max_f1(label, predictions):
recall_update_ops = []
for threshold in thresholds:
pred = (predictions > threshold)
- precision, precision_update_op = tf.metrics.precision(
+ precision, precision_update_op = metrics_tf.precision(
labels=label, predictions=pred, name='precision_%s' % threshold)
- recall, recall_update_op = tf.metrics.recall(
+ recall, recall_update_op = metrics_tf.recall(
labels=label, predictions=pred, name='recall_%s' % threshold)
f1_score = (2 * precision * recall) / (precision + recall + 1e-12)
precision_update_ops.append(precision_update_op)
@@ -97,6 +108,155 @@ def value_pyfunc():
return value_op, update_op
+def fast_auc(labels, predictions, name, num_thresholds=1e5):
+ num_thresholds = int(num_thresholds)
+
+ def value_pyfunc(pos_neg_arr, total_pos_neg):
+ partial_sum_pos = 0
+ auc = 0
+ total_neg = total_pos_neg[0]
+ total_pos = total_pos_neg[1]
+ for i in range(num_thresholds + 1):
+ partial_sum_pos += pos_neg_arr[1][i]
+ auc += (total_pos - partial_sum_pos) * pos_neg_arr[0][i] * 2
+ auc += pos_neg_arr[0][i] * pos_neg_arr[1][i]
+ auc = np.double(auc) / np.double(total_pos * total_neg * 2)
+ logging.info('fast_auc[%s]: total_pos=%d total_neg=%d total=%d' %
+ (name, total_pos, total_neg, total_pos + total_neg))
+ return np.float32(auc)
+
+ with variable_scope.variable_scope(name_or_scope=name), tf.name_scope(name):
+ neg_pos_var = variable_scope.get_variable(
+ name='neg_pos_cnt',
+ shape=[2, num_thresholds + 1],
+ trainable=False,
+ collections=[tf.GraphKeys.METRIC_VARIABLES],
+ initializer=tf.zeros_initializer(),
+ dtype=tf.int64)
+ total_var = variable_scope.get_variable(
+ name='total_cnt',
+ shape=[2],
+ trainable=False,
+ collections=[tf.GraphKeys.METRIC_VARIABLES],
+ initializer=tf.zeros_initializer(),
+ dtype=tf.int64)
+ pred_bins = math_ops.cast(predictions * num_thresholds, dtype=tf.int32)
+ labels = math_ops.cast(labels, dtype=tf.int32)
+ labels = array_ops.reshape(labels, [-1, 1])
+ pred_bins = array_ops.reshape(pred_bins, [-1, 1])
+ update_op0 = state_ops.scatter_nd_add(
+ neg_pos_var, tf.concat([labels, pred_bins], axis=1),
+ array_ops.ones(tf.shape(labels)[0], dtype=tf.int64))
+ total_pos = math_ops.reduce_sum(labels)
+ total_neg = array_ops.shape(labels)[0] - total_pos
+ total_add = math_ops.cast(tf.stack([total_neg, total_pos]), dtype=tf.int64)
+ update_op1 = state_ops.assign_add(total_var, total_add)
+ return tf.py_func(value_pyfunc, [neg_pos_var, total_var],
+ tf.float32), tf.group([update_op0, update_op1])
+
+
+def _distribute_separated_auc_impl(labels,
+ predictions,
+ keys,
+ reduction='mean',
+                                   metric_name='separated_auc'):
+ """Computes the AUC group by the key separately.
+
+ Args:
+ labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+ `bool`.
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
+ are in the range `[0, 1]`.
+ keys: keys to be group by, A int or string `Tensor` whose shape matches `predictions`.
+    reduction: reduction method for auc of different keys
+      * "mean": simple mean of different keys
+      * "mean_by_sample_num": weighted mean with sample num of different keys
+      * "mean_by_positive_num": weighted mean with positive sample num of different keys
+    metric_name: name of the metric, used to tag the per-worker temporary result files
+ """
+ assert reduction in ['mean', 'mean_by_sample_num', 'mean_by_positive_num'], \
+ 'reduction method must in mean | mean_by_sample_num | mean_by_positive_num'
+ separated_label = defaultdict(list)
+ separated_prediction = defaultdict(list)
+ separated_weights = defaultdict(int)
+ tf_config = json.loads(os.environ['TF_CONFIG'])
+ cur_job_name = tf_config['task']['type']
+ cur_task_index, task_num = get_task_index_and_num()
+ cur_work_device = 'job_' + cur_job_name + '__' + 'task_' + str(cur_task_index)
+ eval_tmp_results_dir = os.environ['eval_tmp_results_dir']
+ assert tf.gfile.IsDirectory(
+      eval_tmp_results_dir), 'eval_tmp_results_dir does not exist'
+
+ def update_pyfunc(labels, predictions, keys):
+ for label, prediction, key in zip(labels, predictions, keys):
+ key = str(key)
+ separated_label[key].append(label.item())
+ separated_prediction[key].append(prediction.item())
+ if reduction == 'mean':
+ separated_weights[key] = 1
+ elif reduction == 'mean_by_sample_num':
+ separated_weights[key] += 1
+ elif reduction == 'mean_by_positive_num':
+ separated_weights[key] += label.item()
+ for name, data in zip(
+ ['separated_label', 'separated_prediction', 'separated_weights'],
+ [separated_label, separated_prediction, separated_weights]):
+ cur_json_name = metric_name + '__' + cur_work_device + '__' + name + '.json'
+ cur_json_path = os.path.join(eval_tmp_results_dir, cur_json_name)
+ save_data_to_json_path(cur_json_path, data)
+
+ def value_pyfunc():
+ for task_i in range(1, task_num):
+ work_device_i = 'job_worker__task_' + str(task_i)
+ for name in [
+ 'separated_label', 'separated_prediction', 'separated_weights'
+ ]:
+ json_name_i = metric_name + '__' + work_device_i + '__' + name + '.json'
+ json_path_i = os.path.join(eval_tmp_results_dir, json_name_i)
+ data_i = read_data_from_json_path(json_path_i)
+ if (name == 'separated_label'):
+ separated_label.update({
+ key: separated_label.get(key, []) + data_i.get(key, [])
+ for key in set(
+ list(separated_label.keys()) + list(data_i.keys()))
+ })
+ elif (name == 'separated_prediction'):
+ separated_prediction.update({
+ key: separated_prediction.get(key, []) + data_i.get(key, [])
+ for key in set(
+ list(separated_prediction.keys()) + list(data_i.keys()))
+ })
+ elif (name == 'separated_weights'):
+ if reduction == 'mean':
+ separated_weights.update(data_i)
+ else:
+ separated_weights.update({
+ key: separated_weights.get(key, 0) + data_i.get(key, 0)
+ for key in set(
+ list(separated_weights.keys()) + list(data_i.keys()))
+ })
+ else:
+ assert False, 'Not supported name {}'.format(name)
+ metrics = []
+ weights = []
+ for key in separated_label.keys():
+ per_label = np.asarray(separated_label[key]).reshape([-1])
+ per_prediction = np.asarray(separated_prediction[key]).reshape([-1])
+ if np.all(per_label == 1) or np.all(per_label == 0):
+ continue
+ metric = sklearn_metrics.roc_auc_score(per_label, per_prediction)
+ metrics.append(metric)
+ weights.append(separated_weights[key])
+ if len(metrics) > 0:
+ return np.average(metrics, weights=weights).astype(np.float32)
+ else:
+ return np.float32(0.0)
+
+ update_op = tf.py_func(update_pyfunc, [labels, predictions, keys], [])
+ value_op = tf.py_func(value_pyfunc, [], tf.float32)
+ return value_op, update_op
+
+
def gauc(labels, predictions, uids, reduction='mean'):
"""Computes the AUC group by user separately.
@@ -111,6 +271,9 @@ def gauc(labels, predictions, uids, reduction='mean'):
* "mean_by_sample_num": weighted mean with sample num of different users
* "mean_by_positive_num": weighted mean with positive sample num of different users
"""
+ if os.environ.get('distribute_eval') == 'True':
+ return _distribute_separated_auc_impl(
+ labels, predictions, uids, reduction, metric_name='gauc')
return _separated_auc_impl(labels, predictions, uids, reduction)
@@ -128,6 +291,9 @@ def session_auc(labels, predictions, session_ids, reduction='mean'):
* "mean_by_sample_num": weighted mean with sample num of different sessions
* "mean_by_positive_num": weighted mean with positive sample num of different sessions
"""
+ if os.environ.get('distribute_eval') == 'True':
+ return _distribute_separated_auc_impl(
+ labels, predictions, session_ids, reduction, metric_name='session_auc')
return _separated_auc_impl(labels, predictions, session_ids, reduction)
@@ -145,6 +311,7 @@ def metric_learning_recall_at_k(k,
session_ids: session ids, a `Tensor` with shape [batch_size]
embed_normed: indicator of whether the input embeddings are l2_normalized
"""
+ from easy_rec.python.core.easyrec_metrics import metrics_tf
# make sure embedding should be l2-normalized
if not embed_normed:
embeddings = tf.nn.l2_normalize(embeddings, axis=1)
@@ -160,11 +327,12 @@ def metric_learning_recall_at_k(k,
tf.expand_dims(session_ids, 0), tf.expand_dims(session_ids, 1))
labels_equal = tf.logical_and(sessions_equal, labels_equal)
mask = tf.logical_and(indices_not_equal, labels_equal)
- mask_pos = tf.where(mask, sim_mat,
- -tf.ones_like(sim_mat)) # shape: (batch_size, batch_size)
+ mask_pos = tf.where(
+ mask, sim_mat,
+ -array_ops.ones_like(sim_mat)) # shape: (batch_size, batch_size)
if isinstance(k, int):
_, pos_top_k_idx = tf.nn.top_k(mask_pos, k) # shape: (batch_size, k)
- return tf.metrics.recall_at_k(
+ return metrics_tf.recall_at_k(
labels=tf.to_int64(pos_top_k_idx), predictions=sim_mat, k=k)
if any((isinstance(k, list), isinstance(k, tuple), isinstance(k, set))):
metrics = {}
@@ -172,7 +340,7 @@ def metric_learning_recall_at_k(k,
if kk < 1:
continue
_, pos_top_k_idx = tf.nn.top_k(mask_pos, kk)
- metrics['recall@' + str(kk)] = tf.metrics.recall_at_k(
+ metrics['recall@' + str(kk)] = metrics_tf.recall_at_k(
labels=tf.to_int64(pos_top_k_idx), predictions=sim_mat, k=kk)
return metrics
else:
@@ -184,6 +352,7 @@ def metric_learning_average_precision_at_k(k,
labels,
session_ids=None,
embed_normed=False):
+ from easy_rec.python.core.easyrec_metrics import metrics_tf
# make sure embedding should be l2-normalized
if not embed_normed:
embeddings = tf.nn.l2_normalize(embeddings, axis=1)
@@ -198,13 +367,13 @@ def metric_learning_average_precision_at_k(k,
mask = tf.logical_and(sessions_equal, mask)
label_indices = _get_matrix_mask_indices(mask)
if isinstance(k, int):
- return tf.metrics.average_precision_at_k(label_indices, sim_mat, k)
+ return metrics_tf.average_precision_at_k(label_indices, sim_mat, k)
if any((isinstance(k, list), isinstance(k, tuple), isinstance(k, set))):
metrics = {}
for kk in k:
if kk < 1:
continue
- metrics['MAP@' + str(kk)] = tf.metrics.average_precision_at_k(
+ metrics['MAP@' + str(kk)] = metrics_tf.average_precision_at_k(
label_indices, sim_mat, kk)
return metrics
else:
@@ -226,7 +395,7 @@ def _get_matrix_mask_indices(matrix, num_rows=None):
result = tf.gather(indices[:, 1], idx)
# replace invalid elements with -1
result = tf.where(
- tf.expand_dims(elem_per_row, 1) > r, result, -tf.ones_like(result))
+ tf.expand_dims(elem_per_row, 1) > r, result, -array_ops.ones_like(result))
max_index_per_row = tf.reduce_max(result, axis=1, keepdims=True)
max_index_per_row = tf.tile(max_index_per_row, [1, max_elem_per_row])
result = tf.where(result >= 0, result, max_index_per_row)
diff --git a/easy_rec/python/core/sampler.py b/easy_rec/python/core/sampler.py
index ca3a8f15d..dcdcd44f0 100644
--- a/easy_rec/python/core/sampler.py
+++ b/easy_rec/python/core/sampler.py
@@ -7,15 +7,43 @@
import logging
import math
import os
+import sys
import threading
import numpy as np
+import six
import tensorflow as tf
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import ds_util
+from easy_rec.python.utils.config_util import process_multi_file_input_path
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+if tf.__version__.startswith('1.'):
+ from tensorflow.python.platform import gfile
+else:
+ import tensorflow.io.gfile as gfile
+
+
+# patch graph-learn string_attrs for utf-8
+@property
+def string_attrs(self): # NOQA
+ self._init()
+ return self._string_attrs
+
+
+# pyre-ignore [56]
+@string_attrs.setter
+# pyre-ignore [2, 3]
+def string_attrs(self, string_attrs): # NOQA
+ self._string_attrs = self._reshape(string_attrs, expand_shape=True)
+ self._inited = True
+
try:
import graphlearn as gl
+ from graphlearn.python.data.values import Values
+ Values.string_attrs = string_attrs
except Exception:
logging.info(
'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl"' # noqa: E501
@@ -42,8 +70,8 @@ def _get_np_type(field_type):
type_map = {
DatasetConfig.INT32: np.int32,
DatasetConfig.INT64: np.int64,
- DatasetConfig.STRING: np.str,
- DatasetConfig.BOOL: np.bool,
+ DatasetConfig.STRING: str,
+ DatasetConfig.BOOL: bool,
DatasetConfig.FLOAT: np.float32,
DatasetConfig.DOUBLE: np.double
}
@@ -51,29 +79,22 @@ def _get_np_type(field_type):
return type_map[field_type]
-def _get_tf_type(field_type):
- type_map = {
- DatasetConfig.INT32: tf.int32,
- DatasetConfig.INT64: tf.int64,
- DatasetConfig.STRING: tf.string,
- DatasetConfig.BOOL: tf.bool,
- DatasetConfig.FLOAT: tf.float32,
- DatasetConfig.DOUBLE: tf.double
- }
- assert field_type in type_map, 'invalid type: %s' % field_type
- return type_map[field_type]
-
-
class BaseSampler(object):
_instance_lock = threading.Lock()
def __init__(self, fields, num_sample, num_eval_sample=None):
self._g = None
self._sampler = None
- # TODO(hongsheng.jhs): check eval mode or not?
self._num_sample = num_sample
- self._num_eval_sample = num_eval_sample if num_eval_sample else num_sample
+ self._num_eval_sample = num_eval_sample if num_eval_sample is not None else num_sample
self._build_field_types(fields)
+ self._log_first_n = 5
+ self._is_on_ds = ds_util.is_on_ds()
+
+ def set_eval_num_sample(self):
+ print('set_eval_num_sample: %d %d' %
+ (self._num_sample, self._num_eval_sample))
+ self._num_sample = self._num_eval_sample
def _init_graph(self):
if 'TF_CONFIG' in os.environ:
@@ -81,9 +102,23 @@ def _init_graph(self):
if 'ps' in tf_config['cluster']:
# ps mode
tf_config = json.loads(os.environ['TF_CONFIG'])
- ps_count = len(tf_config['cluster']['ps'])
- task_count = len(tf_config['cluster']['worker']) + 2
- cluster = {'server_count': ps_count, 'client_count': task_count}
+ if 'worker' in tf_config['cluster']:
+ task_count = len(tf_config['cluster']['worker']) + 2
+ else:
+ task_count = 2
+ if self._is_on_ds:
+ gl.set_tracker_mode(0)
+ server_hosts = [
+ host.split(':')[0] + ':888' + str(i)
+ for i, host in enumerate(tf_config['cluster']['ps'])
+ ]
+ cluster = {
+ 'server': ','.join(server_hosts),
+ 'client_count': task_count
+ }
+ else:
+ ps_count = len(tf_config['cluster']['ps'])
+ cluster = {'server_count': ps_count, 'client_count': task_count}
if tf_config['task']['type'] in ['chief', 'master']:
self._g.init(cluster=cluster, job_name='client', task_index=0)
elif tf_config['task']['type'] == 'worker':
@@ -107,11 +142,37 @@ def _init_graph(self):
else:
# worker mode
task_count = len(tf_config['cluster']['worker']) + 1
- if tf_config['task']['type'] in ['chief', 'master']:
- self._g.init(task_index=0, task_count=task_count)
- elif tf_config['task']['type'] == 'worker':
- self._g.init(
- task_index=tf_config['task']['index'] + 1, task_count=task_count)
+ if not self._is_on_ds:
+ if tf_config['task']['type'] in ['chief', 'master']:
+ self._g.init(task_index=0, task_count=task_count)
+ elif tf_config['task']['type'] == 'worker':
+ self._g.init(
+ task_index=tf_config['task']['index'] + 1,
+ task_count=task_count)
+ else:
+ gl.set_tracker_mode(0)
+ if tf_config['cluster'].get('chief', ''):
+ chief_host = tf_config['cluster']['chief'][0].split(
+ ':')[0] + ':8880'
+ else:
+ chief_host = tf_config['cluster']['master'][0].split(
+ ':')[0] + ':8880'
+          worker_hosts = [chief_host] + [
+ host.split(':')[0] + ':888' + str(i)
+ for i, host in enumerate(tf_config['cluster']['worker'])
+ ]
+
+ if tf_config['task']['type'] in ['chief', 'master']:
+ self._g.init(
+ task_index=0,
+ task_count=task_count,
+ hosts=','.join(worker_hosts))
+ elif tf_config['task']['type'] == 'worker':
+ self._g.init(
+ task_index=tf_config['task']['index'] + 1,
+ task_count=task_count,
+              hosts=','.join(worker_hosts))
+
# TODO(hongsheng.jhs): check cluster has evaluator or not?
else:
# local mode
@@ -128,7 +189,7 @@ def _build_field_types(self, fields):
self._attr_types.append(field.input_type)
self._attr_gl_types.append(_get_gl_type(field.input_type))
self._attr_np_types.append(_get_np_type(field.input_type))
- self._attr_tf_types.append(_get_tf_type(field.input_type))
+ self._attr_tf_types.append(get_tf_type(field.input_type))
@classmethod
def instance(cls, *args, **kwargs):
@@ -138,9 +199,14 @@ def instance(cls, *args, **kwargs):
return cls._instance
def __del__(self):
- self._g.close()
+ if self._g is not None:
+ self._g.close()
def _parse_nodes(self, nodes):
+ if self._log_first_n > 0:
+ logging.info('num_example=%d num_eval_example=%d node_num=%d' %
+ (self._num_sample, self._num_eval_sample, len(nodes.ids)))
+ self._log_first_n -= 1
features = []
int_idx = 0
float_idx = 0
@@ -155,11 +221,15 @@ def _parse_nodes(self, nodes):
float_idx += 1
elif attr_gl_type == 'string':
feature = nodes.string_attrs[:, :, string_idx]
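+          # graph-learn returns bytes attrs under python3; decode to utf-8 strings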
+ if int(sys.version_info[0]) == 3:
+ feature = np.char.decode(feature.astype(np.string_), 'utf-8')
string_idx += 1
else:
raise ValueError('Unknown attr type %s' % attr_gl_type)
feature = np.reshape(feature,
[-1])[:self._num_sample].astype(attr_np_type)
+ if attr_gl_type == 'string':
+ feature = feature.tolist()
features.append(feature)
return features
@@ -182,6 +252,8 @@ def _parse_sparse_nodes(self, nodes):
else:
raise ValueError('Unknown attr type %s' % attr_gl_type)
feature = feature.astype(attr_np_type)
+ if attr_gl_type == 'string':
+ feature = feature.tolist()
features.append(feature)
return features, nodes.indices
@@ -223,9 +295,8 @@ def __init__(self,
'item', expand_factor, strategy='node_weight')
def _get_impl(self, ids):
- # assert len(ids) == self._batch_size
- # tf.logging.info("ids: %s", len(ids))
ids = np.array(ids, dtype=np.int64)
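+    # the last batch may contain fewer than batch_size ids; pad by repeating
+    # the last id ('edge' mode) so the sampler always receives a full batch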
+ ids = np.pad(ids, (0, self._batch_size - len(ids)), 'edge')
nodes = self._sampler.get(ids)
features = self._parse_nodes(nodes)
return features
@@ -242,14 +313,165 @@ def get(self, ids):
sampled_values = tf.py_func(self._get_impl, [ids], self._attr_tf_types)
result_dict = {}
for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
- if t == tf.string:
- # string convert from np array to tensor will be padded with \000, we need remove it
- v = tf.regex_replace(v, '\000', '')
v.set_shape([self._num_sample])
result_dict[k] = v
return result_dict
+class NegativeSamplerInMemory(BaseSampler):
+ """Negative Sampler.
+
+ Weighted random sampling items not in batch.
+
+ Args:
+ data_path: item feature data path. id:int64 | weight:float | attrs:string.
+ fields: item input fields.
+ num_sample: number of negative samples.
+ batch_size: mini-batch size.
+ attr_delimiter: delimiter of feature string.
+ num_eval_sample: number of negative samples for evaluator.
+ """
+
+ def __init__(self,
+ data_path,
+ fields,
+ num_sample,
+ batch_size,
+ attr_delimiter=':',
+ num_eval_sample=None):
+ super(NegativeSamplerInMemory, self).__init__(fields, num_sample,
+ num_eval_sample)
+ self._batch_size = batch_size
+
+ self._item_ids = []
+ self._cols = [[] for x in fields]
+
+ if six.PY2 and isinstance(attr_delimiter, type(u'')):
+ attr_delimiter = attr_delimiter.encode('utf-8')
+ if data_path.startswith('odps://'):
+ self._load_table(data_path, attr_delimiter)
+ else:
+ self._load_data(data_path, attr_delimiter)
+
+ print('NegativeSamplerInMemory: total_row_num = %d' % len(self._cols[0]))
+ for col_id in range(len(self._attr_np_types)):
+ np_type = self._attr_np_types[col_id]
+ print('\tcol_id[%d], dtype=%s' % (col_id, self._attr_gl_types[col_id]))
+ if np_type != str:
+ self._cols[col_id] = np.array(self._cols[col_id], dtype=np_type)
+ else:
+ self._cols[col_id] = np.asarray(
+ self._cols[col_id], order='C', dtype=object)
+
+ def _load_table(self, data_path, attr_delimiter):
+ import common_io
+ reader = common_io.table.TableReader(data_path)
+ schema = reader.get_schema()
+ item_id_col = 0
+ fea_id_col = 2
+ for tid in range(len(schema)):
+ if schema[tid][0].startswith('feature'):
+ fea_id_col = tid
+ break
+ for tid in range(len(schema)):
+ if schema[tid][0].startswith('id'):
+ item_id_col = tid
+ break
+ print('NegativeSamplerInMemory: feature_id_col = %d, item_id_col = %d' %
+ (fea_id_col, item_id_col))
+ while True:
+ try:
+ row_arr = reader.read(num_records=1024, allow_smaller_final_batch=True)
+ for row in row_arr:
+ # item_id, weight, feature
+ self._item_ids.append(int(row[item_id_col]))
+ col_vals = row[fea_id_col].split(attr_delimiter)
+ assert len(col_vals) == len(
+ self._cols), 'invalid row[%d %d]: %s %s' % (len(
+ col_vals), len(self._cols), row[item_id_col], row[fea_id_col])
+ for col_id in range(len(col_vals)):
+ self._cols[col_id].append(col_vals[col_id])
+ except common_io.exception.OutOfRangeException:
+ reader.close()
+ break
+
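+  # a minimal sketch of the expected local/hdfs sample file (tab separated,
+  # hypothetical values), assuming attr_delimiter=':':
+  #   id:int64 \t weight:float \t feature:string        <- header row
+  #   10001    \t 1.0          \t 10001:cate_1:brand_2  <- data rows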
+ def _load_data(self, data_path, attr_delimiter):
+ item_id_col = 0
+ fea_id_col = 2
+ print('NegativeSamplerInMemory: load sample feature from %s' % data_path)
+ with gfile.GFile(data_path, 'r') as fin:
+ for line_id, line_str in enumerate(fin):
+ line_str = line_str.strip()
+ cols = line_str.split('\t')
+ if line_id == 0:
+ schema = [x.split(':') for x in cols]
+ for tid in range(len(schema)):
+ if schema[tid][0].startswith('id'):
+ item_id_col = tid
+ if schema[tid][0].startswith('feature'):
+ fea_id_col = tid
+ print('feature_id_col = %d, item_id_col = %d' %
+ (fea_id_col, item_id_col))
+ else:
+ self._item_ids.append(int(cols[item_id_col]))
+ fea_vals = cols[fea_id_col].split(attr_delimiter)
+ assert len(fea_vals) == len(
+ self._cols), 'invalid row[%d][%d %d]:%s %s' % (
+ line_id, len(fea_vals), len(
+ self._cols), cols[item_id_col], cols[fea_id_col])
+ for col_id in range(len(fea_vals)):
+ self._cols[col_id].append(fea_vals[col_id])
+
+ def _get_impl(self, ids):
+ features = []
+ if type(ids[0]) != int:
+ ids = [int(x) for x in ids]
+ assert self._num_sample > 0, 'invalid num_sample: %d' % self._num_sample
+
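+    # draw num_sample + batch_size candidates without replacement so that
+    # enough negatives remain after dropping items that appear in the batch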
+ indices = np.random.choice(
+ len(self._item_ids),
+ size=self._num_sample + self._batch_size,
+ replace=False)
+
+ sel_ids = []
+ for tid in indices:
+ rid = self._item_ids[tid]
+ if rid not in ids:
+ sel_ids.append(tid)
+ if len(sel_ids) >= self._num_sample and self._num_sample > 0:
+ break
+
+ features = []
+ for col_id in range(len(self._cols)):
+ tmp_col = self._cols[col_id]
+ np_type = self._attr_np_types[col_id]
+ if np_type != str:
+ sel_feas = tmp_col[sel_ids]
+ features.append(sel_feas)
+ else:
+ features.append(
+ np.asarray([tmp_col[x] for x in sel_ids], order='C', dtype=object))
+ return features
+
+ def get(self, ids):
+ """Sampling method.
+
+ Args:
+ ids: item id tensor.
+
+ Returns:
+ Negative sampled feature dict.
+ """
+ all_attr_types = list(self._attr_tf_types)
+ if self._num_sample <= 0:
+ all_attr_types.append(tf.float32)
+ sampled_values = tf.py_func(self._get_impl, [ids], all_attr_types)
+ result_dict = {}
+ for k, v in zip(self._attr_names, sampled_values):
+ result_dict[k] = v
+ return result_dict
+
+
class NegativeSamplerV2(BaseSampler):
"""Negative Sampler V2.
@@ -298,7 +520,9 @@ def __init__(self,
def _get_impl(self, src_ids, dst_ids):
src_ids = np.array(src_ids, dtype=np.int64)
+ src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), 'edge')
dst_ids = np.array(dst_ids, dtype=np.int64)
+ dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
nodes = self._sampler.get(src_ids, dst_ids)
features = self._parse_nodes(nodes)
return features
@@ -317,9 +541,6 @@ def get(self, src_ids, dst_ids):
self._attr_tf_types)
result_dict = {}
for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
- if t == tf.string:
- # string convert from np array to tensor will be padded with \000, we need remove it
- v = tf.regex_replace(v, '\000', '')
v.set_shape([self._num_sample])
result_dict[k] = v
return result_dict
@@ -381,6 +602,7 @@ def __init__(self,
def _get_impl(self, src_ids, dst_ids):
src_ids = np.array(src_ids, dtype=np.int64)
dst_ids = np.array(dst_ids, dtype=np.int64)
+ dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
nodes = self._neg_sampler.get(dst_ids)
neg_features = self._parse_nodes(nodes)
sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
@@ -388,7 +610,10 @@ def _get_impl(self, src_ids, dst_ids):
results = []
for i, v in enumerate(hard_neg_features):
- results.append(np.concatenate([neg_features[i], v], axis=-1))
+ if type(v) == list:
+ results.append(np.asarray(neg_features[i] + v, order='C', dtype=object))
+ else:
+ results.append(np.concatenate([neg_features[i], v], axis=0))
results.append(hard_neg_indices)
return results
@@ -407,9 +632,6 @@ def get(self, src_ids, dst_ids):
result_dict = {}
for k, t, v in zip(self._attr_names, self._attr_tf_types,
output_values[:-1]):
- if t == tf.string:
- # string convert from np array to tensor will be padded with \000, we need remove it
- v = tf.regex_replace(v, '\000', '')
v.set_shape([None])
result_dict[k] = v
@@ -464,7 +686,7 @@ def __init__(self,
attr_delimiter=attr_delimiter)) \
.edge(tf.compat.as_str(edge_data_path),
edge_type=('user', 'item', 'edge'),
- decoder=gl.Decoder(weighted=True)) \
+ decoder=gl.Decoder(weighted=True)) \
.edge(tf.compat.as_str(hard_neg_edge_data_path),
edge_type=('user', 'item', 'hard_neg_edge'),
decoder=gl.Decoder(weighted=True))
@@ -479,15 +701,21 @@ def __init__(self,
def _get_impl(self, src_ids, dst_ids):
src_ids = np.array(src_ids, dtype=np.int64)
+ src_ids_padded = np.pad(src_ids, (0, self._batch_size - len(src_ids)),
+ 'edge')
dst_ids = np.array(dst_ids, dtype=np.int64)
- nodes = self._neg_sampler.get(src_ids, dst_ids)
+ dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
+ nodes = self._neg_sampler.get(src_ids_padded, dst_ids)
neg_features = self._parse_nodes(nodes)
sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)
results = []
for i, v in enumerate(hard_neg_features):
- results.append(np.concatenate([neg_features[i], v], axis=-1))
+ if type(v) == list:
+ results.append(np.asarray(neg_features[i] + v, order='C', dtype=object))
+ else:
+ results.append(np.concatenate([neg_features[i], v], axis=0))
results.append(hard_neg_indices)
return results
@@ -506,9 +734,6 @@ def get(self, src_ids, dst_ids):
result_dict = {}
for k, t, v in zip(self._attr_names, self._attr_tf_types,
output_values[:-1]):
- if t == tf.string:
- # string convert from np array to tensor will be padded with \000, we need remove it
- v = tf.regex_replace(v, '\000', '')
v.set_shape([None])
result_dict[k] = v
@@ -519,15 +744,35 @@ def get(self, src_ids, dst_ids):
def build(data_config):
+
if not data_config.HasField('sampler'):
return None
sampler_type = data_config.WhichOneof('sampler')
+ print('sampler_type = %s' % sampler_type)
sampler_config = getattr(data_config, sampler_type)
+
+ if ds_util.is_on_ds():
+ gl.set_field_delimiter(sampler_config.field_delimiter)
+
if sampler_type == 'negative_sampler':
input_fields = {f.input_name: f for f in data_config.input_fields}
attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+ input_path = process_multi_file_input_path(sampler_config.input_path)
return NegativeSampler.instance(
- data_path=sampler_config.input_path,
+ data_path=input_path,
+ fields=attr_fields,
+ num_sample=sampler_config.num_sample,
+ batch_size=data_config.batch_size,
+ attr_delimiter=sampler_config.attr_delimiter,
+ num_eval_sample=sampler_config.num_eval_sample)
+ elif sampler_type == 'negative_sampler_in_memory':
+ input_fields = {f.input_name: f for f in data_config.input_fields}
+ attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+ input_path = process_multi_file_input_path(sampler_config.input_path)
+ return NegativeSamplerInMemory.instance(
+ data_path=input_path,
fields=attr_fields,
num_sample=sampler_config.num_sample,
batch_size=data_config.batch_size,
@@ -536,10 +781,17 @@ def build(data_config):
elif sampler_type == 'negative_sampler_v2':
input_fields = {f.input_name: f for f in data_config.input_fields}
attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+ user_input_path = process_multi_file_input_path(
+ sampler_config.user_input_path)
+ item_input_path = process_multi_file_input_path(
+ sampler_config.item_input_path)
+ pos_edge_input_path = process_multi_file_input_path(
+ sampler_config.pos_edge_input_path)
return NegativeSamplerV2.instance(
- user_data_path=sampler_config.user_input_path,
- item_data_path=sampler_config.item_input_path,
- edge_data_path=sampler_config.pos_edge_input_path,
+ user_data_path=user_input_path,
+ item_data_path=item_input_path,
+ edge_data_path=pos_edge_input_path,
fields=attr_fields,
num_sample=sampler_config.num_sample,
batch_size=data_config.batch_size,
@@ -548,10 +800,17 @@ def build(data_config):
elif sampler_type == 'hard_negative_sampler':
input_fields = {f.input_name: f for f in data_config.input_fields}
attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+ user_input_path = process_multi_file_input_path(
+ sampler_config.user_input_path)
+ item_input_path = process_multi_file_input_path(
+ sampler_config.item_input_path)
+ hard_neg_edge_input_path = process_multi_file_input_path(
+ sampler_config.hard_neg_edge_input_path)
return HardNegativeSampler.instance(
- user_data_path=sampler_config.user_input_path,
- item_data_path=sampler_config.item_input_path,
- hard_neg_edge_data_path=sampler_config.hard_neg_edge_input_path,
+ user_data_path=user_input_path,
+ item_data_path=item_input_path,
+ hard_neg_edge_data_path=hard_neg_edge_input_path,
fields=attr_fields,
num_sample=sampler_config.num_sample,
num_hard_sample=sampler_config.num_hard_sample,
@@ -561,11 +820,20 @@ def build(data_config):
elif sampler_type == 'hard_negative_sampler_v2':
input_fields = {f.input_name: f for f in data_config.input_fields}
attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+ user_input_path = process_multi_file_input_path(
+ sampler_config.user_input_path)
+ item_input_path = process_multi_file_input_path(
+ sampler_config.item_input_path)
+ pos_edge_input_path = process_multi_file_input_path(
+ sampler_config.pos_edge_input_path)
+ hard_neg_edge_input_path = process_multi_file_input_path(
+ sampler_config.hard_neg_edge_input_path)
return HardNegativeSamplerV2.instance(
- user_data_path=sampler_config.user_input_path,
- item_data_path=sampler_config.item_input_path,
- edge_data_path=sampler_config.pos_edge_input_path,
- hard_neg_edge_data_path=sampler_config.hard_neg_edge_input_path,
+ user_data_path=user_input_path,
+ item_data_path=item_input_path,
+ edge_data_path=pos_edge_input_path,
+ hard_neg_edge_data_path=hard_neg_edge_input_path,
fields=attr_fields,
num_sample=sampler_config.num_sample,
num_hard_sample=sampler_config.num_hard_sample,
diff --git a/easy_rec/python/eval.py b/easy_rec/python/eval.py
index 920b3b001..d41c3d7ae 100644
--- a/easy_rec/python/eval.py
+++ b/easy_rec/python/eval.py
@@ -8,8 +8,14 @@
import tensorflow as tf
from tensorflow.python.lib.io import file_io
+from easy_rec.python.main import distribute_evaluate
from easy_rec.python.main import evaluate
+from easy_rec.python.protos.train_pb2 import DistributionStrategy
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import ds_util
+from easy_rec.python.utils import estimator_utils
+from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_distribute_eval_worker_num_on_ds # NOQA
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -29,6 +35,11 @@
'override pipeline_config.eval_input_path')
tf.app.flags.DEFINE_string('model_dir', None, help='will update the model_dir')
tf.app.flags.DEFINE_string('odps_config', None, help='odps config path')
+tf.app.flags.DEFINE_string('eval_result_path', 'eval_result.txt',
+ 'eval result metric file')
+tf.app.flags.DEFINE_bool('distribute_eval', False,
+                         'use distributed evaluation (parameter server mode).')
+tf.app.flags.DEFINE_bool('is_on_ds', False, help='whether the job runs on the DataScience platform')
FLAGS = tf.app.flags.FLAGS
@@ -36,6 +47,11 @@ def main(argv):
if FLAGS.odps_config:
os.environ['ODPS_CONFIG_FILE_PATH'] = FLAGS.odps_config
+ if FLAGS.is_on_ds:
+ ds_util.set_on_ds()
+ if FLAGS.distribute_eval:
+ set_tf_config_and_get_distribute_eval_worker_num_on_ds()
+
assert FLAGS.model_dir or FLAGS.pipeline_config_path, 'At least one of model_dir and pipeline_config_path exists.'
if FLAGS.model_dir:
pipeline_config_path = os.path.join(FLAGS.model_dir, 'pipeline.config')
@@ -46,13 +62,40 @@ def main(argv):
else:
pipeline_config_path = FLAGS.pipeline_config_path
- eval_result = evaluate(pipeline_config_path, FLAGS.checkpoint_path,
- FLAGS.eval_input_path)
- for key in sorted(eval_result):
- # skip logging binary data
- if isinstance(eval_result[key], six.binary_type):
- continue
- logging.info('%s: %s' % (key, str(eval_result[key])))
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path)
+ if FLAGS.model_dir:
+ pipeline_config.model_dir = FLAGS.model_dir
+
+ if pipeline_config.train_config.train_distribute in [
+ DistributionStrategy.HorovodStrategy,
+ ]:
+ estimator_utils.init_hvd()
+ elif pipeline_config.train_config.train_distribute in [
+ DistributionStrategy.EmbeddingParallelStrategy,
+ DistributionStrategy.SokStrategy
+ ]:
+ estimator_utils.init_hvd()
+ estimator_utils.init_sok()
+
+ if FLAGS.distribute_eval:
+ os.environ['distribute_eval'] = 'True'
+ eval_result = distribute_evaluate(pipeline_config, FLAGS.checkpoint_path,
+ FLAGS.eval_input_path,
+ FLAGS.eval_result_path)
+ else:
+ os.environ['distribute_eval'] = 'False'
+ eval_result = evaluate(pipeline_config, FLAGS.checkpoint_path,
+ FLAGS.eval_input_path, FLAGS.eval_result_path)
+ if eval_result is not None:
+    # in distributed evaluation, only the master worker has eval_result.
+ for key in sorted(eval_result):
+ # skip logging binary data
+ if isinstance(eval_result[key], six.binary_type):
+ continue
+ logging.info('%s: %s' % (key, str(eval_result[key])))
+ else:
+    logging.info('Eval result is only available on the master worker.')
if __name__ == '__main__':
diff --git a/easy_rec/python/export.py b/easy_rec/python/export.py
index 0c01f2eed..c1a8ce670 100644
--- a/easy_rec/python/export.py
+++ b/easy_rec/python/export.py
@@ -7,6 +7,14 @@
from tensorflow.python.lib.io import file_io
from easy_rec.python.main import export
+from easy_rec.python.protos.train_pb2 import DistributionStrategy
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import estimator_utils
+
+if tf.__version__.startswith('1.'):
+ from tensorflow.python.platform import gfile
+else:
+ import tensorflow.io.gfile as gfile
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -54,6 +62,10 @@
tf.app.flags.DEFINE_string('model_dir', None, help='will update the model_dir')
tf.app.flags.mark_flag_as_required('export_dir')
+
+tf.app.flags.DEFINE_bool('clear_export', False, 'remove export_dir if it exists')
+tf.app.flags.DEFINE_string('export_done_file', '',
+                           'a flag file to signal that model export is done')
FLAGS = tf.app.flags.FLAGS
@@ -105,8 +117,33 @@ def main(argv):
if FLAGS.oss_embedding_version:
extra_params['oss_embedding_version'] = FLAGS.oss_embedding_version
- export(FLAGS.export_dir, pipeline_config_path, FLAGS.checkpoint_path,
- FLAGS.asset_files, FLAGS.verbose, **extra_params)
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path)
+ if pipeline_config.train_config.train_distribute in [
+ DistributionStrategy.HorovodStrategy,
+ ]:
+ estimator_utils.init_hvd()
+ elif pipeline_config.train_config.train_distribute in [
+ DistributionStrategy.EmbeddingParallelStrategy,
+ DistributionStrategy.SokStrategy
+ ]:
+ estimator_utils.init_hvd()
+ estimator_utils.init_sok()
+
+ if FLAGS.clear_export:
+ logging.info('will clear export_dir=%s' % FLAGS.export_dir)
+ if gfile.IsDirectory(FLAGS.export_dir):
+ gfile.DeleteRecursively(FLAGS.export_dir)
+
+ export_out_dir = export(FLAGS.export_dir, pipeline_config_path,
+ FLAGS.checkpoint_path, FLAGS.asset_files,
+ FLAGS.verbose, **extra_params)
+
+ if FLAGS.export_done_file:
+ flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
+ logging.info('create export done file: %s' % flag_file)
+ with gfile.GFile(flag_file, 'w') as fout:
+ fout.write('ExportDone')
if __name__ == '__main__':
diff --git a/easy_rec/python/feature_column/feature_column.py b/easy_rec/python/feature_column/feature_column.py
index 5a208591c..8701b55fc 100644
--- a/easy_rec/python/feature_column/feature_column.py
+++ b/easy_rec/python/feature_column/feature_column.py
@@ -1,21 +1,22 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import collections
import logging
+import sys
import tensorflow as tf
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.platform import gfile
from easy_rec.python.builders import hyperparams_builder
from easy_rec.python.compat.feature_column import sequence_feature_column
from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
+from easy_rec.python.utils.proto_util import copy_obj
from easy_rec.python.compat.feature_column import feature_column_v2 as feature_column # NOQA
-if tf.__version__ >= '2.0':
- min_max_variable_partitioner = tf.compat.v1.min_max_variable_partitioner
- tf = tf.compat.v1
-else:
- min_max_variable_partitioner = tf.min_max_variable_partitioner
+MAX_HASH_BUCKET_SIZE = 9223372036854775807
class FeatureKeyError(KeyError):
@@ -24,6 +25,19 @@ def __init__(self, feature_name):
super(FeatureKeyError, self).__init__(feature_name)
+class SharedEmbedding(object):
+
+ def __init__(self, embedding_name, index, sequence_combiner=None):
+ self.embedding_name = embedding_name
+ self.index = index
+ self.sequence_combiner = sequence_combiner
+
+
+EVParams = collections.namedtuple('EVParams', [
+ 'filter_freq', 'steps_to_live', 'use_cache', 'init_capacity', 'max_capacity'
+])
+
+
class FeatureColumnParser(object):
"""Parse and generate feature columns."""
@@ -31,7 +45,7 @@ def __init__(self,
feature_configs,
wide_deep_dict={},
wide_output_dim=-1,
- use_embedding_variable=False):
+ ev_params=None):
"""Initializes a `FeatureColumnParser`.
Args:
@@ -42,7 +56,7 @@ def __init__(self,
easy_rec.python.layers.input_layer.InputLayer, it is defined in
easy_rec.python.protos.easy_rec_model_pb2.EasyRecModel.feature_groups
wide_output_dim: output dimension for wide columns
- use_embedding_variable: use EmbeddingVariable, which is provided by pai-tf
+ ev_params: params used by EmbeddingVariable, which is provided by pai-tf
"""
self._feature_configs = feature_configs
self._wide_output_dim = wide_output_dim
@@ -54,31 +68,32 @@ def __init__(self,
self._share_embed_names = {}
self._share_embed_infos = {}
- self._use_embedding_variable = use_embedding_variable
self._vocab_size = {}
+ self._global_ev_params = None
+ if ev_params is not None:
+ self._global_ev_params = self._build_ev_params(ev_params)
+
+ def _cmp_embed_config(a, b):
+ return a.embedding_dim == b.embedding_dim and a.combiner == b.combiner and\
+ a.initializer == b.initializer and a.max_partitions == b.max_partitions and\
+ a.embedding_name == b.embedding_name
+
for config in self._feature_configs:
if not config.HasField('embedding_name'):
continue
embed_name = config.embedding_name
- embed_info = {
- 'embedding_dim':
- config.embedding_dim,
- 'combiner':
- config.combiner,
- 'initializer':
- config.initializer if config.HasField('initializer') else None,
- 'max_partitions':
- config.max_partitions
- }
+
if embed_name in self._share_embed_names:
- assert embed_info == self._share_embed_infos[embed_name], \
+ assert _cmp_embed_config(config, self._share_embed_infos[embed_name]),\
'shared embed info of [%s] is not matched [%s] vs [%s]' % (
- embed_name, embed_info, self._share_embed_infos[embed_name])
+ embed_name, config, self._share_embed_infos[embed_name])
self._share_embed_names[embed_name] += 1
+ if config.feature_type == FeatureConfig.FeatureType.SequenceFeature:
+ self._share_embed_infos[embed_name] = copy_obj(config)
else:
self._share_embed_names[embed_name] = 1
- self._share_embed_infos[embed_name] = embed_info
+ self._share_embed_infos[embed_name] = copy_obj(config)
# remove not shared embedding names
not_shared = [
@@ -100,6 +115,7 @@ def __init__(self,
embed_name: [] for embed_name in self._share_embed_names
}
+ self._feature_vocab_size = {}
for config in self._feature_configs:
assert isinstance(config, FeatureConfig)
try:
@@ -115,55 +131,75 @@ def __init__(self,
self.parse_lookup_feature(config)
elif config.feature_type == config.SequenceFeature:
self.parse_sequence_feature(config)
- else:
+ elif config.feature_type == config.ExprFeature:
+ self.parse_expr_feature(config)
+ elif config.feature_type != config.PassThroughFeature:
assert False, 'invalid feature type: %s' % config.feature_type
except FeatureKeyError:
pass
for embed_name in self._share_embed_names:
initializer = None
- if self._share_embed_infos[embed_name]['initializer']:
+ if self._share_embed_infos[embed_name].HasField('initializer'):
initializer = hyperparams_builder.build_initializer(
- self._share_embed_infos[embed_name]['initializer'])
- partitioner = self._build_partitioner(
- self._share_embed_infos[embed_name]['max_partitions'])
+ self._share_embed_infos[embed_name].initializer)
+
+ partitioner = self._build_partitioner(self._share_embed_infos[embed_name])
+
+ if self._share_embed_infos[embed_name].HasField('ev_params'):
+ ev_params = self._build_ev_params(
+ self._share_embed_infos[embed_name].ev_params)
+ else:
+ ev_params = self._global_ev_params
+
# for handling share embedding columns
- share_embed_fcs = feature_column.shared_embedding_columns(
- self._deep_share_embed_columns[embed_name],
- self._share_embed_infos[embed_name]['embedding_dim'],
- initializer=initializer,
- shared_embedding_collection_name=embed_name,
- combiner=self._share_embed_infos[embed_name]['combiner'],
- partitioner=partitioner,
- use_embedding_variable=self._use_embedding_variable)
- self._deep_share_embed_columns[embed_name] = share_embed_fcs
+ if len(self._deep_share_embed_columns[embed_name]) > 0:
+ share_embed_fcs = feature_column.shared_embedding_columns(
+ self._deep_share_embed_columns[embed_name],
+ self._share_embed_infos[embed_name].embedding_dim,
+ initializer=initializer,
+ shared_embedding_collection_name=embed_name,
+ combiner=self._share_embed_infos[embed_name].combiner,
+ partitioner=partitioner,
+ ev_params=ev_params)
+ config = self._share_embed_infos[embed_name]
+ max_seq_len = config.max_seq_len if config.HasField(
+ 'max_seq_len') else -1
+ for fc in share_embed_fcs:
+ fc.max_seq_length = max_seq_len
+ self._deep_share_embed_columns[embed_name] = share_embed_fcs
+
# for handling wide share embedding columns
- if len(self._wide_share_embed_columns[embed_name]) == 0:
- continue
- share_embed_fcs = feature_column.shared_embedding_columns(
- self._wide_share_embed_columns[embed_name],
- self._wide_output_dim,
- initializer=initializer,
- shared_embedding_collection_name=embed_name + '_wide',
- combiner='sum',
- partitioner=partitioner,
- use_embedding_variable=self._use_embedding_variable)
- self._wide_share_embed_columns[embed_name] = share_embed_fcs
+ if len(self._wide_share_embed_columns[embed_name]) > 0:
+ share_embed_fcs = feature_column.shared_embedding_columns(
+ self._wide_share_embed_columns[embed_name],
+ self._wide_output_dim,
+ initializer=initializer,
+ shared_embedding_collection_name=embed_name + '_wide',
+ combiner='sum',
+ partitioner=partitioner,
+ ev_params=ev_params)
+ config = self._share_embed_infos[embed_name]
+ max_seq_len = config.max_seq_len if config.HasField(
+ 'max_seq_len') else -1
+ for fc in share_embed_fcs:
+ fc.max_seq_length = max_seq_len
+ self._wide_share_embed_columns[embed_name] = share_embed_fcs
for fc_name in self._deep_columns:
fc = self._deep_columns[fc_name]
- if type(fc) == tuple:
+ if isinstance(fc, SharedEmbedding):
self._deep_columns[fc_name] = self._get_shared_embedding_column(fc)
for fc_name in self._wide_columns:
fc = self._wide_columns[fc_name]
- if type(fc) == tuple:
+ if isinstance(fc, SharedEmbedding):
self._wide_columns[fc_name] = self._get_shared_embedding_column(
fc, deep=False)
for fc_name in self._sequence_columns:
fc = self._sequence_columns[fc_name]
- if type(fc) == tuple:
+ if isinstance(fc, SharedEmbedding):
self._sequence_columns[fc_name] = self._get_shared_embedding_column(fc)
@property
@@ -201,14 +237,25 @@ def is_deep(self, config):
WideOrDeep.DEEP, WideOrDeep.WIDE_AND_DEEP
]
+ def get_feature_vocab_size(self, feature):
+ return self._feature_vocab_size.get(feature, 1)
+
def _get_vocab_size(self, vocab_path):
if vocab_path in self._vocab_size:
return self._vocab_size[vocab_path]
- with tf.gfile.GFile(vocab_path, 'r') as fin:
+ with gfile.GFile(vocab_path, 'r') as fin:
vocabulary_size = sum(1 for _ in fin)
self._vocab_size[vocab_path] = vocabulary_size
return vocabulary_size
+ def _get_hash_bucket_size(self, config):
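+    # with embedding variables (ev_params) the embedding table grows dynamically,
+    # so the hash space is widened to the max int64 value instead of a fixed size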
+ if not config.HasField('hash_bucket_size'):
+ return -1
+ if self._global_ev_params is not None or config.HasField('ev_params'):
+ return MAX_HASH_BUCKET_SIZE
+ else:
+ return config.hash_bucket_size
+
def parse_id_feature(self, config):
"""Generate id feature columns.
@@ -219,24 +266,32 @@ def parse_id_feature(self, config):
Args:
config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig
"""
- hash_bucket_size = config.hash_bucket_size
+ feature_name = config.feature_name if config.HasField('feature_name') \
+ else config.input_names[0]
+ hash_bucket_size = self._get_hash_bucket_size(config)
if hash_bucket_size > 0:
fc = feature_column.categorical_column_with_hash_bucket(
- config.input_names[0], hash_bucket_size=hash_bucket_size)
+ feature_name,
+ hash_bucket_size=hash_bucket_size,
+ feature_name=feature_name)
elif config.vocab_list:
fc = feature_column.categorical_column_with_vocabulary_list(
- config.input_names[0],
+ feature_name,
default_value=0,
- vocabulary_list=config.vocab_list)
+ vocabulary_list=config.vocab_list,
+ feature_name=feature_name)
elif config.vocab_file:
fc = feature_column.categorical_column_with_vocabulary_file(
- config.input_names[0],
+ feature_name,
default_value=0,
vocabulary_file=config.vocab_file,
- vocabulary_size=self._get_vocab_size(config.vocab_file))
+ vocabulary_size=self._get_vocab_size(config.vocab_file),
+ feature_name=feature_name)
else:
+ use_ev = self._global_ev_params or config.HasField('ev_params')
+ num_buckets = sys.maxsize if use_ev else config.num_buckets
fc = feature_column.categorical_column_with_identity(
- config.input_names[0], config.num_buckets, default_value=0)
+ feature_name, num_buckets, default_value=0, feature_name=feature_name)
if self.is_wide(config):
self._add_wide_embedding_column(fc, config)
@@ -253,32 +308,40 @@ def parse_tag_feature(self, config):
Args:
config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig
"""
- hash_bucket_size = config.hash_bucket_size
- if config.HasField('hash_bucket_size'):
+ feature_name = config.feature_name if config.HasField('feature_name') \
+ else config.input_names[0]
+ hash_bucket_size = self._get_hash_bucket_size(config)
+ if hash_bucket_size > 0:
tag_fc = feature_column.categorical_column_with_hash_bucket(
- config.input_names[0], hash_bucket_size, dtype=tf.string)
+ feature_name,
+ hash_bucket_size,
+ dtype=tf.string,
+ feature_name=feature_name)
elif config.vocab_list:
tag_fc = feature_column.categorical_column_with_vocabulary_list(
- config.input_names[0],
+ feature_name,
default_value=0,
- vocabulary_list=config.vocab_list)
+ vocabulary_list=config.vocab_list,
+ feature_name=feature_name)
elif config.vocab_file:
tag_fc = feature_column.categorical_column_with_vocabulary_file(
- config.input_names[0],
+ feature_name,
default_value=0,
vocabulary_file=config.vocab_file,
- vocabulary_size=self._get_vocab_size(config.vocab_file))
+ vocabulary_size=self._get_vocab_size(config.vocab_file),
+ feature_name=feature_name)
else:
+ use_ev = self._global_ev_params or config.HasField('ev_params')
+ num_buckets = sys.maxsize if use_ev else config.num_buckets
tag_fc = feature_column.categorical_column_with_identity(
- config.input_names[0], config.num_buckets, default_value=0)
+ feature_name, num_buckets, default_value=0, feature_name=feature_name)
if len(config.input_names) > 1:
tag_fc = feature_column.weighted_categorical_column(
- tag_fc, weight_feature_key=config.input_names[1], dtype=tf.float32)
+ tag_fc, weight_feature_key=feature_name + '_w', dtype=tf.float32)
elif config.HasField('kv_separator'):
- wgt_name = config.input_names[0] + '_WEIGHT'
tag_fc = feature_column.weighted_categorical_column(
- tag_fc, weight_feature_key=wgt_name, dtype=tf.float32)
+ tag_fc, weight_feature_key=feature_name + '_w', dtype=tf.float32)
if self.is_wide(config):
self._add_wide_embedding_column(tag_fc, config)
@@ -296,7 +359,9 @@ def parse_raw_feature(self, config):
feature_name = config.feature_name if config.HasField('feature_name') \
else config.input_names[0]
fc = feature_column.numeric_column(
- config.input_names[0], shape=(config.raw_input_dim,))
+ key=feature_name,
+ shape=(config.raw_input_dim,),
+ feature_name=feature_name)
bounds = None
if config.boundaries:
@@ -314,8 +379,8 @@ def parse_raw_feature(self, config):
try:
fc = feature_column.bucketized_column(fc, bounds)
except Exception as e:
- tf.logging.error('bucketized_column [%s] with bounds %s error' %
- (fc.name, str(bounds)))
+ logging.error('bucketized_column [%s] with bounds %s error' %
+ (fc.name, str(bounds)))
raise e
if self.is_wide(config):
self._add_wide_embedding_column(fc, config)
@@ -323,12 +388,13 @@ def parse_raw_feature(self, config):
self._add_deep_embedding_column(fc, config)
else:
tmp_id_col = feature_column.categorical_column_with_identity(
- config.input_names[0] + '_raw_proj_id',
+ feature_name + '_raw_proj_id',
config.raw_input_dim,
- default_value=0)
+ default_value=0,
+ feature_name=feature_name)
wgt_fc = feature_column.weighted_categorical_column(
tmp_id_col,
- weight_feature_key=config.input_names[0] + '_raw_proj_val',
+ weight_feature_key=feature_name + '_raw_proj_val',
dtype=tf.float32)
if self.is_wide(config):
self._add_wide_embedding_column(wgt_fc, config)
@@ -338,15 +404,50 @@ def parse_raw_feature(self, config):
else:
self._deep_columns[feature_name] = fc
+ def parse_expr_feature(self, config):
+ """Generate raw features columns.
+
+ if boundaries is set, will be converted to category_column first.
+
+ Args:
+ config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig
+ """
+ feature_name = config.feature_name if config.HasField('feature_name') \
+ else config.input_names[0]
+ fc = feature_column.numeric_column(
+ feature_name, shape=(1,), feature_name=feature_name)
+ if self.is_wide(config):
+ self._add_wide_embedding_column(fc, config)
+ if self.is_deep(config):
+ self._deep_columns[feature_name] = fc
+
def parse_combo_feature(self, config):
"""Generate combo feature columns.
Args:
config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig
"""
+ feature_name = config.feature_name if config.HasField('feature_name') \
+ else None
assert len(config.input_names) >= 2
- fc = feature_column.crossed_column(
- config.input_names, config.hash_bucket_size, hash_key=None)
+
+ if len(config.combo_join_sep) == 0:
+ input_names = []
+ for input_id in range(len(config.input_names)):
+ if input_id == 0:
+ input_names.append(feature_name)
+ else:
+ input_names.append(feature_name + '_' + str(input_id))
+ fc = feature_column.crossed_column(
+ input_names,
+ self._get_hash_bucket_size(config),
+ hash_key=None,
+ feature_name=feature_name)
+ else:
+ fc = feature_column.categorical_column_with_hash_bucket(
+ feature_name,
+ hash_bucket_size=self._get_hash_bucket_size(config),
+ feature_name=feature_name)
if self.is_wide(config):
self._add_wide_embedding_column(fc, config)
@@ -362,9 +463,12 @@ def parse_lookup_feature(self, config):
feature_name = config.feature_name if config.HasField('feature_name') \
else config.input_names[0]
assert config.HasField('hash_bucket_size')
- hash_bucket_size = config.hash_bucket_size
+ hash_bucket_size = self._get_hash_bucket_size(config)
fc = feature_column.categorical_column_with_hash_bucket(
- feature_name, hash_bucket_size, dtype=tf.string)
+ feature_name,
+ hash_bucket_size,
+ dtype=tf.string,
+ feature_name=feature_name)
if self.is_wide(config):
self._add_wide_embedding_column(fc, config)
@@ -379,53 +483,115 @@ def parse_sequence_feature(self, config):
"""
feature_name = config.feature_name if config.HasField('feature_name') \
else config.input_names[0]
- if config.HasField('hash_bucket_size'):
- hash_bucket_size = config.hash_bucket_size
- fc = sequence_feature_column.sequence_categorical_column_with_hash_bucket(
- config.input_names[0], hash_bucket_size, dtype=tf.string)
- elif config.vocab_list:
- fc = sequence_feature_column.sequence_categorical_column_with_vocabulary_list(
- config.input_names[0],
- default_value=0,
- vocabulary_list=config.vocab_list)
- elif config.vocab_file:
- fc = sequence_feature_column.sequence_categorical_column_with_vocabulary_file(
- config.input_names[0],
- default_value=0,
- vocabulary_file=config.vocab_file,
- vocabulary_size=self._get_vocab_size(config.vocab_file))
- else:
- fc = sequence_feature_column.sequence_categorical_column_with_identity(
- config.input_names[0], config.num_buckets, default_value=0)
-
- assert config.embedding_dim > 0
+ sub_feature_type = config.sub_feature_type
+ assert sub_feature_type in [config.IdFeature, config.RawFeature], \
+        'Currently sub_feature_type only supports IdFeature and RawFeature.'
+ if sub_feature_type == config.IdFeature:
+ if config.HasField('hash_bucket_size'):
+ hash_bucket_size = self._get_hash_bucket_size(config)
+ fc = sequence_feature_column.sequence_categorical_column_with_hash_bucket(
+ feature_name,
+ hash_bucket_size,
+ dtype=tf.string,
+ feature_name=feature_name)
+ elif config.vocab_list:
+ fc = sequence_feature_column.sequence_categorical_column_with_vocabulary_list(
+ feature_name,
+ default_value=0,
+ vocabulary_list=config.vocab_list,
+ feature_name=feature_name)
+ elif config.vocab_file:
+ fc = sequence_feature_column.sequence_categorical_column_with_vocabulary_file(
+ feature_name,
+ default_value=0,
+ vocabulary_file=config.vocab_file,
+ vocabulary_size=self._get_vocab_size(config.vocab_file),
+ feature_name=feature_name)
+ else:
+ use_ev = self._global_ev_params or config.HasField('ev_params')
+ num_buckets = sys.maxsize if use_ev else config.num_buckets
+ fc = sequence_feature_column.sequence_categorical_column_with_identity(
+ feature_name,
+ num_buckets,
+ default_value=0,
+ feature_name=feature_name)
+ else: # raw feature
+ bounds = None
+ fc = sequence_feature_column.sequence_numeric_column(
+ feature_name, shape=(1,), feature_name=feature_name)
+ if config.hash_bucket_size > 0:
+ hash_bucket_size = self._get_hash_bucket_size(config)
+ assert sub_feature_type == config.IdFeature, \
+ 'You should set sub_feature_type to IdFeature to use hash_bucket_size.'
+ elif config.boundaries:
+ bounds = list(config.boundaries)
+ bounds.sort()
+ elif config.num_buckets > 1 and config.max_val > config.min_val:
+ # the feature values are already normalized into [0, 1]
+ bounds = [
+ x / float(config.num_buckets) for x in range(0, config.num_buckets)
+ ]
+        logging.info('discretize sequence feature %s into %d buckets' %
+ (feature_name, config.num_buckets))
+ if bounds:
+ try:
+ fc = sequence_feature_column.sequence_numeric_column_with_bucketized_column(
+ fc, bounds)
+ except Exception as e:
+ logging.error(
+ 'sequence features bucketized_column [%s] with bounds %s error' %
+ (feature_name, str(bounds)))
+ raise e
+ elif config.hash_bucket_size <= 0:
+ if config.embedding_dim > 0:
+ tmp_id_col = sequence_feature_column.sequence_categorical_column_with_identity(
+ feature_name + '_raw_proj_id',
+ config.raw_input_dim,
+ default_value=0,
+ feature_name=feature_name)
+ wgt_fc = sequence_feature_column.sequence_weighted_categorical_column(
+ tmp_id_col,
+ weight_feature_key=feature_name + '_raw_proj_val',
+ dtype=tf.float32)
+ fc = wgt_fc
+ else:
+ fc = sequence_feature_column.sequence_numeric_column_with_raw_column(
+ fc, config.sequence_length)
- self._add_deep_embedding_column(fc, config)
+ if config.embedding_dim > 0:
+ self._add_deep_embedding_column(fc, config)
+ else:
+ self._sequence_columns[feature_name] = fc
- def _build_partitioner(self, max_partitions):
- if max_partitions > 1:
- if self._use_embedding_variable:
+ def _build_partitioner(self, config):
+ if config.max_partitions > 1:
+ if self._global_ev_params is not None or config.HasField('ev_params'):
# pai embedding_variable should use fixed_size_partitioner
- return tf.fixed_size_partitioner(num_shards=max_partitions)
+ return partitioned_variables.fixed_size_partitioner(
+ num_shards=config.max_partitions)
else:
- return min_max_variable_partitioner(max_partitions=max_partitions)
+ return partitioned_variables.min_max_variable_partitioner(
+ max_partitions=config.max_partitions)
else:
return None
def _add_shared_embedding_column(self, embedding_name, fc, deep=True):
- curr_id = len(self._deep_share_embed_columns[embedding_name])
if deep:
+ curr_id = len(self._deep_share_embed_columns[embedding_name])
self._deep_share_embed_columns[embedding_name].append(fc)
else:
+ curr_id = len(self._wide_share_embed_columns[embedding_name])
self._wide_share_embed_columns[embedding_name].append(fc)
- return (embedding_name, curr_id)
+ return SharedEmbedding(embedding_name, curr_id, None)
def _get_shared_embedding_column(self, fc_handle, deep=True):
- embed_name, embed_id = fc_handle
+ embed_name, embed_id = fc_handle.embedding_name, fc_handle.index
if deep:
- return self._deep_share_embed_columns[embed_name][embed_id]
+ tmp = self._deep_share_embed_columns[embed_name][embed_id]
else:
- return self._wide_share_embed_columns[embed_name][embed_id]
+ tmp = self._wide_share_embed_columns[embed_name][embed_id]
+ tmp.sequence_combiner = fc_handle.sequence_combiner
+ return tmp
def _add_wide_embedding_column(self, fc, config):
"""Generate wide feature columns.
@@ -443,13 +609,17 @@ def _add_wide_embedding_column(self, fc, config):
initializer = None
if config.HasField('initializer'):
initializer = hyperparams_builder.build_initializer(config.initializer)
+ if config.HasField('ev_params'):
+ ev_params = self._build_ev_params(config.ev_params)
+ else:
+ ev_params = self._global_ev_params
wide_fc = feature_column.embedding_column(
fc,
self._wide_output_dim,
combiner='sum',
initializer=initializer,
- partitioner=self._build_partitioner(config.max_partitions),
- use_embedding_variable=self._use_embedding_variable)
+ partitioner=self._build_partitioner(config),
+ ev_params=ev_params)
self._wide_columns[feature_name] = wide_fc
def _add_deep_embedding_column(self, fc, config):
@@ -457,22 +627,38 @@ def _add_deep_embedding_column(self, fc, config):
feature_name = config.feature_name if config.HasField('feature_name') \
else config.input_names[0]
assert config.embedding_dim > 0, 'embedding_dim is not set for %s' % feature_name
+ self._feature_vocab_size[feature_name] = fc.num_buckets
if config.embedding_name in self._deep_share_embed_columns:
fc = self._add_shared_embedding_column(config.embedding_name, fc)
else:
initializer = None
if config.HasField('initializer'):
initializer = hyperparams_builder.build_initializer(config.initializer)
+ if config.HasField('ev_params'):
+ ev_params = self._build_ev_params(config.ev_params)
+ else:
+ ev_params = self._global_ev_params
fc = feature_column.embedding_column(
fc,
config.embedding_dim,
combiner=config.combiner,
initializer=initializer,
- partitioner=self._build_partitioner(config.max_partitions),
- use_embedding_variable=self._use_embedding_variable)
+ partitioner=self._build_partitioner(config),
+ ev_params=ev_params)
+ fc.max_seq_length = config.max_seq_len if config.HasField(
+ 'max_seq_len') else -1
+
if config.feature_type != config.SequenceFeature:
self._deep_columns[feature_name] = fc
else:
if config.HasField('sequence_combiner'):
fc.sequence_combiner = config.sequence_combiner
self._sequence_columns[feature_name] = fc
+
+ def _build_ev_params(self, ev_params):
+ """Build embedding_variables params."""
+ ev_params = EVParams(
+ ev_params.filter_freq,
+ ev_params.steps_to_live if ev_params.steps_to_live > 0 else None,
+ ev_params.use_cache, ev_params.init_capacity, ev_params.max_capacity)
+ return ev_params
diff --git a/easy_rec/python/hpo/pai_hpo.py b/easy_rec/python/hpo/pai_hpo.py
index 12db919ad..8d6edaa19 100644
--- a/easy_rec/python/hpo/pai_hpo.py
+++ b/easy_rec/python/hpo/pai_hpo.py
@@ -44,7 +44,7 @@ def get_tuner(data, max_parallel, max_trial_num):
r = hpo.reader.create(**t['metric_reader'])
t.pop('metric_reader')
if r:
- subtask = hpo.task.create(**t, metric_reader=r)
+ subtask = hpo.task.create(metric_reader=r, **t)
else:
subtask = hpo.task.create(**t)
tasks.append(subtask)
diff --git a/easy_rec/python/inference/client/README.md b/easy_rec/python/inference/client/README.md
new file mode 100644
index 000000000..88871e057
--- /dev/null
+++ b/easy_rec/python/inference/client/README.md
@@ -0,0 +1,38 @@
+# EasyRecProcessor Client
+
+A minimal demo of calling a deployed EasyRec Processor service on PAI-EAS:
+
+```bash
+python -m easy_rec.python.inference.client.client_demo \
+ --endpoint 1301055xxxxxxxxx.cn-hangzhou.pai-eas.aliyuncs.com \
+ --service_name ali_rec_rnk_sample_rt_v3 \
+ --token MmQ3Yxxxxxxxxxxx \
+ --table_schema data/test/client/user_table_schema \
+ --table_data data/test/client/user_table_data \
+ --item_lst data/test/client/item_lst
+
+# output:
+# results {
+# key: "item_0"
+# value {
+# scores: 0.0
+# scores: 0.0
+# }
+# }
+# results {
+# key: "item_1"
+# value {
+# scores: 0.0
+# scores: 0.0
+# }
+# }
+# results {
+# key: "item_2"
+# value {
+# scores: 0.0
+# scores: 0.0
+# }
+# }
+# outputs: "probs_is_click"
+# outputs: "probs_is_go"
+```
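+
+The demo can also be driven programmatically. Below is a minimal sketch that reuses the
+helpers from `client_demo.py` (assuming the package path follows the file location
+`easy_rec/python/inference/client/`); the endpoint, service name, token, schema string
+and feature values are placeholders, not real services or data:
+
+```python
+from eas_prediction import PredictClient
+
+from easy_rec.python.inference.client.client_demo import build_request
+from easy_rec.python.inference.client.client_demo import parse_table_schema
+from easy_rec.python.inference.client.client_demo import send_request
+
+client = PredictClient('1301055xxxxxxxxx.cn-hangzhou.pai-eas.aliyuncs.com',
+                       'ali_rec_rnk_sample_rt_v3')
+client.set_token('MmQ3Yxxxxxxxxxxx')
+client.init()
+
+# parse the create-table statement into [(col_name, col_type), ...]
+table_cols = parse_table_schema('create table t (user_id string, age bigint)')
+# one user feature value per column, plus the candidate item ids to score
+req = build_request(table_cols, ['u_001', 25], item_ids=['item_0', 'item_1'])
+resp = send_request(req, client)
+print(resp)
+```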
diff --git a/easy_rec/python/inference/client/client_demo.py b/easy_rec/python/inference/client/client_demo.py
new file mode 100644
index 000000000..9464b1073
--- /dev/null
+++ b/easy_rec/python/inference/client/client_demo.py
@@ -0,0 +1,134 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import logging
+import sys
+import traceback
+
+from easy_rec.python.inference.client.easyrec_request import EasyrecRequest
+from easy_rec.python.protos.predict_pb2 import PBFeature
+from easy_rec.python.protos.predict_pb2 import PBRequest
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+try:
+ from eas_prediction import PredictClient # TFRequest
+except Exception:
+ logging.error('eas_prediction is not installed: pip install eas-prediction')
+ sys.exit(1)
+
+
+def build_request(table_cols, table_data, item_ids=None):
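+  # table_cols: [(col_name, col_type), ...] as returned by parse_table_schema;
+  # table_data: one user feature value per column;
+  # item_ids: candidate item ids to be scored by the service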
+ request_pb = PBRequest()
+ assert isinstance(table_data, list)
+ try:
+ for col_id in range(len(table_cols)):
+ cname, dtype = table_cols[col_id]
+ value = table_data[col_id]
+ feat = PBFeature()
+ if value is None:
+ continue
+      if dtype == 'STRING':
+        feat.string_feature = value
+      elif dtype in ('FLOAT', 'DOUBLE'):
+        feat.float_feature = float(value)
+      elif dtype == 'BIGINT':
+        feat.long_feature = int(value)
+      elif dtype == 'INT':
+        feat.int_feature = int(value)
+
+ request_pb.user_features[cname].CopyFrom(feat)
+ except Exception:
+ traceback.print_exc()
+    sys.exit(1)
+ request_pb.item_ids.extend(item_ids)
+ return request_pb
+
+
+def parse_table_schema(create_table_sql):
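+  # Parse a 'create table xx (col1 type1, col2 type2, ...)' statement into
+  # [[col_name, COL_TYPE], ...], e.g. (illustrative):
+  #   'create table t (user_id string, age bigint)'
+  #   -> [['user_id', 'STRING'], ['age', 'BIGINT']]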
+ create_table_sql = create_table_sql.lower()
+ spos = create_table_sql.index('(')
+  epos = create_table_sql.index(')', spos + 1)
+ cols = create_table_sql[(spos + 1):epos]
+ cols = [x.strip().lower() for x in cols.split(',')]
+ col_info_arr = []
+ for col in cols:
+ col = [k for k in col.split() if k != '']
+ assert len(col) == 2
+ col[1] = col[1].upper()
+ col_info_arr.append(col)
+ return col_info_arr
+
+
+def send_request(req_pb, client, debug_level=0):
+ req = EasyrecRequest()
+ req.add_feed(req_pb, debug_level)
+ tmp = client.predict(req)
+ return tmp
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--endpoint',
+ type=str,
+ default=None,
+ help='eas endpoint, such as 12345.cn-beijing.pai-eas.aliyuncs.com')
+ parser.add_argument(
+ '--service_name', type=str, default=None, help='eas service name')
+ parser.add_argument(
+ '--token', type=str, default=None, help='eas service token')
+ parser.add_argument(
+ '--table_schema',
+ type=str,
+ default=None,
+ help='user feature table schema path')
+ parser.add_argument(
+ '--table_data',
+ type=str,
+ default=None,
+ help='user feature table data path')
+ parser.add_argument('--item_lst', type=str, default=None, help='item list')
+
+ args, _ = parser.parse_known_args()
+
+ if args.endpoint is None:
+ logging.error('--endpoint is not set')
+ sys.exit(1)
+ if args.service_name is None:
+ logging.error('--service_name is not set')
+ sys.exit(1)
+ if args.token is None:
+ logging.error('--token is not set')
+ sys.exit(1)
+ if args.table_schema is None:
+ logging.error('--table_schema is not set')
+ sys.exit(1)
+ if args.table_data is None:
+ logging.error('--table_data is not set')
+ sys.exit(1)
+ if args.item_lst is None:
+ logging.error('--item_lst is not set')
+ sys.exit(1)
+
+ client = PredictClient(args.endpoint, args.service_name)
+ client.set_token(args.token)
+ client.init()
+
+ with open(args.table_schema, 'r') as fin:
+ create_table_sql = fin.read().strip()
+
+ with open(args.table_data, 'r') as fin:
+ table_data = fin.read().strip()
+
+ table_cols = parse_table_schema(create_table_sql)
+ table_data = table_data.split(';')
+
+ with open(args.item_lst, 'r') as fin:
+ items = fin.read().strip()
+ items = items.split(',')
+
+ req = build_request(table_cols, table_data, item_ids=items)
+ resp = send_request(req, client)
+ logging.info(resp)
diff --git a/easy_rec/python/inference/client/easyrec_request.py b/easy_rec/python/inference/client/easyrec_request.py
new file mode 100644
index 000000000..4980b5064
--- /dev/null
+++ b/easy_rec/python/inference/client/easyrec_request.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from eas_prediction.request import Request
+
+from easy_rec.python.protos.predict_pb2 import PBRequest
+from easy_rec.python.protos.predict_pb2 import PBResponse
+
+# from eas_prediction.request import Response
+
+
+class EasyrecRequest(Request):
+ """Request for tensorflow services whose input data is in format of protobuf.
+
+  This class provides methods to generate PBRequest and parse PBResponse.
+ """
+
+ def __init__(self, signature_name=None):
+ self.request_data = PBRequest()
+ self.signature_name = signature_name
+
+ def __str__(self):
+    return str(self.request_data)
+
+  def set_signature_name(self, signature_name):
+    """Set the signature name of the model.
+
+    Args:
+      signature_name: signature name of the model
+    """
+    self.signature_name = signature_name
+
+ def add_feed(self, data, dbg_lvl=0):
+ if not isinstance(data, PBRequest):
+ self.request_data.ParseFromString(data)
+ else:
+ self.request_data = data
+ self.request_data.debug_level = dbg_lvl
+
+ def add_user_fea_flt(self, k, v):
+ self.request_data.user_features[k].float_feature = float(v)
+
+ def add_user_fea_s(self, k, v):
+ self.request_data.user_features[k].string_feature = str(v)
+
+ def set_faiss_neigh_num(self, neigh_num):
+ self.request_data.faiss_neigh_num = neigh_num
+
+ def keep_one_item_ids(self):
+ item_id = self.request_data.item_ids[0]
+ self.request_data.ClearField('item_ids')
+ self.request_data.item_ids.extend([item_id])
+
+ def to_string(self):
+ """Serialize the request to string for transmission.
+
+ Returns:
+ the request data in format of string
+ """
+ return self.request_data.SerializeToString()
+
+ def parse_response(self, response_data):
+ """Parse the given response data in string format to the related TFResponse object.
+
+ Args:
+ response_data: the service response data in string format
+
+ Returns:
+      the PBResponse object related to the request
+ """
+ self.response = PBResponse()
+ self.response.ParseFromString(response_data)
+ return self.response
diff --git a/easy_rec/python/inference/csv_predictor.py b/easy_rec/python/inference/csv_predictor.py
new file mode 100644
index 000000000..c5e154d2e
--- /dev/null
+++ b/easy_rec/python/inference/csv_predictor.py
@@ -0,0 +1,189 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import SINGLE_PLACEHOLDER_FEATURE_KEY
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils.check_utils import check_split
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class CSVPredictor(Predictor):
+
+ def __init__(self,
+ model_path,
+ data_config,
+ with_header=False,
+ ds_vector_recall=False,
+ fg_json_path=None,
+ profiling_file=None,
+ selected_cols=None,
+ output_sep=chr(1)):
+ super(CSVPredictor, self).__init__(model_path, profiling_file, fg_json_path)
+ self._output_sep = output_sep
+ self._ds_vector_recall = ds_vector_recall
+ input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+ self._with_header = with_header
+
+ if 'rtp' in input_type:
+ self._is_rtp = True
+ self._input_sep = data_config.rtp_separator
+ else:
+ self._is_rtp = False
+ self._input_sep = data_config.separator
+
+ if selected_cols and not ds_vector_recall:
+ self._selected_cols = [int(x) for x in selected_cols.split(',')]
+ elif ds_vector_recall:
+ self._selected_cols = selected_cols.split(',')
+ else:
+ self._selected_cols = None
+
+ def _get_reserved_cols(self, reserved_cols):
+ if reserved_cols == 'ALL_COLUMNS':
+ if self._is_rtp:
+ if self._with_header:
+ reserved_cols = self._all_fields
+ else:
+ idx = 0
+ reserved_cols = []
+ for x in range(len(self._record_defaults) - 1):
+ if not self._selected_cols or x in self._selected_cols[:-1]:
+ reserved_cols.append(self._input_fields[idx])
+ idx += 1
+ else:
+ reserved_cols.append('no_used_%d' % x)
+ reserved_cols.append(SINGLE_PLACEHOLDER_FEATURE_KEY)
+ else:
+ reserved_cols = self._all_fields
+ else:
+ reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+ return reserved_cols
+
+ def _parse_line(self, line):
+ check_list = [
+ tf.py_func(
+ check_split, [line, self._input_sep,
+ len(self._record_defaults)],
+ Tout=tf.bool)
+ ]
+ with tf.control_dependencies(check_list):
+ fields = tf.decode_csv(
+ line,
+ field_delim=self._input_sep,
+ record_defaults=self._record_defaults,
+ name='decode_csv')
+ if self._is_rtp:
+ if self._with_header:
+ inputs = dict(zip(self._all_fields, fields))
+ else:
+ inputs = {}
+ idx = 0
+ for x in range(len(self._record_defaults) - 1):
+ if not self._selected_cols or x in self._selected_cols[:-1]:
+ inputs[self._input_fields[idx]] = fields[x]
+ idx += 1
+ else:
+ inputs['no_used_%d' % x] = fields[x]
+ inputs[SINGLE_PLACEHOLDER_FEATURE_KEY] = fields[-1]
+ else:
+ inputs = {self._all_fields[x]: fields[x] for x in range(len(fields))}
+ return inputs
+
+ def _get_num_cols(self, file_paths):
+ # try to figure out number of fields from one file
+ num_cols = -1
+ with gfile.GFile(file_paths[0], 'r') as fin:
+ num_lines = 0
+ for line_str in fin:
+ line_tok = line_str.strip().split(self._input_sep)
+ if num_cols != -1:
+ assert num_cols == len(line_tok), (
+ 'num selected cols is %d, not equal to %d, current line is: %s, please check input_sep and data.'
+ % (num_cols, len(line_tok), line_str))
+ num_cols = len(line_tok)
+ num_lines += 1
+ if num_lines > 10:
+ break
+ logging.info('num selected cols = %d' % num_cols)
+ return num_cols
+
+ def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+ slice_id):
+ file_paths = []
+ for path in input_path.split(','):
+ for x in gfile.Glob(path):
+ if not x.endswith('_SUCCESS'):
+ file_paths.append(x)
+ assert len(file_paths) > 0, 'match no files with %s' % input_path
+
+ if self._with_header:
+ with gfile.GFile(file_paths[0], 'r') as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ self._field_names = line_str.split(self._input_sep)
+ break
+      logging.info('field_names: %s' % ','.join(self._field_names))
+ self._all_fields = self._field_names
+ elif self._ds_vector_recall:
+ self._all_fields = self._selected_cols
+ else:
+ self._all_fields = self._input_fields
+ if self._is_rtp:
+ num_cols = self._get_num_cols(file_paths)
+ self._record_defaults = ['' for _ in range(num_cols)]
+ if not self._selected_cols:
+ self._selected_cols = list(range(num_cols))
+ for col_idx in self._selected_cols[:-1]:
+ col_name = self._input_fields[col_idx]
+ default_val = self._get_defaults(col_name)
+ self._record_defaults[col_idx] = default_val
+ else:
+ self._record_defaults = [
+ self._get_defaults(col_name) for col_name in self._all_fields
+ ]
+
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+ parallel_num = min(num_parallel_calls, len(file_paths))
+ dataset = dataset.interleave(
+ lambda x: tf.data.TextLineDataset(x).skip(int(self._with_header)),
+ cycle_length=parallel_num,
+ num_parallel_calls=parallel_num)
+ dataset = dataset.shard(slice_num, slice_id)
+ dataset = dataset.batch(batch_size)
+ dataset = dataset.prefetch(buffer_size=64)
+ return dataset
+
+ def _get_writer(self, output_path, slice_id):
+ if not gfile.Exists(output_path):
+ gfile.MakeDirs(output_path)
+ res_path = os.path.join(output_path, 'part-%d.csv' % slice_id)
+ table_writer = gfile.GFile(res_path, 'w')
+ table_writer.write(
+ self._output_sep.join(self._output_cols + self._reserved_cols) + '\n')
+ return table_writer
+
+ def _write_lines(self, table_writer, outputs):
+ outputs = '\n'.join(
+ [self._output_sep.join([str(i) for i in output]) for output in outputs])
+ table_writer.write(outputs + '\n')
+
+ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+ reserve_vals = [outputs[x] for x in output_cols] + \
+ [all_vals[k] for k in reserved_cols]
+ return reserve_vals
+
+ @property
+ def out_of_range_exception(self):
+ return (tf.errors.OutOfRangeError)
diff --git a/easy_rec/python/inference/hive_parquet_predictor.py b/easy_rec/python/inference/hive_parquet_predictor.py
new file mode 100644
index 000000000..bd5178fe7
--- /dev/null
+++ b/easy_rec/python/inference/hive_parquet_predictor.py
@@ -0,0 +1,200 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import tf_utils
+from easy_rec.python.utils.hive_utils import HiveUtils
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class HiveParquetPredictor(Predictor):
+
+ def __init__(self,
+ model_path,
+ data_config,
+ hive_config,
+ fg_json_path=None,
+ profiling_file=None,
+ output_sep=chr(1),
+ all_cols=None,
+ all_col_types=None):
+ super(HiveParquetPredictor, self).__init__(model_path, profiling_file,
+ fg_json_path)
+
+ self._data_config = data_config
+ self._hive_config = hive_config
+ self._output_sep = output_sep
+ input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+ if 'rtp' in input_type:
+ self._is_rtp = True
+ else:
+ self._is_rtp = False
+ self._all_cols = [x.strip() for x in all_cols if x != '']
+ self._all_col_types = [x.strip() for x in all_col_types if x != '']
+ self._record_defaults = [
+ self._get_defaults(col_name, col_type)
+ for col_name, col_type in zip(self._all_cols, self._all_col_types)
+ ]
+
+ def _get_reserved_cols(self, reserved_cols):
+ if reserved_cols == 'ALL_COLUMNS':
+ reserved_cols = self._all_cols
+ else:
+ reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+ return reserved_cols
+
+ def _parse_line(self, *fields):
+ fields = list(fields)
+ field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))}
+ return field_dict
+
+ def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+ slice_id):
+ self._hive_util = HiveUtils(
+ data_config=self._data_config, hive_config=self._hive_config)
+ hdfs_path = self._hive_util.get_table_location(input_path)
+ self._input_hdfs_path = gfile.Glob(os.path.join(hdfs_path, '*'))
+ assert len(self._input_hdfs_path) > 0, 'match no files with %s' % input_path
+
+ list_type = []
+ input_field_type_map = {
+ x.input_name: x.input_type for x in self._data_config.input_fields
+ }
+ type_2_tftype = {
+ 'string': tf.string,
+ 'double': tf.double,
+ 'float': tf.float32,
+ 'bigint': tf.int32,
+ 'boolean': tf.bool
+ }
+ for col_name, col_type in zip(self._all_cols, self._all_col_types):
+ if col_name in input_field_type_map:
+ list_type.append(get_tf_type(input_field_type_map[col_name]))
+ else:
+ list_type.append(type_2_tftype[col_type.lower()])
+ list_type = tuple(list_type)
+ list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
+ list_shapes = tuple(list_shapes)
+
+ def parquet_read():
+ for input_path in self._input_hdfs_path:
+ if input_path.endswith('SUCCESS'):
+ continue
+ df = pd.read_parquet(input_path, engine='pyarrow')
+
+ df.replace('', np.nan, inplace=True)
+ df.replace('NULL', np.nan, inplace=True)
+ total_records_num = len(df)
+
+ for k, v in zip(self._all_cols, self._record_defaults):
+ df[k].fillna(v, inplace=True)
+
+ for start_idx in range(0, total_records_num, batch_size):
+ end_idx = min(total_records_num, start_idx + batch_size)
+ batch_data = df[start_idx:end_idx]
+ inputs = []
+ for k in self._all_cols:
+ inputs.append(batch_data[k].to_numpy())
+ yield tuple(inputs)
+
+ dataset = tf.data.Dataset.from_generator(
+ parquet_read, output_types=list_type, output_shapes=list_shapes)
+ dataset = dataset.shard(slice_num, slice_id)
+ dataset = dataset.prefetch(buffer_size=64)
+ return dataset
+
+ def get_table_info(self, output_path):
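+    # output_path is either 'table_name' or 'table_name/partition_name=partition_val',
+    # e.g. (illustrative) 'rec_result/dt=20230101' -> ('rec_result', 'dt', '20230101')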
+ partition_name, partition_val = None, None
+ if len(output_path.split('/')) == 2:
+ table_name, partition = output_path.split('/')
+ partition_name, partition_val = partition.split('=')
+ else:
+ table_name = output_path
+ return table_name, partition_name, partition_val
+
+ def _get_writer(self, output_path, slice_id):
+ table_name, partition_name, partition_val = self.get_table_info(output_path)
+ is_exist = self._hive_util.is_table_or_partition_exist(
+ table_name, partition_name, partition_val)
+    assert not is_exist, '%s already exists. Please drop it.' % output_path
+
+ output_path = output_path.replace('.', '/')
+ self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
+ self._hive_config.host, output_path)
+ if not gfile.Exists(self._hdfs_path):
+ gfile.MakeDirs(self._hdfs_path)
+ res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
+ table_writer = gfile.GFile(res_path, 'w')
+ return table_writer
+
+ def _write_lines(self, table_writer, outputs):
+ outputs = '\n'.join(
+ [self._output_sep.join([str(i) for i in output]) for output in outputs])
+ table_writer.write(outputs + '\n')
+
+ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+ reserve_vals = [outputs[x] for x in output_cols] + \
+ [all_vals[k] for k in reserved_cols]
+ return reserve_vals
+
+ def load_to_table(self, output_path, slice_num, slice_id):
+ res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
+ success_writer = gfile.GFile(res_path, 'w')
+ success_writer.write('')
+ success_writer.close()
+
+ if slice_id != 0:
+ return
+
+ for id in range(slice_num):
+ res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id)
+ while not gfile.Exists(res_path):
+ time.sleep(10)
+
+ table_name, partition_name, partition_val = self.get_table_info(output_path)
+ schema = ''
+ for output_col_name in self._output_cols:
+ tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
+ col_type = tf_utils.get_col_type(tf_type)
+ schema += output_col_name + ' ' + col_type + ','
+
+ for output_col_name in self._reserved_cols:
+      assert output_col_name in self._all_cols, 'Column: %s does not exist.' % output_col_name
+ idx = self._all_cols.index(output_col_name)
+ output_col_types = self._all_col_types[idx]
+ schema += output_col_name + ' ' + output_col_types + ','
+ schema = schema.rstrip(',')
+
+ if partition_name and partition_val:
+ sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
+ (table_name, schema, partition_name)
+ self._hive_util.run_sql(sql)
+ sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
+ (self._hdfs_path, table_name, partition_name, partition_val)
+ self._hive_util.run_sql(sql)
+ else:
+ sql = 'create table if not exists %s (%s)' % \
+ (table_name, schema)
+ self._hive_util.run_sql(sql)
+ sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
+ (self._hdfs_path, table_name)
+ self._hive_util.run_sql(sql)
+
+ @property
+ def out_of_range_exception(self):
+ return (tf.errors.OutOfRangeError)
diff --git a/easy_rec/python/inference/hive_predictor.py b/easy_rec/python/inference/hive_predictor.py
new file mode 100644
index 000000000..f2923f8a3
--- /dev/null
+++ b/easy_rec/python/inference/hive_predictor.py
@@ -0,0 +1,166 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import tf_utils
+from easy_rec.python.utils.hive_utils import HiveUtils
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class HivePredictor(Predictor):
+
+ def __init__(self,
+ model_path,
+ data_config,
+ hive_config,
+ fg_json_path=None,
+ profiling_file=None,
+ output_sep=chr(1),
+ all_cols=None,
+ all_col_types=None):
+ super(HivePredictor, self).__init__(model_path, profiling_file,
+ fg_json_path)
+
+ self._data_config = data_config
+ self._hive_config = hive_config
+ self._output_sep = output_sep
+ input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+ if 'rtp' in input_type:
+ self._is_rtp = True
+ else:
+ self._is_rtp = False
+ self._all_cols = [x.strip() for x in all_cols if x != '']
+ self._all_col_types = [x.strip() for x in all_col_types if x != '']
+ self._record_defaults = [
+ self._get_defaults(col_name, col_type)
+ for col_name, col_type in zip(self._all_cols, self._all_col_types)
+ ]
+
+ def _get_reserved_cols(self, reserved_cols):
+ if reserved_cols == 'ALL_COLUMNS':
+ reserved_cols = self._all_cols
+ else:
+ reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+ return reserved_cols
+
+ def _parse_line(self, line):
+ field_delim = self._data_config.rtp_separator if self._is_rtp else self._data_config.separator
+ fields = tf.decode_csv(
+ line,
+ field_delim=field_delim,
+ record_defaults=self._record_defaults,
+ name='decode_csv')
+ inputs = {self._all_cols[x]: fields[x] for x in range(len(fields))}
+ return inputs
+
+ def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+ slice_id):
+ self._hive_util = HiveUtils(
+ data_config=self._data_config, hive_config=self._hive_config)
+ self._input_hdfs_path = self._hive_util.get_table_location(input_path)
+ file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
+ assert len(file_paths) > 0, 'match no files with %s' % input_path
+
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+ parallel_num = min(num_parallel_calls, len(file_paths))
+ dataset = dataset.interleave(
+ tf.data.TextLineDataset,
+ cycle_length=parallel_num,
+ num_parallel_calls=parallel_num)
+ dataset = dataset.shard(slice_num, slice_id)
+ dataset = dataset.batch(batch_size)
+ dataset = dataset.prefetch(buffer_size=64)
+ return dataset
+
+ def get_table_info(self, output_path):
+ partition_name, partition_val = None, None
+ if len(output_path.split('/')) == 2:
+ table_name, partition = output_path.split('/')
+ partition_name, partition_val = partition.split('=')
+ else:
+ table_name = output_path
+ return table_name, partition_name, partition_val
+
+ def _get_writer(self, output_path, slice_id):
+ table_name, partition_name, partition_val = self.get_table_info(output_path)
+ is_exist = self._hive_util.is_table_or_partition_exist(
+ table_name, partition_name, partition_val)
+    assert not is_exist, '%s already exists. Please drop it.' % output_path
+
+ output_path = output_path.replace('.', '/')
+ self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
+ self._hive_config.host, output_path)
+ if not gfile.Exists(self._hdfs_path):
+ gfile.MakeDirs(self._hdfs_path)
+ res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
+ table_writer = gfile.GFile(res_path, 'w')
+ return table_writer
+
+ def _write_lines(self, table_writer, outputs):
+ outputs = '\n'.join(
+ [self._output_sep.join([str(i) for i in output]) for output in outputs])
+ table_writer.write(outputs + '\n')
+
+ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+ reserve_vals = [outputs[x] for x in output_cols] + \
+ [all_vals[k] for k in reserved_cols]
+ return reserve_vals
+
+ def load_to_table(self, output_path, slice_num, slice_id):
+ res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
+ success_writer = gfile.GFile(res_path, 'w')
+ success_writer.write('')
+ success_writer.close()
+
+ if slice_id != 0:
+ return
+
+ for id in range(slice_num):
+ res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id)
+ while not gfile.Exists(res_path):
+ time.sleep(10)
+
+ table_name, partition_name, partition_val = self.get_table_info(output_path)
+ schema = ''
+ for output_col_name in self._output_cols:
+ tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
+ col_type = tf_utils.get_col_type(tf_type)
+ schema += output_col_name + ' ' + col_type + ','
+
+ for output_col_name in self._reserved_cols:
+      assert output_col_name in self._all_cols, 'Column: %s does not exist.' % output_col_name
+ idx = self._all_cols.index(output_col_name)
+ output_col_types = self._all_col_types[idx]
+ schema += output_col_name + ' ' + output_col_types + ','
+ schema = schema.rstrip(',')
+
+ if partition_name and partition_val:
+ sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
+ (table_name, schema, partition_name)
+ self._hive_util.run_sql(sql)
+ sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
+ (self._hdfs_path, table_name, partition_name, partition_val)
+ self._hive_util.run_sql(sql)
+ else:
+ sql = 'create table if not exists %s (%s)' % \
+ (table_name, schema)
+ self._hive_util.run_sql(sql)
+ sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
+ (self._hdfs_path, table_name)
+ self._hive_util.run_sql(sql)
+
+ @property
+ def out_of_range_exception(self):
+ return (tf.errors.OutOfRangeError)
diff --git a/easy_rec/python/inference/odps_predictor.py b/easy_rec/python/inference/odps_predictor.py
new file mode 100644
index 000000000..183fc4d13
--- /dev/null
+++ b/easy_rec/python/inference/odps_predictor.py
@@ -0,0 +1,70 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from easy_rec.python.inference.predictor import Predictor
+
+
+class ODPSPredictor(Predictor):
+
+ def __init__(self,
+ model_path,
+ fg_json_path=None,
+ profiling_file=None,
+ all_cols='',
+ all_col_types=''):
+ super(ODPSPredictor, self).__init__(model_path, profiling_file,
+ fg_json_path)
+ self._all_cols = [x.strip() for x in all_cols.split(',') if x != '']
+ self._all_col_types = [
+ x.strip() for x in all_col_types.split(',') if x != ''
+ ]
+ self._record_defaults = [
+ self._get_defaults(col_name, col_type)
+ for col_name, col_type in zip(self._all_cols, self._all_col_types)
+ ]
+
+ def _get_reserved_cols(self, reserved_cols):
+ reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+ return reserved_cols
+
+ def _parse_line(self, *fields):
+ fields = list(fields)
+ field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))}
+ return field_dict
+
+ def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+ slice_id):
+ input_list = input_path.split(',')
+ dataset = tf.data.TableRecordDataset(
+ input_list,
+ record_defaults=self._record_defaults,
+ slice_id=slice_id,
+ slice_count=slice_num,
+ selected_cols=','.join(self._all_cols))
+ dataset = dataset.batch(batch_size)
+ dataset = dataset.prefetch(buffer_size=64)
+ return dataset
+
+ def _get_writer(self, output_path, slice_id):
+ import common_io
+ table_writer = common_io.table.TableWriter(output_path, slice_id=slice_id)
+ return table_writer
+
+ def _write_lines(self, table_writer, outputs):
+ assert len(outputs) > 0
+ indices = list(range(0, len(outputs[0])))
+ table_writer.write(outputs, indices, allow_type_cast=False)
+
+ @property
+ def out_of_range_exception(self):
+ return (tf.python_io.OutOfRangeException, tf.errors.OutOfRangeError)
+
+ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+ reserve_vals = [all_vals[k] for k in reserved_cols] + \
+ [outputs[x] for x in output_cols]
+ return reserve_vals
diff --git a/easy_rec/python/inference/parquet_predictor.py b/easy_rec/python/inference/parquet_predictor.py
new file mode 100644
index 000000000..7fead6388
--- /dev/null
+++ b/easy_rec/python/inference/parquet_predictor.py
@@ -0,0 +1,147 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.input.parquet_input import ParquetInput
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import input_utils
+
+try:
+ from tensorflow.python.framework.load_library import load_op_library
+ import easy_rec
+ load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
+ load_embed_lib = load_op_library(load_embed_lib_path)
+except Exception as ex:
+ logging.warning('load libload_embed.so failed: %s' % str(ex))
+
+
+class ParquetPredictor(Predictor):
+
+ def __init__(self,
+ model_path,
+ data_config,
+ ds_vector_recall=False,
+ fg_json_path=None,
+ profiling_file=None,
+ selected_cols=None,
+ output_sep=chr(1),
+ pipeline_config=None):
+ super(ParquetPredictor, self).__init__(model_path, profiling_file,
+ fg_json_path)
+ self._output_sep = output_sep
+ self._ds_vector_recall = ds_vector_recall
+ input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+ self.pipeline_config = pipeline_config
+
+ if 'rtp' in input_type:
+ self._is_rtp = True
+ self._input_sep = data_config.rtp_separator
+ else:
+ self._is_rtp = False
+ self._input_sep = data_config.separator
+
+ if selected_cols and not ds_vector_recall:
+ self._selected_cols = [int(x) for x in selected_cols.split(',')]
+ elif ds_vector_recall:
+ self._selected_cols = selected_cols.split(',')
+ else:
+ self._selected_cols = None
+
+ def _parse_line(self, line):
+ out_dict = {}
+ for key in line['feature']:
+ out_dict[key] = line['feature'][key]
+ if 'reserve' in line:
+ out_dict['reserve'] = line['reserve']
+ # for key in line['reserve']:
+ # out_dict[key] = line['reserve'][key]
+ return out_dict
+
+ def _get_reserved_cols(self, reserved_cols):
+ # already parsed in _get_dataset
+ return self._reserved_cols
+
+ def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+ slice_id):
+ feature_configs = config_util.get_compatible_feature_configs(
+ self.pipeline_config)
+
+ kwargs = {}
+ if self._reserved_args is not None and len(self._reserved_args) > 0:
+ if self._reserved_args == 'ALL_COLUMNS':
+ parquet_file = gfile.Glob(input_path.split(',')[0])[0]
+ # gfile not supported, read_parquet requires random access
+ all_data = pd.read_parquet(parquet_file)
+ all_cols = list(all_data.columns)
+ kwargs['reserve_fields'] = all_cols
+ self._all_fields = all_cols
+ self._reserved_cols = all_cols
+ kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file(
+ all_cols, parquet_file)
+ else:
+ self._reserved_cols = [
+ x.strip() for x in self._reserved_args.split(',') if x.strip() != ''
+ ]
+ kwargs['reserve_fields'] = self._reserved_cols
+ parquet_file = gfile.Glob(input_path.split(',')[0])[0]
+ kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file(
+ self._reserved_cols, parquet_file)
+ logging.info('reserve_fields=%s reserve_types=%s' %
+ (','.join(self._reserved_cols), ','.join(
+ [str(x) for x in kwargs['reserve_types']])))
+ else:
+ self._reserved_cols = []
+ self.pipeline_config.data_config.batch_size = batch_size
+
+ kwargs['is_predictor'] = True
+ parquet_input = ParquetInput(
+ self.pipeline_config.data_config,
+ feature_configs,
+ input_path,
+ task_index=slice_id,
+ task_num=slice_num,
+ pipeline_config=self.pipeline_config,
+ **kwargs)
+ return parquet_input._build(tf.estimator.ModeKeys.PREDICT, {})
+
+ def _get_writer(self, output_path, slice_id):
+ if not gfile.Exists(output_path):
+ gfile.MakeDirs(output_path)
+ res_path = os.path.join(output_path, 'part-%d.csv' % slice_id)
+ table_writer = gfile.GFile(res_path, 'w')
+ table_writer.write(
+ self._output_sep.join(self._output_cols + self._reserved_cols) + '\n')
+ return table_writer
+
+ def _write_lines(self, table_writer, outputs):
+ outputs = '\n'.join(
+ [self._output_sep.join([str(i) for i in output]) for output in outputs])
+ table_writer.write(outputs + '\n')
+
+ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+ reserve_vals = []
+ for x in outputs:
+ tmp_val = outputs[x]
+ reserve_vals.append(tmp_val)
+ for k in reserved_cols:
+ tmp_val = all_vals['reserve'][k]
+ if tmp_val.dtype == np.object:
+ tmp_val = [x.decode('utf-8') for x in tmp_val]
+ reserve_vals.append(tmp_val)
+ return reserve_vals
+
+ @property
+ def out_of_range_exception(self):
+ return (tf.errors.OutOfRangeError)
diff --git a/easy_rec/python/inference/parquet_predictor_v2.py b/easy_rec/python/inference/parquet_predictor_v2.py
new file mode 100644
index 000000000..1ee08517b
--- /dev/null
+++ b/easy_rec/python/inference/parquet_predictor_v2.py
@@ -0,0 +1,147 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.input.parquet_input_v2 import ParquetInputV2
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import input_utils
+
+try:
+ from tensorflow.python.framework.load_library import load_op_library
+ import easy_rec
+ load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
+ load_embed_lib = load_op_library(load_embed_lib_path)
+except Exception as ex:
+ logging.warning('load libload_embed.so failed: %s' % str(ex))
+
+
+class ParquetPredictorV2(Predictor):
+
+ def __init__(self,
+ model_path,
+ data_config,
+ ds_vector_recall=False,
+ fg_json_path=None,
+ profiling_file=None,
+ selected_cols=None,
+ output_sep=chr(1),
+ pipeline_config=None):
+ super(ParquetPredictorV2, self).__init__(model_path, profiling_file,
+ fg_json_path)
+ self._output_sep = output_sep
+ self._ds_vector_recall = ds_vector_recall
+ input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+ self.pipeline_config = pipeline_config
+
+ if 'rtp' in input_type:
+ self._is_rtp = True
+ self._input_sep = data_config.rtp_separator
+ else:
+ self._is_rtp = False
+ self._input_sep = data_config.separator
+
+ if selected_cols and not ds_vector_recall:
+ self._selected_cols = [int(x) for x in selected_cols.split(',')]
+ elif ds_vector_recall:
+ self._selected_cols = selected_cols.split(',')
+ else:
+ self._selected_cols = None
+
+ def _parse_line(self, line):
+ out_dict = {}
+ for key in line['feature']:
+ out_dict[key] = line['feature'][key]
+ if 'reserve' in line:
+ out_dict['reserve'] = line['reserve']
+ # for key in line['reserve']:
+ # out_dict[key] = line['reserve'][key]
+ return out_dict
+
+ def _get_reserved_cols(self, reserved_cols):
+ # already parsed in _get_dataset
+ return self._reserved_cols
+
+ def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+ slice_id):
+ feature_configs = config_util.get_compatible_feature_configs(
+ self.pipeline_config)
+
+ kwargs = {}
+ if self._reserved_args is not None and len(self._reserved_args) > 0:
+ if self._reserved_args == 'ALL_COLUMNS':
+ parquet_file = gfile.Glob(input_path.split(',')[0])[0]
+ # gfile not supported, read_parquet requires random access
+ all_data = pd.read_parquet(parquet_file)
+ all_cols = list(all_data.columns)
+ kwargs['reserve_fields'] = all_cols
+ self._all_fields = all_cols
+ self._reserved_cols = all_cols
+ kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file(
+ all_cols, parquet_file)
+ else:
+ self._reserved_cols = [
+ x.strip() for x in self._reserved_args.split(',') if x.strip() != ''
+ ]
+ kwargs['reserve_fields'] = self._reserved_cols
+ parquet_file = gfile.Glob(input_path.split(',')[0])[0]
+ kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file(
+ self._reserved_cols, parquet_file)
+ logging.info('reserve_fields=%s reserve_types=%s' %
+ (','.join(self._reserved_cols), ','.join(
+ [str(x) for x in kwargs['reserve_types']])))
+ else:
+ self._reserved_cols = []
+ self.pipeline_config.data_config.batch_size = batch_size
+
+ kwargs['is_predictor'] = True
+ parquet_input = ParquetInputV2(
+ self.pipeline_config.data_config,
+ feature_configs,
+ input_path,
+ task_index=slice_id,
+ task_num=slice_num,
+ pipeline_config=self.pipeline_config,
+ **kwargs)
+ return parquet_input._build(tf.estimator.ModeKeys.PREDICT, {})
+
+ def _get_writer(self, output_path, slice_id):
+ if not gfile.Exists(output_path):
+ gfile.MakeDirs(output_path)
+ res_path = os.path.join(output_path, 'part-%d.csv' % slice_id)
+ table_writer = gfile.GFile(res_path, 'w')
+ table_writer.write(
+ self._output_sep.join(self._output_cols + self._reserved_cols) + '\n')
+ return table_writer
+
+ def _write_lines(self, table_writer, outputs):
+ outputs = '\n'.join(
+ [self._output_sep.join([str(i) for i in output]) for output in outputs])
+ table_writer.write(outputs + '\n')
+
+ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+ reserve_vals = []
+ for x in outputs:
+ tmp_val = outputs[x]
+ reserve_vals.append(tmp_val)
+ for k in reserved_cols:
+ tmp_val = all_vals['reserve'][k]
+ if tmp_val.dtype == np.object:
+ tmp_val = [x.decode('utf-8') for x in tmp_val]
+ reserve_vals.append(tmp_val)
+ return reserve_vals
+
+ @property
+ def out_of_range_exception(self):
+ return (tf.errors.OutOfRangeError)
diff --git a/easy_rec/python/inference/predictor.py b/easy_rec/python/inference/predictor.py
index bc68891d7..2d02fac71 100644
--- a/easy_rec/python/inference/predictor.py
+++ b/easy_rec/python/inference/predictor.py
@@ -5,6 +5,7 @@
from __future__ import print_function
import abc
+import json
import logging
import math
import os
@@ -14,14 +15,23 @@
import six
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_constants
-from easy_rec.python.utils import pai_util
+import easy_rec
+from easy_rec.python.utils import numpy_utils
from easy_rec.python.utils.config_util import get_configs_from_pipeline_file
+from easy_rec.python.utils.config_util import get_input_name_from_fg_json
+from easy_rec.python.utils.config_util import search_fg_json
from easy_rec.python.utils.input_utils import get_type_defaults
from easy_rec.python.utils.load_class import get_register_class_meta
+try:
+ tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libcustom_ops.so'))
+except Exception as ex:
+ logging.warning('exception: %s' % str(ex))
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -90,7 +100,7 @@ def get_output_type(self):
class PredictorImpl(object):
- def __init__(self, model_path, profiling_file=None):
+ def __init__(self, model_path, profiling_file=None, use_latest=False):
"""Impl class for predictor.
Args:
@@ -98,6 +108,8 @@ def __init__(self, model_path, profiling_file=None):
profiling_file: profiling result file, default None.
if not None, predict function will use Timeline to profiling
prediction time, and the result json will be saved to profiling_file
+      use_latest: use the latest saved_model.pb if multiple ones are found,
+          otherwise raise an exception.
"""
self._inputs_map = {}
self._outputs_map = {}
@@ -106,6 +118,7 @@ def __init__(self, model_path, profiling_file=None):
self._model_path = model_path
self._input_names = []
self._is_multi_placeholder = True
+ self._use_latest = use_latest
self._build_model()
@@ -133,21 +146,32 @@ def search_pb(self, directory):
directory contain pb file
"""
dir_list = []
- for root, dirs, files in tf.gfile.Walk(directory):
+ for root, dirs, files in gfile.Walk(directory):
for f in files:
- _, ext = os.path.splitext(f)
- if ext == '.pb':
+ if f.endswith('saved_model.pb'):
dir_list.append(root)
if len(dir_list) == 0:
raise ValueError('savedmodel is not found in directory %s' % directory)
elif len(dir_list) > 1:
- raise ValueError('multiple saved model found in directory %s' % directory)
+ if self._use_latest:
+ logging.info('find %d models: %s' % (len(dir_list), ','.join(dir_list)))
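+        # sort by the numeric name of the export directory and pick the largest;
+        # exported saved_model dirs are usually named by a timestamp / global step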
+ dir_list = sorted(
+ dir_list,
+ key=lambda x: int(x.split('/')[(-2 if (x[-1] == '/') else -1)]))
+ return dir_list[-1]
+ else:
+ raise ValueError('multiple saved model found in directory %s' %
+ directory)
return dir_list[0]
def _get_input_fields_from_pipeline_config(self, model_path):
pipeline_path = os.path.join(model_path, 'assets/pipeline.config')
- assert tf.gfile.Exists(pipeline_path), '%s not exists.' % pipeline_path
+ if not gfile.Exists(pipeline_path):
+ logging.warning(
+          '%s does not exist, default values may be inconsistent with the values used in training.'
+ % pipeline_path)
+ return {}
pipeline_config = get_configs_from_pipeline_file(pipeline_path)
input_fields = pipeline_config.data_config.input_fields
input_fields_info = {
@@ -175,7 +199,7 @@ def _build_model(self):
# load model
_, ext = os.path.splitext(model_path)
tf.logging.info('loading model from %s' % model_path)
- if tf.gfile.IsDirectory(model_path):
+ if gfile.IsDirectory(model_path):
model_path = self.search_pb(model_path)
logging.info('model find in %s' % model_path)
self._input_fields_info, self._input_fields_list = self._get_input_fields_from_pipeline_config(
@@ -239,7 +263,7 @@ def _build_model(self):
type_name = asset_file.tensor_info.name.split(':')[0]
asset_path = os.path.join(model_path, constants.ASSETS_DIRECTORY,
asset_file.filename)
- assert tf.gfile.Exists(
+ assert gfile.Exists(
asset_path), '%s is missing in saved model' % asset_path
self._assets[type_name] = asset_path
logging.info(self._assets)
@@ -299,14 +323,18 @@ def predict(self, input_data_dict, output_names=None):
from tensorflow.python.client import timeline
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format()
- with tf.gfile.GFile(self._profiling_file, 'w') as f:
+ with gfile.GFile(self._profiling_file, 'w') as f:
f.write(ctf)
return results
class Predictor(PredictorInterface):
- def __init__(self, model_path, profiling_file=None):
+ def __init__(self,
+ model_path,
+ profiling_file=None,
+ fg_json_path=None,
+ use_latest=True):
"""Initialize a `Predictor`.
Args:
@@ -314,8 +342,10 @@ def __init__(self, model_path, profiling_file=None):
profiling_file: profiling result file, default None.
if not None, predict function will use Timeline to profiling
prediction time, and the result json will be saved to profiling_file
+ fg_json_path: fg.json file
+      use_latest: use the latest saved_model.pb if multiple ones exist.
"""
- self._predictor_impl = PredictorImpl(model_path, profiling_file)
+ self._predictor_impl = PredictorImpl(model_path, profiling_file, use_latest)
self._inputs_map = self._predictor_impl._inputs_map
self._outputs_map = self._predictor_impl._outputs_map
self._profiling_file = profiling_file
@@ -324,6 +354,9 @@ def __init__(self, model_path, profiling_file=None):
self._is_multi_placeholder = self._predictor_impl._is_multi_placeholder
self._input_fields = self._predictor_impl._input_fields_list
+ fg_json = self._get_fg_json(fg_json_path, model_path)
+ self._all_input_names = get_input_name_from_fg_json(fg_json)
+ logging.info('all_input_names: %s' % self._all_input_names)
@property
def input_names(self):
@@ -343,27 +376,76 @@ def output_names(self):
"""
return list(self._outputs_map.keys())
- def predict_impl(self,
- input_table,
- output_table,
- all_cols='',
- all_col_types='',
- selected_cols='',
- reserved_cols='',
- output_cols=None,
- batch_size=1024,
- slice_id=0,
- slice_num=1,
- input_sep=',',
- output_sep=chr(1)):
+ def _get_defaults(self, col_name, col_type='string'):
+ if col_name in self._input_fields_info:
+ col_type, default_val = self._input_fields_info[col_name]
+ default_val = get_type_defaults(col_type, default_val)
+ logging.info('col_name: %s, default_val: %s' % (col_name, default_val))
+ else:
+ defaults = {'string': '', 'double': 0.0, 'bigint': 0}
+      assert col_type in defaults, 'invalid col_type: %s for col_name: %s' % (
+          col_type, col_name)
+ default_val = defaults[col_type]
+ logging.info(
+          'col_name: %s, default_val: %s [not defined in saved_model_dir/assets/pipeline.config]'
+ % (col_name, default_val))
+ return default_val
+
+ def _parse_line(self, line):
+ pass
+
+ def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+ slice_id):
+ pass
+
+ def _get_writer(self, output_path, slice_id):
+ pass
+
+ def _get_reserved_cols(self, reserved_cols):
+ pass
+
+ @property
+ def out_of_range_exception(self):
+ return None
+
+ def _write_lines(self, table_writer, outputs):
+ pass
+
+ def load_to_table(self, output_path, slice_num, slice_id):
+ pass
+
+ def _get_fg_json(self, fg_json_path, model_path):
+ if fg_json_path and gfile.Exists(fg_json_path):
+      logging.info('load fg_json_path: %s' % fg_json_path)
+ with tf.gfile.GFile(fg_json_path, 'r') as fin:
+ fg_json = json.loads(fin.read())
+ else:
+ fg_json_path = search_fg_json(model_path)
+ if fg_json_path:
+ with tf.gfile.GFile(fg_json_path, 'r') as fin:
+ fg_json = json.loads(fin.read())
+ else:
+ fg_json = {}
+ return fg_json
+
+ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+ pass
+
+ def predict_impl(
+ self,
+ input_path,
+ output_path,
+ reserved_cols='',
+ output_cols=None,
+ batch_size=1024,
+ slice_id=0,
+ slice_num=1,
+ ):
"""Predict table input with loaded model.
Args:
- input_table: table/file_path to read
- output_table: table/file_path to write
- all_cols: union of columns
- all_col_types: data types of the columns
- selected_cols: included column names, comma separated, such as "a,b,c"
+ input_path: table/file_path to read
+ output_path: table/file_path to write
reserved_cols: columns to be copy to output_table, comma separated, such as "a,b"
output_cols: output columns, comma separated, such as "y float, embedding string",
the output names[y, embedding] must be in saved_model output_names
@@ -371,224 +453,35 @@ def predict_impl(self,
slice_id: when multiple workers write the same table, each worker should
be assigned different slice_id, which is usually slice_id
slice_num: table slice number
- input_sep: separator of input file.
- output_sep: separator of predict result file.
"""
- if pai_util.is_on_pai():
- self.predict_table(
- input_table,
- output_table,
- all_cols=all_cols,
- all_col_types=all_col_types,
- selected_cols=selected_cols,
- reserved_cols=reserved_cols,
- output_cols=output_cols,
- batch_size=batch_size,
- slice_id=slice_id,
- slice_num=slice_num)
- else:
- self.predict_csv(
- input_table,
- output_table,
- reserved_cols=reserved_cols,
- output_cols=output_cols,
- batch_size=batch_size,
- slice_id=slice_id,
- slice_num=slice_num,
- input_sep=input_sep,
- output_sep=output_sep)
-
- def predict_csv(self, input_path, output_path, reserved_cols, output_cols,
- batch_size, slice_id, slice_num, input_sep, output_sep):
- record_defaults = [
- self._input_fields_info[col_name][1] for col_name in self._input_fields
- ]
- if reserved_cols == 'ALL_COLUMNS':
- reserved_cols = self._input_fields
- else:
- reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
if output_cols is None or output_cols == 'ALL_COLUMNS':
- output_cols = sorted(self._predictor_impl.output_names)
- logging.info('predict output cols: %s' % output_cols)
+ self._output_cols = sorted(self._predictor_impl.output_names)
+ logging.info('predict output cols: %s' % self._output_cols)
else:
# specified as score float,embedding string
tmp_cols = []
for x in output_cols.split(','):
if x.strip() == '':
continue
- tmp_keys = x.split(' ')
+ tmp_keys = x.strip().split(' ')
tmp_cols.append(tmp_keys[0].strip())
- output_cols = tmp_cols
+ self._output_cols = tmp_cols
with tf.Graph().as_default(), tf.Session() as sess:
num_parallel_calls = 8
- file_paths = []
- for x in input_path.split(','):
- file_paths.extend(tf.gfile.Glob(x))
- assert len(file_paths) > 0, 'match no files with %s' % input_path
-
- dataset = tf.data.Dataset.from_tensor_slices(file_paths)
- parallel_num = min(num_parallel_calls, len(file_paths))
- dataset = dataset.interleave(
- tf.data.TextLineDataset,
- cycle_length=parallel_num,
- num_parallel_calls=parallel_num)
- dataset = dataset.shard(slice_num, slice_id)
- logging.info('batch_size = %d' % batch_size)
- dataset = dataset.batch(batch_size)
- dataset = dataset.prefetch(buffer_size=64)
-
- def _parse_csv(line):
-
- def _check_data(line):
- sep = input_sep
- if type(sep) != type(str):
- sep = sep.encode('utf-8')
- field_num = len(line[0].split(sep))
- assert field_num == len(record_defaults), 'sep[%s] maybe invalid: field_num=%d, required_num=%d' \
- % (sep, field_num, len(record_defaults))
- return True
-
- check_op = tf.py_func(_check_data, [line], Tout=tf.bool)
- with tf.control_dependencies([check_op]):
- fields = tf.decode_csv(
- line,
- field_delim=',',
- record_defaults=record_defaults,
- name='decode_csv')
-
- inputs = {self._input_fields[x]: fields[x] for x in range(len(fields))}
- return inputs
-
- dataset = dataset.map(_parse_csv, num_parallel_calls=num_parallel_calls)
- iterator = dataset.make_one_shot_iterator()
- all_dict = iterator.get_next()
-
- if not tf.gfile.Exists(output_path):
- tf.gfile.MakeDirs(output_path)
- res_path = os.path.join(output_path, 'slice_%d.csv' % slice_id)
- table_writer = tf.gfile.FastGFile(res_path, 'w')
-
- input_names = self._predictor_impl.input_names
- progress = 0
- sum_t0, sum_t1, sum_t2 = 0, 0, 0
- pred_cnt = 0
- table_writer.write(output_sep.join(output_cols + reserved_cols) + '\n')
- while True:
- try:
- ts0 = time.time()
- all_vals = sess.run(all_dict)
-
- ts1 = time.time()
- input_vals = {k: all_vals[k] for k in input_names}
- outputs = self._predictor_impl.predict(input_vals, output_cols)
-
- for x in output_cols:
- if outputs[x].dtype == np.object:
- outputs[x] = [val.decode('utf-8') for val in outputs[x]]
- for k in reserved_cols:
- if all_vals[k].dtype == np.object:
- all_vals[k] = [val.decode('utf-8') for val in all_vals[k]]
-
- ts2 = time.time()
- reserve_vals = [outputs[x] for x in output_cols] + \
- [all_vals[k] for k in reserved_cols]
- outputs = [x for x in zip(*reserve_vals)]
- pred_cnt += len(outputs)
- outputs = '\n'.join(
- [output_sep.join([str(i) for i in output]) for output in outputs])
- table_writer.write(outputs + '\n')
-
- ts3 = time.time()
- progress += 1
- sum_t0 += (ts1 - ts0)
- sum_t1 += (ts2 - ts1)
- sum_t2 += (ts3 - ts2)
- except tf.errors.OutOfRangeError:
- break
- if progress % 100 == 0:
- logging.info('progress: batch_num=%d sample_num=%d' %
- (progress, progress * batch_size))
- logging.info('time_stats: read: %.2f predict: %.2f write: %.2f' %
- (sum_t0, sum_t1, sum_t2))
- logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' %
- (sum_t0, sum_t1, sum_t2))
- table_writer.close()
- logging.info('Predict %s done.' % input_path)
- logging.info('Predict size: %d.' % pred_cnt)
-
- def predict_table(self,
- input_table,
- output_table,
- all_cols,
- all_col_types,
- selected_cols,
- reserved_cols,
- output_cols=None,
- batch_size=1024,
- slice_id=0,
- slice_num=1):
-
- def _get_defaults(col_name, col_type):
- if col_name in self._input_fields_info:
- col_type, default_val = self._input_fields_info[col_name]
- default_val = get_type_defaults(col_type, default_val)
- logging.info('col_name: %s, default_val: %s' % (col_name, default_val))
+ self._reserved_args = reserved_cols
+ dataset = self._get_dataset(input_path, num_parallel_calls, batch_size,
+ slice_num, slice_id)
+ dataset = dataset.map(
+ self._parse_line, num_parallel_calls=num_parallel_calls)
+ if hasattr(tf.data, 'make_one_shot_iterator'):
+ iterator = tf.data.make_one_shot_iterator(dataset)
else:
- logging.info('col_name: %s is not used in predict.' % col_name)
- defaults = {'string': '', 'double': 0.0, 'bigint': 0}
- assert col_type in defaults, 'invalid col_type: %s, col_type: %s' % (
- col_name, col_type)
- default_val = defaults[col_type]
- return default_val
-
- all_cols = [x.strip() for x in all_cols.split(',') if x != '']
- all_col_types = [x.strip() for x in all_col_types.split(',') if x != '']
- reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
-
- if output_cols is None:
- output_cols = self._predictor_impl.output_names
- else:
- # specified as score float,embedding string
- tmp_cols = []
- for x in output_cols.split(','):
- if x.strip() == '':
- continue
- tmp_keys = x.split(' ')
- tmp_cols.append(tmp_keys[0].strip())
- output_cols = tmp_cols
-
- record_defaults = [
- _get_defaults(col_name, col_type)
- for col_name, col_type in zip(all_cols, all_col_types)
- ]
- with tf.Graph().as_default(), tf.Session() as sess:
- num_parallel_calls = 8
- input_table = input_table.split(',')
- dataset = tf.data.TableRecordDataset([input_table],
- record_defaults=record_defaults,
- slice_id=slice_id,
- slice_count=slice_num,
- selected_cols=','.join(all_cols))
-
- logging.info('batch_size = %d' % batch_size)
- dataset = dataset.batch(batch_size)
- dataset = dataset.prefetch(buffer_size=64)
-
- def _parse_table(*fields):
- fields = list(fields)
- field_dict = {all_cols[i]: fields[i] for i in range(len(fields))}
- return field_dict
-
- dataset = dataset.map(_parse_table, num_parallel_calls=num_parallel_calls)
- iterator = dataset.make_one_shot_iterator()
+ iterator = dataset.make_one_shot_iterator()
all_dict = iterator.get_next()
-
- import common_io
- table_writer = common_io.table.TableWriter(
- output_table, slice_id=slice_id)
-
+ self._reserved_cols = self._get_reserved_cols(reserved_cols)
input_names = self._predictor_impl.input_names
+ table_writer = self._get_writer(output_path, slice_id)
def _parse_value(all_vals):
if self._is_multi_placeholder:
@@ -596,11 +489,22 @@ def _parse_value(all_vals):
feature_vals = all_vals[SINGLE_PLACEHOLDER_FEATURE_KEY]
split_index = []
split_vals = {}
- for i, k in enumerate(input_names):
- split_index.append(k)
- split_vals[k] = []
+ fg_input_size = len(feature_vals[0].decode('utf-8').split('\002'))
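+          # fg inputs pack all feature values of one sample into a single string
+          # joined by '\002'; fg_input_size is the number of packed features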
+ if fg_input_size == len(input_names):
+ for i, k in enumerate(input_names):
+ split_index.append(k)
+ split_vals[k] = []
+ else:
+          assert self._all_input_names, 'must set fg_json_path when using fg input'
+ assert fg_input_size == len(self._all_input_names), (
+ 'The number of features defined in fg_json != the size of fg input. '
+ 'The number of features defined in fg_json is: %d; The size of fg input is: %d'
+ % (len(self._all_input_names), fg_input_size))
+ for i, k in enumerate(self._all_input_names):
+ split_index.append(k)
+ split_vals[k] = []
for record in feature_vals:
- split_records = record.split('\002')
+ split_records = record.decode('utf-8').split('\002')
for i, r in enumerate(split_records):
split_vals[split_index[i]].append(r)
return {k: np.array(split_vals[k]) for k in input_names}
@@ -616,25 +520,38 @@ def _parse_value(all_vals):
ts1 = time.time()
input_vals = _parse_value(all_vals)
- # logging.info('input names = %s' % input_names)
- # logging.info('input vals = %s' % input_vals)
- outputs = self._predictor_impl.predict(input_vals, output_cols)
+ outputs = self._predictor_impl.predict(input_vals, self._output_cols)
+ for x in self._output_cols:
+ if outputs[x].dtype == np.object:
+ outputs[x] = [val.decode('utf-8') for val in outputs[x]]
+ elif len(outputs[x].shape) == 2 and outputs[x].shape[1] == 1:
+ # automatic flatten only one element array
+ outputs[x] = [val[0] for val in outputs[x]]
+ elif len(outputs[x].shape) > 1:
+ outputs[x] = [
+ json.dumps(val, cls=numpy_utils.NumpyEncoder)
+ for val in outputs[x]
+ ]
+ for k in self._reserved_cols:
+ if k in all_vals and all_vals[k].dtype == np.object:
+ all_vals[k] = [
+ val.decode('utf-8', errors='ignore') for val in all_vals[k]
+ ]
ts2 = time.time()
- reserve_vals = [all_vals[k] for k in reserved_cols
- ] + [outputs[x] for x in output_cols]
- indices = list(range(0, len(reserve_vals)))
+ reserve_vals = self._get_reserve_vals(self._reserved_cols,
+ self._output_cols, all_vals,
+ outputs)
outputs = [x for x in zip(*reserve_vals)]
+ logging.info('predict size: %s' % len(outputs))
+ self._write_lines(table_writer, outputs)
- table_writer.write(outputs, indices, allow_type_cast=False)
ts3 = time.time()
progress += 1
sum_t0 += (ts1 - ts0)
sum_t1 += (ts2 - ts1)
sum_t2 += (ts3 - ts2)
- except tf.python_io.OutOfRangeException:
- break
- except tf.errors.OutOfRangeError:
+ except self.out_of_range_exception:
break
if progress % 100 == 0:
logging.info('progress: batch_num=%d sample_num=%d' %
@@ -644,7 +561,8 @@ def _parse_value(all_vals):
logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' %
(sum_t0, sum_t1, sum_t2))
table_writer.close()
- logging.info('Predict %s done.' % input_table)
+ self.load_to_table(output_path, slice_num, slice_id)
+ logging.info('Predict %s done.' % input_path)
def predict(self, input_data_dict_list, output_names=None, batch_size=1):
"""Predict input data with loaded model.
diff --git a/easy_rec/python/inference/processor/__init__.py b/easy_rec/python/inference/processor/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/easy_rec/python/inference/processor/test.py b/easy_rec/python/inference/processor/test.py
new file mode 100644
index 000000000..088c93edc
--- /dev/null
+++ b/easy_rec/python/inference/processor/test.py
@@ -0,0 +1,170 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import ctypes
+import glob
+import json
+import logging
+import os
+import subprocess
+import time
+
+import numpy as np
+from google.protobuf import text_format
+
+from easy_rec.python.protos import dataset_pb2
+from easy_rec.python.protos import pipeline_pb2
+from easy_rec.python.protos import tf_predict_pb2
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+PROCESSOR_VERSION = 'LaRec-0.9.5d-b1b1604-TF-2.5.0-Linux'
+PROCESSOR_FILE = PROCESSOR_VERSION + '.tar.gz'
+PROCESSOR_URL = 'http://easyrec.oss-cn-beijing.aliyuncs.com/processor/' + PROCESSOR_FILE
+PROCESSOR_ENTRY_LIB = 'processor/' + PROCESSOR_VERSION + '/larec/libtf_predictor.so'
+
+
+def build_array_proto(array_proto, data, dtype):
+ array_proto.array_shape.dim.append(len(data))
+
+ if dtype == dataset_pb2.DatasetConfig.STRING:
+ array_proto.string_val.extend([x.encode('utf-8') for x in data])
+ array_proto.dtype = tf_predict_pb2.DT_STRING
+ elif dtype == dataset_pb2.DatasetConfig.FLOAT:
+ array_proto.float_val.extend([float(x) for x in data])
+ array_proto.dtype = tf_predict_pb2.DT_FLOAT
+ elif dtype == dataset_pb2.DatasetConfig.DOUBLE:
+ array_proto.double_val.extend([float(x) for x in data])
+ array_proto.dtype = tf_predict_pb2.DT_DOUBLE
+ elif dtype == dataset_pb2.DatasetConfig.INT32:
+ array_proto.int_val.extend([int(x) for x in data])
+ array_proto.dtype = tf_predict_pb2.DT_INT32
+ elif dtype == dataset_pb2.DatasetConfig.INT64:
+ array_proto.int64_val.extend([np.int64(x) for x in data])
+ array_proto.dtype = tf_predict_pb2.DT_INT64
+ else:
+ assert False, 'invalid datatype[%s]' % str(dtype)
+ return array_proto
+
+
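As a usage note, build_array_proto fills one named input of a PredictRequest from a plain python list, choosing the value field of the array proto from the DatasetConfig dtype. A hedged calling sketch follows; the input names 'user_id' and 'price' are made up for illustration.

req = tf_predict_pb2.PredictRequest()
req.signature_name = 'serving_default'
# a string feature column
build_array_proto(req.inputs['user_id'], ['u1', 'u2', 'u3'],
                  dataset_pb2.DatasetConfig.STRING)
# a float feature column
build_array_proto(req.inputs['price'], [1.0, 2.5, 3.0],
                  dataset_pb2.DatasetConfig.FLOAT)
data_bin = req.SerializeToString()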
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input_path', type=str, default=None, help='input data path')
+ parser.add_argument(
+ '--output_path', type=str, default=None, help='output data path')
+ parser.add_argument(
+ '--libc_path',
+ type=str,
+ default='/lib64/libc.so.6',
+ help='libc.so.6 path')
+ parser.add_argument(
+ '--saved_model_dir', type=str, default=None, help='saved model directory')
+ parser.add_argument(
+ '--test_dir', type=str, default=None, help='test directory')
+ args = parser.parse_args()
+
+ if not os.path.exists('processor'):
+ os.mkdir('processor')
+ if not os.path.exists(PROCESSOR_ENTRY_LIB):
+ if not os.path.exists('processor/' + PROCESSOR_FILE):
+ subprocess.check_output(
+ 'wget %s -O processor/%s' % (PROCESSOR_URL, PROCESSOR_FILE),
+ shell=True)
+ subprocess.check_output(
+ 'cd processor && tar -zvxf %s' % PROCESSOR_FILE, shell=True)
+ assert os.path.exists(
+ PROCESSOR_ENTRY_LIB), 'invalid processor path: %s' % PROCESSOR_ENTRY_LIB
+
+ assert os.path.exists(args.libc_path), '%s does not exist' % args.libc_path
+ assert args.saved_model_dir is not None and os.path.isdir(
+ args.saved_model_dir
+ ), '%s is not a valid directory' % args.saved_model_dir
+ assert args.input_path is not None and os.path.exists(
+ args.input_path), '%s does not exist' % args.input_path
+ assert args.output_path is not None, 'output_path is not set'
+
+ pipeline_config = pipeline_pb2.EasyRecConfig()
+ pipeline_config_path = os.path.join(args.saved_model_dir,
+ 'assets/pipeline.config')
+ with open(pipeline_config_path) as fin:
+ config_str = fin.read()
+ text_format.Merge(config_str, pipeline_config)
+
+ data_config = pipeline_config.data_config
+
+ input_fields = [[]
+ for x in data_config.input_fields
+ if x.input_name not in data_config.label_fields]
+
+ with open(args.input_path, 'r') as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ line_toks = line_str.split(data_config.rtp_separator)[-1].split(chr(2))
+ for i, tok in enumerate(line_toks):
+ input_fields[i].append(tok)
+
+ req = tf_predict_pb2.PredictRequest()
+ req.signature_name = 'serving_default'
+ for i in range(len(input_fields)):
+ build_array_proto(req.inputs[data_config.input_fields[i + 1].input_name],
+ input_fields[i],
+ data_config.input_fields[i + 1].input_type)
+
+ tf_predictor = ctypes.cdll.LoadLibrary(PROCESSOR_ENTRY_LIB)
+ tf_predictor.saved_model_init.restype = ctypes.c_void_p
+ handle = tf_predictor.saved_model_init(args.saved_model_dir.encode('utf-8'))
+ logging.info('saved_model handle=%d' % handle)
+
+ num_steps = pipeline_config.train_config.num_steps
+ logging.info('num_steps=%d' % num_steps)
+
+ # last_step could be greater than num_steps for sync_replicas: false
+ train_dir = os.path.dirname(args.saved_model_dir.strip('/'))
+ all_models = glob.glob(
+ os.path.join(args.test_dir, 'train/model.ckpt-*.index'))
+ iters = [int(x.split('-')[-1].replace('.index', '')) for x in all_models]
+ iters.sort()
+ last_step = iters[-1]
+ logging.info('last_step=%d' % last_step)
+
+ sparse_step = ctypes.c_int(0)
+ dense_step = ctypes.c_int(0)
+ start_ts = time.time()
+ while sparse_step.value < last_step or dense_step.value < last_step:
+ tf_predictor.saved_model_step(
+ ctypes.c_void_p(handle), ctypes.byref(sparse_step),
+ ctypes.byref(dense_step))
+ time.sleep(1)
+ if time.time() - start_ts > 300:
+ logging.warning(
+ 'could not reach last_step, sparse_step=%d dense_step=%d' %
+ (sparse_step.value, dense_step.value))
+ break
+
+ data_bin = req.SerializeToString()
+ save_path = os.path.join(args.saved_model_dir, 'req.pb')
+ with open(save_path, 'wb') as fout:
+ fout.write(data_bin)
+ logging.info('save request to %s' % save_path)
+
+ tf_predictor.saved_model_predict.restype = ctypes.c_void_p
+ out_len = ctypes.c_int(0)
+ res_p = tf_predictor.saved_model_predict(
+ ctypes.c_void_p(handle), data_bin, ctypes.c_int32(len(data_bin)),
+ ctypes.byref(out_len))
+ res_bytes = bytearray(ctypes.string_at(res_p, out_len))
+ res = tf_predict_pb2.PredictResponse()
+ res.ParseFromString(res_bytes)
+
+ with open(args.output_path, 'w') as fout:
+ logits = res.outputs['logits'].float_val
+ probs = res.outputs['probs'].float_val
+ for logit, prob in zip(logits, probs):
+ fout.write(json.dumps({'logits': logit, 'probs': prob}) + '\n')
+
+ # free memory
+ tf_predictor.saved_model_release(ctypes.c_void_p(handle))
+ libc = ctypes.cdll.LoadLibrary(args.libc_path)
+ libc.free(ctypes.c_void_p(res_p))
diff --git a/easy_rec/python/inference/vector_retrieve.py b/easy_rec/python/inference/vector_retrieve.py
index 917853484..a8a8122d5 100644
--- a/easy_rec/python/inference/vector_retrieve.py
+++ b/easy_rec/python/inference/vector_retrieve.py
@@ -10,14 +10,14 @@
import common_io
import numpy as np
import tensorflow as tf
+
try:
import graphlearn as gl
-except:
- logging.WARN(
- 'GraphLearn is not installed. You can install it by "pip install http://odps-release.cn-hangzhou.oss-cdn.aliyun-inc.com/graphlearn/tunnel/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl."' # noqa: E501
+except: # noqa: E722
+ logging.warning(
+      'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl".'  # noqa: E501
)
-
if tf.__version__ >= '2.0':
tf = tf.compat.v1
diff --git a/easy_rec/python/input/batch_tfrecord_input.py b/easy_rec/python/input/batch_tfrecord_input.py
index 620b91a59..fb1981f60 100644
--- a/easy_rec/python/input/batch_tfrecord_input.py
+++ b/easy_rec/python/input/batch_tfrecord_input.py
@@ -5,6 +5,7 @@
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.utils.tf_utils import get_tf_type
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -23,9 +24,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(BatchTFRecordInput, self).__init__(data_config, feature_config,
- input_path, task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(BatchTFRecordInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
assert data_config.HasField(
'n_data_batch_tfrecord'), 'Need to set n_data_batch_tfrecord in config.'
self._input_shapes = [x.input_shape for x in data_config.input_fields]
@@ -33,7 +37,7 @@ def __init__(self,
for x, t, d, s in zip(self._input_fields, self._input_field_types,
self._input_field_defaults, self._input_shapes):
d = self.get_type_defaults(t, d)
- t = self.get_tf_type(t)
+ t = get_tf_type(t)
self.feature_desc[x] = tf.io.FixedLenSequenceFeature(
dtype=t, shape=s, allow_missing=True)
@@ -54,7 +58,11 @@ def _parse_tfrecord(self, example):
return features
def _build(self, mode, params):
- file_paths = tf.gfile.Glob(self._input_path)
+ if type(self._input_path) != list:
+ self._input_path = self._input_path.split(',')
+ file_paths = []
+ for x in self._input_path:
+ file_paths.extend(tf.gfile.Glob(x))
assert len(file_paths) > 0, 'match no files with %s' % self._input_path
num_parallel_calls = self._data_config.num_parallel_calls
diff --git a/easy_rec/python/input/criteo_binary_reader.py b/easy_rec/python/input/criteo_binary_reader.py
new file mode 100644
index 000000000..6672165c0
--- /dev/null
+++ b/easy_rec/python/input/criteo_binary_reader.py
@@ -0,0 +1,259 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import concurrent
+import concurrent.futures
+import glob
+import logging
+import os
+import queue
+import time
+
+import numpy as np
+
+
+class BinaryDataset:
+
+ def __init__(
+ self,
+ label_bins,
+ dense_bins,
+ category_bins,
+ batch_size=1,
+ drop_last=False,
+ prefetch=1,
+ global_rank=0,
+ global_size=1,
+ ):
+ total_sample_num = 0
+ self._sample_num_arr = []
+ for label_bin in label_bins:
+ sample_num = os.path.getsize(label_bin) // 4
+ total_sample_num += sample_num
+ self._sample_num_arr.append(sample_num)
+    logging.info('total number of samples = %d' % total_sample_num)
+ self._total_sample_num = total_sample_num
+
+ self._batch_size = batch_size
+
+ self._compute_global_start_pos(total_sample_num, batch_size, global_rank,
+ global_size, drop_last)
+
+ self._label_file_arr = [None for _ in self._sample_num_arr]
+ self._dense_file_arr = [None for _ in self._sample_num_arr]
+ self._category_file_arr = [None for _ in self._sample_num_arr]
+
+ for tmp_file_id in range(self._start_file_id, self._end_file_id + 1):
+ self._label_file_arr[tmp_file_id] = os.open(label_bins[tmp_file_id],
+ os.O_RDONLY)
+ self._dense_file_arr[tmp_file_id] = os.open(dense_bins[tmp_file_id],
+ os.O_RDONLY)
+ self._category_file_arr[tmp_file_id] = os.open(category_bins[tmp_file_id],
+ os.O_RDONLY)
+
+ self._prefetch = min(prefetch, self._num_entries)
+ self._prefetch_queue = queue.Queue()
+ self._executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=self._prefetch)
+
+ self._os_close_func = os.close
+
+ def _compute_global_start_pos(self, total_sample_num, batch_size, global_rank,
+ global_size, drop_last):
+ # ensure all workers have the same number of samples
+ avg_sample_num = (total_sample_num // global_size)
+ res_num = (total_sample_num % global_size)
+ self._num_samples = avg_sample_num
+ if res_num > 0:
+ self._num_samples += 1
+ if global_rank < res_num:
+ global_start_pos = (avg_sample_num + 1) * global_rank
+ else:
+ global_start_pos = avg_sample_num * global_rank + res_num - 1
+ else:
+ global_start_pos = avg_sample_num * global_rank
+ # global_end_pos = global_start_pos + self._num_samples
+
+ self._num_entries = self._num_samples // batch_size
+ self._last_batch_size = batch_size
+ if not drop_last and (self._num_samples % batch_size != 0):
+ self._num_entries += 1
+ self._last_batch_size = self._num_samples % batch_size
+ logging.info('num_batches = %d num_samples = %d' %
+ (self._num_entries, self._num_samples))
+
+ start_file_id = 0
+ curr_pos = 0
+ while curr_pos + self._sample_num_arr[start_file_id] <= global_start_pos:
+ start_file_id += 1
+ curr_pos += self._sample_num_arr[start_file_id]
+ self._start_file_id = start_file_id
+ self._start_file_pos = global_start_pos - curr_pos
+
+ logging.info('start_file_id = %d start_file_pos = %d' %
+ (start_file_id, self._start_file_pos))
+
+ # find the start of each batch
+ self._start_pos_arr = np.zeros([self._num_entries, 2], dtype=np.uint32)
+ batch_id = 0
+ tmp_start_pos = self._start_file_pos
+ while batch_id < self._num_entries:
+ self._start_pos_arr[batch_id] = (start_file_id, tmp_start_pos)
+ batch_id += 1
+ # the last batch
+ if batch_id == self._num_entries:
+ tmp_start_pos += self._last_batch_size
+ while start_file_id < len(
+ self._sample_num_arr
+ ) and tmp_start_pos > self._sample_num_arr[start_file_id]:
+ tmp_start_pos -= self._sample_num_arr[start_file_id]
+ start_file_id += 1
+ else:
+ tmp_start_pos += batch_size
+ while start_file_id < len(
+ self._sample_num_arr
+ ) and tmp_start_pos >= self._sample_num_arr[start_file_id]:
+ tmp_start_pos -= self._sample_num_arr[start_file_id]
+ start_file_id += 1
+
+ self._end_file_id = start_file_id
+ self._end_file_pos = tmp_start_pos
+
+ logging.info('end_file_id = %d end_file_pos = %d' %
+ (self._end_file_id, self._end_file_pos))
+
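To make the balancing above concrete, here is a small worked example of the start-position arithmetic (illustrative only, mirroring the code): with 10 samples and 3 workers, every worker reads 4 samples, so neighbouring ranges overlap by at most one sample.

total, size = 10, 3
avg, res = total // size, total % size           # avg=3, res=1
num_samples = avg + 1 if res > 0 else avg        # 4 samples per worker
starts = []
for rank in range(size):
    if res > 0:
        start = (avg + 1) * rank if rank < res else avg * rank + res - 1
    else:
        start = avg * rank
    starts.append(start)
print(num_samples, starts)                       # 4 [0, 3, 6]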
+ def __del__(self):
+ for f in self._label_file_arr:
+ if f is not None:
+ self._os_close_func(f)
+ for f in self._dense_file_arr:
+ if f is not None:
+ self._os_close_func(f)
+ for f in self._category_file_arr:
+ if f is not None:
+ self._os_close_func(f)
+
+ def __len__(self):
+ return self._num_entries
+
+ def __getitem__(self, idx):
+ if idx >= self._num_entries:
+ raise IndexError()
+
+ if self._prefetch <= 1:
+ return self._get(idx)
+
+ if idx == 0:
+ for i in range(self._prefetch):
+ self._prefetch_queue.put(self._executor.submit(self._get, (i)))
+
+ if idx < (self._num_entries - self._prefetch):
+ self._prefetch_queue.put(
+ self._executor.submit(self._get, (idx + self._prefetch)))
+
+ return self._prefetch_queue.get().result()
+
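The __getitem__ above keeps a fixed window of `prefetch` batches in flight on a thread pool. Below is a generic, standalone sketch of the same sliding-window prefetch pattern; it is not EasyRec code, just an illustration of the technique.

import concurrent.futures
import queue

def prefetch_iter(fetch, num_items, depth=2):
    # keep at most `depth` fetches in flight, yield results in submission order
    q = queue.Queue()
    with concurrent.futures.ThreadPoolExecutor(max_workers=depth) as pool:
        for i in range(min(depth, num_items)):
            q.put(pool.submit(fetch, i))
        for i in range(num_items):
            if i + depth < num_items:
                q.put(pool.submit(fetch, i + depth))
            yield q.get().result()

# e.g. list(prefetch_iter(lambda i: i * i, 5)) -> [0, 1, 4, 9, 16]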
+ def _get(self, idx):
+ curr_file_id = self._start_pos_arr[idx][0]
+ start_read_pos = self._start_pos_arr[idx][1]
+
+ end_read_pos = start_read_pos + self._batch_size
+ total_read_num = 0
+
+ label_read_arr = []
+ dense_read_arr = []
+ cate_read_arr = []
+ while total_read_num < self._batch_size and curr_file_id < len(
+ self._sample_num_arr):
+ tmp_read_num = min(end_read_pos,
+ self._sample_num_arr[curr_file_id]) - start_read_pos
+
+ label_raw_data = os.pread(self._label_file_arr[curr_file_id],
+ 4 * tmp_read_num, start_read_pos * 4)
+ tmp_lbl_np = np.frombuffer(
+ label_raw_data, dtype=np.int32).reshape([tmp_read_num, 1])
+ label_read_arr.append(tmp_lbl_np)
+
+ dense_raw_data = os.pread(self._dense_file_arr[curr_file_id],
+ 52 * tmp_read_num, start_read_pos * 52)
+ part_dense_np = np.frombuffer(
+ dense_raw_data, dtype=np.float32).reshape([tmp_read_num, 13])
+ # part_dense_np = np.log(part_dense_np + 3, dtype=np.float32)
+ dense_read_arr.append(part_dense_np)
+
+ category_raw_data = os.pread(self._category_file_arr[curr_file_id],
+ 104 * tmp_read_num, start_read_pos * 104)
+ part_cate_np = np.frombuffer(
+ category_raw_data, dtype=np.uint32).reshape([tmp_read_num, 26])
+ cate_read_arr.append(part_cate_np)
+
+ curr_file_id += 1
+ start_read_pos = 0
+ total_read_num += tmp_read_num
+
+ if len(label_read_arr) == 1:
+ label = label_read_arr[0]
+ else:
+ label = np.concatenate(label_read_arr, axis=0)
+
+ if len(cate_read_arr) == 1:
+ category = cate_read_arr[0]
+ else:
+ category = np.concatenate(cate_read_arr, axis=0)
+
+ if len(dense_read_arr) == 1:
+ dense = dense_read_arr[0]
+ else:
+ dense = np.concatenate(dense_read_arr, axis=0)
+
+ return dense, category, label
+
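The read sizes above imply a fixed-width layout: per sample, 4 bytes for an int32 label, 13 float32 values (52 bytes) for the dense part, and 26 uint32 values (104 bytes) for the category part, stored in three parallel .bin files. The sketch below writes one sample in this assumed layout (file names are made up) and can be used to sanity-check the reader.

import numpy as np

label = np.array([[1]], dtype=np.int32)                        # [num_samples, 1]
dense = np.random.rand(1, 13).astype(np.float32)               # [num_samples, 13]
category = np.random.randint(0, 100, size=(1, 26), dtype=np.uint32)
with open('demo_label.bin', 'wb') as f:
    f.write(label.tobytes())
with open('demo_dense.bin', 'wb') as f:
    f.write(dense.tobytes())
with open('demo_category.bin', 'wb') as f:
    f.write(category.tobytes())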
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--batch_size', type=int, default=1024, help='batch_size')
+ parser.add_argument(
+ '--dataset_dir', type=str, default='./', help='dataset_dir')
+ parser.add_argument('--task_num', type=int, default=1, help='task number')
+ parser.add_argument('--task_index', type=int, default=0, help='task index')
+ parser.add_argument(
+ '--prefetch_size', type=int, default=10, help='prefetch size')
+ args = parser.parse_args()
+
+ batch_size = args.batch_size
+ dataset_dir = args.dataset_dir
+ logging.info('batch_size = %d' % batch_size)
+ logging.info('dataset_dir = %s' % dataset_dir)
+
+ label_files = glob.glob(os.path.join(dataset_dir, '*_label.bin'))
+ dense_files = glob.glob(os.path.join(dataset_dir, '*_dense.bin'))
+ category_files = glob.glob(os.path.join(dataset_dir, '*_category.bin'))
+
+ label_files.sort()
+ dense_files.sort()
+ category_files.sort()
+
+ test_dataset = BinaryDataset(
+ label_files,
+ dense_files,
+ category_files,
+ batch_size=batch_size,
+ drop_last=False,
+ prefetch=args.prefetch_size,
+ global_rank=args.task_index,
+ global_size=args.task_num,
+ )
+
+ for step, (dense, category, labels) in enumerate(test_dataset):
+ # if (step % 100 == 0):
+ # print(step, dense.shape, category.shape, labels.shape)
+ if step == 0:
+ logging.info('warmup over!')
+ start_time = time.time()
+ if step == 1000:
+ logging.info('1000 steps time = %.3f' % (time.time() - start_time))
+ logging.info('total_steps = %d total_time = %.3f' %
+ (step + 1, time.time() - start_time))
+ logging.info(
+ 'final step[%d] dense_shape=%s category_shape=%s labels_shape=%s' %
+ (step, dense.shape, category.shape, labels.shape))
diff --git a/easy_rec/python/input/criteo_input.py b/easy_rec/python/input/criteo_input.py
new file mode 100644
index 000000000..0eb3ee595
--- /dev/null
+++ b/easy_rec/python/input/criteo_input.py
@@ -0,0 +1,107 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.criteo_binary_reader import BinaryDataset
+from easy_rec.python.input.input import Input
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class CriteoInput(Input):
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(CriteoInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
+ all_label_paths = []
+ all_dense_paths = []
+ all_category_paths = []
+
+ if input_path is not None:
+ assert len(input_path.label_path) == len(input_path.dense_path) and \
+ len(input_path.label_path) == len(input_path.category_path), \
+ 'label_path_num(%d), dense_path_num(%d), category_path_num(%d) must be the same' % \
+ (len(input_path.label_path), len(input_path.dense_path), len(input_path.category_path))
+
+ for label_path, dense_path, category_path in zip(
+ input_path.label_path, input_path.dense_path,
+ input_path.category_path):
+ label_paths = tf.gfile.Glob(input_path.label_path)
+ dense_paths = tf.gfile.Glob(input_path.dense_path)
+ category_paths = tf.gfile.Glob(input_path.category_path)
+ assert len(label_paths) == len(dense_paths) and len(label_paths) == \
+ len(category_paths), 'label_path(%s) dense_path(%s) category_path(%s) ' + \
+ 'matched different number of files(%d %d %d)' % (
+ len(label_paths), len(dense_paths), len(category_paths))
+ label_paths.sort()
+ dense_paths.sort()
+ category_paths.sort()
+ all_label_paths.extend(label_paths)
+ all_dense_paths.extend(dense_paths)
+ all_category_paths.extend(category_paths)
+
+ logging.info('total number of input parts: %s' % len(all_label_paths))
+
+ self._binary_reader = BinaryDataset(
+ all_label_paths,
+ all_dense_paths,
+ all_category_paths,
+ self._batch_size,
+ prefetch=self._prefetch_size,
+ global_rank=self._task_index,
+ global_size=self._task_num)
+ else:
+ self._binary_reader = None
+
+ def _sample_generator(self):
+ num_epoch = 0
+ while not self.should_stop(num_epoch):
+ logging.info('start epoch: %d' % num_epoch)
+ for dense, category, labels in self._binary_reader:
+ yield dense, category, labels.reshape([-1])
+ logging.info('finish epoch: %d' % num_epoch)
+ num_epoch += 1
+
+ def _to_fea_dict(self, dense, category, labels):
+ field_dict = {}
+ for fid in range(1, 14):
+ fea_name = 'f%d' % fid
+ field_dict[fea_name] = dense[:, fid - 1]
+
+ for cid in range(1, 27):
+ fea_name = 'c%d' % cid
+ field_dict[fea_name] = category[:, cid - 1]
+ field_dict['label'] = labels
+ return field_dict
+
+ def _build(self, mode, params):
+ dataset = tf.data.Dataset.from_generator(
+ self._sample_generator,
+ output_types=(tf.float32, tf.int32, tf.int32),
+ output_shapes=(tf.TensorShape([None, 13]), tf.TensorShape([None, 26]),
+ tf.TensorShape([None])))
+ num_parallel_calls = self._data_config.num_parallel_calls
+ dataset = dataset.map(
+ self._to_fea_dict, num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+ dataset = dataset.map(
+ map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ dataset = dataset.map(lambda x:
+ (self._get_features(x), self._get_labels(x)))
+ else:
+ dataset = dataset.map(lambda x: (self._get_features(x)))
+ return dataset
diff --git a/easy_rec/python/input/csv_input.py b/easy_rec/python/input/csv_input.py
index 50ecb668a..b3afd1656 100644
--- a/easy_rec/python/input/csv_input.py
+++ b/easy_rec/python/input/csv_input.py
@@ -5,6 +5,7 @@
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.utils.check_utils import check_split
if tf.__version__ >= '2.0':
ignore_errors = tf.data.experimental.ignore_errors()
@@ -20,9 +21,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(CSVInput, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(CSVInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
self._with_header = data_config.with_header
self._field_names = None
@@ -44,24 +48,21 @@ def _parse_csv(self, line):
else:
record_defaults.append('')
- def _check_data(line):
- sep = self._data_config.separator
- if type(sep) != type(str):
- sep = sep.encode('utf-8')
- field_num = len(line[0].split(sep))
- assert field_num == len(record_defaults), \
- 'sep[%s] maybe invalid: field_num=%d, required_num=%d' % \
- (sep, field_num, len(record_defaults))
- return True
-
- check_op = tf.py_func(_check_data, [line], Tout=tf.bool)
- with tf.control_dependencies([check_op]):
+ check_list = [
+ tf.py_func(
+ check_split, [
+ line, self._data_config.separator,
+ len(record_defaults), self._check_mode
+ ],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
fields = tf.decode_csv(
line,
field_delim=self._data_config.separator,
record_defaults=record_defaults,
name='decode_csv')
- if self._field_names:
+ if self._field_names is not None:
fields = [
fields[self._field_names.index(x)] for x in self._input_fields
]
@@ -75,11 +76,22 @@ def _check_data(line):
return inputs
def _build(self, mode, params):
+ if type(self._input_path) != list:
+ self._input_path = self._input_path.split(',')
file_paths = []
- for x in self._input_path.split(','):
- file_paths.extend(tf.gfile.Glob(x))
+ for path in self._input_path:
+ for x in tf.gfile.Glob(path):
+ if not x.endswith('_SUCCESS'):
+ file_paths.append(x)
assert len(file_paths) > 0, 'match no files with %s' % self._input_path
+ assert not file_paths[0].endswith(
+ '.tar.gz'), 'could only support .csv or .gz(not .tar.gz) files.'
+
+ compression_type = 'GZIP' if file_paths[0].endswith('.gz') else ''
+ if compression_type:
+ logging.info('compression_type = %s' % compression_type)
+
if self._with_header:
with tf.gfile.GFile(file_paths[0], 'r') as fin:
for line_str in fin:
@@ -93,32 +105,55 @@ def _build(self, mode, params):
logging.info('train files[%d]: %s' %
(len(file_paths), ','.join(file_paths)))
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ if self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
if self._data_config.shuffle:
# shuffle input files
dataset = dataset.shuffle(len(file_paths))
+
# too many readers read the same file will cause performance issues
# as the same data will be read multiple times
parallel_num = min(num_parallel_calls, len(file_paths))
dataset = dataset.interleave(
- lambda x: tf.data.TextLineDataset(x).skip(int(self._with_header)),
+ lambda x: tf.data.TextLineDataset(
+ x, compression_type=compression_type).skip(
+ int(self._with_header)),
cycle_length=parallel_num,
num_parallel_calls=parallel_num)
- if self._data_config.chief_redundant:
- dataset = dataset.shard(
- max(self._task_num - 1, 1), max(self._task_index - 1, 0))
- else:
- dataset = dataset.shard(self._task_num, self._task_index)
+ if not self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
if self._data_config.shuffle:
dataset = dataset.shuffle(
self._data_config.shuffle_buffer_size,
seed=2020,
reshuffle_each_iteration=True)
dataset = dataset.repeat(self.num_epochs)
+ elif self._task_num > 1: # For distribute evaluate
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+ parallel_num = min(num_parallel_calls, len(file_paths))
+ dataset = dataset.interleave(
+ lambda x: tf.data.TextLineDataset(
+ x, compression_type=compression_type).skip(
+ int(self._with_header)),
+ cycle_length=parallel_num,
+ num_parallel_calls=parallel_num)
+ dataset = self._safe_shard(dataset)
+ dataset = dataset.repeat(1)
else:
logging.info('eval files[%d]: %s' %
(len(file_paths), ','.join(file_paths)))
- dataset = tf.data.TextLineDataset(file_paths).skip(int(self._with_header))
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+ parallel_num = min(num_parallel_calls, len(file_paths))
+ dataset = dataset.interleave(
+ lambda x: tf.data.TextLineDataset(
+ x, compression_type=compression_type).skip(
+ int(self._with_header)),
+ cycle_length=parallel_num,
+ num_parallel_calls=parallel_num)
dataset = dataset.repeat(1)
dataset = dataset.batch(self._data_config.batch_size)
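The two code paths above implement different sharding granularities: with file_shard enabled, the dataset of file paths is sharded before interleaving, so each worker reads a disjoint subset of files; otherwise the interleaved line-level dataset is sharded, so workers skip individual samples. A minimal tf.data sketch of the two strategies follows; _safe_shard itself is not shown here, so plain Dataset.shard is used as a stand-in.

import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(['a.csv', 'b.csv', 'c.csv', 'd.csv'])

# file_shard=true: shard the file list, then each worker reads whole files
worker_files = files.shard(num_shards=2, index=0)
lines = worker_files.interleave(tf.data.TextLineDataset, cycle_length=2)

# file_shard=false (default): read all files, then shard sample by sample
all_lines = files.interleave(tf.data.TextLineDataset, cycle_length=2)
lines = all_lines.shard(num_shards=2, index=0)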
diff --git a/easy_rec/python/input/csv_input_ex.py b/easy_rec/python/input/csv_input_ex.py
index 3be5c0d46..d3b506fce 100644
--- a/easy_rec/python/input/csv_input_ex.py
+++ b/easy_rec/python/input/csv_input_ex.py
@@ -4,6 +4,7 @@
import tensorflow as tf
from easy_rec.python.input.csv_input import CSVInput
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -16,9 +17,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(CSVInputEx, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(CSVInputEx,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
def _parse_csv(self, line):
record_defaults = [
@@ -36,7 +40,7 @@ def _check_data(line):
(sep, field_num, len(record_defaults))
return True
- fields = tf.string_split(
+ fields = str_split_by_chr(
line, self._data_config.separator, skip_empty=False)
tmp_fields = tf.reshape(fields.values, [-1, len(record_defaults)])
fields = []
diff --git a/easy_rec/python/input/csv_input_v2.py b/easy_rec/python/input/csv_input_v2.py
index 80f80734c..deddc2f06 100644
--- a/easy_rec/python/input/csv_input_v2.py
+++ b/easy_rec/python/input/csv_input_v2.py
@@ -12,14 +12,22 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(CSVInputV2, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(CSVInputV2,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
def _build(self, mode, params):
- if self._input_path.startswith('hdfs://'):
+ if type(self._input_path) != list:
+ self._input_path = self._input_path.split(',')
+ assert len(
+ self._input_path) > 0, 'match no files with %s' % self._input_path
+
+ if self._input_path[0].startswith('hdfs://'):
# support hdfs input
- dataset = tf.data.TextLineDataset([self._input_path])
+ dataset = tf.data.TextLineDataset(self._input_path)
else:
num_epochs = self.num_epochs if mode == tf.estimator.ModeKeys.TRAIN else 1
is_train = (mode == tf.estimator.ModeKeys.TRAIN)
@@ -28,7 +36,7 @@ def _build(self, mode, params):
for x, v in zip(self._input_field_types, self._input_field_defaults)
]
dataset = tf.data.experimental.make_csv_dataset(
- [self._input_path],
+ self._input_path,
self._data_config.batch_size,
column_names=self._input_fields,
field_delim=self._data_config.separator,
diff --git a/easy_rec/python/input/datahub_input.py b/easy_rec/python/input/datahub_input.py
index 8e86feab7..37e3292d4 100644
--- a/easy_rec/python/input/datahub_input.py
+++ b/easy_rec/python/input/datahub_input.py
@@ -1,126 +1,308 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
import logging
-import time
+import traceback
-import numpy as np
import tensorflow as tf
+from tensorflow.python.framework import dtypes
from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util
+from easy_rec.python.utils.config_util import parse_time
+
+if tf.__version__.startswith('1.'):
+ from tensorflow.python.platform import gfile
+else:
+ import tensorflow.io.gfile as gfile
try:
import common_io
except Exception:
common_io = None
+
try:
from datahub import DataHub
from datahub.exceptions import DatahubException
from datahub.models import RecordType
from datahub.models import CursorType
+ import urllib3
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+ logging.getLogger('datahub.account').setLevel(logging.INFO)
except Exception:
logging.warning(
- 'DataHub is not installed. You can install it by: pip install pydatahub')
+ 'DataHub is not installed[%s]. You can install it by: pip install pydatahub'
+ % traceback.format_exc())
DataHub = None
class DataHubInput(Input):
- """Common IO based interface, could run at local or on data science."""
+ """DataHubInput is used for online train."""
def __init__(self,
data_config,
feature_config,
datahub_config,
task_index=0,
- task_num=1):
- super(DataHubInput, self).__init__(data_config, feature_config, '',
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(DataHubInput,
+ self).__init__(data_config, feature_config, '', task_index, task_num,
+ check_mode, pipeline_config)
if DataHub is None:
logging.error('please install datahub: ',
'pip install pydatahub ;Python 3.6 recommended')
try:
- self._datahub_config = datahub_config
- if self._datahub_config is None:
- pass
- self._datahub = DataHub(self._datahub_config.akId,
- self._datahub_config.akSecret,
- self._datahub_config.region)
self._num_epoch = 0
+ self._datahub_config = datahub_config
+ if self._datahub_config is not None:
+ akId = self._datahub_config.akId
+ akSecret = self._datahub_config.akSecret
+ endpoint = self._datahub_config.endpoint
+ if not isinstance(akId, str):
+ akId = akId.encode('utf-8')
+ akSecret = akSecret.encode('utf-8')
+ endpoint = endpoint.encode('utf-8')
+ self._datahub = DataHub(akId, akSecret, endpoint)
+ else:
+ self._datahub = None
except Exception as ex:
- logging.info('exception in init datahub:', str(ex))
+ logging.info('exception in init datahub: %s' % str(ex))
pass
+ self._offset_dict = {}
+ if datahub_config:
+ shard_result = self._datahub.list_shard(self._datahub_config.project,
+ self._datahub_config.topic)
+ shards = shard_result.shards
+ self._all_shards = shards
+ self._shards = [
+ shards[i] for i in range(len(shards)) if (i % task_num) == task_index
+ ]
+ logging.info('all shards: %s' % str(self._shards))
+
+ offset_type = datahub_config.WhichOneof('offset')
+ if offset_type == 'offset_time':
+ ts = parse_time(datahub_config.offset_time) * 1000
+ for x in self._all_shards:
+ ks = str(x.shard_id)
+ cursor_result = self._datahub.get_cursor(self._datahub_config.project,
+ self._datahub_config.topic,
+ ks, CursorType.SYSTEM_TIME,
+ ts)
+ logging.info('shard[%s] cursor = %s' % (ks, cursor_result))
+ self._offset_dict[ks] = cursor_result.cursor
+ elif offset_type == 'offset_info':
+ self._offset_dict = json.loads(self._datahub_config.offset_info)
+ else:
+ self._offset_dict = {}
+
+ self._dh_field_names = []
+ self._dh_field_types = []
+ topic_info = self._datahub.get_topic(
+ project_name=self._datahub_config.project,
+ topic_name=self._datahub_config.topic)
+ for field in topic_info.record_schema.field_list:
+ self._dh_field_names.append(field.name)
+ self._dh_field_types.append(field.type.value)
+
+ assert len(
+ self._feature_fields) > 0, 'data_config.feature_fields are not set.'
+
+ for x in self._feature_fields:
+ assert x in self._dh_field_names, 'feature_field[%s] is not in datahub' % x
+
+ # feature column ids in datahub schema
+ self._dh_fea_ids = [
+ self._dh_field_names.index(x) for x in self._feature_fields
+ ]
+
+ for x in self._label_fields:
+ assert x in self._dh_field_names, 'label_field[%s] is not in datahub' % x
+
+ if self._data_config.HasField('sample_weight'):
+ x = self._data_config.sample_weight
+ assert x in self._dh_field_names, 'sample_weight[%s] is not in datahub' % x
+
+ self._read_cnt = 32
+
+ if len(self._dh_fea_ids) > 1:
+ self._filter_fea_func = lambda record: ''.join(
+ [record.values[x]
+ for x in self._dh_fea_ids]).split(chr(2))[1] == '-1024'
+ else:
+ dh_fea_id = self._dh_fea_ids[0]
+ self._filter_fea_func = lambda record: record.values[dh_fea_id].split(
+ self._data_config.separator)[1] == '-1024'
def _parse_record(self, *fields):
+ field_dict = {}
fields = list(fields)
- inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
- for x in self._label_fids:
- inputs[self._input_fields[x]] = fields[x]
- return inputs
+
+ def _dump_offsets():
+ all_offsets = {
+ x.shard_id: self._offset_dict[x.shard_id]
+ for x in self._shards
+ if x.shard_id in self._offset_dict
+ }
+ return json.dumps(all_offsets)
+
+ field_dict[Input.DATA_OFFSET] = tf.py_func(_dump_offsets, [], dtypes.string)
+
+ for x in self._label_fields:
+ dh_id = self._dh_field_names.index(x)
+ field_dict[x] = fields[dh_id]
+
+ feature_inputs = self.get_feature_input_fields()
+ # only for features, labels and sample_weight excluded
+ record_types = [
+ t for x, t in zip(self._input_fields, self._input_field_types)
+ if x in feature_inputs
+ ]
+ feature_num = len(record_types)
+
+ feature_fields = [
+ fields[self._dh_field_names.index(x)] for x in self._feature_fields
+ ]
+ feature = feature_fields[0]
+ for fea_id in range(1, len(feature_fields)):
+ feature = feature + self._data_config.separator + feature_fields[fea_id]
+
+ feature = tf.string_split(
+ feature, self._data_config.separator, skip_empty=False)
+
+ fields = tf.reshape(feature.values, [-1, feature_num])
+
+ for fid in range(feature_num):
+ field_dict[feature_inputs[fid]] = fields[:, fid]
+ return field_dict
+
+ def _preprocess(self, field_dict):
+ output_dict = super(DataHubInput, self)._preprocess(field_dict)
+
+ # append offset fields
+ if Input.DATA_OFFSET in field_dict:
+ output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+
+ # for _get_features to include DATA_OFFSET
+ if Input.DATA_OFFSET not in self._appended_fields:
+ self._appended_fields.append(Input.DATA_OFFSET)
+
+ return output_dict
+
+ def restore(self, checkpoint_path):
+ if checkpoint_path is None:
+ return
+
+ offset_path = checkpoint_path + '.offset'
+ if not gfile.Exists(offset_path):
+ return
+
+ logging.info('will restore datahub offset from %s' % offset_path)
+ with gfile.GFile(offset_path, 'r') as fin:
+ offset_dict = json.load(fin)
+ for k in offset_dict:
+ v = offset_dict[k]
+ ks = str(k)
+ if ks not in self._offset_dict or v > self._offset_dict[ks]:
+ self._offset_dict[ks] = v
+
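For reference, the restored offsets above are a JSON map from shard id to datahub cursor; the same format is produced by _dump_offsets in _parse_record and is expected in the <checkpoint>.offset file (and in the offset_info config field). A hedged sketch of the format, with made-up cursor values and an illustrative file name:

import json

offset_info = json.dumps({'0': 'cursor_of_shard_0', '1': 'cursor_of_shard_1'})
with open('model.ckpt-1000.offset', 'w') as fout:   # file name is illustrative
    fout.write(offset_info)
# DataHubInput.restore('model.ckpt-1000') would then resume each shard from
# the larger of the saved cursor and any cursor already in self._offset_dict.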
+ def _is_data_empty(self, record):
+ is_empty = True
+ for fid in self._dh_fea_ids:
+ if record.values[fid] is not None and len(record.values[fid]) > 0:
+ is_empty = False
+ break
+ return is_empty
+
+ def _dump_record(self, record):
+ feas = []
+ for fid in range(len(record.values)):
+ if fid not in self._dh_fea_ids:
+ feas.append(self._dh_field_names[fid] + ':' + str(record.values[fid]))
+ return ';'.join(feas)
def _datahub_generator(self):
logging.info('start epoch[%d]' % self._num_epoch)
self._num_epoch += 1
- odps_util.check_input_field_and_types(self._data_config)
- record_defaults = [
- self.get_type_defaults(x, v)
- for x, v in zip(self._input_field_types, self._input_field_defaults)
- ]
- batch_defaults = [
- np.array([x] * self._data_config.batch_size) for x in record_defaults
- ]
+
try:
self._datahub.wait_shards_ready(self._datahub_config.project,
self._datahub_config.topic)
topic_result = self._datahub.get_topic(self._datahub_config.project,
self._datahub_config.topic)
if topic_result.record_type != RecordType.TUPLE:
- logging.error('topic type illegal !')
+ logging.error('datahub topic type(%s) illegal' %
+ str(topic_result.record_type))
record_schema = topic_result.record_schema
- shard_result = self._datahub.list_shard(self._datahub_config.project,
- self._datahub_config.topic)
- shards = shard_result.shards
- for shard in shards:
- shard_id = shard._shard_id
- cursor_result = self._datahub.get_cursor(self._datahub_config.project,
- self._datahub_config.topic,
- shard_id, CursorType.OLDEST)
- cursor = cursor_result.cursor
- limit = self._data_config.batch_size
- while True:
- get_result = self._datahub.get_tuple_records(
- self._datahub_config.project, self._datahub_config.topic,
- shard_id, record_schema, cursor, limit)
- batch_data_np = [x.copy() for x in batch_defaults]
- for row_id, record in enumerate(get_result.records):
- for col_id in range(len(record_defaults)):
- if record.values[col_id] not in ['', 'Null', None]:
- batch_data_np[col_id][row_id] = record.values[col_id]
- yield tuple(batch_data_np)
- if 0 == get_result.record_count:
- time.sleep(1)
- cursor = get_result.next_cursor
- except DatahubException as e:
- logging.error(e)
+
+ tid = 0
+ while True:
+ shard_id = self._shards[tid].shard_id
+ tid += 1
+ if tid >= len(self._shards):
+ tid = 0
+
+ if shard_id not in self._offset_dict:
+ cursor_result = self._datahub.get_cursor(self._datahub_config.project,
+ self._datahub_config.topic,
+ shard_id, CursorType.OLDEST)
+ cursor = cursor_result.cursor
+ else:
+ cursor = self._offset_dict[shard_id]
+
+ get_result = self._datahub.get_tuple_records(
+ self._datahub_config.project, self._datahub_config.topic, shard_id,
+ record_schema, cursor, self._read_cnt)
+ count = get_result.record_count
+ if count == 0:
+ continue
+ for row_id, record in enumerate(get_result.records):
+ if self._is_data_empty(record):
+ logging.warning('skip empty data record: %s' %
+ self._dump_record(record))
+ continue
+ if self._filter_fea_func is not None:
+ if self._filter_fea_func(record):
+ logging.warning('filter data record: %s' %
+ self._dump_record(record))
+ continue
+ yield tuple(list(record.values))
+ if shard_id not in self._offset_dict or get_result.next_cursor > self._offset_dict[
+ shard_id]:
+ self._offset_dict[shard_id] = get_result.next_cursor
+ except DatahubException as ex:
+ logging.error('DatahubException: %s' % str(ex))
def _build(self, mode, params):
- # get input type
- list_type = [self.get_tf_type(x) for x in self._input_field_types]
- list_type = tuple(list_type)
- list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ assert self._datahub is not None, 'datahub_train_input is not set'
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ assert self._datahub is not None, 'datahub_eval_input is not set'
+
+ # get input types
+ list_types = [
+ odps_util.odps_type_2_tf_type(x) for x in self._dh_field_types
+ ]
+ list_types = tuple(list_types)
+ list_shapes = [
+ tf.TensorShape([]) for x in range(0, len(self._dh_field_types))
+ ]
list_shapes = tuple(list_shapes)
# read datahub
dataset = tf.data.Dataset.from_generator(
self._datahub_generator,
- output_types=list_type,
+ output_types=list_types,
output_shapes=list_shapes)
if mode == tf.estimator.ModeKeys.TRAIN:
- dataset = dataset.shuffle(
- self._data_config.shuffle_buffer_size,
- seed=2020,
- reshuffle_each_iteration=True)
- dataset = dataset.repeat(self.num_epochs)
- else:
- dataset = dataset.repeat(1)
+ if self._data_config.shuffle:
+ dataset = dataset.shuffle(
+ self._data_config.shuffle_buffer_size,
+ seed=2020,
+ reshuffle_each_iteration=True)
+
+ dataset = dataset.batch(self._data_config.batch_size)
+
dataset = dataset.map(
self._parse_record,
num_parallel_calls=self._data_config.num_parallel_calls)
diff --git a/easy_rec/python/input/dummy_input.py b/easy_rec/python/input/dummy_input.py
index bc95be436..f556a3686 100644
--- a/easy_rec/python/input/dummy_input.py
+++ b/easy_rec/python/input/dummy_input.py
@@ -4,6 +4,7 @@
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.utils.tf_utils import get_tf_type
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -21,9 +22,12 @@ def __init__(self,
input_path,
task_index=0,
task_num=1,
+ check_mode=False,
+ pipeline_config=None,
input_vals={}):
- super(DummyInput, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ super(DummyInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
self._input_vals = input_vals
def _build(self, mode, params):
@@ -41,12 +45,14 @@ def _build(self, mode, params):
for field, field_type, def_val in zip(self._input_fields,
self._input_field_types,
self._input_field_defaults):
- tf_type = self.get_tf_type(field_type)
+ tf_type = get_tf_type(field_type)
def_val = self.get_type_defaults(field_type, default_val=def_val)
+
if field in self._input_vals:
tensor = self._input_vals[field]
else:
tensor = tf.constant([def_val] * self._batch_size, dtype=tf_type)
+
features[field] = tensor
parse_dict = self._preprocess(features)
return self._get_features(parse_dict), self._get_labels(parse_dict)
diff --git a/easy_rec/python/input/hive_input.py b/easy_rec/python/input/hive_input.py
new file mode 100644
index 000000000..f5c8735af
--- /dev/null
+++ b/easy_rec/python/input/hive_input.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+import logging
+import os
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils.hive_utils import HiveUtils
+
+
+class HiveInput(Input):
+ """Common IO based interface, could run at local or on data science."""
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(HiveInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
+ if input_path is None:
+ return
+ self._data_config = data_config
+ self._feature_config = feature_config
+ self._hive_config = input_path
+
+ hive_util = HiveUtils(
+ data_config=self._data_config, hive_config=self._hive_config)
+ self._input_hdfs_path = hive_util.get_table_location(
+ self._hive_config.table_name)
+ self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols(
+ self._hive_config.table_name)
+
+ def _parse_csv(self, line):
+ record_defaults = []
+ for field_name in self._input_table_col_names:
+ if field_name in self._input_fields:
+ tid = self._input_fields.index(field_name)
+ record_defaults.append(
+ self.get_type_defaults(self._input_field_types[tid],
+ self._input_field_defaults[tid]))
+ else:
+ record_defaults.append('')
+
+ tmp_fields = tf.decode_csv(
+ line,
+ field_delim=self._data_config.separator,
+ record_defaults=record_defaults,
+ name='decode_csv')
+
+ fields = []
+ for x in self._input_fields:
+ assert x in self._input_table_col_names, 'Column %s not in Table %s.' % (
+ x, self._hive_config.table_name)
+ fields.append(tmp_fields[self._input_table_col_names.index(x)])
+
+ # filter only valid fields
+ inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+ for x in self._label_fids:
+ inputs[self._input_fields[x]] = fields[x]
+ return inputs
+
+ def _build(self, mode, params):
+ file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
+ assert len(
+ file_paths) > 0, 'match no files with %s' % self._hive_config.table_name
+
+ num_parallel_calls = self._data_config.num_parallel_calls
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ logging.info('train files[%d]: %s' %
+ (len(file_paths), ','.join(file_paths)))
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ if self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
+ if self._data_config.shuffle:
+ # shuffle input files
+ dataset = dataset.shuffle(len(file_paths))
+
+      # too many readers reading the same file causes performance issues,
+      # as the same data would be read multiple times
+ parallel_num = min(num_parallel_calls, len(file_paths))
+ dataset = dataset.interleave(
+ lambda x: tf.data.TextLineDataset(x),
+ cycle_length=parallel_num,
+ num_parallel_calls=parallel_num)
+
+ if not self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
+ if self._data_config.shuffle:
+ dataset = dataset.shuffle(
+ self._data_config.shuffle_buffer_size,
+ seed=2020,
+ reshuffle_each_iteration=True)
+ dataset = dataset.repeat(self.num_epochs)
+ else:
+ logging.info('eval files[%d]: %s' %
+ (len(file_paths), ','.join(file_paths)))
+ dataset = tf.data.TextLineDataset(file_paths)
+ dataset = dataset.repeat(1)
+
+ dataset = dataset.batch(self._data_config.batch_size)
+ dataset = dataset.map(
+ self._parse_csv, num_parallel_calls=num_parallel_calls)
+
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+ dataset = dataset.map(
+ map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ dataset = dataset.map(lambda x:
+ (self._get_features(x), self._get_labels(x)))
+ else:
+ dataset = dataset.map(lambda x: (self._get_features(x)))
+ return dataset
diff --git a/easy_rec/python/input/hive_parquet_input.py b/easy_rec/python/input/hive_parquet_input.py
new file mode 100644
index 000000000..7bb42dc42
--- /dev/null
+++ b/easy_rec/python/input/hive_parquet_input.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+import logging
+import os
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils.hive_utils import HiveUtils
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+
+class HiveParquetInput(Input):
+ """Common IO based interface, could run at local or on data science."""
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(HiveParquetInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
+ if input_path is None:
+ return
+ self._data_config = data_config
+ self._feature_config = feature_config
+ self._hive_config = input_path
+
+ hive_util = HiveUtils(
+ data_config=self._data_config, hive_config=self._hive_config)
+ input_hdfs_path = hive_util.get_table_location(self._hive_config.table_name)
+ self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols(
+ self._hive_config.table_name)
+ self._all_hdfs_path = tf.gfile.Glob(os.path.join(input_hdfs_path, '*'))
+
+ for x in self._input_fields:
+ assert x in self._input_table_col_names, 'Column %s not in Table %s.' % (
+ x, self._hive_config.table_name)
+
+ self._record_defaults = [
+ self.get_type_defaults(t, v)
+ for t, v in zip(self._input_field_types, self._input_field_defaults)
+ ]
+
+ def _file_shard(self, file_paths, task_num, task_index):
+ if self._data_config.chief_redundant:
+ task_num = max(task_num - 1, 1)
+ task_index = max(task_index - 1, 0)
+ task_file_paths = []
+ for idx in range(task_index, len(file_paths), task_num):
+ task_file_paths.append(file_paths[idx])
+ return task_file_paths
+
+ def _parquet_read(self):
+ for input_path in self._input_hdfs_path:
+ if input_path.endswith('SUCCESS'):
+ continue
+ df = pd.read_parquet(input_path, engine='pyarrow')
+ df = df[self._input_fields]
+ df.replace('', np.nan, inplace=True)
+ df.replace('NULL', np.nan, inplace=True)
+ total_records_num = len(df)
+
+ for k, v in zip(self._input_fields, self._record_defaults):
+ df[k].fillna(v, inplace=True)
+
+ for start_idx in range(0, total_records_num,
+ self._data_config.batch_size):
+ end_idx = min(total_records_num,
+ start_idx + self._data_config.batch_size)
+ batch_data = df[start_idx:end_idx]
+ inputs = []
+ for k in self._input_fields:
+ inputs.append(batch_data[k].to_numpy())
+ yield tuple(inputs)
+
+ def _parse_csv(self, *fields):
+ # filter only valid fields
+ inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+ # filter only valid labels
+ for x in self._label_fids:
+ inputs[self._input_fields[x]] = fields[x]
+ return inputs
+
+ def _build(self, mode, params):
+ # get input type
+ list_type = [get_tf_type(x) for x in self._input_field_types]
+ list_type = tuple(list_type)
+ list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
+ list_shapes = tuple(list_shapes)
+
+ if len(self._all_hdfs_path) >= 2 * self._task_num:
+ file_shard = True
+ self._input_hdfs_path = self._file_shard(self._all_hdfs_path,
+ self._task_num, self._task_index)
+ else:
+ file_shard = False
+ self._input_hdfs_path = self._all_hdfs_path
+ logging.info('input path: %s' % self._input_hdfs_path)
+ assert len(self._input_hdfs_path
+ ) > 0, 'match no files with %s' % self._hive_config.table_name
+
+ dataset = tf.data.Dataset.from_generator(
+ self._parquet_read, output_types=list_type, output_shapes=list_shapes)
+
+ if not file_shard:
+ dataset = self._safe_shard(dataset)
+
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ dataset = dataset.shuffle(
+ self._data_config.shuffle_buffer_size,
+ seed=2020,
+ reshuffle_each_iteration=True)
+ dataset = dataset.repeat(self.num_epochs)
+ else:
+ dataset = dataset.repeat(1)
+
+ dataset = dataset.map(
+ self._parse_csv,
+ num_parallel_calls=self._data_config.num_parallel_calls)
+
+ # preprocess is necessary to transform data
+ # so that they could be feed into FeatureColumns
+ dataset = dataset.map(
+ map_func=self._preprocess,
+ num_parallel_calls=self._data_config.num_parallel_calls)
+
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ dataset = dataset.map(lambda x:
+ (self._get_features(x), self._get_labels(x)))
+ else:
+ dataset = dataset.map(lambda x: (self._get_features(x)))
+ return dataset
diff --git a/easy_rec/python/input/hive_rtp_input.py b/easy_rec/python/input/hive_rtp_input.py
new file mode 100644
index 000000000..b7bbf2148
--- /dev/null
+++ b/easy_rec/python/input/hive_rtp_input.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+import logging
+import os
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils.check_utils import check_split
+from easy_rec.python.utils.hive_utils import HiveUtils
+from easy_rec.python.utils.input_utils import string_to_number
+
+
+class HiveRTPInput(Input):
+ """Common IO based interface, could run at local or on data science."""
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(HiveRTPInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
+ if input_path is None:
+ return
+ self._data_config = data_config
+ self._feature_config = feature_config
+ self._hive_config = input_path
+
+ logging.info('input_fields: %s label_fields: %s' %
+ (','.join(self._input_fields), ','.join(self._label_fields)))
+
+ self._rtp_separator = self._data_config.rtp_separator
+ if not isinstance(self._rtp_separator, str):
+ self._rtp_separator = self._rtp_separator.encode('utf-8')
+ logging.info('rtp separator = %s' % self._rtp_separator)
+ self._selected_cols = [c.strip() for c in self._data_config.selected_cols.split(',')] \
+ if self._data_config.selected_cols else None
+ logging.info('select cols: %s' % self._selected_cols)
+ hive_util = HiveUtils(
+ data_config=self._data_config, hive_config=self._hive_config)
+ self._input_hdfs_path = hive_util.get_table_location(
+ self._hive_config.table_name)
+ self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols(
+ self._hive_config.table_name)
+
+ def _parse_csv(self, line):
+ non_feature_cols = self._label_fields
+ if self._selected_cols:
+ non_feature_cols = self._selected_cols[:-1]
+ record_defaults = []
+ for tid, field_name in enumerate(self._input_table_col_names):
+ if field_name in self._selected_cols[:-1]:
+ idx = self._input_fields.index(field_name)
+ record_defaults.append(
+ self.get_type_defaults(self._input_field_types[idx],
+ self._input_field_defaults[idx]))
+ else:
+ record_defaults.append('')
+ print('record_defaults: ', record_defaults)
+ tmp_fields = tf.decode_csv(
+ line,
+ field_delim=self._rtp_separator,
+ record_defaults=record_defaults,
+ name='decode_csv')
+ print('tmp_fields: ', tmp_fields)
+
+ fields = []
+ if self._selected_cols:
+ for idx, field_name in enumerate(self._input_table_col_names):
+ if field_name in self._selected_cols:
+ fields.append(tmp_fields[idx])
+ print('fields: ', fields)
+ labels = fields[:-1]
+
+ # only for features, labels and sample_weight excluded
+ record_types = [
+ t for x, t in zip(self._input_fields, self._input_field_types)
+ if x not in non_feature_cols
+ ]
+ feature_num = len(record_types)
+
+ check_list = [
+ tf.py_func(
+ check_split,
+ [fields[-1], self._data_config.separator,
+ len(record_types)],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ fields = tf.string_split(
+ fields[-1], self._data_config.separator, skip_empty=False)
+ tmp_fields = tf.reshape(fields.values, [-1, feature_num])
+
+ rtp_record_defaults = [
+ str(self.get_type_defaults(t, v))
+ for x, t, v in zip(self._input_fields, self._input_field_types,
+ self._input_field_defaults)
+ if x not in non_feature_cols
+ ]
+ fields = labels[len(self._label_fields):]
+ for i in range(feature_num):
+ field = string_to_number(tmp_fields[:, i], record_types[i],
+ rtp_record_defaults[i], i)
+ fields.append(field)
+
+ field_keys = [x for x in self._input_fields if x not in self._label_fields]
+ effective_fids = [field_keys.index(x) for x in self._effective_fields]
+ inputs = {field_keys[x]: fields[x] for x in effective_fids}
+
+ for x in range(len(self._label_fields)):
+ inputs[self._label_fields[x]] = labels[x]
+ return inputs
+
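_parse_csv above assumes the RTP layout: the selected leading columns carry labels, and the last selected column packs all feature values joined by data_config.separator. Below is an illustrative row, assuming ';' as the column delimiter and chr(2) as the in-feature separator; the real delimiters come from the data config.

chr2 = chr(2)
row = ';'.join(['1', '0', chr2.join(['123', 'beijing', '8.5'])])
cols = row.split(';')
labels, packed = cols[:-1], cols[-1]
features = packed.split(chr2)   # aligned with the non-label input_fields order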
+ def _build(self, mode, params):
+ file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
+ assert len(
+ file_paths) > 0, 'match no files with %s' % self._hive_config.table_name
+
+ num_parallel_calls = self._data_config.num_parallel_calls
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ logging.info('train files[%d]: %s' %
+ (len(file_paths), ','.join(file_paths)))
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ if self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
+ if self._data_config.shuffle:
+ # shuffle input files
+ dataset = dataset.shuffle(len(file_paths))
+
+      # too many readers reading the same file causes performance issues,
+      # as the same data would be read multiple times
+ parallel_num = min(num_parallel_calls, len(file_paths))
+ dataset = dataset.interleave(
+ lambda x: tf.data.TextLineDataset(x),
+ cycle_length=parallel_num,
+ num_parallel_calls=parallel_num)
+
+ if not self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
+ if self._data_config.shuffle:
+ dataset = dataset.shuffle(
+ self._data_config.shuffle_buffer_size,
+ seed=2020,
+ reshuffle_each_iteration=True)
+ dataset = dataset.repeat(self.num_epochs)
+ else:
+ logging.info('eval files[%d]: %s' %
+ (len(file_paths), ','.join(file_paths)))
+ dataset = tf.data.TextLineDataset(file_paths)
+ dataset = dataset.repeat(1)
+
+ dataset = dataset.batch(self._data_config.batch_size)
+ dataset = dataset.map(
+ self._parse_csv, num_parallel_calls=num_parallel_calls)
+
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+ dataset = dataset.map(
+ map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ dataset = dataset.map(lambda x:
+ (self._get_features(x), self._get_labels(x)))
+ else:
+ dataset = dataset.map(lambda x: (self._get_features(x)))
+ return dataset
diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py
index 500c6ed95..f53c4ee45 100644
--- a/easy_rec/python/input/input.py
+++ b/easy_rec/python/input/input.py
@@ -1,18 +1,30 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
+import os
from abc import abstractmethod
from collections import OrderedDict
import six
import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import gfile
from easy_rec.python.core import sampler as sampler_lib
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import conditional
from easy_rec.python.utils import config_util
from easy_rec.python.utils import constant
+from easy_rec.python.utils.check_utils import check_split
+from easy_rec.python.utils.check_utils import check_string_to_number
+from easy_rec.python.utils.expr_util import get_expression
from easy_rec.python.utils.input_utils import get_type_defaults
from easy_rec.python.utils.load_class import get_register_class_meta
+from easy_rec.python.utils.load_class import load_by_path
+from easy_rec.python.utils.tf_utils import get_tf_type
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -23,17 +35,29 @@
class Input(six.with_metaclass(_meta_type, object)):
+ DATA_OFFSET = 'DATA_OFFSET'
+
def __init__(self,
data_config,
feature_configs,
input_path,
task_index=0,
- task_num=1):
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None,
+ **kwargs):
+ self._pipeline_config = pipeline_config
self._data_config = data_config
-
+ self._check_mode = check_mode
+ logging.info('check_mode: %s ' % self._check_mode)
# tf.estimator.ModeKeys.*, only available before
# calling self._build
self._mode = None
+ if pipeline_config is not None and pipeline_config.model_config.HasField(
+ 'ev_params'):
+ self._has_ev = True
+ else:
+ self._has_ev = False
if self._data_config.auto_expand_input_fields:
input_fields = [x for x in self._data_config.input_fields]
@@ -54,12 +78,18 @@ def __init__(self,
x.default_val for x in data_config.input_fields
]
self._label_fields = list(data_config.label_fields)
+ self._feature_fields = list(data_config.feature_fields)
self._label_sep = list(data_config.label_sep)
self._label_dim = list(data_config.label_dim)
if len(self._label_dim) < len(self._label_fields):
for x in range(len(self._label_fields) - len(self._label_dim)):
self._label_dim.append(1)
+ self._label_udf_map = {}
+ for config in self._data_config.input_fields:
+ if config.HasField('user_define_fn'):
+ self._label_udf_map[config.input_name] = self._load_label_fn(config)
+
self._batch_size = data_config.batch_size
self._prefetch_size = data_config.prefetch_size
self._feature_configs = list(feature_configs)
@@ -75,7 +105,9 @@ def __init__(self,
# from the types defined in input_fields
# it is used in create_multi_placeholders
self._multi_value_types = {}
+ self._multi_value_fields = set()
+ self._normalizer_fn = {}
for fc in self._feature_configs:
for input_name in fc.input_names:
assert input_name in self._input_fields, 'invalid input_name in %s' % str(
@@ -84,20 +116,64 @@ def __init__(self,
self._effective_fields.append(input_name)
if fc.feature_type in [fc.TagFeature, fc.SequenceFeature]:
- if fc.hash_bucket_size > 0:
+ if fc.hash_bucket_size > 0 or len(
+ fc.vocab_list) > 0 or fc.HasField('vocab_file'):
self._multi_value_types[fc.input_names[0]] = tf.string
+ self._multi_value_fields.add(fc.input_names[0])
else:
self._multi_value_types[fc.input_names[0]] = tf.int64
+ self._multi_value_fields.add(fc.input_names[0])
if len(fc.input_names) > 1:
self._multi_value_types[fc.input_names[1]] = tf.float32
+ self._multi_value_fields.add(fc.input_names[1])
- if fc.feature_type == fc.RawFeature:
+ if fc.feature_type == fc.RawFeature and fc.raw_input_dim > 1:
self._multi_value_types[fc.input_names[0]] = tf.float32
+ self._multi_value_fields.add(fc.input_names[0])
+
+ if fc.HasField('normalizer_fn'):
+ feature_name = fc.feature_name if fc.HasField(
+ 'feature_name') else fc.input_names[0]
+ self._normalizer_fn[feature_name] = load_by_path(fc.normalizer_fn)
# add sample weight to effective fields
if self._data_config.HasField('sample_weight'):
self._effective_fields.append(self._data_config.sample_weight)
+ # add uid_field of GAUC and session_fields of SessionAUC
+ if self._pipeline_config is not None:
+ metrics = self._pipeline_config.eval_config.metrics_set
+ for metric in metrics:
+ metric_name = metric.WhichOneof('metric')
+ if metric_name == 'gauc':
+ uid = metric.gauc.uid_field
+ if uid not in self._effective_fields:
+ self._effective_fields.append(uid)
+ elif metric_name == 'session_auc':
+ sid = metric.session_auc.session_id_field
+ if sid not in self._effective_fields:
+ self._effective_fields.append(sid)
+
+ # check multi task model's metrics
+ model_config = self._pipeline_config.model_config
+ model_name = model_config.WhichOneof('model')
+ if model_name in {'mmoe', 'esmm', 'dbmtl', 'simple_multi_task', 'ple'}:
+ model = getattr(model_config, model_name)
+ towers = [model.ctr_tower, model.cvr_tower
+ ] if model_name == 'esmm' else model.task_towers
+ for tower in towers:
+ metrics = tower.metrics_set
+ for metric in metrics:
+ metric_name = metric.WhichOneof('metric')
+ if metric_name == 'gauc':
+ uid = metric.gauc.uid_field
+ if uid not in self._effective_fields:
+ self._effective_fields.append(uid)
+ elif metric_name == 'session_auc':
+ sid = metric.session_auc.session_id_field
+ if sid not in self._effective_fields:
+ self._effective_fields.append(sid)
+
self._effective_fids = [
self._input_fields.index(x) for x in self._effective_fields
]
@@ -121,6 +197,32 @@ def __init__(self,
self.get_type_defaults = get_type_defaults
+ def _load_label_fn(self, config):
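+ # Loads a user defined label transform function. Remote udf files on
+ # oss:// or hdfs:// are first copied into a local /udf/ directory so they
+ # can be imported; the module path is then derived from the file path by
+ # dropping the '.py' suffix and replacing '/' with '.'.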
+ udf_class = config.user_define_fn
+ udf_path = config.user_define_fn_path if config.HasField(
+ 'user_define_fn_path') else None
+ dtype = config.user_define_fn_res_type if config.HasField(
+ 'user_define_fn_res_type') else None
+
+ if udf_path:
+ if udf_path.startswith('oss://') or udf_path.startswith('hdfs://'):
+ with gfile.GFile(udf_path, 'r') as fin:
+ udf_content = fin.read()
+ final_udf_tmp_path = '/udf/'
+ final_udf_path = final_udf_tmp_path + udf_path.split('/')[-1]
+ logging.info('final udf path %s' % final_udf_path)
+ logging.info('udf content: %s' % udf_content)
+ if not gfile.Exists(final_udf_tmp_path):
+ gfile.MkDir(final_udf_tmp_path)
+ with gfile.GFile(final_udf_path, 'w') as fin:
+ fin.write(udf_content)
+ else:
+ final_udf_path = udf_path
+ final_udf_path = final_udf_path[:-3].replace('/', '.')
+ udf_class = final_udf_path + '.' + udf_class
+ logging.info('apply udf %s' % udf_class)
+ return load_by_path(udf_class), udf_class, dtype
+
@property
def num_epochs(self):
if self._data_config.num_epochs > 0:
@@ -128,78 +230,105 @@ def num_epochs(self):
else:
return None
- def get_tf_type(self, field_type):
- type_map = {
- DatasetConfig.INT32: tf.int32,
- DatasetConfig.INT64: tf.int64,
- DatasetConfig.STRING: tf.string,
- DatasetConfig.BOOL: tf.bool,
- DatasetConfig.FLOAT: tf.float32,
- DatasetConfig.DOUBLE: tf.double
- }
- assert field_type in type_map, 'invalid type: %s' % field_type
- return type_map[field_type]
-
- def create_multi_placeholders(self,
- placeholder_named_by_input,
- export_fields_name=None):
- """Create multiply placeholders on export.
+ def get_feature_input_fields(self):
+ return [
+ x for x in self._input_fields
+ if x not in self._label_fields and x != self._data_config.sample_weight
+ ]
+
+ def should_stop(self, curr_epoch):
+ """Check whether have run enough num epochs."""
+ total_epoch = self.num_epochs
+ if self._mode != tf.estimator.ModeKeys.TRAIN:
+ total_epoch = 1
+ return total_epoch is not None and curr_epoch >= total_epoch
+
+ def create_multi_placeholders(self, export_config):
+ """Create multiply placeholders on export, one for each feature.
Args:
- placeholder_named_by_input: If it is true, placeholder is named by the input feature,
- otherwise the placeholder name if input_XX. Default: false.
- export_fields_name: TagFeature / SeqFeature list that needs to be converted into
- 2D placeholders when exporting.
+ export_config: ExportConfig instance.
"""
self._mode = tf.estimator.ModeKeys.PREDICT
- effective_fids = list(self._effective_fids)
+
+ if export_config.auto_multi_value:
+ export_fields_name = self._multi_value_fields
+ elif export_config.multi_value_fields:
+ export_fields_name = export_config.multi_value_fields.input_name
+ else:
+ export_fields_name = None
+ placeholder_named_by_input = export_config.placeholder_named_by_input
+
+ sample_weight_field = ''
if self._data_config.HasField('sample_weight'):
- effective_fids = effective_fids[:-1]
- inputs = {}
+ sample_weight_field = self._data_config.sample_weight
+ if export_config.filter_inputs:
+ effective_fids = list(self._effective_fids)
+ else:
+ effective_fids = [
+ fid for fid in range(len(self._input_fields))
+ if self._input_fields[fid] not in self._label_fields and
+ self._input_fields[fid] != sample_weight_field
+ ]
+
+ inputs = {}
for fid in effective_fids:
input_name = self._input_fields[fid]
+ if input_name == sample_weight_field:
+ continue
if placeholder_named_by_input:
placeholder_name = input_name
else:
placeholder_name = 'input_%d' % fid
if input_name in export_fields_name:
- tf_type = self._multi_value_types[input_name]
+ tf_type = self._multi_value_types[input_name] if input_name in self._multi_value_types \
+ else get_tf_type(self._input_field_types[fid])
logging.info('multi value input_name: %s, dtype: %s' %
(input_name, tf_type))
- finput = tf.placeholder(tf_type, [None, None], name=placeholder_name)
+ finput = array_ops.placeholder(
+ tf_type, [None, None], name=placeholder_name)
else:
ftype = self._input_field_types[fid]
- tf_type = self.get_tf_type(ftype)
+ tf_type = get_tf_type(ftype)
logging.info('input_name: %s, dtype: %s' % (input_name, tf_type))
- finput = tf.placeholder(tf_type, [None], name=placeholder_name)
+ finput = array_ops.placeholder(tf_type, [None], name=placeholder_name)
inputs[input_name] = finput
features = {x: inputs[x] for x in inputs}
features = self._preprocess(features)
- return inputs, features
+ return inputs, features['feature']
def create_placeholders(self, export_config):
self._mode = tf.estimator.ModeKeys.PREDICT
- inputs_placeholder = tf.placeholder(tf.string, [None], name='features')
+ inputs_placeholder = array_ops.placeholder(
+ tf.string, [None], name='features')
input_vals = tf.string_split(
inputs_placeholder, self._data_config.separator,
skip_empty=False).values
+
+ sample_weight_field = ''
+ if self._data_config.HasField('sample_weight'):
+ sample_weight_field = self._data_config.sample_weight
+
if export_config.filter_inputs:
effective_fids = list(self._effective_fids)
logging.info('number of effective inputs:%d, total number inputs: %d' %
(len(effective_fids), len(self._input_fields)))
else:
- effective_fids = list(range(1, len(self._input_fields)))
- logging.info('will not filter any input, total number inputs:%d' %
- len(effective_fids))
- if self._data_config.HasField('sample_weight'):
- effective_fids = effective_fids[:-1]
+ effective_fids = [
+ fid for fid in range(len(self._input_fields))
+ if self._input_fields[fid] not in self._label_fields and
+ self._input_fields[fid] != sample_weight_field
+ ]
+ logging.info(
+ 'will not filter any input (except labels and sample_weight), '
+ 'total number of inputs: %d' % len(effective_fids))
input_vals = tf.reshape(
input_vals, [-1, len(effective_fids)], name='input_reshape')
features = {}
for tmp_id, fid in enumerate(effective_fids):
ftype = self._input_field_types[fid]
- tf_type = self.get_tf_type(ftype)
+ tf_type = get_tf_type(ftype)
input_name = self._input_fields[fid]
if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]:
features[input_name] = tf.string_to_number(
@@ -212,25 +341,468 @@ def create_placeholders(self, export_config):
(ftype, tf_type))
features[input_name] = input_vals[:, tmp_id]
features = self._preprocess(features)
- return {'features': inputs_placeholder}, features
+ return {'features': inputs_placeholder}, features['feature']
def _get_features(self, fields):
- field_dict = {x: fields[x] for x in self._effective_fields if x in fields}
- for k in self._appended_fields:
- field_dict[k] = fields[k]
- if constant.SAMPLE_WEIGHT in fields:
- logging.info('will use field %s as sample weight' %
- self._data_config.sample_weight)
- field_dict[constant.SAMPLE_WEIGHT] = fields[constant.SAMPLE_WEIGHT]
- return field_dict
+ return fields['feature']
def _get_labels(self, fields):
+ labels = fields['label']
return OrderedDict([
- (x, tf.squeeze(fields[x], axis=1) if len(fields[x].get_shape()) == 2 and
- fields[x].get_shape()[1] == 1 else fields[x])
- for x in self._label_fields
+ (x, tf.squeeze(labels[x], axis=1) if len(labels[x].get_shape()) == 2 and
+ labels[x].get_shape()[1] == 1 else labels[x]) for x in labels
])
+ def _as_string(self, field, fc):
+ if field.dtype == tf.string:
+ return field
+ if field.dtype in [tf.float32, tf.double]:
+ feature_name = fc.feature_name if fc.HasField(
+ 'feature_name') else fc.input_names[0]
+ assert fc.precision > 0, 'fc.precision is not set for feature[%s]; it is dangerous to convert ' \
+ 'float or double to string due to precision loss, so it is suggested ' \
+ 'to convert them into string format before using EasyRec; ' \
+ 'if you really need to do so, please set precision (the number of ' \
+ 'decimal digits) carefully.' % feature_name
+ precision = None
+ if field.dtype in [tf.float32, tf.double]:
+ if fc.precision > 0:
+ precision = fc.precision
+
+ # convert to string
+ if 'as_string' in dir(tf.strings):
+ return tf.strings.as_string(field, precision=precision)
+ else:
+ return tf.as_string(field, precision=precision)
+
+ def _parse_combo_feature(self, fc, parsed_dict, field_dict):
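+ # When combo_join_sep is empty, each input is kept as a separate parsed
+ # field (feature_name, feature_name_1, ...), optionally split by its
+ # combo_input_seps entry; otherwise the inputs are crossed (sparse_cross)
+ # or joined (string_join) into a single feature using combo_join_sep.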
+ # for compatibility with existing implementations
+ feature_name = fc.feature_name if fc.HasField(
+ 'feature_name') else fc.input_names[0]
+
+ if len(fc.combo_input_seps) > 0:
+ assert len(fc.combo_input_seps) == len(fc.input_names), \
+ 'len(combo_separator)[%d] != len(fc.input_names)[%d]' % (
+ len(fc.combo_input_seps), len(fc.input_names))
+
+ def _get_input_sep(input_id):
+ if input_id < len(fc.combo_input_seps):
+ return fc.combo_input_seps[input_id]
+ else:
+ return ''
+
+ if len(fc.combo_join_sep) == 0:
+ for input_id, input_name in enumerate(fc.input_names):
+ if input_id > 0:
+ key = feature_name + '_' + str(input_id)
+ else:
+ key = feature_name
+ input_sep = _get_input_sep(input_id)
+ if input_sep != '':
+ assert field_dict[
+ input_name].dtype == tf.string, 'could not apply string_split to input-name[%s] dtype=%s' % (
+ input_name, field_dict[input_name].dtype)
+ parsed_dict[key] = tf.string_split(field_dict[input_name], input_sep)
+ else:
+ parsed_dict[key] = self._as_string(field_dict[input_name], fc)
+ else:
+ if len(fc.combo_input_seps) > 0:
+ split_inputs = []
+ for input_id, input_name in enumerate(fc.input_names):
+ input_sep = fc.combo_input_seps[input_id]
+ if len(input_sep) > 0:
+ assert field_dict[
+ input_name].dtype == tf.string, 'could not apply string_split to input-name[%s] dtype=%s' % (
+ input_name, field_dict[input_name].dtype)
+ split_inputs.append(
+ tf.string_split(field_dict[input_name],
+ fc.combo_input_seps[input_id]))
+ else:
+ split_inputs.append(tf.reshape(field_dict[input_name], [-1, 1]))
+ parsed_dict[feature_name] = sparse_ops.sparse_cross(
+ split_inputs, fc.combo_join_sep)
+ else:
+ inputs = [
+ self._as_string(field_dict[input_name], fc)
+ for input_name in fc.input_names
+ ]
+ parsed_dict[feature_name] = string_ops.string_join(
+ inputs, fc.combo_join_sep)
+
+ def _parse_tag_feature(self, fc, parsed_dict, field_dict):
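+ # Tag features are split by fc.separator into a sparse tensor; an optional
+ # kv_separator further splits each tag into a key and a float weight stored
+ # under feature_name + '_w', and a second input_name may alternatively
+ # provide the weights.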
+ input_0 = fc.input_names[0]
+ feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+ field = field_dict[input_0]
+ # Construct the output of TagFeature according to the dimension of field_dict.
+ # When the input field exceeds 2 dimensions, convert TagFeature to 2D output.
+ if len(field.get_shape()) < 2 or field.get_shape()[-1] == 1:
+ if len(field.get_shape()) == 0:
+ field = tf.expand_dims(field, axis=0)
+ elif len(field.get_shape()) == 2:
+ field = tf.squeeze(field, axis=-1)
+ if fc.HasField('kv_separator') and len(fc.input_names) > 1:
+ assert False, 'Tag Feature Error, ' \
+ 'Cannot set both kv_separator and multiple input_names in one feature config. Feature: %s.' % input_0
+ parsed_dict[feature_name] = tf.string_split(field, fc.separator)
+ if fc.HasField('kv_separator'):
+ indices = parsed_dict[feature_name].indices
+ tmp_kvs = parsed_dict[feature_name].values
+ tmp_kvs = tf.string_split(tmp_kvs, fc.kv_separator, skip_empty=False)
+ tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2])
+ tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1]
+
+ check_list = [
+ tf.py_func(check_string_to_number, [tmp_vs, input_0], Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ tmp_vs = tf.string_to_number(
+ tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0)
+ parsed_dict[feature_name] = tf.sparse.SparseTensor(
+ indices, tmp_ks, parsed_dict[feature_name].dense_shape)
+ parsed_dict[feature_name + '_w'] = tf.sparse.SparseTensor(
+ indices, tmp_vs, parsed_dict[feature_name].dense_shape)
+ if not fc.HasField('hash_bucket_size') and fc.num_buckets > 0:
+ check_list = [
+ tf.py_func(
+ check_string_to_number,
+ [parsed_dict[feature_name].values, input_0],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ vals = tf.string_to_number(
+ parsed_dict[feature_name].values,
+ tf.int32,
+ name='tag_fea_%s' % input_0)
+ parsed_dict[feature_name] = tf.sparse.SparseTensor(
+ parsed_dict[feature_name].indices, vals,
+ parsed_dict[feature_name].dense_shape)
+ if len(fc.input_names) > 1:
+ input_1 = fc.input_names[1]
+ field = field_dict[input_1]
+ if len(field.get_shape()) == 0:
+ field = tf.expand_dims(field, axis=0)
+ field = tf.string_split(field, fc.separator)
+ check_list = [
+ tf.py_func(
+ check_string_to_number, [field.values, input_1], Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ field_vals = tf.string_to_number(
+ field.values, tf.float32, name='tag_wgt_str_2_flt_%s' % input_1)
+ assert_op = tf.assert_equal(
+ tf.shape(field_vals)[0],
+ tf.shape(parsed_dict[feature_name].values)[0],
+ message='TagFeature Error: The size of %s not equal to the size of %s. Please check input: %s and %s.'
+ % (input_0, input_1, input_0, input_1))
+ with tf.control_dependencies([assert_op]):
+ field = tf.sparse.SparseTensor(field.indices, tf.identity(field_vals),
+ field.dense_shape)
+ parsed_dict[feature_name + '_w'] = field
+ else:
+ parsed_dict[feature_name] = field_dict[input_0]
+ if len(fc.input_names) > 1:
+ input_1 = fc.input_names[1]
+ parsed_dict[feature_name + '_w'] = field_dict[input_1]
+
+ def _parse_expr_feature(self, fc, parsed_dict, field_dict):
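+ # Expression features cast every input to float64 (under an 'expr_' prefix
+ # to avoid name clashes), rewrite fc.expression against the prefixed names
+ # via get_expression, and evaluate the resulting expression with eval().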
+ fea_name = fc.feature_name
+ prefix = 'expr_'
+ for input_name in fc.input_names:
+ new_input_name = prefix + input_name
+ if field_dict[input_name].dtype == tf.string:
+ check_list = [
+ tf.py_func(
+ check_string_to_number, [field_dict[input_name], input_name],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ parsed_dict[new_input_name] = tf.string_to_number(
+ field_dict[input_name],
+ tf.float64,
+ name='%s_str_2_int_for_expr' % new_input_name)
+ elif field_dict[input_name].dtype in [
+ tf.int32, tf.int64, tf.double, tf.float32
+ ]:
+ parsed_dict[new_input_name] = tf.cast(field_dict[input_name],
+ tf.float64)
+ else:
+ assert False, 'invalid input dtype[%s] for expr feature' % str(
+ field_dict[input_name].dtype)
+
+ expression = get_expression(fc.expression, fc.input_names, prefix=prefix)
+ logging.info('expression: %s' % expression)
+ parsed_dict[fea_name] = eval(expression)
+ self._appended_fields.append(fea_name)
+
+ def _parse_id_feature(self, fc, parsed_dict, field_dict):
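+ # Id features are kept as strings when hash_bucket_size is set (numeric
+ # inputs are converted with as_string), or converted to int32 ids when
+ # num_buckets is used.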
+ input_0 = fc.input_names[0]
+ feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+ parsed_dict[feature_name] = field_dict[input_0]
+ if fc.HasField('hash_bucket_size'):
+ if field_dict[input_0].dtype != tf.string:
+ parsed_dict[feature_name] = self._as_string(field_dict[input_0], fc)
+ elif fc.num_buckets > 0:
+ if parsed_dict[feature_name].dtype == tf.string:
+ check_list = [
+ tf.py_func(
+ check_string_to_number, [parsed_dict[feature_name], input_0],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ parsed_dict[feature_name] = tf.string_to_number(
+ parsed_dict[feature_name],
+ tf.int32,
+ name='%s_str_2_int' % input_0)
+
+ def _parse_raw_feature(self, fc, parsed_dict, field_dict):
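+ # String inputs are converted to float tensors: multi-dimensional raw values
+ # are split by fc.separator, an optional seq_multi_sep + combiner aggregates
+ # sequences of raw values (max/sum/min/mean), and the result may be min-max
+ # normalized or passed through normalizer_fn.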
+ input_0 = fc.input_names[0]
+ feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+ if field_dict[input_0].dtype == tf.string:
+ if fc.HasField('seq_multi_sep') and fc.HasField('combiner'):
+ fea = tf.string_split(field_dict[input_0], fc.seq_multi_sep)
+ segment_ids = fea.indices[:, 0]
+ vals = fea.values
+ else:
+ vals = field_dict[input_0]
+ segment_ids = tf.range(0, tf.shape(vals)[0])
+ if fc.raw_input_dim > 1:
+ check_list = [
+ tf.py_func(
+ check_split, [vals, fc.separator, fc.raw_input_dim, input_0],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ tmp_fea = tf.string_split(vals, fc.separator)
+ check_list = [
+ tf.py_func(
+ check_string_to_number, [tmp_fea.values, input_0], Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ tmp_vals = tf.string_to_number(
+ tmp_fea.values,
+ tf.float32,
+ name='multi_raw_fea_to_flt_%s' % input_0)
+ if fc.HasField('seq_multi_sep') and fc.HasField('combiner'):
+ emb = tf.reshape(tmp_vals, [-1, fc.raw_input_dim])
+ if fc.combiner == 'max':
+ emb = tf.segment_max(emb, segment_ids)
+ elif fc.combiner == 'sum':
+ emb = tf.segment_sum(emb, segment_ids)
+ elif fc.combiner == 'min':
+ emb = tf.segment_min(emb, segment_ids)
+ elif fc.combiner == 'mean':
+ emb = tf.segment_mean(emb, segment_ids)
+ else:
+ assert False, 'unsupported combine operator: ' + fc.combiner
+ parsed_dict[feature_name] = emb
+ else:
+ parsed_dict[feature_name] = tf.sparse_to_dense(
+ tmp_fea.indices,
+ [tf.shape(field_dict[input_0])[0], fc.raw_input_dim],
+ tmp_vals,
+ default_value=0)
+ elif fc.HasField('seq_multi_sep') and fc.HasField('combiner'):
+ check_list = [
+ tf.py_func(check_string_to_number, [vals, input_0], Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ emb = tf.string_to_number(
+ vals, tf.float32, name='raw_fea_to_flt_%s' % input_0)
+ if fc.combiner == 'max':
+ emb = tf.segment_max(emb, segment_ids)
+ elif fc.combiner == 'sum':
+ emb = tf.segment_sum(emb, segment_ids)
+ elif fc.combiner == 'min':
+ emb = tf.segment_min(emb, segment_ids)
+ elif fc.combiner == 'mean':
+ emb = tf.segment_mean(emb, segment_ids)
+ else:
+ assert False, 'unsupported combine operator: ' + fc.combiner
+ parsed_dict[feature_name] = emb
+ else:
+ check_list = [
+ tf.py_func(
+ check_string_to_number, [field_dict[input_0], input_0],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ parsed_dict[feature_name] = tf.string_to_number(
+ field_dict[input_0], tf.float32)
+ elif field_dict[input_0].dtype in [
+ tf.int32, tf.int64, tf.double, tf.float32
+ ]:
+ parsed_dict[feature_name] = tf.to_float(field_dict[input_0])
+ else:
+ assert False, 'invalid dtype[%s] for raw feature' % str(
+ field_dict[input_0].dtype)
+ if fc.max_val > fc.min_val:
+ parsed_dict[feature_name] = (parsed_dict[feature_name] - fc.min_val) / (
+ fc.max_val - fc.min_val)
+
+ if fc.HasField('normalizer_fn'):
+ logging.info('apply normalizer_fn %s to `%s`' %
+ (fc.normalizer_fn, feature_name))
+ parsed_dict[feature_name] = self._normalizer_fn[feature_name](
+ parsed_dict[feature_name])
+
+ if not fc.boundaries and fc.num_buckets <= 1 and \
+ fc.embedding_dim > 0 and \
+ self._data_config.sample_weight != input_0:
+ # may need by wide model and deep model to project
+ # raw values to a vector, it maybe better implemented
+ # by a ProjectionColumn later
+ sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
+ indices_0 = tf.range(sample_num, dtype=tf.int64)
+ indices_1 = tf.range(fc.raw_input_dim, dtype=tf.int64)
+ indices_0 = indices_0[:, None]
+ indices_1 = indices_1[None, :]
+ indices_0 = tf.tile(indices_0, [1, fc.raw_input_dim])
+ indices_1 = tf.tile(indices_1, [sample_num, 1])
+ indices_0 = tf.reshape(indices_0, [-1, 1])
+ indices_1 = tf.reshape(indices_1, [-1, 1])
+ indices = tf.concat([indices_0, indices_1], axis=1)
+
+ tmp_parsed = parsed_dict[feature_name]
+ parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
+ indices=indices,
+ values=indices_1[:, 0],
+ dense_shape=[sample_num, fc.raw_input_dim])
+ parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
+ indices=indices,
+ values=tf.reshape(tmp_parsed, [-1]),
+ dense_shape=[sample_num, fc.raw_input_dim])
+ # self._appended_fields.append(input_0 + '_raw_proj_id')
+ # self._appended_fields.append(input_0 + '_raw_proj_val')
+
+ def _parse_seq_feature(self, fc, parsed_dict, field_dict):
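+ # Sequence inputs are split by fc.separator into a sparse tensor; an
+ # optional seq_multi_sep yields a 3-dimensional sparse tensor, and values
+ # are converted to int64 (id sequences) or float32 (raw sub-features),
+ # with optional min-max normalization.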
+ input_0 = fc.input_names[0]
+ feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+ field = field_dict[input_0]
+ sub_feature_type = fc.sub_feature_type
+ # Construct the output of SeqFeature according to the dimension of field_dict.
+ # When the input field exceeds 2 dimensions, convert SeqFeature to 2D output.
+ if len(field.get_shape()) < 2:
+ parsed_dict[feature_name] = tf.strings.split(field, fc.separator)
+ if fc.HasField('seq_multi_sep'):
+ indices = parsed_dict[feature_name].indices
+ values = parsed_dict[feature_name].values
+ multi_vals = tf.string_split(values, fc.seq_multi_sep)
+ indices_1 = multi_vals.indices
+ indices = tf.gather(indices, indices_1[:, 0])
+ out_indices = tf.concat([indices, indices_1[:, 1:]], axis=1)
+ # 3 dimensional sparse tensor
+ out_shape = tf.concat(
+ [parsed_dict[feature_name].dense_shape, multi_vals.dense_shape[1:]],
+ axis=0)
+ parsed_dict[feature_name] = tf.sparse.SparseTensor(
+ out_indices, multi_vals.values, out_shape)
+ if (fc.num_buckets > 1 and fc.max_val == fc.min_val):
+ check_list = [
+ tf.py_func(
+ check_string_to_number,
+ [parsed_dict[feature_name].values, input_0],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ parsed_dict[feature_name] = tf.sparse.SparseTensor(
+ parsed_dict[feature_name].indices,
+ tf.string_to_number(
+ parsed_dict[feature_name].values,
+ tf.int64,
+ name='sequence_str_2_int_%s' % input_0),
+ parsed_dict[feature_name].dense_shape)
+ elif sub_feature_type == fc.RawFeature:
+ check_list = [
+ tf.py_func(
+ check_string_to_number,
+ [parsed_dict[feature_name].values, input_0],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ parsed_dict[feature_name] = tf.sparse.SparseTensor(
+ parsed_dict[feature_name].indices,
+ tf.string_to_number(
+ parsed_dict[feature_name].values,
+ tf.float32,
+ name='sequence_str_2_float_%s' % input_0),
+ parsed_dict[feature_name].dense_shape)
+ if fc.num_buckets > 1 and fc.max_val > fc.min_val:
+ normalized_values = (parsed_dict[feature_name].values - fc.min_val) / (
+ fc.max_val - fc.min_val)
+ parsed_dict[feature_name] = tf.sparse.SparseTensor(
+ parsed_dict[feature_name].indices, normalized_values,
+ parsed_dict[feature_name].dense_shape)
+ else:
+ parsed_dict[feature_name] = field
+ if not fc.boundaries and fc.num_buckets <= 1 and\
+ self._data_config.sample_weight != input_0 and\
+ sub_feature_type == fc.RawFeature and\
+ fc.raw_input_dim == 1:
+ logging.info(
+ 'boundaries, num_buckets and hash_bucket_size are not set, '
+ '%s will be processed as a two dimensional sequence raw feature' % feature_name)
+ parsed_dict[feature_name] = tf.sparse_to_dense(
+ parsed_dict[feature_name].indices,
+ [tf.shape(parsed_dict[feature_name])[0], fc.sequence_length],
+ parsed_dict[feature_name].values)
+ sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
+ indices_0 = tf.range(sample_num, dtype=tf.int64)
+ indices_1 = tf.range(fc.sequence_length, dtype=tf.int64)
+ indices_0 = indices_0[:, None]
+ indices_1 = indices_1[None, :]
+ indices_0 = tf.tile(indices_0, [1, fc.sequence_length])
+ indices_1 = tf.tile(indices_1, [sample_num, 1])
+ indices_0 = tf.reshape(indices_0, [-1, 1])
+ indices_1 = tf.reshape(indices_1, [-1, 1])
+ indices = tf.concat([indices_0, indices_1], axis=1)
+ tmp_parsed = parsed_dict[feature_name]
+ parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
+ indices=indices,
+ values=indices_1[:, 0],
+ dense_shape=[sample_num, fc.sequence_length])
+ parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
+ indices=indices,
+ values=tf.reshape(tmp_parsed, [-1]),
+ dense_shape=[sample_num, fc.sequence_length])
+ elif (not fc.boundaries and fc.num_buckets <= 1 and
+ self._data_config.sample_weight != input_0 and
+ sub_feature_type == fc.RawFeature and fc.raw_input_dim > 1):
+ # for 3 dimension sequence feature input.
+ logging.info('boundaries, num_buckets and hash_bucket_size are not set,'
+ ' %s will be processed as a three dimensional sequence raw feature' %
+ feature_name)
+ parsed_dict[feature_name] = tf.sparse_to_dense(
+ parsed_dict[feature_name].indices, [
+ tf.shape(parsed_dict[feature_name])[0], fc.sequence_length,
+ fc.raw_input_dim
+ ], parsed_dict[feature_name].values)
+ sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
+ indices_0 = tf.range(sample_num, dtype=tf.int64)
+ indices_1 = tf.range(fc.sequence_length, dtype=tf.int64)
+ indices_2 = tf.range(fc.raw_input_dim, dtype=tf.int64)
+ indices_0 = indices_0[:, None, None]
+ indices_1 = indices_1[None, :, None]
+ indices_2 = indices_2[None, None, :]
+ indices_0 = tf.tile(indices_0, [1, fc.sequence_length, fc.raw_input_dim])
+ indices_1 = tf.tile(indices_1, [sample_num, 1, fc.raw_input_dim])
+ indices_2 = tf.tile(indices_2, [sample_num, fc.sequence_length, 1])
+ indices_0 = tf.reshape(indices_0, [-1, 1])
+ indices_1 = tf.reshape(indices_1, [-1, 1])
+ indices_2 = tf.reshape(indices_2, [-1, 1])
+ indices = tf.concat([indices_0, indices_1, indices_2], axis=1)
+
+ tmp_parsed = parsed_dict[feature_name]
+ parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
+ indices=indices,
+ values=indices_1[:, 0],
+ dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim])
+ parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
+ indices=indices,
+ values=tf.reshape(parsed_dict[feature_name], [-1]),
+ dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim])
+ # self._appended_fields.append(input_0 + '_raw_proj_id')
+ # self._appended_fields.append(input_0 + '_raw_proj_val')
+
def _preprocess(self, field_dict):
"""Preprocess the feature columns.
@@ -248,11 +820,13 @@ def _preprocess(self, field_dict):
"""
parsed_dict = {}
- if self._sampler is not None:
+ if self._sampler is not None and self._mode != tf.estimator.ModeKeys.PREDICT:
+ if self._mode != tf.estimator.ModeKeys.TRAIN:
+ self._sampler.set_eval_num_sample()
sampler_type = self._data_config.WhichOneof('sampler')
sampler_config = getattr(self._data_config, sampler_type)
item_ids = field_dict[sampler_config.item_id_field]
- if sampler_type == 'negative_sampler':
+ if sampler_type in ['negative_sampler', 'negative_sampler_in_memory']:
sampled = self._sampler.get(item_ids)
elif sampler_type == 'negative_sampler_v2':
user_ids = field_dict[sampler_config.user_id_field]
@@ -266,212 +840,103 @@ def _preprocess(self, field_dict):
if k in field_dict:
field_dict[k] = tf.concat([field_dict[k], v], axis=0)
else:
+ logging.info('appended fields: %s' % k)
parsed_dict[k] = v
self._appended_fields.append(k)
for fc in self._feature_configs:
feature_name = fc.feature_name
feature_type = fc.feature_type
- input_0 = fc.input_names[0]
if feature_type == fc.TagFeature:
- input_0 = fc.input_names[0]
- field = field_dict[input_0]
- # Construct the output of TagFeature according to the dimension of field_dict.
- # When the input field exceeds 2 dimensions, convert TagFeature to 2D output.
- if len(field.get_shape()) < 2 or field.get_shape()[-1] == 1:
- if len(field.get_shape()) == 0:
- field = tf.expand_dims(field, axis=0)
- elif len(field.get_shape()) == 2:
- field = tf.squeeze(field, axis=-1)
- parsed_dict[input_0] = tf.string_split(field, fc.separator)
- if fc.HasField('kv_separator'):
- indices = parsed_dict[input_0].indices
- tmp_kvs = parsed_dict[input_0].values
- tmp_kvs = tf.string_split(
- tmp_kvs, fc.kv_separator, skip_empty=False)
- tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2])
- tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1]
- tmp_vs = tf.string_to_number(
- tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0)
- parsed_dict[input_0] = tf.sparse.SparseTensor(
- indices, tmp_ks, parsed_dict[input_0].dense_shape)
- input_wgt = input_0 + '_WEIGHT'
- parsed_dict[input_wgt] = tf.sparse.SparseTensor(
- indices, tmp_vs, parsed_dict[input_0].dense_shape)
- self._appended_fields.append(input_wgt)
- if not fc.HasField('hash_bucket_size'):
- vals = tf.string_to_number(
- parsed_dict[input_0].values,
- tf.int32,
- name='tag_fea_%s' % input_0)
- parsed_dict[input_0] = tf.sparse.SparseTensor(
- parsed_dict[input_0].indices, vals,
- parsed_dict[input_0].dense_shape)
- if len(fc.input_names) > 1:
- input_1 = fc.input_names[1]
- field = field_dict[input_1]
- if len(field.get_shape()) == 0:
- field = tf.expand_dims(field, axis=0)
- field = tf.string_split(field, fc.separator)
- field_vals = tf.string_to_number(
- field.values, tf.float32, name='tag_wgt_str_2_flt_%s' % input_1)
- assert_op = tf.assert_equal(
- tf.shape(field_vals)[0],
- tf.shape(parsed_dict[input_0].values)[0],
- message='tag_feature_kv_size_not_eq_%s' % input_0)
- with tf.control_dependencies([assert_op]):
- field = tf.sparse.SparseTensor(field.indices,
- tf.identity(field_vals),
- field.dense_shape)
- parsed_dict[input_1] = field
- else:
- parsed_dict[input_0] = field_dict[input_0]
- if len(fc.input_names) > 1:
- input_1 = fc.input_names[1]
- parsed_dict[input_1] = field_dict[input_1]
+ self._parse_tag_feature(fc, parsed_dict, field_dict)
elif feature_type == fc.LookupFeature:
assert feature_name is not None and feature_name != ''
assert len(fc.input_names) == 2
parsed_dict[feature_name] = self._lookup_preprocess(fc, field_dict)
elif feature_type == fc.SequenceFeature:
- input_0 = fc.input_names[0]
- field = field_dict[input_0]
- # Construct the output of SeqFeature according to the dimension of field_dict.
- # When the input field exceeds 2 dimensions, convert SeqFeature to 2D output.
- if len(field.get_shape()) < 2:
- parsed_dict[input_0] = tf.strings.split(field, fc.separator)
- if fc.HasField('seq_multi_sep'):
- indices = parsed_dict[input_0].indices
- values = parsed_dict[input_0].values
- multi_vals = tf.string_split(values, fc.seq_multi_sep)
- indices_1 = multi_vals.indices
- indices = tf.gather(indices, indices_1[:, 0])
- out_indices = tf.concat([indices, indices_1[:, 1:]], axis=1)
- # 3 dimensional sparse tensor
- out_shape = tf.concat(
- [parsed_dict[input_0].dense_shape, multi_vals.dense_shape[1:]],
- axis=0)
- parsed_dict[input_0] = tf.sparse.SparseTensor(
- out_indices, multi_vals.values, out_shape)
- if fc.num_buckets > 0:
- parsed_dict[input_0] = tf.sparse.SparseTensor(
- parsed_dict[input_0].indices,
- tf.string_to_number(
- parsed_dict[input_0].values,
- tf.int64,
- name='sequence_str_2_int_%s' % input_0),
- parsed_dict[input_0].dense_shape)
- else:
- parsed_dict[input_0] = field
+ self._parse_seq_feature(fc, parsed_dict, field_dict)
elif feature_type == fc.RawFeature:
- input_0 = fc.input_names[0]
- if field_dict[input_0].dtype == tf.string:
- if fc.raw_input_dim > 1:
- tmp_fea = tf.string_split(field_dict[input_0], fc.separator)
- tmp_vals = tf.string_to_number(
- tmp_fea.values,
- tf.float32,
- name='multi_raw_fea_to_flt_%s' % input_0)
- parsed_dict[input_0] = tf.sparse_to_dense(
- tmp_fea.indices,
- [tf.shape(field_dict[input_0])[0], fc.raw_input_dim],
- tmp_vals,
- default_value=0)
- else:
- parsed_dict[input_0] = tf.string_to_number(field_dict[input_0],
- tf.float32)
- elif field_dict[input_0].dtype in [
- tf.int32, tf.int64, tf.double, tf.float32
- ]:
- parsed_dict[input_0] = tf.to_float(field_dict[input_0])
- else:
- assert False, 'invalid dtype[%s] for raw feature' % str(
- field_dict[input_0].dtype)
- if fc.max_val > fc.min_val:
- parsed_dict[input_0] = (parsed_dict[input_0] - fc.min_val) /\
- (fc.max_val - fc.min_val)
- if not fc.boundaries and fc.num_buckets <= 1 and \
- self._data_config.sample_weight != input_0:
- # may need by wide model and deep model to project
- # raw values to a vector, it maybe better implemented
- # by a ProjectionColumn later
- sample_num = tf.to_int64(tf.shape(parsed_dict[input_0])[0])
- indices_0 = tf.range(sample_num, dtype=tf.int64)
- indices_1 = tf.range(fc.raw_input_dim, dtype=tf.int64)
- indices_0 = indices_0[:, None]
- indices_1 = indices_1[None, :]
- indices_0 = tf.tile(indices_0, [1, fc.raw_input_dim])
- indices_1 = tf.tile(indices_1, [sample_num, 1])
- indices_0 = tf.reshape(indices_0, [-1, 1])
- indices_1 = tf.reshape(indices_1, [-1, 1])
- indices = tf.concat([indices_0, indices_1], axis=1)
-
- parsed_dict[input_0 + '_raw_proj_id'] = tf.SparseTensor(
- indices=indices,
- values=indices_1[:, 0],
- dense_shape=[sample_num, fc.raw_input_dim])
- parsed_dict[input_0 + '_raw_proj_val'] = tf.SparseTensor(
- indices=indices,
- values=tf.reshape(parsed_dict[input_0], [-1]),
- dense_shape=[sample_num, fc.raw_input_dim])
- self._appended_fields.append(input_0 + '_raw_proj_id')
- self._appended_fields.append(input_0 + '_raw_proj_val')
+ self._parse_raw_feature(fc, parsed_dict, field_dict)
elif feature_type == fc.IdFeature:
- input_0 = fc.input_names[0]
- parsed_dict[input_0] = field_dict[input_0]
- if fc.HasField('hash_bucket_size'):
- if field_dict[input_0].dtype != tf.string:
- if field_dict[input_0].dtype in [tf.float32, tf.double]:
- assert fc.precision > 0, 'it is dangerous to convert float or double to string due to ' \
- 'precision problem, it is suggested to convert them into string ' \
- 'format during feature generalization before using EasyRec; ' \
- 'if you really need to do so, please set precision (the number of ' \
- 'decimal digits) carefully.'
- precision = None
- if field_dict[input_0].dtype in [tf.float32, tf.double]:
- if fc.precision > 0:
- precision = fc.precision
- # convert to string
- if 'as_string' in dir(tf.strings):
- parsed_dict[input_0] = tf.strings.as_string(
- field_dict[input_0], precision=precision)
- else:
- parsed_dict[input_0] = tf.as_string(
- field_dict[input_0], precision=precision)
- elif fc.num_buckets > 0:
- if parsed_dict[input_0].dtype == tf.string:
- parsed_dict[input_0] = tf.string_to_number(
- parsed_dict[input_0], tf.int32, name='%s_str_2_int' % input_0)
+ self._parse_id_feature(fc, parsed_dict, field_dict)
+ elif feature_type == fc.ExprFeature:
+ self._parse_expr_feature(fc, parsed_dict, field_dict)
+ elif feature_type == fc.ComboFeature:
+ self._parse_combo_feature(fc, parsed_dict, field_dict)
else:
- for input_name in fc.input_names:
- parsed_dict[input_name] = field_dict[input_name]
+ feature_name = fc.feature_name if fc.HasField(
+ 'feature_name') else fc.input_names[0]
+ for input_id, input_name in enumerate(fc.input_names):
+ if input_id > 0:
+ key = feature_name + '_' + str(input_id)
+ else:
+ key = feature_name
+ parsed_dict[key] = field_dict[input_name]
+ label_dict = {}
for input_id, input_name in enumerate(self._label_fields):
if input_name not in field_dict:
continue
+ if input_name in self._label_udf_map:
+ udf, udf_class, dtype = self._label_udf_map[input_name]
+ if dtype is None or dtype == '':
+ logging.info('apply tensorflow function transform: %s' % udf_class)
+ field_dict[input_name] = udf(field_dict[input_name])
+ else:
+ assert dtype is not None, 'must set user_define_fn_res_type'
+ logging.info('apply py_func transform: %s' % udf_class)
+ field_dict[input_name] = tf.py_func(
+ udf, [field_dict[input_name]], Tout=get_tf_type(dtype))
+ field_dict[input_name].set_shape(tf.TensorShape([None]))
+
if field_dict[input_name].dtype == tf.string:
if self._label_dim[input_id] > 1:
logging.info('will split labels[%d]=%s' % (input_id, input_name))
- parsed_dict[input_name] = tf.string_split(
- field_dict[input_name], self._label_sep[input_id]).values
- parsed_dict[input_name] = tf.reshape(parsed_dict[input_name],
- [-1, self._label_dim[input_id]])
+ check_list = [
+ tf.py_func(
+ check_split, [
+ field_dict[input_name], self._label_sep[input_id],
+ self._label_dim[input_id], input_name
+ ],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ label_dict[input_name] = tf.string_split(
+ field_dict[input_name], self._label_sep[input_id]).values
+ label_dict[input_name] = tf.reshape(label_dict[input_name],
+ [-1, self._label_dim[input_id]])
else:
- parsed_dict[input_name] = field_dict[input_name]
- parsed_dict[input_name] = tf.string_to_number(
- parsed_dict[input_name], tf.float32, name=input_name)
+ label_dict[input_name] = field_dict[input_name]
+ check_list = [
+ tf.py_func(
+ check_string_to_number, [label_dict[input_name], input_name],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ label_dict[input_name] = tf.string_to_number(
+ label_dict[input_name], tf.float32, name=input_name)
else:
assert field_dict[input_name].dtype in [
tf.float32, tf.double, tf.int32, tf.int64
], 'invalid label dtype: %s' % str(field_dict[input_name].dtype)
- parsed_dict[input_name] = field_dict[input_name]
+ label_dict[input_name] = field_dict[input_name]
- if self._data_config.HasField('sample_weight'):
- if self._mode != tf.estimator.ModeKeys.PREDICT:
+ if self._mode != tf.estimator.ModeKeys.PREDICT:
+ for func_config in self._data_config.extra_label_func:
+ lbl_name = func_config.label_name
+ func_name = func_config.label_func
+ logging.info('generating new label `%s` by transform: %s' %
+ (lbl_name, func_name))
+ lbl_fn = load_by_path(func_name)
+ label_dict[lbl_name] = lbl_fn(label_dict)
+
+ if self._data_config.HasField('sample_weight'):
parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
self._data_config.sample_weight]
- return parsed_dict
+ if Input.DATA_OFFSET in field_dict:
+ parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+ return {'feature': parsed_dict, 'label': label_dict}
def _lookup_preprocess(self, fc, field_dict):
"""Preprocess function for lookup features.
@@ -541,6 +1006,22 @@ def _lookup(args, pad=True):
def _build(self, mode, params):
raise NotImplementedError
+ def _pre_build(self, mode, params):
+ pass
+
+ def restore(self, checkpoint_path):
+ pass
+
+ def stop(self):
+ pass
+
+ def _safe_shard(self, dataset):
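+ # With chief_redundant enabled, the chief (task 0) reads the same shard as
+ # worker 1, so the data is effectively sharded across task_num - 1 workers;
+ # otherwise the dataset is sharded evenly across all tasks.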
+ if self._data_config.chief_redundant:
+ return dataset.shard(
+ max(self._task_num - 1, 1), max(self._task_index - 1, 0))
+ else:
+ return dataset.shard(self._task_num, self._task_index)
+
def create_input(self, export_config=None):
def _input_fn(mode=None, params=None, config=None):
@@ -548,8 +1029,8 @@ def _input_fn(mode=None, params=None, config=None):
Args:
mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
- params: `dict` of hyper parameters, from Estimator
- config: tf.estimator.RunConfig instance
+ params: `dict` of hyper parameters, from Estimator
+ config: tf.estimator.RunConfig instance
Return:
if mode is not None, return:
@@ -558,6 +1039,7 @@ def _input_fn(mode=None, params=None, config=None):
else, return:
tf.estimator.export.ServingInputReceiver instance
"""
+ self._pre_build(mode, params)
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
tf.estimator.ModeKeys.PREDICT):
# build dataset from self._config.input_path
@@ -565,17 +1047,18 @@ def _input_fn(mode=None, params=None, config=None):
dataset = self._build(mode, params)
return dataset
elif mode is None: # serving_input_receiver_fn for export SavedModel
+ place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
+ place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
if export_config.multi_placeholder:
- if export_config.multi_value_fields:
- export_fields_name = export_config.multi_value_fields.input_name
- else:
- export_fields_name = None
- placeholder_named_by_input = export_config.placeholder_named_by_input
- inputs, features = self.create_multi_placeholders(
- placeholder_named_by_input, export_fields_name)
+ with conditional(place_on_cpu, ops.device('/CPU:0')):
+ inputs, features = self.create_multi_placeholders(export_config)
return tf.estimator.export.ServingInputReceiver(features, inputs)
else:
- inputs, features = self.create_placeholders(export_config)
+ with conditional(place_on_cpu, ops.device('/CPU:0')):
+ inputs, features = self.create_placeholders(export_config)
+ logging.info('built feature placeholders. features: {}'.format(
+ features.keys()))
return tf.estimator.export.ServingInputReceiver(features, inputs)
+ _input_fn.input_creator = self
return _input_fn
diff --git a/easy_rec/python/input/kafka_dataset.py b/easy_rec/python/input/kafka_dataset.py
new file mode 100644
index 000000000..22ae45b90
--- /dev/null
+++ b/easy_rec/python/input/kafka_dataset.py
@@ -0,0 +1,144 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kafka Dataset."""
+
+import logging
+import traceback
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+try:
+ from easy_rec.python.ops import gen_kafka_ops
+except ImportError:
+ logging.warning('failed to import gen_kafka_ops: %s' % traceback.format_exc())
+
+
+class KafkaDataset(dataset_ops.Dataset):
+ """A Kafka Dataset that consumes the message."""
+
+ def __init__(self,
+ topics,
+ servers='localhost',
+ group='',
+ eof=False,
+ timeout=1000,
+ config_global=None,
+ config_topic=None,
+ message_key=False,
+ message_offset=False):
+ """Create a KafkaReader.
+
+ Args:
+ topics: A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset:length],
+ by default length is -1 for unlimited.
+ servers: A list of bootstrap servers.
+ group: The consumer group id.
+ eof: If True, the kafka reader will stop on EOF.
+ timeout: The timeout value for the Kafka Consumer to wait
+ (in millisecond).
+ config_global: A `tf.string` tensor containing global configuration
+ properties in [Key=Value] format,
+ eg. ["enable.auto.commit=false",
+ "heartbeat.interval.ms=2000"],
+ please refer to 'Global configuration properties'
+ in librdkafka doc.
+ config_topic: A `tf.string` tensor containing topic configuration
+ properties in [Key=Value] format,
+ eg. ["auto.offset.reset=earliest"],
+ please refer to 'Topic configuration properties'
+ in librdkafka doc.
+ message_key: If True, the kafka will output both message value and key.
+ message_offset: If True, the kafka will output both message value and offset.
+ """
+ self._topics = ops.convert_to_tensor(
+ topics, dtype=dtypes.string, name='topics')
+ self._servers = ops.convert_to_tensor(
+ servers, dtype=dtypes.string, name='servers')
+ self._group = ops.convert_to_tensor(
+ group, dtype=dtypes.string, name='group')
+ self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name='eof')
+ self._timeout = ops.convert_to_tensor(
+ timeout, dtype=dtypes.int64, name='timeout')
+ config_global = config_global if config_global else []
+ self._config_global = ops.convert_to_tensor(
+ config_global, dtype=dtypes.string, name='config_global')
+ config_topic = config_topic if config_topic else []
+ self._config_topic = ops.convert_to_tensor(
+ config_topic, dtype=dtypes.string, name='config_topic')
+ self._message_key = message_key
+ self._message_offset = message_offset
+ super(KafkaDataset, self).__init__()
+
+ def _inputs(self):
+ return []
+
+ def _as_variant_tensor(self):
+ return gen_kafka_ops.io_kafka_dataset_v2(
+ self._topics,
+ self._servers,
+ self._group,
+ self._eof,
+ self._timeout,
+ self._config_global,
+ self._config_topic,
+ self._message_key,
+ self._message_offset,
+ )
+
+ @property
+ def output_classes(self):
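+ # exactly one of message_key / message_offset requested -> (value, extra);
+ # both requested -> (value, key, offset); neither -> value only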
+ if self._message_key ^ self._message_offset:
+ return (ops.Tensor, ops.Tensor)
+ elif self._message_key and self._message_offset:
+ return (ops.Tensor, ops.Tensor, ops.Tensor)
+ return (ops.Tensor)
+
+ @property
+ def output_shapes(self):
+ if self._message_key ^ self._message_offset:
+ return ((tensor_shape.TensorShape([]), tensor_shape.TensorShape([])))
+ elif self._message_key and self._message_offset:
+ return ((tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
+ tensor_shape.TensorShape([])))
+ return ((tensor_shape.TensorShape([])))
+
+ @property
+ def output_types(self):
+ if self._message_key ^ self._message_offset:
+ return ((dtypes.string, dtypes.string))
+ elif self._message_key and self._message_offset:
+ return ((dtypes.string, dtypes.string, dtypes.string))
+ return ((dtypes.string))
+
+
+def write_kafka_v2(message, topic, servers='localhost', name=None):
+ """Write kafka.
+
+ Args:
+ message: A `Tensor` of type `string`. 0-D.
+ topic: A `tf.string` tensor containing one subscription,
+ in the format of topic:partition.
+ servers: A list of bootstrap servers.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `string`. 0-D.
+ """
+ return gen_kafka_ops.io_write_kafka_v2(
+ message=message, topic=topic, servers=servers, name=name)
diff --git a/easy_rec/python/input/kafka_input.py b/easy_rec/python/input/kafka_input.py
index 63bf5a4d2..38eab121f 100644
--- a/easy_rec/python/input/kafka_input.py
+++ b/easy_rec/python/input/kafka_input.py
@@ -1,117 +1,226 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
import logging
-import sys
+import traceback
+import six
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.input.kafka_dataset import KafkaDataset
+from easy_rec.python.utils.config_util import parse_time
+
+if tf.__version__.startswith('1.'):
+ from tensorflow.python.platform import gfile
+else:
+ import tensorflow.io.gfile as gfile
+
+try:
+ from kafka import KafkaConsumer, TopicPartition
+except ImportError:
+ logging.warning(
+ 'kafka-python is not installed[%s]. You can install it by: pip install kafka-python'
+ % traceback.format_exc())
if tf.__version__ >= '2.0':
+ ignore_errors = tf.data.experimental.ignore_errors()
tf = tf.compat.v1
+else:
+ ignore_errors = tf.contrib.data.ignore_errors()
class KafkaInput(Input):
+ DATA_OFFSET = 'DATA_OFFSET'
+
def __init__(self,
data_config,
feature_config,
kafka_config,
task_index=0,
- task_num=1):
- super(KafkaInput, self).__init__(data_config, feature_config, '',
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(KafkaInput,
+ self).__init__(data_config, feature_config, '', task_index, task_num,
+ check_mode, pipeline_config)
self._kafka = kafka_config
-
- def _parse_csv(self, line):
+ self._offset_dict = {}
+ if self._kafka is not None:
+ consumer = KafkaConsumer(
+ group_id='kafka_dataset_consumer',
+ bootstrap_servers=[self._kafka.server],
+ api_version_auto_timeout_ms=60000) # in milliseconds
+ partitions = consumer.partitions_for_topic(self._kafka.topic)
+ self._num_partition = len(partitions)
+ logging.info('all partitions[%d]: %s' % (self._num_partition, partitions))
+
+ # determine kafka offsets for each partition
+ offset_type = self._kafka.WhichOneof('offset')
+ if offset_type is not None:
+ if offset_type == 'offset_time':
+ ts = parse_time(self._kafka.offset_time)
+ input_map = {
+ TopicPartition(partition=part_id, topic=self._kafka.topic):
+ ts * 1000 for part_id in partitions
+ }
+ part_offsets = consumer.offsets_for_times(input_map)
+ # part_offsets is a dictionary:
+ # {
+ # TopicPartition(topic=u'kafka_data_20220408', partition=0):
+ # OffsetAndTimestamp(offset=2, timestamp=1650611437895)
+ # }
+ for part in part_offsets:
+ self._offset_dict[part.partition] = part_offsets[part].offset
+ logging.info(
+ 'Find offset by time, topic[%s], partition[%d], timestamp[%ss], offset[%d], offset_timestamp[%dms]'
+ % (self._kafka.topic, part.partition, ts,
+ part_offsets[part].offset, part_offsets[part].timestamp))
+ elif offset_type == 'offset_info':
+ offset_dict = json.loads(self._kafka.offset_info)
+ for part in offset_dict:
+ part_id = int(part)
+ self._offset_dict[part_id] = offset_dict[part]
+ else:
+ assert False, 'invalid offset_type: %s' % offset_type
+ self._task_offset_dict = {}
+
+ def _preprocess(self, field_dict):
+ output_dict = super(KafkaInput, self)._preprocess(field_dict)
+
+ # append offset fields
+ if Input.DATA_OFFSET in field_dict:
+ output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+
+ # for _get_features to include DATA_OFFSET
+ if Input.DATA_OFFSET not in self._appended_fields:
+ self._appended_fields.append(Input.DATA_OFFSET)
+
+ return output_dict
+
+ def _parse_csv(self, line, message_key, message_offset):
record_defaults = [
self.get_type_defaults(t, v)
for t, v in zip(self._input_field_types, self._input_field_defaults)
]
- def _check_data(line):
- sep = self._data_config.separator
- if type(sep) != type(str):
- sep = sep.encode('utf-8')
- field_num = len(line[0].split(sep))
- assert field_num == len(record_defaults),\
- 'sep[%s] maybe invalid: field_num=%d, required_num=%d' % (sep, field_num, len(record_defaults))
- return True
-
- check_op = tf.py_func(_check_data, [line], Tout=tf.bool)
- with tf.control_dependencies([check_op]):
- fields = tf.decode_csv(
- line,
- field_delim=self._data_config.separator,
- record_defaults=record_defaults,
- name='decode_csv')
+ fields = tf.decode_csv(
+ line,
+ use_quote_delim=False,
+ field_delim=self._data_config.separator,
+ record_defaults=record_defaults,
+ name='decode_csv')
inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
for x in self._label_fids:
inputs[self._input_fields[x]] = fields[x]
+
+ # record current offset
+ def _parse_offset(message_offset):
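+ # each element of message_offset is a 'partition:offset' string; keep the
+ # largest offset seen per partition so the latest consumed position can
+ # later be restored (see restore)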
+ for kv in message_offset:
+ if six.PY3:
+ kv = kv.decode('utf-8')
+ k, v = kv.split(':')
+ k = int(k)
+ v = int(v)
+ if k not in self._task_offset_dict or v > self._task_offset_dict[k]:
+ self._task_offset_dict[k] = v
+ return json.dumps(self._task_offset_dict)
+
+ inputs[Input.DATA_OFFSET] = tf.py_func(_parse_offset, [message_offset],
+ tf.string)
return inputs
- def _build(self, mode, params):
- try:
- import tensorflow_io.kafka as kafka_io
- except ImportError:
- logging.error(
- 'Please install tensorflow-io, '
- 'version compatibility can refer to https://github.com/tensorflow/io#tensorflow-version-compatibility'
- )
+ def restore(self, checkpoint_path):
+ if checkpoint_path is None:
+ return
+
+ offset_path = checkpoint_path + '.offset'
+ if not gfile.Exists(offset_path):
+ return
+
+ logging.info('will restore kafka offset from %s' % offset_path)
+ with gfile.GFile(offset_path, 'r') as fin:
+ offset_dict = json.load(fin)
+ self._offset_dict = {}
+ for k in offset_dict:
+ v = offset_dict[k]
+ k = int(k)
+ if k not in self._offset_dict or v > self._offset_dict[k]:
+ self._offset_dict[k] = v
+
+ def _get_topics(self):
+ task_num = self._task_num
+ task_index = self._task_index
+ if self._data_config.chief_redundant and self._mode == tf.estimator.ModeKeys.TRAIN:
+ task_index = max(task_index - 1, 0)
+ task_num = max(task_num - 1, 1)
+
+ topics = []
+ self._task_offset_dict = {}
+ for part_id in range(self._num_partition):
+ if (part_id % task_num) == task_index:
+ offset = self._offset_dict.get(part_id, 0)
+ topics.append('%s:%d:%d' % (self._kafka.topic, part_id, offset))
+ self._task_offset_dict[part_id] = offset
+ logging.info('assigned topic partitions: %s' % (','.join(topics)))
+ assert len(
+ topics) > 0, 'no partitions are assigned for this task(%d/%d)' % (
+ self._task_index, self._task_num)
+ return topics
+ def _build(self, mode, params):
num_parallel_calls = self._data_config.num_parallel_calls
+ task_topics = self._get_topics()
if mode == tf.estimator.ModeKeys.TRAIN:
- train = self._kafka
- topics = []
- i = self._task_index
- assert len(train.offset) == 1 or len(train.offset) == train.partitions, \
- 'number of train.offset must be 1 or train.partitions'
- while i < train.partitions:
- offset_i = train.offset[i] if i < len(
- train.offset) else train.offset[-1]
- topics.append(train.topic + ':' + str(i) + ':' + str(offset_i) + ':-1')
- i = i + self._task_num
-
+ assert self._kafka is not None, 'kafka_train_input is not set.'
+ train_kafka = self._kafka
logging.info(
'train kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
- %
- (train.server, train.topic, self._task_num, self._task_index, topics))
- if len(topics) == 0:
- logging.info('train kafka topic is empty')
- sys.exit(1)
-
- dataset = kafka_io.KafkaDataset(
- topics, servers=train.server, group=train.group, eof=False)
- dataset = dataset.repeat(1)
+ % (train_kafka.server, train_kafka.topic, self._task_num,
+ self._task_index, task_topics))
+
+ dataset = KafkaDataset(
+ task_topics,
+ servers=train_kafka.server,
+ group=train_kafka.group,
+ eof=False,
+ config_global=list(self._kafka.config_global),
+ config_topic=list(self._kafka.config_topic),
+ message_key=True,
+ message_offset=True)
+
+ if self._data_config.shuffle:
+ dataset = dataset.shuffle(
+ self._data_config.shuffle_buffer_size,
+ seed=2020,
+ reshuffle_each_iteration=True)
else:
- eval = self._kafka
- topics = []
- i = 0
- assert len(eval.offset) == 1 or len(eval.offset) == eval.partitions, \
- 'number of eval.offset must be 1 or eval.partitions'
- while i < eval.partitions:
- offset_i = eval.offset[i] if i < len(eval.offset) else eval.offset[-1]
- topics.append(eval.topic + ':' + str(i) + ':' + str(eval.offset) +
- ':-1')
- i = i + 1
+ assert self._kafka is not None, 'kafka_eval_input is not set.'
+ eval_kafka = self._kafka
logging.info(
'eval kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
- % (eval.server, eval.topic, self._task_num, self._task_index, topics))
-
- if len(topics) == 0:
- logging.info('eval kafka topic is empty')
- sys.exit(1)
-
- dataset = kafka_io.KafkaDataset(
- topics, servers=eval.server, group=eval.group, eof=False)
- dataset = dataset.repeat(1)
+ % (eval_kafka.server, eval_kafka.topic, self._task_num,
+ self._task_index, task_topics))
+
+ dataset = KafkaDataset(
+ task_topics,
+ servers=self._kafka.server,
+ group=eval_kafka.group,
+ eof=False,
+ config_global=list(self._kafka.config_global),
+ config_topic=list(self._kafka.config_topic),
+ message_key=True,
+ message_offset=True)
dataset = dataset.batch(self._data_config.batch_size)
dataset = dataset.map(
self._parse_csv, num_parallel_calls=num_parallel_calls)
+ if self._data_config.ignore_error:
+ dataset = dataset.apply(ignore_errors)
dataset = dataset.prefetch(buffer_size=self._prefetch_size)
dataset = dataset.map(
map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
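For orientation, the topic strings assembled by KafkaInput._get_topics above follow the topic:partition:offset convention consumed by KafkaDataset. The standalone sketch below reproduces the round-robin assignment; the topic name, partition count and restored offsets are made-up example values, not values taken from this change.

def assign_partitions(topic, num_partition, task_num, task_index, offset_dict):
  # mirror of the loop in _get_topics: a task takes the partitions whose
  # id modulo task_num equals its task_index, resuming from saved offsets
  topics = []
  for part_id in range(num_partition):
    if part_id % task_num == task_index:
      offset = offset_dict.get(part_id, 0)
      topics.append('%s:%d:%d' % (topic, part_id, offset))
  return topics

# task 1 of 3 over 8 partitions resumes partition 4 from a restored offset
print(assign_partitions('kafka_data_20220408', 8, 3, 1, {4: 1024}))
# ['kafka_data_20220408:1:0', 'kafka_data_20220408:4:1024', 'kafka_data_20220408:7:0']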
diff --git a/easy_rec/python/input/load_parquet.py b/easy_rec/python/input/load_parquet.py
new file mode 100644
index 000000000..a798efb58
--- /dev/null
+++ b/easy_rec/python/input/load_parquet.py
@@ -0,0 +1,317 @@
+import logging
+import multiprocessing
+import queue
+
+import numpy as np
+import pandas as pd
+
+
+def start_data_proc(task_index,
+ task_num,
+ num_proc,
+ file_que,
+ data_que,
+ proc_start_que,
+ proc_stop_que,
+ batch_size,
+ label_fields,
+ sparse_fea_names,
+ dense_fea_names,
+ dense_fea_cfgs,
+ reserve_fields,
+ drop_remainder,
+ need_pack=True):
+ mp_ctxt = multiprocessing.get_context('spawn')
+ proc_arr = []
+ for proc_id in range(num_proc):
+ proc = mp_ctxt.Process(
+ target=load_data_proc,
+ args=(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
+ batch_size, label_fields, sparse_fea_names, dense_fea_names,
+ dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
+ task_num, need_pack),
+ name='task_%d_data_proc_%d' % (task_index, proc_id))
+ proc.daemon = True
+ proc.start()
+ proc_arr.append(proc)
+ return proc_arr
+
+
+def _should_stop(proc_stop_que):
+ try:
+ proc_stop_que.get(block=False)
+ logging.info('data_proc stop signal received')
+ proc_stop_que.close()
+ return True
+ except queue.Empty:
+ return False
+ except ValueError:
+ return True
+ except AssertionError:
+ return True
+
+
+def _add_to_que(data_dict, data_que, proc_stop_que):
+ while True:
+ try:
+ data_que.put(data_dict, timeout=5)
+ return True
+ except queue.Full:
+ logging.warning('data_que is full')
+ if _should_stop(proc_stop_que):
+ return False
+ except ValueError:
+ logging.warning('data_que is closed')
+ return False
+ except AssertionError:
+ logging.warning('data_que is closed')
+ return False
+
+
+def _get_one_file(file_que, proc_stop_que):
+ while True:
+ try:
+ input_file = file_que.get(timeout=1)
+ return input_file
+ except queue.Empty:
+ pass
+ return None
+
+
+def _pack_sparse_feas(data_dict, sparse_fea_names):
+ fea_val_arr = []
+ fea_len_arr = []
+ for fea_name in sparse_fea_names:
+ fea_len_arr.append(data_dict[fea_name][0])
+ fea_val_arr.append(data_dict[fea_name][1])
+ del data_dict[fea_name]
+ fea_lens = np.concatenate(fea_len_arr, axis=0)
+ fea_vals = np.concatenate(fea_val_arr, axis=0)
+ data_dict['sparse_fea'] = (fea_lens, fea_vals)
+
+
+def _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
+ fea_val_arr = []
+ for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
+ fea_val_arr.append(data_dict[fea_name].reshape([-1, fea_cfg.raw_input_dim]))
+ del data_dict[fea_name]
+ fea_vals = np.concatenate(fea_val_arr, axis=1)
+ data_dict['dense_fea'] = fea_vals
+
+
+def _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
+ for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
+ data_dict[fea_name] = data_dict[fea_name].reshape(
+ [-1, fea_cfg.raw_input_dim])
+
+
+def _load_dense(input_data, field_names, sid, eid, dense_dict):
+ for k in field_names:
+ if isinstance(input_data[k][0], np.ndarray):
+ np_dtype = type(input_data[k][sid][0])
+ dense_dict[k] = np.array([x[0] for x in input_data[k][sid:eid]],
+ dtype=np_dtype)
+ else:
+ dense_dict[k] = input_data[k][sid:eid].to_numpy()
+
+
+def _load_and_pad_dense(input_data, field_names, sid, dense_dict,
+ part_dense_dict, part_dense_dict_n, batch_size):
+ for k in field_names:
+ if isinstance(input_data[k][0], np.ndarray):
+ np_dtype = type(input_data[k][sid][0])
+ tmp_lbls = np.array([x[0] for x in input_data[k][sid:]], dtype=np_dtype)
+ else:
+ tmp_lbls = input_data[k][sid:].to_numpy()
+ if part_dense_dict is not None and k in part_dense_dict:
+ tmp_lbls = np.concatenate([part_dense_dict[k], tmp_lbls], axis=0)
+ if len(tmp_lbls) > batch_size:
+ dense_dict[k] = tmp_lbls[:batch_size]
+ part_dense_dict_n[k] = tmp_lbls[batch_size:]
+ elif len(tmp_lbls) == batch_size:
+ dense_dict[k] = tmp_lbls
+ else:
+ part_dense_dict_n[k] = tmp_lbls
+ else:
+ part_dense_dict_n[k] = tmp_lbls
+
+
+def load_data_proc(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
+ batch_size, label_fields, sparse_fea_names, dense_fea_names,
+ dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
+ task_num, need_pack):
+ logging.info('data proc %d start, proc_start_que=%s' %
+ (proc_id, proc_start_que.qsize()))
+ proc_start_que.get()
+ effective_fields = sparse_fea_names + dense_fea_names
+ all_fields = effective_fields
+ if label_fields is not None:
+ all_fields = all_fields + label_fields
+ if reserve_fields is not None:
+ for tmp in reserve_fields:
+ if tmp not in all_fields:
+ all_fields.append(tmp)
+ logging.info('data proc %d start, file_que.qsize=%d' %
+ (proc_id, file_que.qsize()))
+ num_files = 0
+ part_data_dict = {}
+
+ is_good = True
+ total_batch_cnt = 0
+ total_sample_cnt = 0
+ while is_good:
+ if _should_stop(proc_stop_que):
+ is_good = False
+ break
+ input_file = _get_one_file(file_que, proc_stop_que)
+ if input_file is None:
+ break
+ num_files += 1
+ input_data = pd.read_parquet(input_file, columns=all_fields)
+ data_len = len(input_data[all_fields[0]])
+ total_sample_cnt += data_len
+ batch_num = int(data_len / batch_size)
+ res_num = data_len % batch_size
+
+ sid = 0
+ for batch_id in range(batch_num):
+ eid = sid + batch_size
+ data_dict = {}
+
+ if label_fields is not None and len(label_fields) > 0:
+ _load_dense(input_data, label_fields, sid, eid, data_dict)
+
+ if reserve_fields is not None and len(reserve_fields) > 0:
+ data_dict['reserve'] = {}
+ _load_dense(input_data, reserve_fields, sid, eid, data_dict['reserve'])
+
+ if len(sparse_fea_names) > 0:
+ for k in sparse_fea_names:
+ val = input_data[k][sid:eid]
+ if isinstance(input_data[k][sid], np.ndarray):
+ all_lens = np.array([len(x) for x in val], dtype=np.int32)
+ all_vals = np.concatenate(val.to_numpy())
+ else:
+ all_lens = np.ones([len(val)], dtype=np.int32)
+ all_vals = val.to_numpy()
+ assert np.sum(all_lens) == len(
+ all_vals), 'len(all_vals)=%d np.sum(all_lens)=%d' % (
+ len(all_vals), np.sum(all_lens))
+ data_dict[k] = (all_lens, all_vals)
+
+ if len(dense_fea_names) > 0:
+ _load_dense(input_data, dense_fea_names, sid, eid, data_dict)
+
+ if need_pack:
+ if len(sparse_fea_names) > 0:
+ _pack_sparse_feas(data_dict, sparse_fea_names)
+ if len(dense_fea_names) > 0:
+ _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+ else:
+ if len(dense_fea_names) > 0:
+ _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+ # logging.info('task_index=%d sid=%d eid=%d total_len=%d' % (task_index, sid, eid,
+ # len(data_dict['sparse_fea'][1])))
+ if not _add_to_que(data_dict, data_que, proc_stop_que):
+ logging.info('add to que failed')
+ is_good = False
+ break
+ total_batch_cnt += 1
+ sid += batch_size
+
+ if res_num > 0 and is_good:
+ data_dict = {}
+ part_data_dict_n = {}
+
+ if label_fields is not None and len(label_fields) > 0:
+ _load_and_pad_dense(input_data, label_fields, sid, data_dict,
+ part_data_dict, part_data_dict_n, batch_size)
+
+ if reserve_fields is not None and len(reserve_fields) > 0:
+ data_dict['reserve'] = {}
+ part_data_dict_n['reserve'] = {}
+ _load_and_pad_dense(input_data, reserve_fields, sid, data_dict['reserve'],
+ part_data_dict.get('reserve', None),
+ part_data_dict_n['reserve'], batch_size)
+
+ if len(dense_fea_names) > 0:
+ _load_and_pad_dense(input_data, dense_fea_names, sid, data_dict,
+ part_data_dict, part_data_dict_n, batch_size)
+
+ if len(sparse_fea_names) > 0:
+ for k in sparse_fea_names:
+ val = input_data[k][sid:]
+
+ if isinstance(input_data[k][sid], np.ndarray):
+ all_lens = np.array([len(x) for x in val], dtype=np.int32)
+ all_vals = np.concatenate(val.to_numpy())
+ else:
+ all_lens = np.ones([len(val)], dtype=np.int32)
+ all_vals = val.to_numpy()
+
+ if part_data_dict is not None and k in part_data_dict:
+ tmp_lens = np.concatenate([part_data_dict[k][0], all_lens], axis=0)
+ tmp_vals = np.concatenate([part_data_dict[k][1], all_vals], axis=0)
+ if len(tmp_lens) > batch_size:
+ tmp_res_lens = tmp_lens[batch_size:]
+ tmp_lens = tmp_lens[:batch_size]
+ tmp_num_elems = np.sum(tmp_lens)
+ tmp_res_vals = tmp_vals[tmp_num_elems:]
+ tmp_vals = tmp_vals[:tmp_num_elems]
+ part_data_dict_n[k] = (tmp_res_lens, tmp_res_vals)
+ data_dict[k] = (tmp_lens, tmp_vals)
+ elif len(tmp_lens) == batch_size:
+ data_dict[k] = (tmp_lens, tmp_vals)
+ else:
+ part_data_dict_n[k] = (tmp_lens, tmp_vals)
+ else:
+ part_data_dict_n[k] = (all_lens, all_vals)
+
+ if effective_fields[0] in data_dict:
+ if need_pack:
+ if len(sparse_fea_names) > 0:
+ _pack_sparse_feas(data_dict, sparse_fea_names)
+ if len(dense_fea_names) > 0:
+ _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+ else:
+ if len(dense_fea_names) > 0:
+ _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+ if not _add_to_que(data_dict, data_que, proc_stop_que):
+ logging.info('add to que failed')
+ is_good = False
+ break
+ total_batch_cnt += 1
+ part_data_dict = part_data_dict_n
+ if len(part_data_dict) > 0 and is_good:
+ batch_len = len(part_data_dict[effective_fields[0]][0])
+ if not drop_remainder:
+ if need_pack:
+ if len(sparse_fea_names) > 0:
+ _pack_sparse_feas(part_data_dict, sparse_fea_names)
+ if len(dense_fea_names) > 0:
+ _pack_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
+ else:
+ if len(dense_fea_names) > 0:
+ _reshape_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
+ logging.info('remainder batch: %s sample_num=%d' %
+ (','.join(part_data_dict.keys()), batch_len))
+ _add_to_que(part_data_dict, data_que, proc_stop_que)
+ total_batch_cnt += 1
+ else:
+ logging.warning('drop remaining %d samples as drop_remainder is set' %
+ batch_len)
+ if is_good:
+ is_good = _add_to_que(None, data_que, proc_stop_que)
+ logging.info(
+ 'data_proc_id[%d]: is_good = %s, total_batch_cnt=%d, total_sample_cnt=%d'
+ % (proc_id, is_good, total_batch_cnt, total_sample_cnt))
+ data_que.close(wait_send_finish=is_good)
+
+ while not is_good:
+ try:
+ if file_que.get(timeout=1) is None:
+ break
+ except queue.Empty:
+ pass
+ file_que.close()
+ logging.info('data proc %d done, file_num=%d' % (proc_id, num_files))
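The (lengths, values) packing performed by _pack_sparse_feas above can be reproduced in a few lines of numpy; the feature names and id values below are illustrative only, not taken from any dataset.

import numpy as np

# two sparse features over a batch of two samples, already in (lengths, values) form
data_dict = {
    'tag_fea': (np.array([2, 1], dtype=np.int32), np.array([3, 5, 7])),
    'id_fea': (np.array([1, 1], dtype=np.int32), np.array([11, 13])),
}
fea_lens = np.concatenate([data_dict[k][0] for k in ['tag_fea', 'id_fea']])
fea_vals = np.concatenate([data_dict[k][1] for k in ['tag_fea', 'id_fea']])
# fea_lens == [2 1 1 1], fea_vals == [3 5 7 11 13]; downstream code recovers
# the ragged structure from the per-(feature, sample) lengths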
diff --git a/easy_rec/python/input/odps_input.py b/easy_rec/python/input/odps_input.py
index d70cbd42d..bbd08bc39 100644
--- a/easy_rec/python/input/odps_input.py
+++ b/easy_rec/python/input/odps_input.py
@@ -18,9 +18,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(OdpsInput, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(OdpsInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
def _build(self, mode, params):
# check data_config are consistent with odps tables
@@ -42,7 +45,10 @@ def _build(self, mode, params):
slice_id=self._task_index)
if type(self._input_path) != list:
- self._input_path = [x for x in self._input_path.split(',')]
+ self._input_path = self._input_path.split(',')
+ assert len(
+ self._input_path) > 0, 'match no files with %s' % self._input_path
+
if mode == tf.estimator.ModeKeys.TRAIN:
if self._data_config.pai_worker_queue:
work_queue = pai.data.WorkQueue(
diff --git a/easy_rec/python/input/odps_input_v2.py b/easy_rec/python/input/odps_input_v2.py
index 60a2ae080..bd58de390 100644
--- a/easy_rec/python/input/odps_input_v2.py
+++ b/easy_rec/python/input/odps_input_v2.py
@@ -20,9 +20,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(OdpsInputV2, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(OdpsInputV2,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
def _parse_table(self, *fields):
fields = list(fields)
@@ -33,8 +36,9 @@ def _parse_table(self, *fields):
def _build(self, mode, params):
if type(self._input_path) != list:
- self._input_path = [x for x in self._input_path.split(',')]
-
+ self._input_path = self._input_path.split(',')
+ assert len(
+ self._input_path) > 0, 'match no files with %s' % self._input_path
# check data_config are consistent with odps tables
odps_util.check_input_field_and_types(self._data_config)
diff --git a/easy_rec/python/input/odps_input_v3.py b/easy_rec/python/input/odps_input_v3.py
index 4a67049b2..6ec737ca1 100644
--- a/easy_rec/python/input/odps_input_v3.py
+++ b/easy_rec/python/input/odps_input_v3.py
@@ -4,11 +4,11 @@
import logging
import sys
-import numpy as np
import tensorflow as tf
from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util
+from easy_rec.python.utils.tf_utils import get_tf_type
try:
import common_io
@@ -24,13 +24,17 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(OdpsInputV3, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(OdpsInputV3,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
self._num_epoch = 0
if common_io is None:
- logging.error("""please install common_io pip install
- https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/common_io-0.1.0-cp37-cp37m-linux_x86_64.whl"""
+ logging.error('''
+ please install common_io:
+ pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/common_io-0.4.2%2Btunnel-py2.py3-none-any.whl'''
)
sys.exit(1)
@@ -45,7 +49,9 @@ def _odps_read(self):
logging.info('start epoch[%d]' % self._num_epoch)
self._num_epoch += 1
if type(self._input_path) != list:
- self._input_path = [x for x in self._input_path.split(',')]
+ self._input_path = self._input_path.split(',')
+ assert len(
+ self._input_path) > 0, 'match no files with %s' % self._input_path
# check data_config are consistent with odps tables
odps_util.check_input_field_and_types(self._data_config)
@@ -66,7 +72,7 @@ def _odps_read(self):
batch_num = int(total_records_num / self._data_config.batch_size)
res_num = total_records_num - batch_num * self._data_config.batch_size
batch_defaults = [
- np.array([x] * self._data_config.batch_size) for x in record_defaults
+ [x] * self._data_config.batch_size for x in record_defaults
]
for batch_id in range(batch_num):
batch_data_np = [x.copy() for x in batch_defaults]
@@ -88,7 +94,7 @@ def _odps_read(self):
def _build(self, mode, params):
# get input type
- list_type = [self.get_tf_type(x) for x in self._input_field_types]
+ list_type = [get_tf_type(x) for x in self._input_field_types]
list_type = tuple(list_type)
list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
list_shapes = tuple(list_shapes)
diff --git a/easy_rec/python/input/odps_rtp_input.py b/easy_rec/python/input/odps_rtp_input.py
index bbcf70cc2..6ae6096e0 100644
--- a/easy_rec/python/input/odps_rtp_input.py
+++ b/easy_rec/python/input/odps_rtp_input.py
@@ -5,6 +5,8 @@
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
+from easy_rec.python.utils.check_utils import check_split
from easy_rec.python.utils.input_utils import string_to_number
try:
@@ -32,9 +34,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(OdpsRTPInput, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(OdpsRTPInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
logging.info('input_fields: %s label_fields: %s' %
(','.join(self._input_fields), ','.join(self._label_fields)))
@@ -42,20 +47,46 @@ def _parse_table(self, *fields):
fields = list(fields)
labels = fields[:-1]
- # only for features, labels excluded
+ selected_cols = self._data_config.selected_cols \
+ if self._data_config.selected_cols else None
+ non_feature_cols = self._label_fields
+ if selected_cols:
+ cols = [c.strip() for c in selected_cols.split(',')]
+ non_feature_cols = cols[:-1]
+ # only for features, labels and sample_weight excluded
record_types = [
t for x, t in zip(self._input_fields, self._input_field_types)
- if x not in self._label_fields
+ if x not in non_feature_cols
]
+ record_defaults = [
+ self.get_type_defaults(t, v)
+ for x, t, v in zip(self._input_fields, self._input_field_types,
+ self._input_field_defaults)
+ if x not in non_feature_cols
+ ]
+
+ feature_num = len(record_types)
# assume that the last field is the generated feature column
- print('field_delim = %s, input_field_name = %d' %
- (self._data_config.separator, len(record_types)))
- fields = tf.string_split(
- fields[-1], self._data_config.separator, skip_empty=False)
- tmp_fields = tf.reshape(fields.values, [-1, len(record_types)])
- fields = []
- for i in range(len(record_types)):
- field = string_to_number(tmp_fields[:, i], record_types[i], i)
+ print('field_delim = %s, feature_num = %d' %
+ (self._data_config.separator, feature_num))
+ logging.info('field_delim = %s, input_field_name = %d' %
+ (self._data_config.separator, len(record_types)))
+
+ check_list = [
+ tf.py_func(
+ check_split,
+ [fields[-1], self._data_config.separator,
+ len(record_types)],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ fields = str_split_by_chr(
+ fields[-1], self._data_config.separator, skip_empty=False)
+ tmp_fields = tf.reshape(fields.values, [-1, feature_num])
+ fields = labels[len(self._label_fields):]
+ for i in range(feature_num):
+ field = string_to_number(tmp_fields[:, i], record_types[i],
+ record_defaults[i], i)
fields.append(field)
field_keys = [x for x in self._input_fields if x not in self._label_fields]
@@ -64,18 +95,35 @@ def _parse_table(self, *fields):
for x in range(len(self._label_fields)):
inputs[self._label_fields[x]] = labels[x]
+ print('effective field num = %d, input_num = %d' %
+ (len(fields), len(inputs)))
return inputs
def _build(self, mode, params):
if type(self._input_path) != list:
- self._input_path = [x for x in self._input_path.split(',')]
+ self._input_path = self._input_path.split(',')
+ assert len(
+ self._input_path) > 0, 'match no files with %s' % self._input_path
- record_defaults = [
- self.get_type_defaults(t, v)
- for x, t, v in zip(self._input_fields, self._input_field_types,
- self._input_field_defaults)
- if x in self._label_fields
- ]
+ selected_cols = self._data_config.selected_cols \
+ if self._data_config.selected_cols else None
+ if selected_cols:
+ cols = [c.strip() for c in selected_cols.split(',')]
+ record_defaults = [
+ self.get_type_defaults(t, v)
+ for x, t, v in zip(self._input_fields, self._input_field_types,
+ self._input_field_defaults)
+ if x in cols[:-1]
+ ]
+ print('selected_cols: %s; defaults num: %d' %
+ (','.join(cols), len(record_defaults)))
+ else:
+ record_defaults = [
+ self.get_type_defaults(t, v)
+ for x, t, v in zip(self._input_fields, self._input_field_types,
+ self._input_field_defaults)
+ if x in self._label_fields
+ ]
# the actual features are in one single column
record_defaults.append(
self._data_config.separator.join([
@@ -84,8 +132,6 @@ def _build(self, mode, params):
self._input_field_defaults)
if x not in self._label_fields
]))
- selected_cols = self._data_config.selected_cols \
- if self._data_config.selected_cols else None
if self._data_config.pai_worker_queue and \
mode == tf.estimator.ModeKeys.TRAIN:
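As a reading aid (this is not code from the change), the selected_cols convention handled by OdpsRTPInput above reduces to: the last selected column carries the concatenated feature string, and every column before it is treated as a label or sample weight. The config value below is a hypothetical example.

selected_cols = 'clk,pay,features'  # hypothetical data_config.selected_cols
cols = [c.strip() for c in selected_cols.split(',')]
non_feature_cols = cols[:-1]  # ['clk', 'pay'] are excluded from record_types
feature_col = cols[-1]        # 'features' is split by data_config.separator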
diff --git a/easy_rec/python/input/odps_rtp_input_v2.py b/easy_rec/python/input/odps_rtp_input_v2.py
new file mode 100644
index 000000000..77edb46c1
--- /dev/null
+++ b/easy_rec/python/input/odps_rtp_input_v2.py
@@ -0,0 +1,104 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.odps_rtp_input import OdpsRTPInput
+
+if tf.__version__.startswith('1.'):
+ from tensorflow.python.platform import gfile
+else:
+ import tensorflow.io.gfile as gfile
+try:
+ import pai
+ import rtp_fg
+except Exception:
+ pai = None
+ rtp_fg = None
+
+
+class OdpsRTPInputV2(OdpsRTPInput):
+ """RTPInput for parsing rtp fg new input format on odps.
+
+ Our new format(csv in table) of rtp output:
+ label0, item_id, ..., user_id, features
+ Where features is in default RTP-tensorflow format.
+ The features column and labels are specified by data_config.selected_cols,
+ columns are selected by names in the table
+ such as: clk,features, the last selected column is features, the first
+ selected columns are labels
+ """
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ fg_json_path=None,
+ pipeline_config=None):
+ super(OdpsRTPInputV2,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
+ if fg_json_path is None:
+ raise ValueError('fg_json_path is not set')
+ if fg_json_path.startswith('!'):
+ fg_json_path = fg_json_path[1:]
+ self._fg_config_path = fg_json_path
+ logging.info('fg config path: {}'.format(self._fg_config_path))
+ with gfile.GFile(self._fg_config_path, 'r') as f:
+ self._fg_config = json.load(f)
+
+ def _parse_table(self, *fields):
+ self.check_rtp()
+
+ fields = list(fields)
+ labels = fields[:-1]
+
+ # assume that the last field is the generated feature column
+ features = rtp_fg.parse_genreated_fg(self._fg_config, fields[-1])
+
+ field_keys = [x for x in self._input_fields if x not in self._label_fields]
+ for feature_key in list(features.keys()):
+ if feature_key not in field_keys or feature_key not in self._effective_fields:
+ del features[feature_key]
+ inputs = {x: features[x] for x in features.keys()}
+
+ for x in range(len(self._label_fields)):
+ inputs[self._label_fields[x]] = labels[x]
+ return inputs
+
+ def create_placeholders(self, *args, **kwargs):
+ """Create serving placeholders with rtp_fg."""
+ self.check_rtp()
+ self._mode = tf.estimator.ModeKeys.PREDICT
+ inputs_placeholder = tf.placeholder(tf.string, [None], name='features')
+ print('[OdpsRTPInputV2] building placeholders.')
+ print('[OdpsRTPInputV2] fg_config: {}'.format(self._fg_config))
+ features = rtp_fg.parse_genreated_fg(self._fg_config, inputs_placeholder)
+ print('[OdpsRTPInputV2] built features: {}'.format(features.keys()))
+ features = self._preprocess(features)
+ print('[OdpsRTPInputV2] processed features: {}'.format(features.keys()))
+ return {'features': inputs_placeholder}, features['feature']
+
+ def create_multi_placeholders(self, *args, **kwargs):
+ """Create serving multi-placeholders with rtp_fg."""
+ raise NotImplementedError(
+ 'create_multi_placeholders is not supported for OdpsRTPInputV2')
+
+ def check_rtp(self):
+ if rtp_fg is None:
+ raise NotImplementedError(
+ 'OdpsRTPInputV2 cannot run without rtp_fg, which is not installed')
+
+ def _pre_build(self, mode, params):
+ try:
+ # Prevent TF from replacing the shape tensor with a constant tensor, which
+ # would fix the batch size and leave RTP unable to recognize the input shape.
+ tf.get_default_graph().set_shape_optimize(False)
+ except AttributeError as e:
+ logging.warning('failed to disable shape optimization: %s', e)
diff --git a/easy_rec/python/input/parquet_input.py b/easy_rec/python/input/parquet_input.py
new file mode 100644
index 000000000..dcc6e867f
--- /dev/null
+++ b/easy_rec/python/input/parquet_input.py
@@ -0,0 +1,397 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import multiprocessing
+import queue
+import time
+
+import tensorflow as tf
+from tensorflow.python.ops import array_ops
+
+from easy_rec.python.compat import queues
+from easy_rec.python.input import load_parquet
+from easy_rec.python.input.input import Input
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class ParquetInput(Input):
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None,
+ **kwargs):
+ super(ParquetInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config, **kwargs)
+ self._need_pack = True
+ if input_path is None:
+ return
+
+ self._input_files = []
+ for sub_path in input_path.strip().split(','):
+ self._input_files.extend(tf.gfile.Glob(sub_path))
+ logging.info('parquet input_path=%s file_num=%d' %
+ (input_path, len(self._input_files)))
+ mp_ctxt = multiprocessing.get_context('spawn')
+ self._data_que = queues.Queue(
+ name='data_que', ctx=mp_ctxt, maxsize=self._data_config.prefetch_size)
+
+ file_num = len(self._input_files)
+ logging.info('[task_index=%d] total_file_num=%d task_num=%d' %
+ (task_index, file_num, task_num))
+
+ self._my_files = []
+ for file_id in range(file_num):
+ if (file_id % task_num) == task_index:
+ self._my_files.append(self._input_files[file_id])
+ # self._my_files = self._input_files
+
+ logging.info('[task_index=%d] task_file_num=%d' %
+ (task_index, len(self._my_files)))
+ self._file_que = queues.Queue(name='file_que', ctx=mp_ctxt)
+
+ self._num_proc = 8
+ if file_num < self._num_proc:
+ self._num_proc = file_num
+
+ self._proc_start = False
+ self._proc_start_que = queues.Queue(name='proc_start_que', ctx=mp_ctxt)
+ self._proc_stop = False
+ self._proc_stop_que = queues.Queue(name='proc_stop_que', ctx=mp_ctxt)
+
+ self._reserve_fields = None
+ self._reserve_types = None
+ if 'reserve_fields' in kwargs and 'reserve_types' in kwargs:
+ self._reserve_fields = kwargs['reserve_fields']
+ self._reserve_types = kwargs['reserve_types']
+
+ # indicates whether this input is created from a Predictor;
+ # if so, the normal feature preprocess pass is skipped
+ if 'is_predictor' in kwargs:
+ self._is_predictor = kwargs['is_predictor']
+ else:
+ self._is_predictor = False
+
+ self._proc_arr = None
+
+ self._sparse_fea_names = []
+ self._dense_fea_names = []
+ self._dense_fea_cfgs = []
+ self._total_dense_fea_dim = 0
+ for fc in self._feature_configs:
+ feature_type = fc.feature_type
+ if feature_type in [fc.IdFeature, fc.TagFeature]:
+ input_name0 = fc.input_names[0]
+ self._sparse_fea_names.append(input_name0)
+ elif feature_type in [fc.RawFeature]:
+ input_name0 = fc.input_names[0]
+ self._dense_fea_names.append(input_name0)
+ self._dense_fea_cfgs.append(fc)
+ self._total_dense_fea_dim += fc.raw_input_dim
+ else:
+ assert False, 'feature_type[%s] not supported' % str(feature_type)
+
+ def _rebuild_que(self):
+ mp_ctxt = multiprocessing.get_context('spawn')
+ self._data_que = queues.Queue(
+ name='data_que', ctx=mp_ctxt, maxsize=self._data_config.prefetch_size)
+ self._file_que = queues.Queue(name='file_que', ctx=mp_ctxt)
+ self._proc_start_que = queues.Queue(name='proc_start_que', ctx=mp_ctxt)
+ self._proc_stop_que = queues.Queue(name='proc_stop_que', ctx=mp_ctxt)
+
+ def _sample_generator(self):
+ if not self._proc_start:
+ self._proc_start = True
+ for proc in (self._proc_arr):
+ self._proc_start_que.put(True)
+ logging.info('task[%s] data_proc=%s is_alive=%s' %
+ (self._task_index, proc, proc.is_alive()))
+
+ done_proc_cnt = 0
+ fetch_timeout_cnt = 0
+
+ # # for mock purpose
+ # all_samples = []
+ # while len(all_samples) < 64:
+ # try:
+ # sample = self._data_que.get(block=False)
+ # all_samples.append(sample)
+ # except queue.Empty:
+ # continue
+ # sid = 0
+ # while True:
+ # yield all_samples[sid]
+ # sid += 1
+ # if sid >= len(all_samples):
+ # sid = 0
+
+ fetch_good_cnt = 0
+ while True:
+ try:
+ sample = self._data_que.get(timeout=1)
+ if sample is None:
+ done_proc_cnt += 1
+ else:
+ fetch_good_cnt += 1
+ yield sample
+ if fetch_good_cnt % 200 == 0:
+ logging.info(
+ 'task[%d] fetch_batch_cnt=%d, fetch_timeout_cnt=%d, qsize=%d' %
+ (self._task_index, fetch_good_cnt, fetch_timeout_cnt,
+ self._data_que.qsize()))
+ except queue.Empty:
+ fetch_timeout_cnt += 1
+ if done_proc_cnt >= len(self._proc_arr):
+ logging.info('all sample finished, fetch_timeout_cnt=%d' %
+ fetch_timeout_cnt)
+ break
+ except Exception as ex:
+ logging.warning('task[%d] get from data_que exception: %s' %
+ (self._task_index, str(ex)))
+ break
+ logging.info('task[%d] sample_generator: total_batches=%d' %
+ (self._task_index, fetch_good_cnt))
+
+ def stop(self):
+ if self._proc_arr is None or len(self._proc_arr) == 0:
+ return
+ logging.info('task[%d] will stop dataset procs, proc_num=%d' %
+ (self._task_index, len(self._proc_arr)))
+ self._file_que.close()
+ if self._proc_start:
+ logging.info('try close data que')
+ for _ in range(len(self._proc_arr)):
+ self._proc_stop_que.put(1)
+ self._proc_stop_que.close()
+
+ def _any_alive():
+ for proc in self._proc_arr:
+ if proc.is_alive():
+ return True
+ return False
+
+ # to ensure the sender part of the python Queue could exit
+ while _any_alive():
+ try:
+ self._data_que.get(timeout=1)
+ except Exception:
+ pass
+ time.sleep(1)
+ self._data_que.close()
+ logging.info('data que closed')
+ # import time
+ # time.sleep(10)
+ for proc in self._proc_arr:
+ # proc.terminate()
+ proc.join()
+ logging.info('join proc done')
+
+ # rebuild for next run, which is necessary for evaluation
+ self._rebuild_que()
+ self._proc_arr = None
+ self._proc_start = False
+ self._proc_stop = False
+
+ def _to_fea_dict(self, input_dict):
+ fea_dict = {}
+
+ if len(self._sparse_fea_names) > 0:
+ if self._has_ev:
+ tmp_vals, tmp_lens = input_dict['sparse_fea'][1], input_dict[
+ 'sparse_fea'][0]
+
+ fea_dict['sparse_fea'] = (tmp_vals, tmp_lens)
+ else:
+ tmp_vals, tmp_lens = input_dict['sparse_fea'][1], input_dict[
+ 'sparse_fea'][0]
+ num_buckets = -1
+ for fc in self._feature_configs:
+ if fc.num_buckets > 0:
+ if num_buckets < 0:
+ num_buckets = fc.num_buckets
+ else:
+ assert num_buckets == fc.num_buckets, 'all features must share the same buckets, but are %d and %s' % (
+ num_buckets, str(fc))
+ fea_dict['sparse_fea'] = (tmp_vals % num_buckets, tmp_lens)
+
+ if len(self._dense_fea_names) > 0:
+ fea_dict['dense_fea'] = input_dict['dense_fea']
+
+ output_dict = {'feature': fea_dict}
+
+ lbl_dict = {}
+ for lbl_name in self._label_fields:
+ if lbl_name in input_dict:
+ lbl_dict[lbl_name] = input_dict[lbl_name]
+
+ if len(lbl_dict) > 0:
+ output_dict['label'] = lbl_dict
+
+ if self._reserve_fields is not None:
+ output_dict['reserve'] = input_dict['reserve']
+
+ return output_dict
+
+ def add_fea_type_and_shape(self, out_types, out_shapes):
+ # all features are packed into one tuple sparse_fea
+ # first field: field lengths
+ # second field: field values
+ if len(self._sparse_fea_names) > 0:
+ out_types['sparse_fea'] = (tf.int32, tf.int64)
+ out_shapes['sparse_fea'] = (tf.TensorShape([None]), tf.TensorShape([None
+ ]))
+ if len(self._dense_fea_names) > 0:
+ out_types['dense_fea'] = tf.float32
+ out_shapes['dense_fea'] = tf.TensorShape(
+ [None, self._total_dense_fea_dim])
+
+ def _build(self, mode, params):
+ if mode == tf.estimator.ModeKeys.TRAIN and self._data_config.num_epochs > 1:
+ logging.info('will repeat train data for %d epochs' %
+ self._data_config.num_epochs)
+ my_files = self._my_files * self._data_config.num_epochs
+ else:
+ my_files = self._my_files
+
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ drop_remainder = self._data_config.drop_remainder
+ lbl_fields = self._label_fields
+ else:
+ lbl_fields = self._label_fields
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ lbl_fields = None
+ drop_remainder = False
+ self._proc_arr = load_parquet.start_data_proc(
+ self._task_index,
+ self._task_num,
+ self._num_proc,
+ self._file_que,
+ self._data_que,
+ self._proc_start_que,
+ self._proc_stop_que,
+ self._batch_size,
+ lbl_fields,
+ # self._effective_fields,
+ self._sparse_fea_names,
+ self._dense_fea_names,
+ self._dense_fea_cfgs,
+ self._reserve_fields,
+ drop_remainder,
+ need_pack=self._need_pack)
+
+ for input_file in my_files:
+ self._file_que.put(input_file)
+
+ # add end signal
+ for proc in self._proc_arr:
+ self._file_que.put(None)
+ logging.info('add input_files to file_que, qsize=%d' %
+ self._file_que.qsize())
+
+ out_types = {}
+ out_shapes = {}
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ for k in self._label_fields:
+ out_types[k] = tf.int32
+ out_shapes[k] = tf.TensorShape([None])
+
+ if self._reserve_fields is not None:
+ out_types['reserve'] = {}
+ out_shapes['reserve'] = {}
+ for k, t in zip(self._reserve_fields, self._reserve_types):
+ out_types['reserve'][k] = t
+ out_shapes['reserve'][k] = tf.TensorShape([None])
+
+ self.add_fea_type_and_shape(out_types, out_shapes)
+
+ dataset = tf.data.Dataset.from_generator(
+ self._sample_generator,
+ output_types=out_types,
+ output_shapes=out_shapes)
+ num_parallel_calls = self._data_config.num_parallel_calls
+ dataset = dataset.map(
+ self._to_fea_dict, num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+ # Note: Input._preprocess is currently not supported as all features
+ # are concatenated together
+ # dataset = dataset.map(
+ # map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ dataset = dataset.map(lambda x:
+ (self._get_features(x), self._get_labels(x)))
+ # initial test show that prefetch to gpu has no performance gain
+ # dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0'))
+ else:
+ if self._is_predictor:
+ dataset = dataset.map(self._get_for_predictor)
+ else:
+ dataset = dataset.map(lambda x: self._get_features(x))
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+ return dataset
+
+ def _get_for_predictor(self, fea_dict):
+ out_dict = {
+ 'feature': {
+ 'ragged_ids': fea_dict['feature']['sparse_fea'][0],
+ 'ragged_lens': fea_dict['feature']['sparse_fea'][1]
+ }
+ }
+ if self._is_predictor and self._reserve_fields is not None:
+ out_dict['reserve'] = fea_dict['reserve']
+ return out_dict
+
+ def create_input(self, export_config=None):
+
+ def _input_fn(mode=None, params=None, config=None):
+ """Build input_fn for estimator.
+
+ Args:
+ mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
+ params: `dict` of hyper parameters, from Estimator
+ config: tf.estimator.RunConfig instance
+
+ Return:
+ if mode is not None, return:
+ features: inputs to the model.
+ labels: groundtruth
+ else, return:
+ tf.estimator.export.ServingInputReceiver instance
+ """
+ if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
+ tf.estimator.ModeKeys.PREDICT):
+ # build dataset from self._config.input_path
+ self._mode = mode
+ dataset = self._build(mode, params)
+ return dataset
+ elif mode is None: # serving_input_receiver_fn for export SavedModel
+ inputs, features = {}, {}
+ if len(self._sparse_fea_names) > 0:
+ ragged_ids = array_ops.placeholder(
+ tf.int64, [None], name='ragged_ids')
+ ragged_lens = array_ops.placeholder(
+ tf.int32, [None], name='ragged_lens')
+ inputs = {'ragged_ids': ragged_ids, 'ragged_lens': ragged_lens}
+ if self._has_ev:
+ features = {'ragged_ids': ragged_ids, 'ragged_lens': ragged_lens}
+ else:
+ features = {
+ 'ragged_ids': ragged_ids % self._feature_configs[0].num_buckets,
+ 'ragged_lens': ragged_lens
+ }
+ if len(self._dense_fea_names) > 0:
+ inputs['dense_fea'] = array_ops.placeholder(
+ tf.float32, [None, self._total_dense_fea_dim], name='dense_fea')
+ features['dense_fea'] = inputs['dense_fea']
+ return tf.estimator.export.ServingInputReceiver(features, inputs)
+
+ _input_fn.input_creator = self
+ return _input_fn
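For the ServingInputReceiver exported above, a client flattens each batch of the packed sparse feature into ragged_ids plus ragged_lens. A hedged sketch with made-up id values:

import numpy as np

batch = [[101, 102], [103], [104, 105, 106]]  # three samples of one id feature
ragged_ids = np.concatenate(batch).astype(np.int64)  # [101 102 103 104 105 106]
ragged_lens = np.array([len(x) for x in batch], dtype=np.int32)  # [2 1 3]
feed_dict = {'ragged_ids': ragged_ids, 'ragged_lens': ragged_lens}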
diff --git a/easy_rec/python/input/parquet_input_v2.py b/easy_rec/python/input/parquet_input_v2.py
new file mode 100644
index 000000000..dba54498c
--- /dev/null
+++ b/easy_rec/python/input/parquet_input_v2.py
@@ -0,0 +1,180 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# import logging
+import os
+
+# import numpy as np
+# import pandas as pd
+import tensorflow as tf
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+# from tensorflow.python.ops import math_ops
+# from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import string_ops
+
+from easy_rec.python.input.parquet_input import ParquetInput
+from easy_rec.python.utils import conditional
+
+# from easy_rec.python.utils.tf_utils import get_tf_type
+
+
+class ParquetInputV2(ParquetInput):
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None,
+ **kwargs):
+ super(ParquetInputV2,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config, **kwargs)
+ self._need_pack = False
+
+ def _predictor_preprocess(self, input_dict):
+ # when the ParquetInputV2 is built from ParquetPredictorV2,
+ # the feature preprocess stage is skipped.
+ fea_dict = {}
+ for k in input_dict:
+ vals = input_dict[k]
+ if isinstance(vals, tuple) and len(vals) == 2 and k != 'reserve':
+ fea_dict[k + '/lens'] = vals[0]
+ fea_dict[k + '/ids'] = vals[1]
+ else:
+ fea_dict[k] = vals
+ return fea_dict
+
+ def _to_fea_dict(self, input_dict):
+ if self._is_predictor:
+ fea_dict = self._predictor_preprocess(input_dict)
+ else:
+ fea_dict = self._preprocess(input_dict)
+
+ output_dict = {'feature': fea_dict}
+
+ lbl_dict = {}
+ for lbl_name in self._label_fields:
+ if lbl_name in input_dict:
+ lbl_dict[lbl_name] = input_dict[lbl_name]
+
+ if len(lbl_dict) > 0:
+ output_dict['label'] = lbl_dict
+
+ if self._reserve_fields is not None:
+ output_dict['reserve'] = input_dict['reserve']
+
+ return output_dict
+
+ def add_fea_type_and_shape(self, out_types, out_shapes):
+ # override ParquetInput.add_fea_type_and_shape
+ for k in self._sparse_fea_names:
+ out_types[k] = (tf.int32, tf.int64)
+ out_shapes[k] = (tf.TensorShape([None]), tf.TensorShape([None]))
+ for fc in self._dense_fea_cfgs:
+ k = fc.input_names[0]
+ out_types[k] = tf.float32
+ out_shapes[k] = tf.TensorShape([None, fc.raw_input_dim])
+
+ def _preprocess(self, inputs=None):
+ features = {}
+ placeholders = {}
+ for fc in self._feature_configs:
+ feature_name = fc.feature_name if fc.feature_name != '' else fc.input_names[
+ 0]
+ feature_type = fc.feature_type
+ if feature_type in [fc.IdFeature, fc.TagFeature]:
+ input_name0 = fc.input_names[0]
+ if inputs is not None:
+ input_lens, input_vals = inputs[input_name0]
+ else:
+ if input_name0 in placeholders:
+ input_lens, input_vals = placeholders[input_name0]
+ else:
+ input_vals = array_ops.placeholder(
+ dtypes.int64, [None], name=input_name0 + '/ids')
+ input_lens = array_ops.placeholder(
+ dtypes.int64, [None], name=input_name0 + '/lens')
+ placeholders[input_name0] = (input_lens, input_vals)
+ if not self._has_ev:
+ if fc.num_buckets > 0:
+ input_vals = input_vals % fc.num_buckets
+ else:
+ input_vals = string_ops.as_string(input_vals)
+ features[feature_name] = tf.RaggedTensor.from_row_lengths(
+ values=input_vals, row_lengths=input_lens)
+ elif feature_type in [fc.RawFeature]:
+ input_name0 = fc.input_names[0]
+ if inputs is not None:
+ input_vals = inputs[input_name0]
+ else:
+ if input_name0 in placeholders:
+ input_vals = placeholders[input_name0]
+ else:
+ if fc.raw_input_dim > 1:
+ input_vals = array_ops.placeholder(
+ dtypes.float32, [None, fc.raw_input_dim], name=input_name0)
+ else:
+ input_vals = array_ops.placeholder(
+ dtypes.float32, [None], name=input_name0)
+ placeholders[input_name0] = input_vals
+ features[feature_name] = input_vals
+ else:
+ assert False, 'feature_type[%s] not supported' % str(feature_type)
+
+ if inputs is not None:
+ return features
+ else:
+ inputs = {}
+ for key in placeholders:
+ vals = placeholders[key]
+ if isinstance(vals, tuple):
+ inputs[key + '/lens'] = vals[0]
+ inputs[key + '/ids'] = vals[1]
+ else:
+ inputs[key] = vals
+ return features, inputs
+
+ def _get_for_predictor(self, fea_dict):
+ # called by ParquetInputV2._build, format:
+ # {
+ # "feature": {"user_id/ids":..., "user_id/lens":..., ... },
+ # "reserve": {"sample_id":..., ...}
+ # }
+ return fea_dict
+
+ def create_input(self, export_config=None):
+
+ def _input_fn(mode=None, params=None, config=None):
+ """Build input_fn for estimator.
+
+ Args:
+ mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
+ params: `dict` of hyper parameters, from Estimator
+ config: tf.estimator.RunConfig instance
+
+ Return:
+ if mode is not None, return:
+ features: inputs to the model.
+ labels: groundtruth
+ else, return:
+ tf.estimator.export.ServingInputReceiver instance
+ """
+ if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
+ tf.estimator.ModeKeys.PREDICT):
+ # build dataset from self._config.input_path
+ self._mode = mode
+ dataset = self._build(mode, params)
+ return dataset
+ elif mode is None: # serving_input_receiver_fn for export SavedModel
+ place_on_cpu = os.getenv('place_embedding_on_cpu')
+ place_on_cpu = bool(place_on_cpu) if place_on_cpu else False
+ with conditional(place_on_cpu, ops.device('/CPU:0')):
+ features, inputs = self._preprocess()
+ return tf.estimator.export.ServingInputReceiver(features, inputs)
+
+ _input_fn.input_creator = self
+ return _input_fn
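The RaggedTensor reconstruction in ParquetInputV2._preprocess is a direct from_row_lengths over the (lens, ids) pair produced by load_parquet; a minimal sketch with illustrative values:

import tensorflow as tf
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

input_vals = tf.constant([101, 102, 103, 104, 105, 106], dtype=tf.int64)
input_lens = tf.constant([2, 1, 3], dtype=tf.int64)
ragged = tf.RaggedTensor.from_row_lengths(values=input_vals, row_lengths=input_lens)
# -> [[101, 102], [103], [104, 105, 106]]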
diff --git a/easy_rec/python/input/parquet_input_v3.py b/easy_rec/python/input/parquet_input_v3.py
new file mode 100644
index 000000000..300a5bd1e
--- /dev/null
+++ b/easy_rec/python/input/parquet_input_v3.py
@@ -0,0 +1,203 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils.input_utils import get_type_defaults
+
+try:
+ from tensorflow.python.data.experimental.ops import parquet_dataset_ops
+ from tensorflow.python.data.experimental.ops import parquet_pybind
+ from tensorflow.python.data.experimental.ops import dataframe
+ from tensorflow.python.ops import gen_ragged_conversion_ops
+ from tensorflow.python.ops.work_queue import WorkQueue
+ _has_deep_rec = True
+except Exception:
+ _has_deep_rec = False
+ pass
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class ParquetInputV3(Input):
+
+ def __init__(self,
+ data_config,
+ feature_config,
+ input_path,
+ task_index=0,
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None,
+ **kwargs):
+ if not _has_deep_rec:
+ raise RuntimeError('You should install DeepRec first.')
+ super(ParquetInputV3,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
+
+ self._ignore_val_dict = {}
+ for f in data_config.input_fields:
+ if f.HasField('ignore_val'):
+ self._ignore_val_dict[f.input_name] = get_type_defaults(
+ f.input_type, f.ignore_val)
+
+ self._true_type_dict = {}
+ for fc in self._feature_configs:
+ if fc.feature_type in [fc.IdFeature, fc.TagFeature, fc.SequenceFeature]:
+ if fc.hash_bucket_size > 0 or len(
+ fc.vocab_list) > 0 or fc.HasField('vocab_file'):
+ self._true_type_dict[fc.input_names[0]] = tf.string
+ else:
+ self._true_type_dict[fc.input_names[0]] = tf.int64
+ if len(fc.input_names) > 1:
+ self._true_type_dict[fc.input_names[1]] = tf.float32
+ if fc.feature_type == fc.RawFeature:
+ self._true_type_dict[fc.input_names[0]] = tf.float32
+
+ self._reserve_fields = None
+ self._reserve_types = None
+ if 'reserve_fields' in kwargs and 'reserve_types' in kwargs:
+ self._reserve_fields = kwargs['reserve_fields']
+ self._reserve_types = kwargs['reserve_types']
+
+ # In ParquetDataset multi_value use input type
+ self._multi_value_types = {}
+
+ def _ignore_and_cast(self, name, value):
+ ignore_value = self._ignore_val_dict.get(name, None)
+ if ignore_value:
+ if isinstance(value, tf.SparseTensor):
+ indices = tf.where(tf.equal(value.values, ignore_value))
+ value = tf.SparseTensor(
+ tf.gather_nd(value.indices, indices),
+ tf.gather_nd(value.values, indices), value.dense_shape)
+ elif isinstance(value, tf.Tensor):
+ indices = tf.where(tf.not_equal(value, ignore_value), name='indices')
+ value = tf.SparseTensor(
+ indices=indices,
+ values=tf.gather_nd(value, indices),
+ dense_shape=tf.shape(value, out_type=tf.int64))
+ dtype = self._true_type_dict.get(name, None)
+ if dtype:
+ value = tf.cast(value, dtype)
+ return value
+
+ def _parse_dataframe_value(self, value):
+ if len(value.nested_row_splits) == 0:
+ return value.values
+ value.values.set_shape([None])
+ sparse_value = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
+ value.nested_row_splits, value.values)
+ return tf.SparseTensor(sparse_value.sparse_indices,
+ sparse_value.sparse_values,
+ sparse_value.sparse_dense_shape)
+
+ def _parse_dataframe(self, df):
+ inputs = {}
+ for k, v in df.items():
+ if k in self._effective_fields:
+ if isinstance(v, dataframe.DataFrame.Value):
+ v = self._parse_dataframe_value(v)
+ elif k in self._label_fields:
+ if isinstance(v, dataframe.DataFrame.Value):
+ v = v.values
+ elif k in self._reserve_fields:
+ if isinstance(v, dataframe.DataFrame.Value):
+ v = v.values
+ else:
+ continue
+ inputs[k] = v
+ return inputs
+
+ def _build(self, mode, params):
+ input_files = []
+ for sub_path in self._input_path.strip().split(','):
+ input_files.extend(tf.gfile.Glob(sub_path))
+ file_num = len(input_files)
+ logging.info('[task_index=%d] total_file_num=%d task_num=%d' %
+ (self._task_index, file_num, self._task_num))
+
+ task_index = self._task_index
+ task_num = self._task_num
+ if self._data_config.chief_redundant:
+ task_index = max(self._task_index - 1, 0)
+ task_num = max(self._task_num - 1, 1)
+
+ if self._data_config.pai_worker_queue and \
+ mode == tf.estimator.ModeKeys.TRAIN:
+ work_queue = WorkQueue(
+ input_files,
+ num_epochs=self.num_epochs,
+ shuffle=self._data_config.shuffle)
+ my_files = work_queue.input_dataset()
+ else:
+ my_files = []
+ for file_id in range(file_num):
+ if (file_id % task_num) == task_index:
+ my_files.append(input_files[file_id])
+
+ parquet_fields = parquet_pybind.parquet_fields(input_files[0])
+ parquet_input_fields = []
+ for f in parquet_fields:
+ if f.name in self._input_fields:
+ parquet_input_fields.append(f)
+
+ all_fields = set(self._effective_fields)
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ all_fields |= set(self._label_fields)
+ if self._reserve_fields:
+ all_fields |= set(self._reserve_fields)
+
+ selected_fields = []
+ for f in parquet_input_fields:
+ if f.name in all_fields:
+ selected_fields.append(f)
+
+ num_parallel_reads = min(self._data_config.num_parallel_calls,
+ len(input_files) // task_num)
+ dataset = parquet_dataset_ops.ParquetDataset(
+ my_files,
+ batch_size=self._batch_size,
+ fields=selected_fields,
+ drop_remainder=self._data_config.drop_remainder,
+ num_parallel_reads=num_parallel_reads)
+ # partition_count=task_num,
+ # partition_index=task_index)
+
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ if self._data_config.shuffle:
+ dataset = dataset.shuffle(
+ self._data_config.shuffle_buffer_size,
+ seed=2020,
+ reshuffle_each_iteration=True)
+ dataset = dataset.repeat(self.num_epochs)
+ else:
+ dataset = dataset.repeat(1)
+
+ dataset = dataset.map(
+ self._parse_dataframe,
+ num_parallel_calls=self._data_config.num_parallel_calls)
+
+ # preprocess is necessary to transform data
+ # so that it can be fed into FeatureColumns
+ dataset = dataset.map(
+ map_func=self._preprocess,
+ num_parallel_calls=self._data_config.num_parallel_calls)
+
+ dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ dataset = dataset.map(lambda x:
+ (self._get_features(x), self._get_labels(x)))
+ else:
+ dataset = dataset.map(lambda x: (self._get_features(x)))
+ return dataset
+
+ def _preprocess(self, field_dict):
+ for k, v in field_dict.items():
+ field_dict[k] = self._ignore_and_cast(k, v)
+ return super(ParquetInputV3, self)._preprocess(field_dict)
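The dense branch of _ignore_and_cast above drops entries equal to the configured ignore_val and re-expresses the column as a SparseTensor. A small sketch; the ignore value 0 and the tensor contents are assumed examples, not config defaults.

import tensorflow as tf
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

value = tf.constant([[3, 0], [0, 7]], dtype=tf.int64)
ignore_value = 0
indices = tf.where(tf.not_equal(value, ignore_value))
sparse = tf.SparseTensor(
    indices=indices,
    values=tf.gather_nd(value, indices),
    dense_shape=tf.shape(value, out_type=tf.int64))
# keeps only the entries 3 at [0, 0] and 7 at [1, 1]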
diff --git a/easy_rec/python/input/rtp_input.py b/easy_rec/python/input/rtp_input.py
index 00081962b..9f9679b9e 100644
--- a/easy_rec/python/input/rtp_input.py
+++ b/easy_rec/python/input/rtp_input.py
@@ -5,7 +5,11 @@
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
+from easy_rec.python.utils.check_utils import check_split
+from easy_rec.python.utils.check_utils import check_string_to_number
from easy_rec.python.utils.input_utils import string_to_number
+from easy_rec.python.utils.tf_utils import get_tf_type
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -31,9 +35,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(RTPInput, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(RTPInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
logging.info('input_fields: %s label_fields: %s' %
(','.join(self._input_fields), ','.join(self._label_fields)))
self._rtp_separator = self._data_config.rtp_separator
@@ -48,13 +55,6 @@ def __init__(self,
def _parse_csv(self, line):
record_defaults = ['' for i in range(self._num_cols)]
- lbl_id = 0
- for x, t, v in zip(self._input_fields, self._input_field_types,
- self._input_field_defaults):
- if x not in self._label_fields:
- continue
- record_defaults[self._selected_cols[lbl_id]] = self.get_type_defaults(
- t, v)
# the actual features are in one single column
record_defaults[self._feature_col_id] = self._data_config.separator.join([
@@ -64,9 +64,30 @@ def _parse_csv(self, line):
if x not in self._label_fields
])
- fields = tf.string_split(line, self._rtp_separator, skip_empty=False)
+ check_list = [
+ tf.py_func(
+ check_split, [line, self._rtp_separator,
+ len(record_defaults)],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ fields = tf.string_split(line, self._rtp_separator, skip_empty=False)
+
fields = tf.reshape(fields.values, [-1, len(record_defaults)])
- labels = [fields[:, x] for x in self._selected_cols[:-1]]
+
+ labels = []
+ for idx, x in enumerate(self._selected_cols[:-1]):
+ field = fields[:, x]
+ fname = self._input_fields[idx]
+ ftype = self._input_field_types[idx]
+ tf_type = get_tf_type(ftype)
+ if field.dtype in [tf.string]:
+ check_list = [
+ tf.py_func(check_string_to_number, [field, fname], Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ field = tf.string_to_number(field, tf_type)
+ labels.append(field)
# only for features, labels excluded
record_types = [
@@ -75,14 +96,28 @@ def _parse_csv(self, line):
]
# assume that the last field is the generated feature column
print('field_delim = %s' % self._data_config.separator)
- fields = tf.string_split(
- fields[:, self._feature_col_id],
- self._data_config.separator,
- skip_empty=False)
+ feature_str = fields[:, self._feature_col_id]
+ check_list = [
+ tf.py_func(
+ check_split,
+ [feature_str, self._data_config.separator,
+ len(record_types)],
+ Tout=tf.bool)
+ ] if self._check_mode else []
+ with tf.control_dependencies(check_list):
+ fields = str_split_by_chr(
+ feature_str, self._data_config.separator, skip_empty=False)
tmp_fields = tf.reshape(fields.values, [-1, len(record_types)])
+ rtp_record_defaults = [
+ str(self.get_type_defaults(t, v))
+ for x, t, v in zip(self._input_fields, self._input_field_types,
+ self._input_field_defaults)
+ if x not in self._label_fields
+ ]
fields = []
for i in range(len(record_types)):
- field = string_to_number(tmp_fields[:, i], record_types[i], i)
+ field = string_to_number(tmp_fields[:, i], record_types[i],
+ rtp_record_defaults[i], i)
fields.append(field)
field_keys = [x for x in self._input_fields if x not in self._label_fields]
@@ -94,7 +129,11 @@ def _parse_csv(self, line):
return inputs
def _build(self, mode, params):
- file_paths = tf.gfile.Glob(self._input_path)
+ if type(self._input_path) != list:
+ self._input_path = self._input_path.split(',')
+ file_paths = []
+ for x in self._input_path:
+ file_paths.extend(tf.gfile.Glob(x))
assert len(file_paths) > 0, 'match no files with %s' % self._input_path
# try to figure out number of fields from one file
@@ -103,7 +142,9 @@ def _build(self, mode, params):
for line_str in fin:
line_tok = line_str.strip().split(self._rtp_separator)
if self._num_cols != -1:
- assert self._num_cols == len(line_tok)
+ assert self._num_cols == len(line_tok), \
+ 'num selected cols is %d, not equal to %d, current line is: %s, please check rtp_separator and data.' % \
+ (self._num_cols, len(line_tok), line_str)
self._num_cols = len(line_tok)
num_lines += 1
if num_lines > 10:
@@ -131,9 +172,14 @@ def _build(self, mode, params):
logging.info('train files[%d]: %s' %
(len(file_paths), ','.join(file_paths)))
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ if self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
if self._data_config.shuffle:
# shuffle input files
dataset = dataset.shuffle(len(file_paths))
+
# too many readers read the same file will cause performance issues
# as the same data will be read multiple times
parallel_num = min(num_parallel_calls, len(file_paths))
@@ -141,11 +187,10 @@ def _build(self, mode, params):
tf.data.TextLineDataset,
cycle_length=parallel_num,
num_parallel_calls=parallel_num)
- if self._data_config.chief_redundant:
- dataset = dataset.shard(
- max(self._task_num - 1, 1), max(self._task_index - 1, 0))
- else:
- dataset = dataset.shard(self._task_num, self._task_index)
+
+ if not self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
if self._data_config.shuffle:
dataset = dataset.shuffle(
self._data_config.shuffle_buffer_size,
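The check_mode additions above guard each parse with a tf.py_func validator inside tf.control_dependencies, so a malformed row raises a readable error before tf.string_split runs. A minimal standalone sketch of that pattern; check_columns below is an illustrative stand-in, not the easy_rec check_split helper itself:

import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def check_columns(lines, sep, expected_num):
  # stand-in validator: every row must split into `expected_num` columns
  sep = sep.decode('utf-8') if isinstance(sep, bytes) else sep
  for row in lines:
    row = row.decode('utf-8') if isinstance(row, bytes) else row
    assert len(row.split(sep)) == int(expected_num), 'bad row: %s' % row
  return True


def split_with_check(lines, sep, expected_num, check_mode=True):
  check_ops = [
      tf.py_func(check_columns, [lines, sep, expected_num], Tout=tf.bool)
  ] if check_mode else []
  # string_split only runs after the validator has succeeded
  with tf.control_dependencies(check_ops):
    return tf.string_split(lines, sep, skip_empty=False)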
diff --git a/easy_rec/python/input/rtp_input_v2.py b/easy_rec/python/input/rtp_input_v2.py
index 1635c623a..32e841f8d 100644
--- a/easy_rec/python/input/rtp_input_v2.py
+++ b/easy_rec/python/input/rtp_input_v2.py
@@ -22,9 +22,12 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(RTPInputV2, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(RTPInputV2,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
def _parse_rtp(self, lines):
tf_types = [tf.string for x in self._input_field_types]
@@ -82,16 +85,26 @@ def _convert(x, target_type, name):
return inputs
def _build(self, mode, params):
- file_paths = tf.gfile.Glob(self._input_path)
+ if type(self._input_path) != list:
+ self._input_path = self._input_path.split(',')
+ file_paths = []
+ for x in self._input_path:
+ file_paths.extend(tf.gfile.Glob(x))
+ assert len(file_paths) > 0, 'match no files with %s' % self._input_path
num_parallel_calls = self._data_config.num_parallel_calls
if mode == tf.estimator.ModeKeys.TRAIN:
logging.info('train files[%d]: %s' %
(len(file_paths), ','.join(file_paths)))
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ if self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
if self._data_config.shuffle:
# shuffle input files
dataset = dataset.shuffle(len(file_paths))
+
# too many readers read the same file will cause performance issues
# as the same data will be read multiple times
parallel_num = min(num_parallel_calls, len(file_paths))
@@ -99,11 +112,10 @@ def _build(self, mode, params):
tf.data.TextLineDataset,
cycle_length=parallel_num,
num_parallel_calls=parallel_num)
- if self._data_config.chief_redundant:
- dataset = dataset.shard(
- max(self._task_num - 1, 1), max(self._task_index - 1, 0))
- else:
- dataset = dataset.shard(self._task_num, self._task_index)
+
+ if not self._data_config.file_shard:
+ dataset = self._safe_shard(dataset)
+
if self._data_config.shuffle:
dataset = dataset.shuffle(
self._data_config.shuffle_buffer_size,
diff --git a/easy_rec/python/input/tfrecord_input.py b/easy_rec/python/input/tfrecord_input.py
index c3d9e228e..b9e25eef0 100644
--- a/easy_rec/python/input/tfrecord_input.py
+++ b/easy_rec/python/input/tfrecord_input.py
@@ -5,6 +5,7 @@
import tensorflow as tf
from easy_rec.python.input.input import Input
+from easy_rec.python.utils.tf_utils import get_tf_type
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -17,15 +18,18 @@ def __init__(self,
feature_config,
input_path,
task_index=0,
- task_num=1):
- super(TFRecordInput, self).__init__(data_config, feature_config, input_path,
- task_index, task_num)
+ task_num=1,
+ check_mode=False,
+ pipeline_config=None):
+ super(TFRecordInput,
+ self).__init__(data_config, feature_config, input_path, task_index,
+ task_num, check_mode, pipeline_config)
self.feature_desc = {}
for x, t, d in zip(self._input_fields, self._input_field_types,
self._input_field_defaults):
d = self.get_type_defaults(t, d)
- t = self.get_tf_type(t)
+ t = get_tf_type(t)
self.feature_desc[x] = tf.FixedLenFeature(
dtype=t, shape=1, default_value=d)
@@ -37,7 +41,11 @@ def _parse_tfrecord(self, example):
return inputs
def _build(self, mode, params):
- file_paths = tf.gfile.Glob(self._input_path)
+ if type(self._input_path) != list:
+ self._input_path = self._input_path.split(',')
+ file_paths = []
+ for x in self._input_path:
+ file_paths.extend(tf.gfile.Glob(x))
assert len(file_paths) > 0, 'match no files with %s' % self._input_path
num_parallel_calls = self._data_config.num_parallel_calls
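RTPInput, RTPInputV2 and TFRecordInput now all accept either a list of paths or a single comma-separated string of glob patterns. A condensed sketch of that shared path-expansion logic:

import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def expand_input_paths(input_path):
  """Expand a list or a comma-separated string of glob patterns to file paths."""
  if not isinstance(input_path, list):
    input_path = input_path.split(',')
  file_paths = []
  for pattern in input_path:
    file_paths.extend(tf.gfile.Glob(pattern))
  assert len(file_paths) > 0, 'match no files with %s' % input_path
  return file_paths


# e.g. expand_input_paths('data/part-0*.csv,data/part-1*.csv')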
diff --git a/easy_rec/python/layers/backbone.py b/easy_rec/python/layers/backbone.py
new file mode 100644
index 000000000..e77ea1da5
--- /dev/null
+++ b/easy_rec/python/layers/backbone.py
@@ -0,0 +1,571 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import six
+import tensorflow as tf
+from google.protobuf import struct_pb2
+
+from easy_rec.python.layers.common_layers import EnhancedInputLayer
+from easy_rec.python.layers.keras import MLP
+from easy_rec.python.layers.keras import EmbeddingLayer
+from easy_rec.python.layers.utils import Parameter
+from easy_rec.python.protos import backbone_pb2
+from easy_rec.python.utils.dag import DAG
+from easy_rec.python.utils.load_class import load_keras_layer
+from easy_rec.python.utils.tf_utils import add_elements_to_collection
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class Package(object):
+ """A sub DAG of tf ops for reuse."""
+ __packages = {}
+
+ @staticmethod
+ def has_backbone_block(name):
+ if 'backbone' not in Package.__packages:
+ return False
+ backbone = Package.__packages['backbone']
+ return backbone.has_block(name)
+
+ @staticmethod
+ def backbone_block_outputs(name):
+ if 'backbone' not in Package.__packages:
+ return None
+ backbone = Package.__packages['backbone']
+ return backbone.block_outputs(name)
+
+ def __init__(self, config, features, input_layer, l2_reg=None):
+ self._config = config
+ self._features = features
+ self._input_layer = input_layer
+ self._l2_reg = l2_reg
+ self._dag = DAG()
+ self._name_to_blocks = {}
+ self._name_to_layer = {}
+ self.reset_input_config(None)
+ self._block_outputs = {}
+ self._package_input = None
+ self._feature_group_inputs = {}
+ reuse = None if config.name == 'backbone' else tf.AUTO_REUSE
+ input_feature_groups = self._feature_group_inputs
+
+ for block in config.blocks:
+ if len(block.inputs) == 0:
+ raise ValueError('block takes at least one input: %s' % block.name)
+ self._dag.add_node(block.name)
+ self._name_to_blocks[block.name] = block
+ layer = block.WhichOneof('layer')
+ if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
+ if len(block.inputs) != 1:
+ raise ValueError('input layer `%s` takes only one input' % block.name)
+ one_input = block.inputs[0]
+ name = one_input.WhichOneof('name')
+ if name != 'feature_group_name':
+ raise KeyError(
+ '`feature_group_name` should be set for input layer: ' +
+ block.name)
+ group = one_input.feature_group_name
+ if not input_layer.has_group(group):
+ raise KeyError('invalid feature group name: ' + group)
+ if group in input_feature_groups:
+ if layer == 'input_layer':
+ logging.warning('input `%s` already exists in another block' % group)
+ elif layer == 'raw_input':
+ input_fn = input_feature_groups[group]
+ self._name_to_layer[block.name] = input_fn
+ elif layer == 'embedding_layer':
+ inputs, vocab, weights = input_feature_groups[group]
+ block.embedding_layer.vocab_size = vocab
+ params = Parameter.make_from_pb(block.embedding_layer)
+ input_fn = EmbeddingLayer(params, block.name)
+ self._name_to_layer[block.name] = input_fn
+ else:
+ if layer == 'input_layer':
+ input_fn = EnhancedInputLayer(self._input_layer, self._features,
+ group, reuse)
+ input_feature_groups[group] = input_fn
+ elif layer == 'raw_input':
+ input_fn = self._input_layer.get_raw_features(self._features, group)
+ input_feature_groups[group] = input_fn
+ else: # embedding_layer
+ inputs, vocab, weights = self._input_layer.get_bucketized_features(
+ self._features, group)
+ block.embedding_layer.vocab_size = vocab
+ params = Parameter.make_from_pb(block.embedding_layer)
+ input_fn = EmbeddingLayer(params, block.name)
+ input_feature_groups[group] = (inputs, vocab, weights)
+ logging.info('add an embedding layer %s with vocab size %d',
+ block.name, vocab)
+ self._name_to_layer[block.name] = input_fn
+ else:
+ self.define_layers(layer, block, block.name, reuse)
+
+ # sequential layers
+ for i, layer_cnf in enumerate(block.layers):
+ layer = layer_cnf.WhichOneof('layer')
+ name_i = '%s_l%d' % (block.name, i)
+ self.define_layers(layer, layer_cnf, name_i, reuse)
+
+ num_groups = len(input_feature_groups)
+ num_blocks = len(self._name_to_blocks) - num_groups
+ assert num_blocks > 0, 'there must be at least one block in backbone'
+
+ num_pkg_input = 0
+ for block in config.blocks:
+ layer = block.WhichOneof('layer')
+ if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
+ continue
+ name = block.name
+ if name in input_feature_groups:
+ raise KeyError('block name can not be one of feature groups: ' + name)
+ for input_node in block.inputs:
+ input_type = input_node.WhichOneof('name')
+ input_name = getattr(input_node, input_type)
+ if input_type == 'use_package_input':
+ assert input_name, 'use_package_input can not be set to false'
+ num_pkg_input += 1
+ continue
+ if input_type == 'package_name':
+ num_pkg_input += 1
+ self._dag.add_node_if_not_exists(input_name)
+ self._dag.add_edge(input_name, name)
+ if input_node.HasField('package_input'):
+ pkg_input_name = input_node.package_input
+ self._dag.add_node_if_not_exists(pkg_input_name)
+ self._dag.add_edge(pkg_input_name, input_name)
+ continue
+ iname = input_name
+ if iname in self._name_to_blocks:
+ assert iname != name, 'input name can not equal to block name:' + iname
+ self._dag.add_edge(iname, name)
+ else:
+ is_fea_group = input_type == 'feature_group_name'
+ if is_fea_group and input_layer.has_group(iname):
+ logging.info('adding an input_layer block: ' + iname)
+ new_block = backbone_pb2.Block()
+ new_block.name = iname
+ input_cfg = backbone_pb2.Input()
+ input_cfg.feature_group_name = iname
+ new_block.inputs.append(input_cfg)
+ new_block.input_layer.CopyFrom(backbone_pb2.InputLayer())
+ self._name_to_blocks[iname] = new_block
+ self._dag.add_node(iname)
+ self._dag.add_edge(iname, name)
+ if iname in input_feature_groups:
+ fn = input_feature_groups[iname]
+ else:
+ fn = EnhancedInputLayer(self._input_layer, self._features, iname)
+ input_feature_groups[iname] = fn
+ self._name_to_layer[iname] = fn
+ elif Package.has_backbone_block(iname):
+ backbone = Package.__packages['backbone']
+ backbone._dag.add_node_if_not_exists(self._config.name)
+ backbone._dag.add_edge(iname, self._config.name)
+ num_pkg_input += 1
+ else:
+ raise KeyError(
+ 'invalid input name `%s`, must be the name of either a feature group or another block'
+ % iname)
+ num_groups = len(input_feature_groups)
+ assert num_pkg_input > 0 or num_groups > 0, 'there must be at least one input layer/feature group'
+
+ if len(config.concat_blocks) == 0 and len(config.output_blocks) == 0:
+ leaf = self._dag.all_leaves()
+ logging.warning(
+ '%s has no `concat_blocks` or `output_blocks`, try to concat all leaf blocks: %s'
+ % (config.name, ','.join(leaf)))
+ self._config.concat_blocks.extend(leaf)
+
+ Package.__packages[self._config.name] = self
+ logging.info('%s layers: %s' %
+ (config.name, ','.join(self._name_to_layer.keys())))
+
+ def define_layers(self, layer, layer_cnf, name, reuse):
+ if layer == 'keras_layer':
+ layer_obj = self.load_keras_layer(layer_cnf.keras_layer, name, reuse)
+ self._name_to_layer[name] = layer_obj
+ elif layer == 'recurrent':
+ keras_layer = layer_cnf.recurrent.keras_layer
+ for i in range(layer_cnf.recurrent.num_steps):
+ name_i = '%s_%d' % (name, i)
+ layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
+ self._name_to_layer[name_i] = layer_obj
+ elif layer == 'repeat':
+ keras_layer = layer_cnf.repeat.keras_layer
+ for i in range(layer_cnf.repeat.num_repeat):
+ name_i = '%s_%d' % (name, i)
+ layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
+ self._name_to_layer[name_i] = layer_obj
+
+ def reset_input_config(self, config):
+ self.input_config = config
+
+ def set_package_input(self, pkg_input):
+ self._package_input = pkg_input
+
+ def has_block(self, name):
+ return name in self._name_to_blocks
+
+ def block_outputs(self, name):
+ return self._block_outputs.get(name, None)
+
+ def block_input(self, config, block_outputs, training=None, **kwargs):
+ inputs = []
+ for input_node in config.inputs:
+ input_type = input_node.WhichOneof('name')
+ input_name = getattr(input_node, input_type)
+ if input_type == 'use_package_input':
+ input_feature = self._package_input
+ input_name = 'package_input'
+ elif input_type == 'package_name':
+ if input_name not in Package.__packages:
+ raise KeyError('package name `%s` does not exist' % input_name)
+ package = Package.__packages[input_name]
+ if input_node.HasField('reset_input'):
+ package.reset_input_config(input_node.reset_input)
+ if input_node.HasField('package_input'):
+ pkg_input_name = input_node.package_input
+ if pkg_input_name in block_outputs:
+ pkg_input = block_outputs[pkg_input_name]
+ else:
+ if pkg_input_name not in Package.__packages:
+ raise KeyError('package name `%s` does not exist' %
+ pkg_input_name)
+ inner_package = Package.__packages[pkg_input_name]
+ pkg_input = inner_package(training)
+ if input_node.HasField('package_input_fn'):
+ fn = eval(input_node.package_input_fn)
+ pkg_input = fn(pkg_input)
+ package.set_package_input(pkg_input)
+ input_feature = package(training, **kwargs)
+ elif input_name in block_outputs:
+ input_feature = block_outputs[input_name]
+ else:
+ input_feature = Package.backbone_block_outputs(input_name)
+
+ if input_feature is None:
+ raise KeyError('input name `%s` does not exist' % input_name)
+
+ if input_node.ignore_input:
+ continue
+ if input_node.HasField('input_slice'):
+ fn = eval('lambda x: x' + input_node.input_slice.strip())
+ input_feature = fn(input_feature)
+ if input_node.HasField('input_fn'):
+ with tf.name_scope(config.name):
+ fn = eval(input_node.input_fn)
+ input_feature = fn(input_feature)
+ inputs.append(input_feature)
+
+ if config.merge_inputs_into_list:
+ output = inputs
+ else:
+ try:
+ output = merge_inputs(inputs, config.input_concat_axis, config.name)
+ except ValueError as e:
+ msg = getattr(e, 'message', str(e))
+ logging.error('merge inputs of block %s failed: %s', config.name, msg)
+ raise e
+
+ if config.HasField('extra_input_fn'):
+ fn = eval(config.extra_input_fn)
+ output = fn(output)
+ return output
+
+ def __call__(self, is_training, **kwargs):
+ with tf.name_scope(self._config.name):
+ return self.call(is_training, **kwargs)
+
+ def call(self, is_training, **kwargs):
+ block_outputs = {}
+ self._block_outputs = block_outputs # reset
+ blocks = self._dag.topological_sort()
+ logging.info(self._config.name + ' topological order: ' + ','.join(blocks))
+ for block in blocks:
+ if block not in self._name_to_blocks:
+ assert block in Package.__packages, 'invalid block: ' + block
+ continue
+ config = self._name_to_blocks[block]
+ if config.layers: # sequential layers
+ logging.info('call sequential %d layers' % len(config.layers))
+ output = self.block_input(config, block_outputs, is_training, **kwargs)
+ for i, layer in enumerate(config.layers):
+ name_i = '%s_l%d' % (block, i)
+ output = self.call_layer(output, layer, name_i, is_training, **kwargs)
+ block_outputs[block] = output
+ continue
+ # just one of layer
+ layer = config.WhichOneof('layer')
+ if layer is None: # identity layer
+ output = self.block_input(config, block_outputs, is_training, **kwargs)
+ block_outputs[block] = output
+ elif layer == 'raw_input':
+ block_outputs[block] = self._name_to_layer[block]
+ elif layer == 'input_layer':
+ input_fn = self._name_to_layer[block]
+ input_config = config.input_layer
+ if self.input_config is not None:
+ input_config = self.input_config
+ input_fn.reset(input_config, is_training)
+ block_outputs[block] = input_fn(input_config, is_training)
+ elif layer == 'embedding_layer':
+ input_fn = self._name_to_layer[block]
+ feature_group = config.inputs[0].feature_group_name
+ inputs, _, weights = self._feature_group_inputs[feature_group]
+ block_outputs[block] = input_fn([inputs, weights], is_training)
+ else:
+ with tf.name_scope(block + '_input'):
+ inputs = self.block_input(config, block_outputs, is_training,
+ **kwargs)
+ output = self.call_layer(inputs, config, block, is_training, **kwargs)
+ block_outputs[block] = output
+
+ outputs = []
+ for output in self._config.output_blocks:
+ if output in block_outputs:
+ temp = block_outputs[output]
+ outputs.append(temp)
+ else:
+ raise ValueError('No output `%s` of backbone to be concat' % output)
+ if outputs:
+ return outputs
+
+ for output in self._config.concat_blocks:
+ if output in block_outputs:
+ temp = block_outputs[output]
+ outputs.append(temp)
+ else:
+ raise ValueError('No output `%s` of backbone to be concat' % output)
+ try:
+ output = merge_inputs(outputs, msg='backbone')
+ except ValueError as e:
+ msg = getattr(e, 'message', str(e))
+ logging.error("merge backbone's output failed: %s", msg)
+ raise e
+ return output
+
+ def load_keras_layer(self, layer_conf, name, reuse=None):
+ layer_cls, customize = load_keras_layer(layer_conf.class_name)
+ if layer_cls is None:
+ raise ValueError('Invalid keras layer class name: ' +
+ layer_conf.class_name)
+
+ param_type = layer_conf.WhichOneof('params')
+ if customize:
+ if param_type is None or param_type == 'st_params':
+ params = Parameter(layer_conf.st_params, True, l2_reg=self._l2_reg)
+ else:
+ pb_params = getattr(layer_conf, param_type)
+ params = Parameter(pb_params, False, l2_reg=self._l2_reg)
+
+ has_reuse = True
+ try:
+ from funcsigs import signature
+ sig = signature(layer_cls.__init__)
+ has_reuse = 'reuse' in sig.parameters.keys()
+ except ImportError:
+ try:
+ from sklearn.externals.funcsigs import signature
+ sig = signature(layer_cls.__init__)
+ has_reuse = 'reuse' in sig.parameters.keys()
+ except ImportError:
+ logging.warning('import funcsigs failed')
+
+ if has_reuse:
+ layer = layer_cls(params, name=name, reuse=reuse)
+ else:
+ layer = layer_cls(params, name=name)
+ return layer, customize
+ elif param_type is None: # internal keras layer
+ layer = layer_cls(name=name)
+ return layer, customize
+ else:
+ assert param_type == 'st_params', 'internal keras layer only supports st_params'
+ try:
+ kwargs = convert_to_dict(layer_conf.st_params)
+ logging.info('call %s layer with params %r' %
+ (layer_conf.class_name, kwargs))
+ layer = layer_cls(name=name, **kwargs)
+ except TypeError as e:
+ logging.warning(e)
+ args = map(format_value, layer_conf.st_params.values())
+ logging.info('try to call %s layer with params %r' %
+ (layer_conf.class_name, args))
+ layer = layer_cls(*args, name=name)
+ return layer, customize
+
+ def call_keras_layer(self, inputs, name, training, **kwargs):
+ """Call predefined Keras Layer, which can be reused."""
+ layer, customize = self._name_to_layer[name]
+ cls = layer.__class__.__name__
+ if customize:
+ try:
+ output = layer(inputs, training=training, **kwargs)
+ except Exception as e:
+ msg = getattr(e, 'message', str(e))
+ logging.error('call keras layer %s (%s) failed: %s' % (name, cls, msg))
+ raise e
+ else:
+ try:
+ output = layer(inputs, training=training)
+ if cls == 'BatchNormalization':
+ add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+ except TypeError:
+ output = layer(inputs)
+ return output
+
+ def call_layer(self, inputs, config, name, training, **kwargs):
+ layer_name = config.WhichOneof('layer')
+ if layer_name == 'keras_layer':
+ return self.call_keras_layer(inputs, name, training, **kwargs)
+ if layer_name == 'lambda':
+ conf = getattr(config, 'lambda')
+ fn = eval(conf.expression)
+ return fn(inputs)
+ if layer_name == 'repeat':
+ conf = config.repeat
+ n_loop = conf.num_repeat
+ outputs = []
+ for i in range(n_loop):
+ name_i = '%s_%d' % (name, i)
+ ly_inputs = inputs
+ if conf.HasField('input_slice'):
+ fn = eval('lambda x, i: x' + conf.input_slice.strip())
+ ly_inputs = fn(ly_inputs, i)
+ if conf.HasField('input_fn'):
+ with tf.name_scope(config.name):
+ fn = eval(conf.input_fn)
+ ly_inputs = fn(ly_inputs, i)
+ output = self.call_keras_layer(ly_inputs, name_i, training, **kwargs)
+ outputs.append(output)
+ if len(outputs) == 1:
+ return outputs[0]
+ if conf.HasField('output_concat_axis'):
+ return tf.concat(outputs, conf.output_concat_axis)
+ return outputs
+ if layer_name == 'recurrent':
+ conf = config.recurrent
+ fixed_input_index = -1
+ if conf.HasField('fixed_input_index'):
+ fixed_input_index = conf.fixed_input_index
+ if fixed_input_index >= 0:
+ assert type(inputs) in (tuple, list), '%s inputs must be a list' % name
+ output = inputs
+ for i in range(conf.num_steps):
+ name_i = '%s_%d' % (name, i)
+ output_i = self.call_keras_layer(output, name_i, training, **kwargs)
+ if fixed_input_index >= 0:
+ j = 0
+ for idx in range(len(output)):
+ if idx == fixed_input_index:
+ continue
+ if type(output_i) in (tuple, list):
+ output[idx] = output_i[j]
+ else:
+ output[idx] = output_i
+ j += 1
+ else:
+ output = output_i
+ if fixed_input_index >= 0:
+ del output[fixed_input_index]
+ if len(output) == 1:
+ return output[0]
+ return output
+ return output
+
+ raise NotImplementedError('Unsupported backbone layer: ' + layer_name)
+
+
+class Backbone(object):
+ """Configurable Backbone Network."""
+
+ def __init__(self, config, features, input_layer, l2_reg=None):
+ self._config = config
+ self._l2_reg = l2_reg
+ main_pkg = backbone_pb2.BlockPackage()
+ main_pkg.name = 'backbone'
+ main_pkg.blocks.MergeFrom(config.blocks)
+ if config.concat_blocks:
+ main_pkg.concat_blocks.extend(config.concat_blocks)
+ if config.output_blocks:
+ main_pkg.output_blocks.extend(config.output_blocks)
+ self._main_pkg = Package(main_pkg, features, input_layer, l2_reg)
+ for pkg in config.packages:
+ Package(pkg, features, input_layer, l2_reg)
+
+ def __call__(self, is_training, **kwargs):
+ output = self._main_pkg(is_training, **kwargs)
+
+ if self._config.HasField('top_mlp'):
+ params = Parameter.make_from_pb(self._config.top_mlp)
+ params.l2_regularizer = self._l2_reg
+ final_mlp = MLP(params, name='backbone_top_mlp')
+ if type(output) in (list, tuple):
+ output = tf.concat(output, axis=-1)
+ output = final_mlp(output, training=is_training, **kwargs)
+ return output
+
+ @classmethod
+ def wide_embed_dim(cls, config):
+ wide_embed_dim = None
+ for pkg in config.packages:
+ wide_embed_dim = get_wide_embed_dim(pkg.blocks, wide_embed_dim)
+ return get_wide_embed_dim(config.blocks, wide_embed_dim)
+
+
+def get_wide_embed_dim(blocks, wide_embed_dim=None):
+ for block in blocks:
+ layer = block.WhichOneof('layer')
+ if layer == 'input_layer':
+ if block.input_layer.HasField('wide_output_dim'):
+ wide_dim = block.input_layer.wide_output_dim
+ if wide_embed_dim:
+ assert wide_embed_dim == wide_dim, 'wide_output_dim must be consistent'
+ else:
+ wide_embed_dim = wide_dim
+ return wide_embed_dim
+
+
+def merge_inputs(inputs, axis=-1, msg=''):
+ if len(inputs) == 0:
+ raise ValueError('no inputs to be concat:' + msg)
+ if len(inputs) == 1:
+ return inputs[0]
+
+ from functools import reduce
+ if all(map(lambda x: type(x) == list, inputs)):
+ # merge multiple lists into a list
+ return reduce(lambda x, y: x + y, inputs)
+
+ if any(map(lambda x: type(x) == list, inputs)):
+ logging.warning('%s: try to merge inputs into list' % msg)
+ return reduce(lambda x, y: x + y,
+ [e if type(e) == list else [e] for e in inputs])
+
+ if axis != -1:
+ logging.info('concat inputs %s axis=%d' % (msg, axis))
+ return tf.concat(inputs, axis=axis)
+
+
+def format_value(value):
+ value_type = type(value)
+ if value_type == six.text_type:
+ return str(value)
+ if value_type == float:
+ int_v = int(value)
+ return int_v if int_v == value else value
+ if value_type == struct_pb2.ListValue:
+ return map(format_value, value)
+ if value_type == struct_pb2.Struct:
+ return convert_to_dict(value)
+ return value
+
+
+def convert_to_dict(struct):
+ kwargs = {}
+ for key, value in struct.items():
+ kwargs[str(key)] = format_value(value)
+ return kwargs
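Package.call above is the heart of the new backbone: blocks are evaluated in topological order over the DAG built in __init__, each block gathers its inputs from the outputs of its predecessor blocks (or feature groups), and the result is cached for downstream blocks. A condensed, dependency-free sketch of that evaluation loop; block_fns and block_deps are hypothetical stand-ins for the layer objects and input edges the real Package derives from the config:

def run_blocks(block_fns, block_deps, feature_inputs):
  """Evaluate blocks in dependency order; returns a dict of all block outputs."""
  remaining = dict(block_deps)
  order = []
  # Kahn-style topological sort over the block DAG
  while remaining:
    ready = [b for b, deps in remaining.items()
             if all(d not in remaining for d in deps)]
    assert ready, 'cycle detected among backbone blocks'
    order.extend(ready)
    for b in ready:
      remaining.pop(b)

  outputs = dict(feature_inputs)  # seed with feature-group tensors
  for block in order:
    block_inputs = [outputs[d] for d in block_deps[block]]
    outputs[block] = block_fns[block](block_inputs)
  return outputs

The real implementation additionally handles packages, input_fn/input_slice transforms, repeat/recurrent layers, and the final concat of concat_blocks.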
diff --git a/easy_rec/python/layers/capsule_layer.py b/easy_rec/python/layers/capsule_layer.py
index 4b6928402..22c8363cd 100644
--- a/easy_rec/python/layers/capsule_layer.py
+++ b/easy_rec/python/layers/capsule_layer.py
@@ -24,16 +24,41 @@ def __init__(self, capsule_config, is_training):
self._routing_logits_scale = capsule_config.routing_logits_scale
# routing_logits_stddev
self._routing_logits_stddev = capsule_config.routing_logits_stddev
+ # squash power
+ self._squash_pow = capsule_config.squash_pow
+ # scale ratio
+ self._scale_ratio = capsule_config.scale_ratio
+ self._const_caps_num = capsule_config.const_caps_num
self._is_training = is_training
def squash(self, inputs):
"""Squash inputs over the last dimension."""
input_norm = tf.reduce_sum(tf.square(inputs), keep_dims=True, axis=-1)
- scalar_factor = input_norm / (1 + input_norm) / tf.sqrt(input_norm + 1e-8)
- return scalar_factor * inputs
+ input_norm_eps = tf.maximum(input_norm, 1e-8)
+ scale_factor = tf.pow(input_norm_eps / (1 + input_norm_eps), self._squash_pow) * \
+ self._scale_ratio / tf.sqrt(input_norm_eps)
+ tf.summary.histogram('capsule/squash_scale_factor', scale_factor)
+ return scale_factor * inputs
+
+ def _build_capsule_simi(self, high_capsules, capsule_num):
+ high_capsule_mask = tf.sequence_mask(capsule_num,
+ tf.shape(high_capsules)[1])
+ high_capsules = high_capsules * tf.to_float(high_capsule_mask[:, :, None])
+ high_capsules = tf.nn.l2_normalize(high_capsules, axis=-1)
+ sum_sqr = tf.square(tf.reduce_sum(high_capsules, axis=1))
+ sqr_sum = tf.reduce_sum(tf.square(high_capsules), axis=1)
+ simi = sum_sqr - sqr_sum
+
+ div = tf.maximum(tf.to_float(capsule_num * (capsule_num - 1)), 1.0)
+ simi = tf.reduce_sum(simi, axis=1) / div
+
+ is_multi = tf.to_float(capsule_num > 1)
+ avg_simi = tf.reduce_sum((simi + 1) * is_multi) / \
+ (2.0 * tf.reduce_sum(is_multi))
+ return avg_simi
def __call__(self, seq_feas, seq_lens):
- """Capsule layer.
+ """Capsule layer implementation.
Args:
seq_feas: tensor of shape batch_size x self._max_seq_len x low_fea_dim(bsd)
@@ -77,8 +102,18 @@ def __call__(self, seq_feas, seq_lens):
seq_feas_high = tf.tensordot(seq_feas, bilinear_matrix, axes=1)
seq_feas_high_stop = tf.stop_gradient(seq_feas_high)
seq_feas_high_norm = tf.nn.l2_normalize(seq_feas_high_stop, -1)
- num_high_capsules = tf.maximum(
- 1, tf.minimum(self._max_k, tf.to_int32(tf.log(tf.to_float(seq_lens)))))
+
+ if self._const_caps_num:
+ logging.info('will use a constant number of capsules: %d' % self._max_k)
+ num_high_capsules = tf.zeros_like(seq_lens, dtype=tf.int32) + self._max_k
+ else:
+ logging.info(
+ 'will use log(seq_len) number of capsules, max_capsules: %d' %
+ self._max_k)
+ num_high_capsules = tf.maximum(
+ 1, tf.minimum(self._max_k,
+ tf.to_int32(tf.log(tf.to_float(seq_lens)))))
+
# batch_size x max_seq_len(bs)
mask = tf.sequence_mask(seq_lens, self._max_seq_len)
mask = tf.cast(mask, tf.float32)
@@ -93,16 +128,37 @@ def __call__(self, seq_feas, seq_lens):
# batch_size x max_seq_len x max_k(bsh)
routing_logits = tf.minimum(routing_logits, max_cap_thresh)
routing_logits = tf.nn.softmax(routing_logits, axis=2)
+
routing_logits = routing_logits * mask[:, :, None]
+
+ logits_simi = self._build_capsule_simi(routing_logits, seq_lens)
+ tf.summary.scalar('capsule/rlogits_simi_%d' % iter_id, logits_simi)
+
+ seq_fea_simi = self._build_capsule_simi(seq_feas_high_stop, seq_lens)
+ tf.summary.scalar('capsule/seq_fea_simi_%d' % iter_id, seq_fea_simi)
+
# batch_size x max_k x high_dim(bse,bsh->bhe)
high_capsules = tf.einsum(
'bse, bsh->bhe', seq_feas_high_stop
if iter_id + 1 < self._num_iters else seq_feas_high, routing_logits)
if iter_id + 1 == self._num_iters:
+ capsule_simi = self._build_capsule_simi(high_capsules,
+ num_high_capsules)
+ tf.summary.scalar('capsule/simi_%d' % iter_id, capsule_simi)
+ tf.summary.scalar('capsule/before_squash',
+ tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
high_capsules = self.squash(high_capsules)
+ tf.summary.scalar('capsule/after_squash',
+ tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
+ capsule_simi_final = self._build_capsule_simi(high_capsules,
+ num_high_capsules)
+ tf.summary.scalar('capsule/simi_final', capsule_simi_final)
break
+
# batch_size x max_k x high_dim(bhe)
high_capsules = tf.nn.l2_normalize(high_capsules, -1)
+ capsule_simi = self._build_capsule_simi(high_capsules, num_high_capsules)
+ tf.summary.scalar('capsule/simi_%d' % iter_id, capsule_simi)
# batch_size x max_seq_len x max_k(bse, bhe->bsh)
if self._routing_logits_scale > 0:
if iter_id == 0:
@@ -115,6 +171,6 @@ def __call__(self, seq_feas, seq_lens):
high_capsules)
# zero paddings
- # high_capsule_mask = tf.sequence_mask(num_high_capsules, self._max_k)
- # high_capsules = high_capsules * tf.to_float(high_capsule_mask[:, :, None])
+ high_capsule_mask = tf.sequence_mask(num_high_capsules, self._max_k)
+ high_capsules = high_capsules * tf.to_float(high_capsule_mask[:, :, None])
return high_capsules, num_high_capsules
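The reworked squash keeps the classic capsule formula but exposes two knobs: squash_pow reshapes the length-dependent scaling and scale_ratio rescales the output norm; with squash_pow=1.0 and scale_ratio=1.0 it reduces to the original ||v||^2/(1+||v||^2) * v/||v||. A small numpy sketch of the same scale factor:

import numpy as np


def squash(x, squash_pow=1.0, scale_ratio=1.0, eps=1e-8):
  """Squash over the last axis; (1.0, 1.0) recovers the classic capsule squash."""
  sq_norm = np.maximum(np.sum(np.square(x), axis=-1, keepdims=True), eps)
  scale = (sq_norm / (1.0 + sq_norm)) ** squash_pow * scale_ratio / np.sqrt(sq_norm)
  return scale * x


v = np.array([3.0, 4.0])          # ||v|| = 5
print(np.linalg.norm(squash(v)))  # 25/26 * 5 / 5 = 0.9615...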
diff --git a/easy_rec/python/layers/cmbf.py b/easy_rec/python/layers/cmbf.py
new file mode 100644
index 000000000..e5f1caeb2
--- /dev/null
+++ b/easy_rec/python/layers/cmbf.py
@@ -0,0 +1,390 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.layers import multihead_cross_attention
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class CMBF(object):
+ """CMBF: Cross-Modal-Based Fusion Recommendation Algorithm.
+
+ This is almost an exact implementation of the original CMBF model.
+ See the original paper:
+ https://www.mdpi.com/1424-8220/21/16/5275
+ """
+
+ def __init__(self, model_config, feature_configs, features, cmbf_config,
+ input_layer):
+ self._model_config = cmbf_config
+
+ has_feature = False
+ self._img_features = None
+ if input_layer.has_group('image'):
+ self._img_features, _ = input_layer(features, 'image')
+ has_feature = True
+ self._general_features = None
+ if input_layer.has_group('general'):
+ self._general_features, _ = input_layer(features, 'general')
+ has_feature = True
+ self._txt_seq_features = None
+ if input_layer.has_group('text'):
+ self._txt_seq_features, _, _ = input_layer(
+ features, 'text', is_combine=False)
+ has_feature = True
+ self._other_features = None
+ if input_layer.has_group('other'): # e.g. statistical feature
+ self._other_features, _ = input_layer(features, 'other')
+ has_feature = True
+ assert has_feature, 'there must be at least one of the feature groups: [image, text, general, other]'
+
+ self._general_feature_num, self._img_feature_num = 0, 0
+ self._txt_feature_num = 0
+ general_feature_names, txt_seq_feature_names = set(), set()
+ img_feature_names = set()
+ for fea_group in model_config.feature_groups:
+ if fea_group.group_name == 'general':
+ self._general_feature_num = len(fea_group.feature_names)
+ general_feature_names = set(fea_group.feature_names)
+ assert self._general_feature_num == len(general_feature_names), (
+ 'there are duplicate features in `general` feature group')
+ elif fea_group.group_name == 'image':
+ self._img_feature_num = len(fea_group.feature_names)
+ img_feature_names = set(fea_group.feature_names)
+ assert self._img_feature_num == len(img_feature_names), (
+ 'there are duplicate features in `image` feature group')
+ elif fea_group.group_name == 'text':
+ txt_seq_feature_names = set(fea_group.feature_names)
+ self._txt_feature_num = len(fea_group.feature_names)
+ assert self._txt_feature_num == len(txt_seq_feature_names), (
+ 'there are duplicate features in `text` feature group')
+
+ max_seq_len = 0
+ txt_fea_emb_dim_list = []
+ general_emb_dim_list = []
+ img_fea_emb_dim_list = []
+ for feature_config in feature_configs:
+ fea_name = feature_config.input_names[0]
+ if feature_config.HasField('feature_name'):
+ fea_name = feature_config.feature_name
+ if fea_name in img_feature_names:
+ img_fea_emb_dim_list.append(feature_config.raw_input_dim)
+ if fea_name in general_feature_names:
+ general_emb_dim_list.append(feature_config.embedding_dim)
+ if fea_name in txt_seq_feature_names:
+ txt_fea_emb_dim_list.append(feature_config.embedding_dim)
+ if feature_config.HasField('max_seq_len'):
+ assert feature_config.max_seq_len > 0, (
+ 'feature config `max_seq_len` must be greater than 0 for feature: '
+ + fea_name)
+ if feature_config.max_seq_len > max_seq_len:
+ max_seq_len = feature_config.max_seq_len
+
+ unique_dim_num = len(set(txt_fea_emb_dim_list))
+ assert unique_dim_num <= 1 and len(
+ txt_fea_emb_dim_list
+ ) == self._txt_feature_num, (
+ 'CMBF requires that all `text` feature dimensions be consistent.')
+ unique_dim_num = len(set(general_emb_dim_list))
+ assert unique_dim_num <= 1 and len(
+ general_emb_dim_list
+ ) == self._general_feature_num, (
+ 'CMBF requires that all `general` feature dimensions be consistent.'
+ )
+ unique_dim_num = len(set(img_fea_emb_dim_list))
+ assert unique_dim_num <= 1 and len(
+ img_fea_emb_dim_list
+ ) == self._img_feature_num, (
+ 'CMBF requires that all `image` feature dimensions be consistent.')
+
+ if cmbf_config.use_position_embeddings:
+ assert cmbf_config.max_position_embeddings > 0, (
+ 'model config `max_position_embeddings` must be greater than 0. '
+ 'It must be set when `use_position_embeddings` is true (default)')
+ assert cmbf_config.max_position_embeddings >= max_seq_len, (
+ 'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
+ '`max_seq_len`, which is %d' % max_seq_len)
+
+ self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
+ self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
+ self._general_emb_size = general_emb_dim_list[
+ 0] if general_emb_dim_list else 0
+ self._head_num = cmbf_config.multi_head_num
+ self._img_head_num = cmbf_config.image_multi_head_num
+ self._txt_head_num = cmbf_config.text_multi_head_num
+ self._txt_head_size = cmbf_config.text_head_size
+ self._img_head_size = cmbf_config.image_head_size
+ self._img_patch_num = cmbf_config.image_feature_patch_num
+ self._img_self_attention_layer_num = cmbf_config.image_self_attention_layer_num
+ self._txt_self_attention_layer_num = cmbf_config.text_self_attention_layer_num
+ self._cross_modal_layer_num = cmbf_config.cross_modal_layer_num
+ print('txt_feature_num: {0}, img_feature_num: {1}, txt_seq_feature_num: {2}'
+ .format(self._general_feature_num, self._img_feature_num,
+ len(self._txt_seq_features) if self._txt_seq_features else 0))
+ print('txt_embedding_size: {0}, img_embedding_size: {1}'.format(
+ self._txt_emb_size, self._img_emb_size))
+ if self._img_features is not None:
+ assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'
+
+ def image_self_attention_tower(self):
+ """The input of image self attention tower can be one of.
+
+ 1. multiple image embeddings, each corresponding to a patch, or a ROI(region of interest), or a frame of video
+ 2. one big image embedding composed by stacking multiple image embeddings
+ 3. one conventional image embedding extracted by an image model
+
+ If image embedding size is not equal to configured `image_feature_dim` argument,
+ do dimension reduce to this size before single modal learning module
+ """
+ if self._img_features is None:
+ return None
+ image_features = self._img_features
+ img_fea_num = self._img_feature_num
+ if self._img_self_attention_layer_num <= 0:
+ hidden_size = self._model_config.multi_head_num * self._model_config.image_cross_head_size
+ if self._img_emb_size != hidden_size:
+ # Run a linear projection of `hidden_size`
+ image_features = tf.reshape(
+ self._img_features, shape=[-1, self._img_emb_size])
+ image_features = tf.layers.dense(
+ image_features, hidden_size, name='img_projection')
+ image_features = tf.reshape(
+ image_features, shape=[-1, img_fea_num, hidden_size])
+ return image_features
+
+ hidden_size = self._img_head_size * self._img_head_num
+ if img_fea_num > 1: # in case of video frames or ROIs (Region Of Interest)
+ if self._img_emb_size != hidden_size:
+ # Run a linear projection of `hidden_size`
+ image_features = tf.reshape(
+ self._img_features, shape=[-1, self._img_emb_size])
+ image_features = tf.layers.dense(
+ image_features, hidden_size, name='img_projection')
+ image_features = tf.reshape(
+ image_features, shape=[-1, img_fea_num, hidden_size])
+ elif img_fea_num == 1:
+ if self._img_patch_num > 1: # image feature dimension: patch_num * emb_size
+ img_fea_num = self._img_patch_num
+ img_emb_size = self._img_emb_size // self._img_patch_num
+ assert img_emb_size * self._img_patch_num == self._img_emb_size, (
+ 'image feature dimension must be equal to `image_feature_slice_num * embedding_size_per_region`'
+ )
+ self._img_emb_size = img_emb_size
+ if self._img_emb_size != hidden_size:
+ # Run a linear projection of `hidden_size`
+ image_features = tf.reshape(
+ self._img_features, shape=[-1, self._img_emb_size])
+ image_features = tf.layers.dense(
+ image_features, hidden_size, name='img_projection')
+ image_features = tf.reshape(
+ image_features, shape=[-1, img_fea_num, hidden_size])
+ else:
+ img_fea_num = self._model_config.image_feature_dim
+ if img_fea_num != self._img_emb_size:
+ image_features = tf.layers.dense(
+ image_features, img_fea_num, name='img_projection')
+ # convert each element of image feature to a feature vector
+ img_mapping_matrix = tf.get_variable(
+ 'img_map_matrix', [1, img_fea_num, hidden_size], dtype=tf.float32)
+ image_features = tf.expand_dims(image_features, -1) * img_mapping_matrix
+
+ img_attention_fea = multihead_cross_attention.transformer_encoder(
+ image_features,
+ hidden_size=hidden_size, # head_num * size_per_head
+ num_hidden_layers=self._img_self_attention_layer_num,
+ num_attention_heads=self._head_num,
+ intermediate_size=hidden_size * 4,
+ hidden_dropout_prob=self._model_config.hidden_dropout_prob,
+ attention_probs_dropout_prob=self._model_config
+ .attention_probs_dropout_prob,
+ name='image_self_attention'
+ ) # shape: [batch_size, image_seq_num/image_feature_dim, hidden_size]
+ # print('img_attention_fea:', img_attention_fea.shape)
+ return img_attention_fea
+
+ def text_self_attention_tower(self):
+ hidden_size = self._txt_head_size * self._txt_head_num
+ txt_features = None
+ all_txt_features = []
+ input_masks = []
+
+ if self._general_features is not None:
+ general_features = self._general_features
+ if self._general_emb_size != hidden_size:
+ # Run a linear projection of `hidden_size`
+ general_features = tf.reshape(
+ general_features, shape=[-1, self._general_emb_size])
+ general_features = tf.layers.dense(
+ general_features, hidden_size, name='txt_projection')
+ txt_features = tf.reshape(
+ general_features, shape=[-1, self._general_feature_num, hidden_size])
+
+ all_txt_features.append(txt_features)
+ batch_size = tf.shape(txt_features)[0]
+ mask = tf.ones(
+ shape=tf.stack([batch_size, self._general_feature_num]),
+ dtype=tf.int32)
+ input_masks.append(mask)
+
+ input_mask = None
+ attention_mask = None
+ if self._txt_seq_features is not None:
+
+ def dynamic_mask(x, max_len):
+ ones = tf.ones(shape=tf.stack([x]), dtype=tf.int32)
+ zeros = tf.zeros(shape=tf.stack([max_len - x]), dtype=tf.int32)
+ return tf.concat([ones, zeros], axis=0)
+
+ token_type_vocab_size = len(self._txt_seq_features)
+ for i, (seq_fea, seq_len) in enumerate(self._txt_seq_features):
+ batch_size, max_seq_len, emb_size = get_shape_list(seq_fea, 3)
+ if emb_size != hidden_size:
+ seq_fea = tf.reshape(seq_fea, shape=[-1, emb_size])
+ seq_fea = tf.layers.dense(
+ seq_fea, hidden_size, name='txt_seq_projection_%d' % i)
+ seq_fea = tf.reshape(seq_fea, shape=[-1, max_seq_len, hidden_size])
+
+ seq_fea = multihead_cross_attention.embedding_postprocessor(
+ seq_fea,
+ use_token_type=self._model_config.use_token_type,
+ token_type_ids=tf.ones(
+ shape=tf.stack([batch_size, max_seq_len]), dtype=tf.int32) * i,
+ token_type_vocab_size=token_type_vocab_size,
+ reuse_token_type=tf.AUTO_REUSE,
+ use_position_embeddings=self._model_config.use_position_embeddings,
+ max_position_embeddings=self._model_config.max_position_embeddings,
+ position_embedding_name='position_embeddings_%d' % i,
+ dropout_prob=self._model_config.text_seq_emb_dropout_prob)
+ all_txt_features.append(seq_fea)
+
+ input_mask = tf.map_fn(
+ fn=lambda t: dynamic_mask(t, max_seq_len),
+ elems=tf.to_int32(seq_len))
+ input_masks.append(input_mask)
+
+ txt_features = tf.concat(all_txt_features, axis=1)
+ input_mask = tf.concat(input_masks, axis=1)
+ attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
+ from_tensor=txt_features, to_mask=input_mask)
+
+ if txt_features is None:
+ return None, None, None
+
+ txt_attention_fea = multihead_cross_attention.transformer_encoder(
+ txt_features,
+ hidden_size=hidden_size,
+ num_hidden_layers=self._txt_self_attention_layer_num,
+ num_attention_heads=self._head_num,
+ attention_mask=attention_mask,
+ intermediate_size=hidden_size * 4,
+ hidden_dropout_prob=self._model_config.hidden_dropout_prob,
+ attention_probs_dropout_prob=self._model_config
+ .attention_probs_dropout_prob,
+ name='text_self_attention'
+ ) # shape: [batch_size, txt_seq_length, hidden_size]
+ print('txt_attention_fea:', txt_attention_fea.shape)
+ return txt_attention_fea, input_mask, input_masks
+
+ def merge_text_embedding(self, txt_embeddings, input_masks):
+ shape = get_shape_list(txt_embeddings)
+ if self._txt_seq_features is None:
+ return tf.reshape(txt_embeddings, shape=[-1, shape[1] * shape[2]])
+
+ text_seq_emb = []
+ if self._general_feature_num > 0:
+ text_emb = tf.slice(txt_embeddings, [0, 0, 0],
+ [shape[0], self._general_feature_num, shape[2]])
+ text_seq_emb.append(text_emb)
+
+ begin = self._general_feature_num
+ for i in range(len(text_seq_emb), len(input_masks)):
+ size = tf.shape(input_masks[i])[1]
+ temp_emb = tf.slice(txt_embeddings, [0, begin, 0],
+ [shape[0], size, shape[2]])
+ mask = tf.expand_dims(tf.to_float(input_masks[i]), -1)
+ temp_emb = temp_emb * mask
+ # avg pooling
+ emb_sum = tf.reduce_sum(
+ temp_emb, axis=1,
+ keepdims=True) # shape: [batch_size, 1, hidden_size]
+ count = tf.reduce_sum(
+ mask, axis=1, keepdims=True) # shape: [batch_size, 1, 1]
+ seq_emb = emb_sum / count # shape: [batch_size, 1, hidden_size]
+
+ text_seq_emb.append(seq_emb)
+ begin = begin + size
+
+ txt_emb = tf.concat(text_seq_emb, axis=1)
+ seq_num = len(text_seq_emb)
+ if self._general_feature_num > 0:
+ seq_num += self._general_feature_num - 1
+ txt_embeddings = tf.reshape(txt_emb, shape=[-1, seq_num * shape[2]])
+ return txt_embeddings
+
+ def __call__(self, is_training, *args, **kwargs):
+ if not is_training:
+ self._model_config.hidden_dropout_prob = 0.0
+ self._model_config.attention_probs_dropout_prob = 0.0
+
+ # shape: [batch_size, image_num/image_dim, hidden_size]
+ img_attention_fea = self.image_self_attention_tower()
+
+ # shape: [batch_size, txt_seq_length, hidden_size]
+ txt_attention_fea, input_mask, input_masks = self.text_self_attention_tower(
+ )
+
+ all_fea = []
+ if None not in [img_attention_fea, txt_attention_fea]:
+ img_embeddings, txt_embeddings = multihead_cross_attention.cross_attention_tower(
+ img_attention_fea,
+ txt_attention_fea,
+ num_hidden_layers=self._cross_modal_layer_num,
+ num_attention_heads=self._head_num,
+ right_input_mask=input_mask,
+ left_size_per_head=self._model_config.image_cross_head_size,
+ left_intermediate_size=4 * self._model_config.image_cross_head_size *
+ self._head_num,
+ right_size_per_head=self._model_config.text_cross_head_size,
+ right_intermediate_size=4 * self._model_config.text_cross_head_size *
+ self._head_num,
+ hidden_dropout_prob=self._model_config.hidden_dropout_prob,
+ attention_probs_dropout_prob=self._model_config
+ .attention_probs_dropout_prob)
+ # img_embeddings shape: [batch_size, image_(region_)num/image_feature_dim, multi_head_num * image_cross_head_size]
+ print('img_embeddings:', img_embeddings.shape)
+ # txt_embeddings shape: [batch_size, general_feature_num + max_txt_seq_len, multi_head_num * text_cross_head_size]
+ print('txt_embeddings:', txt_embeddings.shape)
+
+ # shape: [batch_size, multi_head_num * image_cross_head_size]
+ img_embeddings = tf.reduce_mean(img_embeddings, axis=1)
+
+ # shape: [batch_size, (general_feature_num + txt_seq_num) * multi_head_num * text_cross_head_size]
+ txt_embeddings = self.merge_text_embedding(txt_embeddings, input_masks)
+ all_fea = [img_embeddings, txt_embeddings]
+
+ elif img_attention_fea is not None: # only has image tower
+ # avg pooling, shape: [batch_size, multi_head_num * image_head_size]
+ img_embeddings = tf.reduce_mean(img_attention_fea, axis=1)
+ all_fea = [img_embeddings]
+
+ elif txt_attention_fea is not None: # only has text tower
+ # shape: [batch_size, (general_feature_num + txt_seq_num) * multi_head_num * text_head_size]
+ txt_embeddings = self.merge_text_embedding(txt_attention_fea, input_masks)
+ all_fea = [txt_embeddings]
+
+ if self._other_features is not None:
+ if self._model_config.HasField('other_feature_dnn'):
+ l2_reg = kwargs['l2_reg'] if 'l2_reg' in kwargs else 0
+ other_dnn_layer = dnn.DNN(self._model_config.other_feature_dnn, l2_reg,
+ 'other_dnn', is_training)
+ other_fea = other_dnn_layer(self._other_features)
+ all_fea.append(other_fea) # e.g. statistical features
+ else:
+ all_fea.append(self._other_features) # e.g. statistical features
+
+ output = tf.concat(all_fea, axis=-1)
+ return output
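merge_text_embedding above keeps each `general` feature embedding intact and collapses every padded text sequence into a single vector by masked average pooling before the final concat. A minimal sketch of that pooling step, assuming a padded [batch, seq_len, dim] embedding and a 0/1 validity mask:

import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def masked_mean_pool(seq_emb, mask):
  """seq_emb: [batch, seq_len, dim]; mask: [batch, seq_len], 1 for valid steps."""
  mask = tf.expand_dims(tf.to_float(mask), -1)          # [batch, seq_len, 1]
  emb_sum = tf.reduce_sum(seq_emb * mask, axis=1)       # [batch, dim]
  count = tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)  # guard empty sequences
  return emb_sum / count                                # [batch, dim]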
diff --git a/easy_rec/python/layers/common_layers.py b/easy_rec/python/layers/common_layers.py
index 883f2a67c..68ecf37f5 100644
--- a/easy_rec/python/layers/common_layers.py
+++ b/easy_rec/python/layers/common_layers.py
@@ -1,45 +1,41 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
-import numpy as np
+import six
import tensorflow as tf
+from easy_rec.python.compat.layers import layer_norm as tf_layer_norm
+from easy_rec.python.utils.activation import get_activation
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
-def gelu(x):
- """Gaussian Error Linear Unit.
-
- This is a smoother version of the RELU.
- Original paper: https://arxiv.org/abs/1606.08415
- Args:
- x: float Tensor to perform activation.
-
- Returns:
- `x` with the GELU activation applied.
- """
- cdf = 0.5 * (1.0 + tf.tanh(
- (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
- return x * cdf
-
-
def highway(x,
size=None,
activation=None,
num_layers=1,
scope='highway',
dropout=0.0,
+ init_gate_bias=-1.0,
reuse=None):
+ if isinstance(activation, six.string_types):
+ activation = get_activation(activation)
with tf.variable_scope(scope, reuse):
if size is None:
size = x.shape.as_list()[-1]
else:
x = tf.layers.dense(x, size, name='input_projection', reuse=reuse)
+ initializer = tf.constant_initializer(init_gate_bias)
for i in range(num_layers):
T = tf.layers.dense(
- x, size, activation=tf.sigmoid, name='gate_%d' % i, reuse=reuse)
+ x,
+ size,
+ activation=tf.sigmoid,
+ bias_initializer=initializer,
+ name='gate_%d' % i,
+ reuse=reuse)
H = tf.layers.dense(
x, size, activation=activation, name='activation_%d' % i, reuse=reuse)
if dropout > 0.0:
@@ -65,8 +61,8 @@ def text_cnn(x,
# conv shape: (batch_size, seq_len - filter_size + 1, num_filters)
conv = tf.layers.conv1d(
x,
- filters=num_filter,
- kernel_size=filter_size,
+ filters=int(num_filter),
+ kernel_size=int(filter_size),
activation=tf.nn.relu,
name='conv_layer',
reuse=reuse,
@@ -78,3 +74,119 @@ def text_cnn(x,
pool_flat = tf.concat(
pooled_outputs, 1) # shape: (batch_size, num_filters * len(filter_sizes))
return pool_flat
+
+
+def layer_norm(input_tensor, name=None, reuse=None):
+ """Run layer normalization on the last dimension of the tensor."""
+ return tf_layer_norm(
+ inputs=input_tensor,
+ begin_norm_axis=-1,
+ begin_params_axis=-1,
+ reuse=reuse,
+ scope=name)
+
+
+class EnhancedInputLayer(object):
+ """Enhance the raw input layer."""
+
+ def __init__(self, input_layer, feature_dict, group_name, reuse=None):
+ self._group_name = group_name
+ self.name = 'input_' + self._group_name
+ self._input_layer = input_layer
+ self._feature_dict = feature_dict
+ self._reuse = reuse
+ self.built = False
+
+ def __call__(self, config, is_training, **kwargs):
+ if not self.built:
+ self.build(config, is_training)
+
+ if config.output_seq_and_normal_feature:
+ return self.inputs
+
+ if config.do_batch_norm and config.do_layer_norm:
+ raise ValueError(
+ 'can not do batch norm and layer norm for input layer at the same time'
+ )
+ with tf.name_scope(self.name):
+ return self.call(config, is_training)
+
+ def build(self, config, training):
+ self.built = True
+ combine = not config.output_seq_and_normal_feature
+ self.inputs = self._input_layer(
+ self._feature_dict, self._group_name, is_combine=combine)
+ if config.output_seq_and_normal_feature:
+ seq_feature_and_len, _, target_features = self.inputs
+ seq_len = seq_feature_and_len[0][1]
+ seq_features = [seq_fea for seq_fea, _ in seq_feature_and_len]
+ if config.concat_seq_feature:
+ if target_features:
+ target_features = tf.concat(target_features, axis=-1)
+ else:
+ target_features = None
+ assert len(
+ seq_features) > 0, '[%s] sequence feature is empty' % self.name
+ seq_features = tf.concat(seq_features, axis=-1)
+ self.inputs = seq_features, seq_len, target_features
+ self.reset(config, training)
+
+ def reset(self, config, training):
+ if 0.0 < config.dropout_rate < 1.0:
+ self.dropout = tf.keras.layers.Dropout(rate=config.dropout_rate)
+
+ if training and 0.0 < config.feature_dropout_rate < 1.0:
+ keep_prob = 1.0 - config.feature_dropout_rate
+ self.bern = tf.distributions.Bernoulli(probs=keep_prob, dtype=tf.float32)
+
+ def call(self, config, training):
+ features, feature_list = self.inputs
+ num_features = len(feature_list)
+
+ do_ln = config.do_layer_norm
+ do_bn = config.do_batch_norm
+ do_feature_dropout = training and 0.0 < config.feature_dropout_rate < 1.0
+ if do_feature_dropout:
+ keep_prob = 1.0 - config.feature_dropout_rate
+ mask = self.bern.sample(num_features)
+ elif do_bn:
+ features = tf.layers.batch_normalization(
+ features, training=training, reuse=self._reuse)
+ elif do_ln:
+ features = layer_norm(
+ features, name=self._group_name + '_features', reuse=self._reuse)
+
+ output_feature_list = config.output_2d_tensor_and_feature_list
+ output_feature_list = output_feature_list or config.only_output_feature_list
+ output_feature_list = output_feature_list or config.only_output_3d_tensor
+ rate = config.dropout_rate
+ do_dropout = 0.0 < rate < 1.0
+ if do_feature_dropout or do_ln or do_bn or do_dropout:
+ for i in range(num_features):
+ fea = feature_list[i]
+ if do_bn:
+ fea = tf.layers.batch_normalization(
+ fea, training=training, reuse=self._reuse)
+ elif do_ln:
+ ln_name = self._group_name + 'f_%d' % i
+ fea = layer_norm(fea, name=ln_name, reuse=self._reuse)
+ if do_dropout and output_feature_list:
+ fea = self.dropout.call(fea, training=training)
+ if do_feature_dropout:
+ fea = tf.div(fea, keep_prob) * mask[i]
+ feature_list[i] = fea
+ if do_feature_dropout:
+ features = tf.concat(feature_list, axis=-1)
+
+ if do_dropout and not do_feature_dropout:
+ features = self.dropout.call(features, training=training)
+ if features.shape.ndims == 3 and int(features.shape[0]) == 1:
+ features = tf.squeeze(features, axis=0)
+
+ if config.only_output_feature_list:
+ return feature_list
+ if config.only_output_3d_tensor:
+ return tf.stack(feature_list, axis=1)
+ if config.output_2d_tensor_and_feature_list:
+ return features, feature_list
+ return features
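The only behavioral change to highway above is the configurable gate-bias initializer: with a negative init_gate_bias the transform gate T = sigmoid(Wx + b) starts near 0, so each layer begins close to the identity mapping y = T * H(x) + (1 - T) * x. The combination step itself lies outside the shown hunk, so the single-layer sketch below assumes the standard highway form:

import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def highway_step(x, size, init_gate_bias=-1.0, name='hw0'):
  # transform gate; negative bias init keeps T small so the layer starts near identity
  T = tf.layers.dense(
      x, size, activation=tf.sigmoid,
      bias_initializer=tf.constant_initializer(init_gate_bias),
      name='%s_gate' % name)
  H = tf.layers.dense(x, size, activation=tf.nn.relu, name='%s_act' % name)
  return T * H + (1.0 - T) * x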
diff --git a/easy_rec/python/layers/dnn.py b/easy_rec/python/layers/dnn.py
index f4de24455..7a57f5661 100644
--- a/easy_rec/python/layers/dnn.py
+++ b/easy_rec/python/layers/dnn.py
@@ -4,7 +4,7 @@
import tensorflow as tf
-from easy_rec.python.utils.load_class import load_by_path
+from easy_rec.python.utils.activation import get_activation
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -12,21 +12,32 @@
class DNN:
- def __init__(self, dnn_config, l2_reg, name='dnn', is_training=False):
+ def __init__(self,
+ dnn_config,
+ l2_reg,
+ name='dnn',
+ is_training=False,
+ last_layer_no_activation=False,
+ last_layer_no_batch_norm=False):
"""Initializes a `DNN` Layer.
Args:
dnn_config: instance of easy_rec.python.protos.dnn_pb2.DNN
l2_reg: l2 regularizer
name: scope of the DNN, so that the parameters could be separated from other dnns
- is_training: train phase or not, impact batchnorm and dropout
+ is_training: train phase or not, impacts batch_norm and dropout
+ last_layer_no_activation: if True, do not apply the activation in the last layer
+ last_layer_no_batch_norm: if True, do not apply batch norm in the last layer
"""
self._config = dnn_config
self._l2_reg = l2_reg
self._name = name
self._is_training = is_training
logging.info('dnn activation function = %s' % self._config.activation)
- self.activation = load_by_path(self._config.activation)
+ self.activation = get_activation(
+ self._config.activation, training=is_training)
+ self._last_layer_no_activation = last_layer_no_activation
+ self._last_layer_no_batch_norm = last_layer_no_batch_norm
@property
def hidden_units(self):
@@ -49,14 +60,16 @@ def __call__(self, deep_fea, hidden_layer_feature_output=False):
kernel_regularizer=self._l2_reg,
activation=None,
name='%s/dnn_%d' % (self._name, i))
- if self._config.use_bn:
+ if self._config.use_bn and ((i + 1 < hidden_units_len) or
+ not self._last_layer_no_batch_norm):
deep_fea = tf.layers.batch_normalization(
deep_fea,
training=self._is_training,
trainable=True,
name='%s/dnn_%d/bn' % (self._name, i))
- deep_fea = self.activation(
- deep_fea, name='%s/dnn_%d/act' % (self._name, i))
+ if (i + 1 < hidden_units_len) or not self._last_layer_no_activation:
+ deep_fea = self.activation(
+ deep_fea, name='%s/dnn_%d/act' % (self._name, i))
if len(self.dropout_ratio) > 0 and self._is_training:
assert self.dropout_ratio[
i] < 1, 'invalid dropout_ratio: %.3f' % self.dropout_ratio[i]
diff --git a/easy_rec/python/layers/fm.py b/easy_rec/python/layers/fm.py
index 4765a4512..1929e00aa 100644
--- a/easy_rec/python/layers/fm.py
+++ b/easy_rec/python/layers/fm.py
@@ -19,8 +19,7 @@ def __init__(self, name='fm'):
def __call__(self, fm_fea):
with tf.name_scope(self._name):
- fm_feas = [tf.expand_dims(x, axis=1) for x in fm_fea]
- fm_feas = tf.concat(fm_feas, axis=1)
+ fm_feas = tf.stack(fm_fea, axis=1)
sum_square = tf.square(tf.reduce_sum(fm_feas, 1))
square_sum = tf.reduce_sum(tf.square(fm_feas), 1)
y_v = 0.5 * tf.subtract(sum_square, square_sum)
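The FM layer above relies on the standard second-order trick: the sum over all pairs of element-wise products of the field embeddings equals 0.5 * ((sum of embeddings)^2 - sum of squared embeddings), which is exactly what sum_square and square_sum compute. A small numpy check of that identity:

import numpy as np

rng = np.random.RandomState(0)
feas = rng.randn(5, 8)                      # 5 fields, embedding dim 8

# direct sum over all pairs i < j of element-wise products
direct = sum(feas[i] * feas[j]
             for i in range(5) for j in range(i + 1, 5))

# FM trick: 0.5 * (square-of-sum - sum-of-squares)
trick = 0.5 * (np.square(feas.sum(axis=0)) - np.square(feas).sum(axis=0))

assert np.allclose(direct, trick)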
diff --git a/easy_rec/python/layers/input_layer.py b/easy_rec/python/layers/input_layer.py
index 3085adc7a..27bc9bdf4 100644
--- a/easy_rec/python/layers/input_layer.py
+++ b/easy_rec/python/layers/input_layer.py
@@ -1,23 +1,27 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
+import os
+from collections import OrderedDict
import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
from easy_rec.python.compat import regularizers
from easy_rec.python.compat.feature_column import feature_column
from easy_rec.python.feature_column.feature_column import FeatureColumnParser
from easy_rec.python.feature_column.feature_group import FeatureGroup
-from easy_rec.python.layers import dnn
-from easy_rec.python.layers import seq_input_layer
+from easy_rec.python.layers import sequence_feature_layer
from easy_rec.python.layers import variational_dropout_layer
-from easy_rec.python.layers.common_layers import text_cnn
+from easy_rec.python.layers.keras import TextCNN
+from easy_rec.python.layers.utils import Parameter
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
+from easy_rec.python.utils import conditional
+from easy_rec.python.utils import shape_utils
-from easy_rec.python.compat.feature_column.feature_column import _SharedEmbeddingColumn # NOQA
-from easy_rec.python.compat.feature_column.feature_column_v2 import EmbeddingColumn # NOQA
-if tf.__version__ >= '2.0':
- tf = tf.compat.v1
+from easy_rec.python.compat.feature_column.feature_column_v2 import is_embedding_column # NOQA
class InputLayer(object):
@@ -31,101 +35,214 @@ def __init__(self,
feature_groups_config,
variational_dropout_config=None,
wide_output_dim=-1,
- use_embedding_variable=False,
+ ev_params=None,
embedding_regularizer=None,
kernel_regularizer=None,
- is_training=False):
+ is_training=False,
+ is_predicting=False):
self._feature_groups = {
x.group_name: FeatureGroup(x) for x in feature_groups_config
}
- self._seq_feature_groups_config = [
- x.sequence_features
- for x in feature_groups_config
- if x.HasField('sequence_features')
- ]
+ self.sequence_feature_layer = sequence_feature_layer.SequenceFeatureLayer(
+ feature_configs, feature_groups_config, ev_params,
+ embedding_regularizer, kernel_regularizer, is_training, is_predicting)
+ self._seq_feature_groups_config = []
+ for x in feature_groups_config:
+ for y in x.sequence_features:
+ self._seq_feature_groups_config.append(y)
self._group_name_to_seq_features = {
x.group_name: x.sequence_features
for x in feature_groups_config
- if x.HasField('sequence_features')
+ if len(x.sequence_features) > 0
}
- self._seq_input_layer = None
- if len(self._seq_feature_groups_config) > 0:
- self._seq_input_layer = seq_input_layer.SeqInputLayer(
- feature_configs, self._seq_feature_groups_config)
wide_and_deep_dict = self.get_wide_deep_dict()
self._fc_parser = FeatureColumnParser(
feature_configs,
wide_and_deep_dict,
wide_output_dim,
- use_embedding_variable=use_embedding_variable)
+ ev_params=ev_params)
self._embedding_regularizer = embedding_regularizer
self._kernel_regularizer = kernel_regularizer
self._is_training = is_training
+ self._is_predicting = is_predicting
self._variational_dropout_config = variational_dropout_config
def has_group(self, group_name):
return group_name in self._feature_groups
- def target_attention(self, dnn_config, deep_fea, name):
- cur_id, hist_id_col, seq_len = deep_fea['key'], deep_fea[
- 'hist_seq_emb'], deep_fea['hist_seq_len']
-
- seq_max_len = tf.shape(hist_id_col)[1]
- emb_dim = hist_id_col.shape[2]
-
- cur_ids = tf.tile(cur_id, [1, seq_max_len])
- cur_ids = tf.reshape(cur_ids,
- tf.shape(hist_id_col)) # (B, seq_max_len, emb_dim)
-
- din_net = tf.concat(
- [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
- axis=-1) # (B, seq_max_len, emb_dim*4)
-
- din_layer = dnn.DNN(dnn_config, None, name, self._is_training)
- din_net = din_layer(din_net)
- scores = tf.reshape(din_net, [-1, 1, seq_max_len]) # (B, 1, ?)
-
- seq_len = tf.expand_dims(seq_len, 1)
- mask = tf.sequence_mask(seq_len)
- padding = tf.ones_like(scores) * (-2**32 + 1)
- scores = tf.where(mask, scores, padding) # [B, 1, seq_max_len]
-
- # Scale
- scores = tf.nn.softmax(scores) # (B, 1, seq_max_len)
- hist_din_emb = tf.matmul(scores, hist_id_col) # [B, 1, emb_dim]
- hist_din_emb = tf.reshape(hist_din_emb, [-1, emb_dim]) # [B, emb_dim]
- din_output = tf.concat([hist_din_emb, cur_id], axis=1)
- return din_output
-
- def call_seq_input_layer(self,
- features,
- seq_att_map_config,
- feature_name_to_output_tensors=None):
- group_name = seq_att_map_config.group_name
- allow_key_search = seq_att_map_config.allow_key_search
- seq_features = self._seq_input_layer(features, group_name,
- feature_name_to_output_tensors,
- allow_key_search)
- regularizers.apply_regularization(
- self._embedding_regularizer, weights_list=[seq_features['key']])
- regularizers.apply_regularization(
- self._embedding_regularizer,
- weights_list=[seq_features['hist_seq_emb']])
- seq_dnn_config = None
- if seq_att_map_config.HasField('seq_dnn'):
- seq_dnn_config = seq_att_map_config.seq_dnn
+ def get_combined_feature(self, features, group_name, is_dict=False):
+ """Get combined features by group_name.
+
+ Args:
+ features: input tensor dict
+ group_name: feature_group name
+ is_dict: whether to return group_features in dict
+
+ Return:
+      features: all features concatenated together
+ group_features: list of features
+ feature_name_to_output_tensors: dict, feature_name to feature_value, only present when is_dict is True
+ """
+ feature_name_to_output_tensors = {}
+ negative_sampler = self._feature_groups[group_name]._config.negative_sampler
+
+ place_on_cpu = os.getenv('place_embedding_on_cpu')
+ place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+ with conditional(self._is_predicting and place_on_cpu,
+ ops.device('/CPU:0')):
+ concat_features, group_features = self.single_call_input_layer(
+ features, group_name, feature_name_to_output_tensors)
+ if group_name in self._group_name_to_seq_features:
+ # for target attention
+ group_seq_arr = self._group_name_to_seq_features[group_name]
+ concat_features, all_seq_fea = self.sequence_feature_layer(
+ features,
+ concat_features,
+ group_seq_arr,
+ feature_name_to_output_tensors,
+ negative_sampler=negative_sampler,
+ scope_name=group_name)
+ group_features.extend(all_seq_fea)
+ for col, fea in zip(group_seq_arr, all_seq_fea):
+ feature_name_to_output_tensors['seq_fea/' + col.group_name] = fea
+ all_seq_fea = array_ops.concat(all_seq_fea, axis=-1)
+ concat_features = array_ops.concat([concat_features, all_seq_fea],
+ axis=-1)
+ if is_dict:
+ return concat_features, group_features, feature_name_to_output_tensors
else:
- logging.info(
- 'seq_dnn not set in seq_att_groups, will use default settings')
- from easy_rec.python.protos.dnn_pb2 import DNN
- seq_dnn_config = DNN()
- seq_dnn_config.hidden_units.extend([128, 64, 32, 1])
- seq_fea = self.target_attention(
- seq_dnn_config, seq_features, name='seq_dnn')
- return seq_fea
-
- def __call__(self, features, group_name, is_combine=True):
+ return concat_features, group_features
+
+ def get_plain_feature(self, features, group_name):
+ """Get plain features by group_name. Exclude sequence features.
+
+ Args:
+ features: input tensor dict
+ group_name: feature_group name
+
+ Return:
+      features: all features concatenated together
+ group_features: list of features
+ """
+ assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
+ group_name, ','.join([x for x in self._feature_groups]))
+
+ feature_group = self._feature_groups[group_name]
+ group_columns, _ = feature_group.select_columns(self._fc_parser)
+ if not group_columns:
+ return None, []
+
+ cols_to_output_tensors = OrderedDict()
+ output_features = feature_column.input_layer(
+ features,
+ group_columns,
+ cols_to_output_tensors=cols_to_output_tensors,
+ is_training=self._is_training)
+ group_features = [cols_to_output_tensors[x] for x in group_columns]
+
+ embedding_reg_lst = []
+ for col, val in cols_to_output_tensors.items():
+ if is_embedding_column(col):
+ embedding_reg_lst.append(val)
+
+ if self._embedding_regularizer is not None and len(embedding_reg_lst) > 0:
+ regularizers.apply_regularization(
+ self._embedding_regularizer, weights_list=embedding_reg_lst)
+ return output_features, group_features
+
+ def get_sequence_feature(self, features, group_name):
+ """Get sequence features by group_name. Exclude plain features.
+
+ Args:
+ features: input tensor dict
+ group_name: feature_group name
+
+ Return:
+ seq_features: list of sequence features, each element is a tuple:
+ 3d embedding tensor (batch_size, max_seq_len, embedding_dimension),
+ 1d sequence length tensor.
+ """
+ assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
+ group_name, ','.join([x for x in self._feature_groups]))
+
+ if self._variational_dropout_config is not None:
+ raise ValueError(
+          'variational dropout is not supported in non-combine mode yet.')
+
+ feature_group = self._feature_groups[group_name]
+ _, group_seq_columns = feature_group.select_columns(self._fc_parser)
+
+ embedding_reg_lst = []
+ builder = feature_column._LazyBuilder(features)
+ seq_features = []
+ for fc in group_seq_columns:
+ with variable_scope.variable_scope('input_layer/' +
+ fc.categorical_column.name):
+ tmp_embedding, tmp_seq_len = fc._get_sequence_dense_tensor(builder)
+ if fc.max_seq_length > 0:
+ tmp_embedding, tmp_seq_len = shape_utils.truncate_sequence(
+ tmp_embedding, tmp_seq_len, fc.max_seq_length)
+ seq_features.append((tmp_embedding, tmp_seq_len))
+ embedding_reg_lst.append(tmp_embedding)
+
+ if self._embedding_regularizer is not None and len(embedding_reg_lst) > 0:
+ regularizers.apply_regularization(
+ self._embedding_regularizer, weights_list=embedding_reg_lst)
+ return seq_features
+
+ def get_raw_features(self, features, group_name):
+ """Get features by group_name.
+
+ Args:
+ features: input tensor dict
+ group_name: feature_group name
+
+ Return:
+      features: all raw features in a list
+ """
+ assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
+ group_name, ','.join([x for x in self._feature_groups]))
+ feature_group = self._feature_groups[group_name]
+ return [features[x] for x in feature_group.feature_names]
+
+ def get_bucketized_features(self, features, group_name):
+ """Get features by group_name.
+
+ Args:
+ features: input tensor dict
+ group_name: feature_group name
+
+ Return:
+      values, offset, weights: bucketized id tensors (with per-feature id offsets added), the total vocab size, and optional weight tensors
+ """
+ assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
+ group_name, ','.join([x for x in self._feature_groups]))
+ feature_group = self._feature_groups[group_name]
+ offset = 0
+ values = []
+ weights = []
+ for feature in feature_group.feature_names:
+ vocab = self._fc_parser.get_feature_vocab_size(feature)
+ logging.info('vocab size of feature %s is %d' % (feature, vocab))
+ weights.append(None)
+ if tf.is_numeric_tensor(features[feature]):
+          # suppose the feature has already been bucketized
+ value = tf.to_int64(features[feature])
+ elif isinstance(features[feature], tf.SparseTensor):
+ # TagFeature
+ dense = tf.sparse.to_dense(features[feature], default_value='')
+ value = tf.string_to_hash_bucket_fast(dense, vocab)
+ if (feature + '_w') in features:
+ weights[-1] = features[feature + '_w'] # SparseTensor
+ logging.info('feature %s has weight %s', feature, feature + '_w')
+ else: # IdFeature
+ value = tf.string_to_hash_bucket_fast(features[feature], vocab)
+ values.append(value + offset)
+ offset += vocab
+ return values, offset, weights
+
+ def __call__(self, features, group_name, is_combine=True, is_dict=False):
"""Get features by group_name.
Args:
@@ -133,137 +250,130 @@ def __call__(self, features, group_name, is_combine=True):
group_name: feature_group name
is_combine: whether to combine sequence features over the
time dimension.
+ is_dict: whether to return group_features in dict
Return:
- features: all features concatenate together
- group_features: list of features
- seq_features: list of sequence features, each element is a tuple:
+ is_combine: True
+        features: all features concatenated together
+ group_features: list of features
+ feature_name_to_output_tensors: dict, feature_name to feature_value, only present when is_dict is True
+ is_combine: False
+ seq_features: list of sequence features, each element is a tuple:
3 dimension embedding tensor (batch_size, max_seq_len, embedding_dimension),
1 dimension sequence length tensor.
"""
assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
group_name, ','.join([x for x in self._feature_groups]))
- feature_name_to_output_tensors = {}
- if group_name in self._group_name_to_seq_features:
- for seq_att in self._group_name_to_seq_features[group_name].seq_att_map:
- for k in seq_att.key:
- feature_name_to_output_tensors[k] = None
if is_combine:
- concat_features, group_features = self.single_call_input_layer(
- features, group_name, is_combine, feature_name_to_output_tensors)
- if group_name in self._group_name_to_seq_features:
- seq_fea = self.call_seq_input_layer(
- features, self._group_name_to_seq_features[group_name],
- feature_name_to_output_tensors)
- concat_features = tf.concat([concat_features, seq_fea], axis=1)
- return concat_features, group_features
- else:
- return self.single_call_input_layer(features, group_name, is_combine)
+ return self.get_combined_feature(features, group_name, is_dict)
+
+    # return sequence features in raw format instead of combining them
+ place_on_cpu = os.getenv('place_embedding_on_cpu')
+ place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+ with conditional(self._is_predicting and place_on_cpu,
+ ops.device('/CPU:0')):
+ seq_features = self.get_sequence_feature(features, group_name)
+ plain_features, feature_list = self.get_plain_feature(
+ features, group_name)
+ return seq_features, plain_features, feature_list
def single_call_input_layer(self,
features,
group_name,
- is_combine=True,
feature_name_to_output_tensors=None):
"""Get features by group_name.
Args:
features: input tensor dict
group_name: feature_group name
- is_combine: whether to combine sequence features over the
- time dimension.
- feature_name_to_output_tensors: if set sequence_features, feature_name_to_output_tensors will
- take key tensors to reuse.
+      feature_name_to_output_tensors: if sequence_features is set,
+        feature_name_to_output_tensors will hold the key tensors to be reused.
Return:
features: all features concatenate together
group_features: list of features
- seq_features: list of sequence features, each element is a tuple:
- 3 dimension embedding tensor (batch_size, max_seq_len, embedding_dimension),
- 1 dimension sequence length tensor.
"""
assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
group_name, ','.join([x for x in self._feature_groups]))
feature_group = self._feature_groups[group_name]
group_columns, group_seq_columns = feature_group.select_columns(
self._fc_parser)
- if is_combine:
- cols_to_output_tensors = {}
- output_features = feature_column.input_layer(
- features,
- group_columns,
- cols_to_output_tensors=cols_to_output_tensors,
- feature_name_to_output_tensors=feature_name_to_output_tensors)
- embedding_reg_lst = [output_features]
- builder = feature_column._LazyBuilder(features)
- seq_features = []
- for column in sorted(group_seq_columns, key=lambda x: x.name):
- with tf.variable_scope(None, default_name=column._var_scope_name):
- seq_feature, seq_len = column._get_sequence_dense_tensor(builder)
- embedding_reg_lst.append(seq_feature)
-
- sequence_combiner = column.sequence_combiner
- if sequence_combiner is None:
- raise ValueError(
- 'sequence_combiner is none, please set sequence_combiner or use TagFeature'
- )
- if sequence_combiner.WhichOneof('combiner') == 'attention':
- attn_logits = tf.layers.dense(
- inputs=seq_feature,
- units=1,
- kernel_regularizer=self._kernel_regularizer,
- use_bias=False,
- activation=None,
- name='attention')
- attn_logits = tf.squeeze(attn_logits, axis=-1)
- attn_logits_padding = tf.ones_like(attn_logits) * (-2**32 + 1)
- seq_mask = tf.sequence_mask(seq_len)
- attn_score = tf.nn.softmax(
- tf.where(seq_mask, attn_logits, attn_logits_padding))
- seq_feature = tf.reduce_sum(
- attn_score[:, :, tf.newaxis] * seq_feature, axis=1)
- seq_features.append(seq_feature)
- cols_to_output_tensors[column] = seq_feature
- elif sequence_combiner.WhichOneof('combiner') == 'text_cnn':
- num_filters = sequence_combiner.text_cnn.num_filters
- filter_sizes = sequence_combiner.text_cnn.filter_sizes
- cnn_feature = text_cnn(seq_feature, filter_sizes, num_filters)
- seq_features.append(cnn_feature)
- cols_to_output_tensors[column] = cnn_feature
- else:
- raise NotImplementedError
- if self._variational_dropout_config is not None:
- features_dimension = [
- cols_to_output_tensors[x].get_shape()[-1] for x in group_columns
- ]
- variational_dropout = variational_dropout_layer.VariationalDropoutLayer(
- self._variational_dropout_config, features_dimension,
- self._is_training)
- noisy_features = variational_dropout(output_features)
- concat_features = tf.concat([noisy_features] + seq_features, axis=-1)
- else:
- concat_features = tf.concat([output_features] + seq_features, axis=-1)
- regularizers.apply_regularization(
- self._embedding_regularizer, weights_list=embedding_reg_lst)
+ cols_to_output_tensors = OrderedDict()
+ output_features = feature_column.input_layer(
+ features,
+ group_columns,
+ cols_to_output_tensors=cols_to_output_tensors,
+ feature_name_to_output_tensors=feature_name_to_output_tensors,
+ is_training=self._is_training)
+ embedding_reg_lst = []
+ builder = feature_column._LazyBuilder(features)
+ seq_features = []
+ for column in sorted(group_seq_columns, key=lambda x: x.name):
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name):
+ seq_feature, seq_len = column._get_sequence_dense_tensor(builder)
+ embedding_reg_lst.append(seq_feature)
+
+ sequence_combiner = column.sequence_combiner
+ if sequence_combiner is None:
+ raise ValueError(
+ 'sequence_combiner is none, please set sequence_combiner or use TagFeature'
+ )
+ if sequence_combiner.WhichOneof('combiner') == 'attention':
+ attn_logits = tf.layers.dense(
+ inputs=seq_feature,
+ units=1,
+ kernel_regularizer=self._kernel_regularizer,
+ use_bias=False,
+ activation=None,
+ name='attention')
+ attn_logits = tf.squeeze(attn_logits, axis=-1)
+ attn_logits_padding = tf.ones_like(attn_logits) * (-2**32 + 1)
+ seq_mask = tf.sequence_mask(seq_len)
+ attn_score = tf.nn.softmax(
+ tf.where(seq_mask, attn_logits, attn_logits_padding))
+ seq_feature = tf.reduce_sum(
+ attn_score[:, :, tf.newaxis] * seq_feature, axis=1)
+ seq_features.append(seq_feature)
+ cols_to_output_tensors[column] = seq_feature
+ elif sequence_combiner.WhichOneof('combiner') == 'text_cnn':
+ params = Parameter.make_from_pb(sequence_combiner.text_cnn)
+ text_cnn_layer = TextCNN(params, name=column.name + '_text_cnn')
+ cnn_feature = text_cnn_layer((seq_feature, seq_len))
+ seq_features.append(cnn_feature)
+ cols_to_output_tensors[column] = cnn_feature
+ else:
+ raise NotImplementedError
+ if self._variational_dropout_config is not None:
+ features_dimension = OrderedDict([
+ (k.raw_name, int(v.shape[-1]))
+ for k, v in cols_to_output_tensors.items()
+ ])
+ concat_features = array_ops.concat(
+ [output_features] + seq_features, axis=-1)
+ variational_dropout = variational_dropout_layer.VariationalDropoutLayer(
+ self._variational_dropout_config,
+ features_dimension,
+ self._is_training,
+ name=group_name)
+ concat_features = variational_dropout(concat_features)
+ group_features = tf.split(
+ concat_features, list(features_dimension.values()), axis=-1)
+ else:
+ concat_features = array_ops.concat(
+ [output_features] + seq_features, axis=-1)
group_features = [cols_to_output_tensors[x] for x in group_columns] + \
[cols_to_output_tensors[x] for x in group_seq_columns]
- return concat_features, group_features
- else: # return sequence feature in raw format instead of combine them
- assert len(group_columns) == 0, \
- 'there are none sequence columns: %s' % str(group_columns)
- builder = feature_column._LazyBuilder(features)
- seq_features = []
- embedding_reg_lst = []
- for fc in group_seq_columns:
- with tf.variable_scope('input_layer/' + fc.categorical_column.name):
- tmp_embedding, tmp_seq_len = fc._get_sequence_dense_tensor(builder)
- seq_features.append((tmp_embedding, tmp_seq_len))
- embedding_reg_lst.append(tmp_embedding)
- regularizers.apply_regularization(
- self._embedding_regularizer, weights_list=embedding_reg_lst)
- return seq_features
+ if self._embedding_regularizer is not None:
+ for fc, val in cols_to_output_tensors.items():
+ if is_embedding_column(fc):
+ embedding_reg_lst.append(val)
+ if embedding_reg_lst:
+ regularizers.apply_regularization(
+ self._embedding_regularizer, weights_list=embedding_reg_lst)
+ return concat_features, group_features
def get_wide_deep_dict(self):
"""Get wide or deep indicator for feature columns.
diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py
new file mode 100644
index 000000000..17e7cdb1c
--- /dev/null
+++ b/easy_rec/python/layers/keras/__init__.py
@@ -0,0 +1,34 @@
+from .attention import Attention
+from .auxiliary_loss import AuxiliaryLoss
+from .blocks import MLP
+from .blocks import Gate
+from .blocks import Highway
+from .blocks import TextCNN
+from .bst import BST
+from .custom_ops import EditDistance
+from .custom_ops import MappedDotProduct
+from .custom_ops import OverlapFeature
+from .custom_ops import SeqAugmentOps
+from .custom_ops import TextNormalize
+from .data_augment import SeqAugment
+from .din import DIN
+from .embedding import EmbeddingLayer
+from .fibinet import BiLinear
+from .fibinet import FiBiNet
+from .fibinet import SENet
+from .interaction import CIN
+from .interaction import FM
+from .interaction import Cross
+from .interaction import DotInteraction
+from .mask_net import MaskBlock
+from .mask_net import MaskNet
+from .multi_head_attention import MultiHeadAttention
+from .multi_task import AITMTower
+from .multi_task import MMoE
+from .numerical_embedding import AutoDisEmbedding
+from .numerical_embedding import NaryDisEmbedding
+from .numerical_embedding import PeriodicEmbedding
+from .ppnet import PPNet
+from .transformer import TextEncoder
+from .transformer import TransformerBlock
+from .transformer import TransformerEncoder
diff --git a/easy_rec/python/layers/keras/activation.py b/easy_rec/python/layers/keras/activation.py
new file mode 100644
index 000000000..fa6218e64
--- /dev/null
+++ b/easy_rec/python/layers/keras/activation.py
@@ -0,0 +1,114 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+from tensorflow.python.keras.layers import Activation
+from tensorflow.python.keras.layers import Layer
+
+import easy_rec.python.utils.activation
+
+try:
+ from tensorflow.python.ops.init_ops import Zeros
+except ImportError:
+ from tensorflow.python.ops.init_ops_v2 import Zeros
+
+try:
+ from tensorflow.python.keras.layers import BatchNormalization
+except ImportError:
+ BatchNormalization = tf.keras.layers.BatchNormalization
+
+try:
+ unicode
+except NameError:
+ unicode = str
+
+
+class Dice(Layer):
+ """The Data Adaptive Activation Function in DIN.
+
+  It can be viewed as a generalization of PReLU and adaptively adjusts the
+  rectified point according to the distribution of the input data.
+
+ Input shape
+ - Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis)
+ when using this layer as the first layer in a model.
+
+ Output shape
+ - Same shape as the input.
+
+ Arguments
+ - **axis** : Integer, the axis that should be used to compute data distribution (typically the features axis).
+ - **epsilon** : Small float added to variance to avoid dividing by zero.
+
+ References
+ - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
+ https://arxiv.org/pdf/1706.06978.pdf
+ """
+
+ def __init__(self, axis=-1, epsilon=1e-9, **kwargs):
+ self.axis = axis
+ self.epsilon = epsilon
+ super(Dice, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ self.bn = BatchNormalization(
+ axis=self.axis, epsilon=self.epsilon, center=False, scale=False)
+ self.alphas = self.add_weight(
+ shape=(input_shape[-1],),
+ initializer=Zeros(),
+ dtype=tf.float32,
+ name='dice_alpha') # name='alpha_'+self.name
+ super(Dice, self).build(input_shape) # Be sure to call this somewhere!
+ self.uses_learning_phase = True
+
+ def call(self, inputs, training=None, **kwargs):
+ inputs_normed = self.bn(inputs, training=training)
+ # tf.layers.batch_normalization(
+ # inputs, axis=self.axis, epsilon=self.epsilon, center=False, scale=False)
+ x_p = tf.sigmoid(inputs_normed)
+ return self.alphas * (1.0 - x_p) * inputs + x_p * inputs
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+ @property
+ def updates(self):
+ return self.bn.updates
+
+ def get_config(self,):
+ config = {'axis': self.axis, 'epsilon': self.epsilon}
+ base_config = super(Dice, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
+class MaskedSoftmax(Layer):
+
+ def __init__(self, axis=-1, **kwargs):
+ super(MaskedSoftmax, self).__init__(**kwargs)
+ self.axis = axis
+
+ def call(self, inputs, mask=None):
+ if mask is not None:
+ adder = (1.0 - tf.cast(mask, inputs.dtype)) * -1e9
+ inputs += adder
+ # Calculate softmax
+ if isinstance(self.axis, (tuple, list)):
+ if len(self.axis) > 1:
+        raise ValueError('MaskedSoftmax does not support multiple axes')
+ else:
+ return tf.nn.softmax(inputs, axis=self.axis[0])
+ return tf.nn.softmax(inputs, axis=self.axis)
+
+
+def activation_layer(activation, name=None):
+ if activation in ('dice', 'Dice'):
+ act_layer = Dice(name=name)
+ elif isinstance(activation, (str, unicode)):
+ act_fn = easy_rec.python.utils.activation.get_activation(activation)
+ act_layer = Activation(act_fn, name=name)
+ elif issubclass(activation, Layer):
+ act_layer = activation(name=name)
+ else:
+ raise ValueError(
+        'Invalid activation, found %s. You should use a str or an Activation layer class.'
+ % (activation))
+ return act_layer
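
A minimal usage sketch of activation_layer defined above (input tensor is illustrative; 'dice' builds the Dice layer, any other string falls back to a plain Keras Activation, assuming get_activation resolves it to a callable):

import tensorflow as tf
from easy_rec.python.layers.keras.activation import activation_layer

x = tf.random_normal([32, 16])
dice = activation_layer('dice', name='dice_act')
y = dice(x, training=True)                        # same shape as x: (32, 16)
relu = activation_layer('relu', name='relu_act')
z = relu(x)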
diff --git a/easy_rec/python/layers/keras/attention.py b/easy_rec/python/layers/keras/attention.py
new file mode 100644
index 000000000..4831ccae8
--- /dev/null
+++ b/easy_rec/python/layers/keras/attention.py
@@ -0,0 +1,267 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Attention layers that can be used in sequence DNN/CNN models.
+
+This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
+Attention is formed by three tensors: Query, Key and Value.
+"""
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer
+
+
+class Attention(Layer):
+ """Dot-product attention layer, a.k.a. Luong-style attention.
+
+ Inputs are a list with 2 or 3 elements:
+ 1. A `query` tensor of shape `(batch_size, Tq, dim)`.
+ 2. A `value` tensor of shape `(batch_size, Tv, dim)`.
+    3. An optional `key` tensor of shape `(batch_size, Tv, dim)`. If none
+ supplied, `value` will be used as a `key`.
+
+ The calculation follows the steps:
+ 1. Calculate attention scores using `query` and `key` with shape
+ `(batch_size, Tq, Tv)`.
+ 2. Use scores to calculate a softmax distribution with shape
+ `(batch_size, Tq, Tv)`.
+ 3. Use the softmax distribution to create a linear combination of `value`
+ with shape `(batch_size, Tq, dim)`.
+
+ Args:
+ use_scale: If `True`, will create a scalar variable to scale the
+ attention scores.
+ dropout: Float between 0 and 1. Fraction of the units to drop for the
+ attention scores. Defaults to `0.0`.
+ seed: A Python integer to use as random seed in case of `dropout`.
+ score_mode: Function to use to compute attention scores, one of
+ `{"dot", "concat"}`. `"dot"` refers to the dot product between the
+ query and key vectors. `"concat"` refers to the hyperbolic tangent
+ of the concatenation of the `query` and `key` vectors.
+
+ Call Args:
+ inputs: List of the following tensors:
+ - `query`: Query tensor of shape `(batch_size, Tq, dim)`.
+ - `value`: Value tensor of shape `(batch_size, Tv, dim)`.
+ - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If
+ not given, will use `value` for both `key` and `value`, which is
+ the most common case.
+ mask: List of the following tensors:
+ - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`.
+ If given, the output will be zero at the positions where
+ `mask==False`.
+ - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`.
+ If given, will apply the mask such that values at positions
+ where `mask==False` do not contribute to the result.
+    return_attention_scores: bool, if `True`, returns the attention scores
+ (after masking and softmax) as an additional output argument.
+ training: Python boolean indicating whether the layer should behave in
+ training mode (adding dropout) or in inference mode (no dropout).
+ use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds
+ a mask such that position `i` cannot attend to positions `j > i`.
+ This prevents the flow of information from the future towards the
+ past. Defaults to `False`.
+
+ Output:
+ Attention outputs of shape `(batch_size, Tq, dim)`.
+ (Optional) Attention scores after masking and softmax with shape
+ `(batch_size, Tq, Tv)`.
+ """
+
+ def __init__(self, params, name='attention', reuse=None, **kwargs):
+ super(Attention, self).__init__(name=name, **kwargs)
+ self.use_scale = params.get_or_default('use_scale', False)
+ self.scale_by_dim = params.get_or_default('scale_by_dim', False)
+ self.score_mode = params.get_or_default('score_mode', 'dot')
+ if self.score_mode not in ['dot', 'concat']:
+ raise ValueError('Invalid value for argument score_mode. '
+ "Expected one of {'dot', 'concat'}. "
+ 'Received: score_mode=%s' % self.score_mode)
+ self.dropout = params.get_or_default('dropout', 0.0)
+ self.seed = params.get_or_default('seed', None)
+ self.scale = None
+ self.concat_score_weight = None
+ self._return_attention_scores = params.get_or_default(
+ 'return_attention_scores', False)
+ self.use_causal_mask = params.get_or_default('use_causal_mask', False)
+
+ @property
+ def return_attention_scores(self):
+ return self._return_attention_scores
+
+ def build(self, input_shape):
+ self._validate_inputs(input_shape)
+ if self.use_scale:
+ self.scale = self.add_weight(
+ name='scale',
+ shape=(),
+ initializer='ones',
+ dtype=self.dtype,
+ trainable=True,
+ )
+ if self.score_mode == 'concat':
+ self.concat_score_weight = self.add_weight(
+ name='concat_score_weight',
+ shape=(),
+ initializer='ones',
+ dtype=self.dtype,
+ trainable=True,
+ )
+ super(Attention, self).build(input_shape) # Be sure to call this somewhere!
+
+ def _calculate_scores(self, query, key):
+ """Calculates attention scores as a query-key dot product.
+
+ Args:
+ query: Query tensor of shape `(batch_size, Tq, dim)`.
+ key: Key tensor of shape `(batch_size, Tv, dim)`.
+
+ Returns:
+ Tensor of shape `(batch_size, Tq, Tv)`.
+ """
+ if self.score_mode == 'dot':
+ scores = tf.matmul(query, tf.transpose(key, [0, 2, 1]))
+ if self.scale is not None:
+ scores *= self.scale
+ elif self.scale_by_dim:
+ dk = tf.cast(tf.shape(key)[-1], tf.float32)
+ scores /= tf.math.sqrt(dk)
+ elif self.score_mode == 'concat':
+ # Reshape tensors to enable broadcasting.
+ # Reshape into [batch_size, Tq, 1, dim].
+ q_reshaped = tf.expand_dims(query, axis=-2)
+ # Reshape into [batch_size, 1, Tv, dim].
+ k_reshaped = tf.expand_dims(key, axis=-3)
+ if self.scale is not None:
+ scores = self.concat_score_weight * tf.reduce_sum(
+ tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1)
+ else:
+ scores = self.concat_score_weight * tf.reduce_sum(
+ tf.tanh(q_reshaped + k_reshaped), axis=-1)
+ return scores
+
+ def _apply_scores(self, scores, value, scores_mask=None, training=False):
+ """Applies attention scores to the given value tensor.
+
+ To use this method in your attention layer, follow the steps:
+
+ * Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of
+ shape `(batch_size, Tv)` to calculate the attention `scores`.
+ * Pass `scores` and `value` tensors to this method. The method applies
+ `scores_mask`, calculates
+ `attention_distribution = softmax(scores)`, then returns
+      `matmul(attention_distribution, value)`.
+ * Apply `query_mask` and return the result.
+
+ Args:
+ scores: Scores float tensor of shape `(batch_size, Tq, Tv)`.
+ value: Value tensor of shape `(batch_size, Tv, dim)`.
+ scores_mask: A boolean mask tensor of shape `(batch_size, 1, Tv)`
+ or `(batch_size, Tq, Tv)`. If given, scores at positions where
+ `scores_mask==False` do not contribute to the result. It must
+ contain at least one `True` value in each line along the last
+ dimension.
+ training: Python boolean indicating whether the layer should behave
+ in training mode (adding dropout) or in inference mode
+ (no dropout).
+
+ Returns:
+ Tensor of shape `(batch_size, Tq, dim)`.
+ Attention scores after masking and softmax with shape
+ `(batch_size, Tq, Tv)`.
+ """
+ if scores_mask is not None:
+ padding_mask = tf.logical_not(scores_mask)
+ # Bias so padding positions do not contribute to attention
+ # distribution. Note 65504. is the max float16 value.
+ max_value = 65504.0 if scores.dtype == 'float16' else 1.0e9
+ scores -= max_value * tf.cast(padding_mask, dtype=scores.dtype)
+
+ weights = tf.nn.softmax(scores, axis=-1)
+ if training and self.dropout > 0:
+ weights = tf.nn.dropout(weights, 1.0 - self.dropout, seed=self.seed)
+ return tf.matmul(weights, value), weights
+
+ def _calculate_score_mask(self, scores, v_mask, use_causal_mask):
+ if use_causal_mask:
+ # Creates a lower triangular mask, so position i cannot attend to
+ # positions j > i. This prevents the flow of information from the
+ # future into the past.
+ score_shape = tf.shape(scores)
+ # causal_mask_shape = [1, Tq, Tv].
+ mask_shape = (1, score_shape[-2], score_shape[-1])
+ ones_mask = tf.ones(shape=mask_shape, dtype='int32')
+ row_index = tf.cumsum(ones_mask, axis=-2)
+ col_index = tf.cumsum(ones_mask, axis=-1)
+ causal_mask = tf.greater_equal(row_index, col_index)
+
+ if v_mask is not None:
+ # Mask of shape [batch_size, 1, Tv].
+ v_mask = tf.expand_dims(v_mask, axis=-2)
+ return tf.logical_and(v_mask, causal_mask)
+ return causal_mask
+ else:
+ # If not using causal mask, return the value mask as is,
+ # or None if the value mask is not provided.
+ return v_mask
+
+ def call(self, inputs, mask=None, training=False, **kwargs):
+ self._validate_inputs(inputs=inputs, mask=mask)
+ q = inputs[0]
+ v = inputs[1]
+ k = inputs[2] if len(inputs) > 2 else v
+ q_mask = mask[0] if mask else None
+ v_mask = mask[1] if mask else None
+ scores = self._calculate_scores(query=q, key=k)
+ scores_mask = self._calculate_score_mask(scores, v_mask,
+ self.use_causal_mask)
+ result, attention_scores = self._apply_scores(
+ scores=scores, value=v, scores_mask=scores_mask, training=training)
+ if q_mask is not None:
+ # Mask of shape [batch_size, Tq, 1].
+ q_mask = tf.expand_dims(q_mask, axis=-1)
+ result *= tf.cast(q_mask, dtype=result.dtype)
+ if self._return_attention_scores:
+ return result, attention_scores
+ return result
+
+ def compute_mask(self, inputs, mask=None):
+ self._validate_inputs(inputs=inputs, mask=mask)
+ if mask is None or mask[0] is None:
+ return None
+ return tf.convert_to_tensor(mask[0])
+
+ def compute_output_shape(self, input_shape):
+ """Returns shape of value tensor dim, but for query tensor length."""
+ return list(input_shape[0][:-1]), input_shape[1][-1]
+
+ def _validate_inputs(self, inputs, mask=None):
+ """Validates arguments of the call method."""
+ class_name = self.__class__.__name__
+ if not isinstance(inputs, list):
+ raise ValueError('{class_name} layer must be called on a list of inputs, '
+ 'namely [query, value] or [query, value, key]. '
+ 'Received: inputs={inputs}.'.format(
+ class_name=class_name, inputs=inputs))
+ if len(inputs) < 2 or len(inputs) > 3:
+ raise ValueError('%s layer accepts inputs list of length 2 or 3, '
+ 'namely [query, value] or [query, value, key]. '
+ 'Received length: %d.' % (class_name, len(inputs)))
+ if mask is not None:
+ if not isinstance(mask, list):
+ raise ValueError(
+ '{class_name} layer mask must be a list, '
+ 'namely [query_mask, value_mask]. Received: mask={mask}.'.format(
+ class_name=class_name, mask=mask))
+ if len(mask) < 2 or len(mask) > 3:
+ raise ValueError(
+ '{class_name} layer accepts mask list of length 2 or 3. '
+ 'Received: inputs={inputs}, mask={mask}.'.format(
+ class_name=class_name, inputs=inputs, mask=mask))
+
+ def get_config(self):
+ base_config = super(Attention, self).get_config()
+ config = {
+ 'use_scale': self.use_scale,
+ 'score_mode': self.score_mode,
+ 'dropout': self.dropout,
+ }
+ return dict(list(base_config.items()) + list(config.items()))
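
A minimal usage sketch of the Attention layer above. The tiny _Params wrapper below only mimics the get_or_default interface the layer relies on; it is an illustrative stand-in, not the EasyRec Parameter class:

import tensorflow as tf
from easy_rec.python.layers.keras.attention import Attention

class _Params(object):  # illustrative stand-in, only get_or_default is needed here
  def __init__(self, d):
    self._d = d
  def get_or_default(self, key, default):
    return self._d.get(key, default)

attn = Attention(_Params({'score_mode': 'dot'}))
query = tf.random_normal([2, 5, 16])   # (batch, Tq, dim)
value = tf.random_normal([2, 7, 16])   # (batch, Tv, dim)
out = attn([query, value])             # (batch, Tq, dim) == (2, 5, 16)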
diff --git a/easy_rec/python/layers/keras/auxiliary_loss.py b/easy_rec/python/layers/keras/auxiliary_loss.py
new file mode 100644
index 000000000..6be248872
--- /dev/null
+++ b/easy_rec/python/layers/keras/auxiliary_loss.py
@@ -0,0 +1,47 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.loss import contrastive_loss
+
+
+class AuxiliaryLoss(tf.keras.layers.Layer):
+ """Compute auxiliary loss, usually use for contrastive learning."""
+
+ def __init__(self, params, name='auxiliary_loss', reuse=None, **kwargs):
+ super(AuxiliaryLoss, self).__init__(name=name, **kwargs)
+ params.check_required('loss_type')
+ self.loss_type = params.get_or_default('loss_type', None)
+ self.loss_weight = params.get_or_default('loss_weight', 1.0)
+ logging.info('init layer `%s` with loss type: %s and weight: %f' %
+ (self.name, self.loss_type, self.loss_weight))
+ self.temperature = params.get_or_default('temperature', 0.1)
+
+ def call(self, inputs, training=None, **kwargs):
+ if self.loss_type is None:
+ logging.warning('loss_type is None in auxiliary loss layer')
+ return 0
+
+ loss_dict = kwargs['loss_dict']
+ loss_value = 0
+
+ if self.loss_type == 'l2_loss':
+ x1, x2 = inputs
+ loss = contrastive_loss.l2_loss(x1, x2)
+ loss_value = loss if self.loss_weight == 1.0 else loss * self.loss_weight
+ loss_dict['%s_l2_loss' % self.name] = loss_value
+ elif self.loss_type == 'info_nce':
+ query, positive = inputs
+ loss = contrastive_loss.info_nce_loss(
+ query, positive, temperature=self.temperature)
+ loss_value = loss if self.loss_weight == 1.0 else loss * self.loss_weight
+ loss_dict['%s_info_nce_loss' % self.name] = loss_value
+ elif self.loss_type == 'nce_loss':
+ x1, x2 = inputs
+ loss = contrastive_loss.nce_loss(x1, x2, temperature=self.temperature)
+ loss_value = loss if self.loss_weight == 1.0 else loss * self.loss_weight
+ loss_dict['%s_nce_loss' % self.name] = loss_value
+
+ return loss_value
diff --git a/easy_rec/python/layers/keras/blocks.py b/easy_rec/python/layers/keras/blocks.py
new file mode 100644
index 000000000..c9e722a67
--- /dev/null
+++ b/easy_rec/python/layers/keras/blocks.py
@@ -0,0 +1,262 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Convenience blocks for building models."""
+import logging
+
+import tensorflow as tf
+from tensorflow.python.keras.initializers import Constant
+from tensorflow.python.keras.layers import Dense
+from tensorflow.python.keras.layers import Dropout
+from tensorflow.python.keras.layers import Lambda
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras.activation import activation_layer
+from easy_rec.python.layers.utils import Parameter
+from easy_rec.python.utils.shape_utils import pad_or_truncate_sequence
+from easy_rec.python.utils.tf_utils import add_elements_to_collection
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class MLP(Layer):
+ """Sequential multi-layer perceptron (MLP) block.
+
+ Attributes:
+ units: Sequential list of layer sizes.
+ use_bias: Whether to include a bias term.
+ activation: Type of activation to use on all except the last layer.
+ final_activation: Type of activation to use on last layer.
+ **kwargs: Extra args passed to the Keras Layer base class.
+ """
+
+ def __init__(self, params, name='mlp', reuse=None, **kwargs):
+ super(MLP, self).__init__(name=name, **kwargs)
+ self.layer_name = name # for add to output
+ params.check_required('hidden_units')
+ use_bn = params.get_or_default('use_bn', True)
+ use_final_bn = params.get_or_default('use_final_bn', True)
+ use_bias = params.get_or_default('use_bias', False)
+ use_final_bias = params.get_or_default('use_final_bias', False)
+ dropout_rate = list(params.get_or_default('dropout_ratio', []))
+ activation = params.get_or_default('activation', 'relu')
+ initializer = params.get_or_default('initializer', 'he_uniform')
+ final_activation = params.get_or_default('final_activation', None)
+ use_bn_after_act = params.get_or_default('use_bn_after_activation', False)
+ units = list(params.hidden_units)
+ logging.info(
+ 'MLP(%s) units: %s, dropout: %r, activate=%s, use_bn=%r, final_bn=%r,'
+ ' final_activate=%s, bias=%r, initializer=%s, bn_after_activation=%r' %
+ (name, units, dropout_rate, activation, use_bn, use_final_bn,
+ final_activation, use_bias, initializer, use_bn_after_act))
+    assert len(units) > 0, 'MLP(%s) requires at least one hidden unit' % name
+ self.reuse = reuse
+ self.add_to_outputs = params.get_or_default('add_to_outputs', False)
+
+ num_dropout = len(dropout_rate)
+ self._sub_layers = []
+ for i, num_units in enumerate(units[:-1]):
+ name = 'layer_%d' % i
+ drop_rate = dropout_rate[i] if i < num_dropout else 0.0
+ self.add_rich_layer(num_units, use_bn, drop_rate, activation, initializer,
+ use_bias, use_bn_after_act, name,
+ params.l2_regularizer)
+
+ n = len(units) - 1
+ drop_rate = dropout_rate[n] if num_dropout > n else 0.0
+ name = 'layer_%d' % n
+ self.add_rich_layer(units[-1], use_final_bn, drop_rate, final_activation,
+ initializer, use_final_bias, use_bn_after_act, name,
+ params.l2_regularizer)
+
+ def add_rich_layer(self,
+ num_units,
+ use_bn,
+ dropout_rate,
+ activation,
+ initializer,
+ use_bias,
+ use_bn_after_activation,
+ name,
+ l2_reg=None):
+ act_layer = activation_layer(activation, name='%s/act' % name)
+ if use_bn and not use_bn_after_activation:
+ dense = Dense(
+ units=num_units,
+ use_bias=use_bias,
+ kernel_initializer=initializer,
+ kernel_regularizer=l2_reg,
+ name='%s/dense' % name)
+ self._sub_layers.append(dense)
+ bn = tf.keras.layers.BatchNormalization(
+ name='%s/bn' % name, trainable=True)
+ self._sub_layers.append(bn)
+ self._sub_layers.append(act_layer)
+ else:
+ dense = Dense(
+ num_units,
+ use_bias=use_bias,
+ kernel_initializer=initializer,
+ kernel_regularizer=l2_reg,
+ name='%s/dense' % name)
+ self._sub_layers.append(dense)
+ self._sub_layers.append(act_layer)
+ if use_bn and use_bn_after_activation:
+ bn = tf.keras.layers.BatchNormalization(name='%s/bn' % name)
+ self._sub_layers.append(bn)
+
+ if 0.0 < dropout_rate < 1.0:
+ dropout = Dropout(dropout_rate, name='%s/dropout' % name)
+ self._sub_layers.append(dropout)
+ elif dropout_rate >= 1.0:
+ raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)
+
+ def call(self, x, training=None, **kwargs):
+ """Performs the forward computation of the block."""
+ for layer in self._sub_layers:
+ cls = layer.__class__.__name__
+ if cls in ('Dropout', 'BatchNormalization', 'Dice'):
+ x = layer(x, training=training)
+ if cls in ('BatchNormalization', 'Dice') and training:
+ add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+ else:
+ x = layer(x)
+ if self.add_to_outputs and 'prediction_dict' in kwargs:
+ outputs = kwargs['prediction_dict']
+ outputs[self.layer_name] = tf.squeeze(x, axis=1)
+ logging.info('add `%s` to model outputs' % self.layer_name)
+ return x
+
+
+class Highway(Layer):
+
+ def __init__(self, params, name='highway', reuse=None, **kwargs):
+ super(Highway, self).__init__(name=name, **kwargs)
+ self.emb_size = params.get_or_default('emb_size', None)
+ self.num_layers = params.get_or_default('num_layers', 1)
+ self.activation = params.get_or_default('activation', 'relu')
+ self.dropout_rate = params.get_or_default('dropout_rate', 0.0)
+ self.init_gate_bias = params.get_or_default('init_gate_bias', -3.0)
+ self.act_layer = activation_layer(self.activation)
+ self.dropout_layer = Dropout(
+ self.dropout_rate) if self.dropout_rate > 0.0 else None
+ self.project_layer = None
+ self.gate_bias_initializer = Constant(self.init_gate_bias)
+ self.gates = [] # T
+ self.transforms = [] # H
+ self.multiply_layer = tf.keras.layers.Multiply()
+ self.add_layer = tf.keras.layers.Add()
+
+ def build(self, input_shape):
+ dim = input_shape[-1]
+ if self.emb_size is not None and dim != self.emb_size:
+ self.project_layer = Dense(self.emb_size, name='input_projection')
+ dim = self.emb_size
+ self.carry_gate = Lambda(lambda x: 1.0 - x, output_shape=(dim,))
+ for i in range(self.num_layers):
+ gate = Dense(
+ units=dim,
+ bias_initializer=self.gate_bias_initializer,
+ activation='sigmoid',
+ name='gate_%d' % i)
+ self.gates.append(gate)
+ self.transforms.append(Dense(units=dim))
+
+ def call(self, inputs, training=None, **kwargs):
+ value = inputs
+ if self.project_layer is not None:
+ value = self.project_layer(inputs)
+ for i in range(self.num_layers):
+ gate = self.gates[i](value)
+ transformed = self.act_layer(self.transforms[i](value))
+ if self.dropout_layer is not None:
+ transformed = self.dropout_layer(transformed, training=training)
+ transformed_gated = self.multiply_layer([gate, transformed])
+ identity_gated = self.multiply_layer([self.carry_gate(gate), value])
+ value = self.add_layer([transformed_gated, identity_gated])
+ return value
+
+
+class Gate(Layer):
+ """Weighted sum gate."""
+
+ def __init__(self, params, name='gate', reuse=None, **kwargs):
+ super(Gate, self).__init__(name=name, **kwargs)
+ self.weight_index = params.get_or_default('weight_index', 0)
+ if params.has_field('mlp'):
+ mlp_cfg = Parameter.make_from_pb(params.mlp)
+ mlp_cfg.l2_regularizer = params.l2_regularizer
+ self.top_mlp = MLP(mlp_cfg, name='top_mlp')
+ else:
+ self.top_mlp = None
+
+ def call(self, inputs, training=None, **kwargs):
+ assert len(
+ inputs
+ ) > 1, 'input of Gate layer must be a list containing at least 2 elements'
+ weights = inputs[self.weight_index]
+ j = 0
+ for i, x in enumerate(inputs):
+ if i == self.weight_index:
+ continue
+ if j == 0:
+ output = weights[:, j, None] * x
+ else:
+ output += weights[:, j, None] * x
+ j += 1
+ if self.top_mlp is not None:
+ output = self.top_mlp(output, training=training)
+ return output
+
+
+class TextCNN(Layer):
+ """Text CNN Model.
+
+ References
+ - [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882)
+ """
+
+ def __init__(self, params, name='text_cnn', reuse=None, **kwargs):
+ super(TextCNN, self).__init__(name=name, **kwargs)
+ self.config = params.get_pb_config()
+ self.pad_seq_length = self.config.pad_sequence_length
+ if self.pad_seq_length <= 0:
+ logging.warning(
+          'running text cnn with pad_sequence_length <= 0, model predictions may be unstable'
+ )
+ self.conv_layers = []
+ self.pool_layer = tf.keras.layers.GlobalMaxPool1D()
+ self.concat_layer = tf.keras.layers.Concatenate(axis=-1)
+ for size, filters in zip(self.config.filter_sizes, self.config.num_filters):
+ conv = tf.keras.layers.Conv1D(
+ filters=int(filters),
+ kernel_size=int(size),
+ activation=self.config.activation)
+ self.conv_layers.append(conv)
+ if self.config.HasField('mlp'):
+ p = Parameter.make_from_pb(self.config.mlp)
+ p.l2_regularizer = params.l2_regularizer
+ self.mlp = MLP(p, name='mlp', reuse=reuse)
+ else:
+ self.mlp = None
+
+ def call(self, inputs, training=None, **kwargs):
+ """Input shape: 3D tensor with shape: `(batch_size, steps, input_dim)."""
+ assert isinstance(inputs, (list, tuple))
+ assert len(inputs) >= 2
+ seq_emb, seq_len = inputs[:2]
+
+ if self.pad_seq_length > 0:
+ seq_emb, seq_len = pad_or_truncate_sequence(seq_emb, seq_len,
+ self.pad_seq_length)
+ pooled_outputs = []
+ for layer in self.conv_layers:
+ conv = layer(seq_emb)
+ pooled = self.pool_layer(conv)
+ pooled_outputs.append(pooled)
+ net = self.concat_layer(pooled_outputs)
+ if self.mlp is not None:
+ output = self.mlp(net, training=training)
+ else:
+ output = net
+ return output
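
TextCNN above follows the standard multi-filter-size Conv1D + global max pooling pattern; a self-contained sketch of that pattern in plain Keras (without the EasyRec wrapper or its proto config):

import tensorflow as tf

def text_cnn(seq_emb, filter_sizes=(2, 3, 4), num_filters=(8, 8, 8)):
  # seq_emb: (batch, steps, dim) -> (batch, sum(num_filters))
  pooled = []
  for size, filters in zip(filter_sizes, num_filters):
    conv = tf.keras.layers.Conv1D(filters, kernel_size=size, activation='relu')(seq_emb)
    pooled.append(tf.keras.layers.GlobalMaxPool1D()(conv))
  return tf.keras.layers.Concatenate(axis=-1)(pooled)

emb = tf.random_normal([4, 10, 16])
out = text_cnn(emb)  # shape (4, 24)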
diff --git a/easy_rec/python/layers/keras/bst.py b/easy_rec/python/layers/keras/bst.py
new file mode 100644
index 000000000..dbd4882ed
--- /dev/null
+++ b/easy_rec/python/layers/keras/bst.py
@@ -0,0 +1,119 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers import multihead_cross_attention
+from easy_rec.python.utils.activation import get_activation
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class BST(Layer):
+
+ def __init__(self, params, name='bst', reuse=None, **kwargs):
+ super(BST, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ self.l2_reg = params.l2_regularizer
+ self.config = params.get_pb_config()
+
+ def encode(self, seq_input, max_position):
+ seq_fea = multihead_cross_attention.embedding_postprocessor(
+ seq_input,
+ position_embedding_name=self.name,
+ max_position_embeddings=max_position,
+ reuse_position_embedding=self.reuse)
+
+ n = tf.count_nonzero(seq_input, axis=-1)
+ seq_mask = tf.cast(n > 0, tf.int32)
+
+ attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
+ from_tensor=seq_fea, to_mask=seq_mask)
+
+ hidden_act = get_activation(self.config.hidden_act)
+ attention_fea = multihead_cross_attention.transformer_encoder(
+ seq_fea,
+ hidden_size=self.config.hidden_size,
+ num_hidden_layers=self.config.num_hidden_layers,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_mask=attention_mask,
+ intermediate_size=self.config.intermediate_size,
+ intermediate_act_fn=hidden_act,
+ hidden_dropout_prob=self.config.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
+ initializer_range=self.config.initializer_range,
+ name=self.name + '/transformer',
+ reuse=self.reuse)
+ # attention_fea shape: [batch_size, seq_length, hidden_size]
+ if self.config.output_all_token_embeddings:
+ out_fea = tf.reshape(attention_fea,
+ [-1, max_position * self.config.hidden_size])
+ else:
+ out_fea = attention_fea[:, 0, :] # target feature
+ print('bst output shape:', out_fea.shape)
+ return out_fea
+
+ def call(self, inputs, training=None, **kwargs):
+ if not training:
+ self.config.hidden_dropout_prob = 0.0
+ self.config.attention_probs_dropout_prob = 0.0
+ assert isinstance(inputs, (list, tuple))
+ assert len(inputs) >= 2
+ # seq_input: [batch_size, seq_len, embed_size]
+ seq_input, seq_len = inputs[:2]
+ target = inputs[2] if len(inputs) > 2 else None
+ max_position = self.config.max_position_embeddings
+ # max_seq_len: the max sequence length in current mini-batch, all sequences are padded to this length
+ batch_size, cur_batch_max_seq_len, seq_embed_size = get_shape_list(
+ seq_input, 3)
+ valid_len = tf.assert_less_equal(
+ cur_batch_max_seq_len,
+ max_position,
+ message='sequence length is greater than `max_position_embeddings`:' +
+ str(max_position) + ' in feature group:' + self.name +
+ ', you should set `max_seq_len` in sequence feature configs')
+
+ if self.config.output_all_token_embeddings:
+ seq_input = tf.cond(
+ tf.constant(max_position) > cur_batch_max_seq_len, lambda: tf.pad(
+ seq_input, [[0, 0], [0, max_position - cur_batch_max_seq_len],
+ [0, 0]], 'CONSTANT'),
+ lambda: tf.slice(seq_input, [0, 0, 0], [-1, max_position, -1]))
+
+ if seq_embed_size != self.config.hidden_size:
+ seq_input = tf.layers.dense(
+ seq_input,
+ self.config.hidden_size,
+ activation=tf.nn.relu,
+ kernel_regularizer=self.l2_reg,
+ name=self.name + '/seq_project',
+ reuse=self.reuse)
+
+ keep_target = self.config.target_item_position in ('head', 'tail')
+ if target is not None and keep_target:
+ target_size = target.shape.as_list()[-1]
+      assert seq_embed_size == target_size, 'the embedding sizes of sequence and target item are not equal' \
+                                            ' in feature group:' + self.name
+ if target_size != self.config.hidden_size:
+ target = tf.layers.dense(
+ target,
+ self.config.hidden_size,
+ activation=tf.nn.relu,
+ kernel_regularizer=self.l2_reg,
+ name=self.name + '/target_project',
+ reuse=self.reuse)
+ # target_feature: [batch_size, 1, embed_size]
+ target = tf.expand_dims(target, 1)
+ # seq_input: [batch_size, seq_len+1, embed_size]
+ if self.config.target_item_position == 'head':
+ seq_input = tf.concat([target, seq_input], axis=1)
+ else:
+ seq_input = tf.concat([seq_input, target], axis=1)
+ max_position += 1
+ elif self.config.reserve_target_position:
+ max_position += 1
+
+ with tf.control_dependencies([valid_len]):
+ return self.encode(seq_input, max_position)
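
The BST layer above derives its attention mask from the positions whose embeddings are non-zero; a minimal sketch of just that masking step (plain TF, tensors are illustrative):

import tensorflow as tf

seq_input = tf.concat([tf.random_normal([2, 4, 16]),
                       tf.zeros([2, 2, 16])], axis=1)  # last 2 steps are padding

n = tf.count_nonzero(seq_input, axis=-1)               # (batch, seq_len)
seq_mask = tf.cast(n > 0, tf.int32)                    # 1 = real token, 0 = padding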
diff --git a/easy_rec/python/layers/keras/custom_ops.py b/easy_rec/python/layers/keras/custom_ops.py
new file mode 100644
index 000000000..c215ee332
--- /dev/null
+++ b/easy_rec/python/layers/keras/custom_ops.py
@@ -0,0 +1,250 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Convenience blocks for using custom ops."""
+import logging
+import os
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.layers import Layer
+
+curr_dir, _ = os.path.split(__file__)
+parent_dir = os.path.dirname(curr_dir)
+ops_idr = os.path.dirname(parent_dir)
+ops_dir = os.path.join(ops_idr, 'ops')
+if 'PAI' in tf.__version__:
+ ops_dir = os.path.join(ops_dir, '1.12_pai')
+elif tf.__version__.startswith('1.12'):
+ ops_dir = os.path.join(ops_dir, '1.12')
+elif tf.__version__.startswith('1.15'):
+ if 'IS_ON_PAI' in os.environ:
+ ops_dir = os.path.join(ops_dir, 'DeepRec')
+ else:
+ ops_dir = os.path.join(ops_dir, '1.15')
+elif tf.__version__.startswith('2.12'):
+ ops_dir = os.path.join(ops_dir, '2.12')
+
+logging.info('ops_dir is %s' % ops_dir)
+custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so')
+try:
+ custom_ops = tf.load_op_library(custom_op_path)
+ logging.info('load custom op from %s succeed' % custom_op_path)
+except Exception as ex:
+ logging.warning('load custom op from %s failed: %s' %
+ (custom_op_path, str(ex)))
+ custom_ops = None
+
+# if tf.__version__ >= '2.0':
+# tf = tf.compat.v1
+
+
+class SeqAugmentOps(Layer):
+ """Do data augmentation for input sequence embedding."""
+
+ def __init__(self, params, name='sequence_aug', reuse=None, **kwargs):
+ super(SeqAugmentOps, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ self.seq_aug_params = params.get_pb_config()
+ self.seq_augment = custom_ops.my_seq_augment
+
+ def call(self, inputs, training=None, **kwargs):
+ assert isinstance(
+ inputs,
+ (list, tuple)), 'the inputs of SeqAugmentOps must be type of list/tuple'
+ assert len(inputs) >= 2, 'SeqAugmentOps must have at least 2 inputs'
+ seq_input, seq_len = inputs[:2]
+ embedding_dim = int(seq_input.shape[-1])
+ with tf.variable_scope(self.name, reuse=self.reuse):
+ mask_emb = tf.get_variable(
+ 'mask', (embedding_dim,), dtype=tf.float32, trainable=True)
+ seq_len = tf.to_int32(seq_len)
+ with ops.device('/CPU:0'):
+ aug_seq, aug_len = self.seq_augment(seq_input, seq_len, mask_emb,
+ self.seq_aug_params.crop_rate,
+ self.seq_aug_params.reorder_rate,
+ self.seq_aug_params.mask_rate)
+ return aug_seq, aug_len
+
+
+class TextNormalize(Layer):
+
+ def __init__(self, params, name='text_normalize', reuse=None, **kwargs):
+ super(TextNormalize, self).__init__(name=name, **kwargs)
+ self.txt_normalizer = custom_ops.text_normalize_op
+ self.norm_parameter = params.get_or_default('norm_parameter', 0)
+ self.remove_space = params.get_or_default('remove_space', False)
+
+ def call(self, inputs, training=None, **kwargs):
+ inputs = inputs if type(inputs) in (tuple, list) else [inputs]
+ with ops.device('/CPU:0'):
+ result = [
+ self.txt_normalizer(
+ txt,
+ parameter=self.norm_parameter,
+ remove_space=self.remove_space) for txt in inputs
+ ]
+ if len(result) == 1:
+ return result[0]
+ return result
+
+
+class MappedDotProduct(Layer):
+
+ def __init__(self, params, name='mapped_dot_product', reuse=None, **kwargs):
+ super(MappedDotProduct, self).__init__(name=name, **kwargs)
+ self.mapped_dot_product = custom_ops.mapped_dot_product
+ self.bucketize = custom_ops.my_bucketize
+ self.default_value = params.get_or_default('default_value', 0)
+ self.separator = params.get_or_default('separator', '\035')
+ self.norm_fn = params.get_or_default('normalize_fn', None)
+ self.boundaries = list(params.get_or_default('boundaries', []))
+ self.emb_dim = params.get_or_default('embedding_dim', 0)
+ self.print_first_n = params.get_or_default('print_first_n', 0)
+ self.summarize = params.get_or_default('summarize', None)
+ if self.emb_dim > 0:
+ vocab_size = len(self.boundaries) + 1
+ with tf.variable_scope(self.name, reuse=reuse):
+ self.embedding_table = tf.get_variable(
+ name='dot_product_emb_table',
+ shape=[vocab_size, self.emb_dim],
+ dtype=tf.float32)
+
+ def call(self, inputs, training=None, **kwargs):
+ query, doc = inputs[:2]
+ with ops.device('/CPU:0'):
+ feature = self.mapped_dot_product(
+ query=query,
+ document=doc,
+ feature_name=self.name,
+ separator=self.separator,
+ default_value=self.default_value)
+ tf.summary.scalar(self.name, tf.reduce_mean(feature))
+ if self.print_first_n:
+ encode_q = tf.regex_replace(query, self.separator, ' ')
+        encode_t = tf.regex_replace(doc, self.separator, ' ')
+ feature = tf.Print(
+ feature, [encode_q, encode_t, feature],
+ message=self.name,
+ first_n=self.print_first_n,
+ summarize=self.summarize)
+ if self.norm_fn is not None:
+ fn = eval(self.norm_fn)
+ feature = fn(feature)
+ tf.summary.scalar('normalized_%s' % self.name, tf.reduce_mean(feature))
+ if self.print_first_n:
+ feature = tf.Print(
+ feature, [feature],
+ message='normalized %s' % self.name,
+ first_n=self.print_first_n,
+ summarize=self.summarize)
+ if self.boundaries:
+ feature = self.bucketize(feature, boundaries=self.boundaries)
+ tf.summary.histogram('bucketized_%s' % self.name, feature)
+ if self.emb_dim > 0 and self.boundaries:
+ vocab_size = len(self.boundaries) + 1
+ one_hot_input_ids = tf.one_hot(feature, depth=vocab_size)
+ return tf.matmul(one_hot_input_ids, self.embedding_table)
+ return tf.expand_dims(feature, axis=-1)
+
+
+class OverlapFeature(Layer):
+
+ def __init__(self, params, name='overlap_feature', reuse=None, **kwargs):
+ super(OverlapFeature, self).__init__(name=name, **kwargs)
+ self.overlap_feature = custom_ops.overlap_fg_op
+ methods = params.get_or_default('methods', [])
+ assert methods, 'overlap feature methods must be set'
+ self.methods = [str(method) for method in methods]
+ self.norm_fn = params.get_or_default('normalize_fn', None)
+ self.boundaries = list(params.get_or_default('boundaries', []))
+ self.separator = params.get_or_default('separator', '\035')
+ self.default_value = params.get_or_default('default_value', '-1')
+ self.emb_dim = params.get_or_default('embedding_dim', 0)
+ self.print_first_n = params.get_or_default('print_first_n', 0)
+ self.summarize = params.get_or_default('summarize', None)
+ if self.emb_dim > 0:
+ vocab_size = len(self.boundaries) + 1
+ vocab_size *= len(self.methods)
+ with tf.variable_scope(self.name, reuse=reuse):
+ self.embedding_table = tf.get_variable(
+ name='overlap_emb_table',
+ shape=[vocab_size, self.emb_dim],
+ dtype=tf.float32)
+
+ def call(self, inputs, training=None, **kwargs):
+ query, title = inputs[:2]
+ with ops.device('/CPU:0'):
+ feature = self.overlap_feature(
+ query=query,
+ title=title,
+ feature_name=self.name,
+ separator=self.separator,
+ default_value=self.default_value,
+ boundaries=self.boundaries,
+ methods=self.methods,
+ dtype=tf.int32 if self.boundaries else tf.float32)
+
+ for i, method in enumerate(self.methods):
+      # warning: feature[:, i] may not be the result of methods[i]
+ if self.boundaries:
+ tf.summary.histogram('bucketized_%s' % method, feature[:, i])
+ else:
+ tf.summary.scalar(method, tf.reduce_mean(feature[:, i]))
+ if self.print_first_n:
+ encode_q = tf.regex_replace(query, self.separator, ' ')
+      encode_t = tf.regex_replace(title, self.separator, ' ')
+ feature = tf.Print(
+ feature, [encode_q, encode_t, feature],
+ message=self.name,
+ first_n=self.print_first_n,
+ summarize=self.summarize)
+ if self.norm_fn is not None:
+ fn = eval(self.norm_fn)
+ feature = fn(feature)
+
+ if self.emb_dim > 0 and self.boundaries:
+ # This vocab will be small so we always do one-hot here, since it is always
+ # faster for a small vocabulary.
+ batch_size = tf.shape(feature)[0]
+ vocab_size = len(self.boundaries) + 1
+ num_indices = len(self.methods)
+ # Compute offsets, add to every column indices
+ offsets = tf.range(num_indices) * vocab_size # Shape: [3]
+ offsets = tf.reshape(offsets, [1, num_indices]) # Shape: [1, 3]
+ offsets = tf.tile(offsets,
+ [batch_size, 1]) # Shape: [batch_size, num_indices]
+ shifted_indices = feature + offsets # Shape: [batch_size, num_indices]
+ flat_feature_ids = tf.reshape(shifted_indices, [-1])
+ one_hot_ids = tf.one_hot(flat_feature_ids, depth=vocab_size * num_indices)
+ feature_embeddings = tf.matmul(one_hot_ids, self.embedding_table)
+ feature_embeddings = tf.reshape(feature_embeddings,
+ [batch_size, num_indices * self.emb_dim])
+ return feature_embeddings
+ return feature
+
+
+class EditDistance(Layer):
+
+ def __init__(self, params, name='edit_distance', reuse=None, **kwargs):
+ super(EditDistance, self).__init__(name=name, **kwargs)
+ self.edit_distance = custom_ops.my_edit_distance
+ self.txt_encoding = params.get_or_default('text_encoding', 'utf-8')
+ self.emb_size = params.get_or_default('embedding_size', 512)
+ emb_dim = params.get_or_default('embedding_dim', 4)
+ with tf.variable_scope(self.name, reuse=reuse):
+ self.embedding_table = tf.get_variable('embedding_table',
+ [self.emb_size, emb_dim],
+ tf.float32)
+
+ def call(self, inputs, training=None, **kwargs):
+ input1, input2 = inputs[:2]
+ with ops.device('/CPU:0'):
+ dist = self.edit_distance(
+ input1,
+ input2,
+ normalize=False,
+ dtype=tf.int32,
+ encoding=self.txt_encoding)
+ ids = tf.clip_by_value(dist, 0, self.emb_size - 1)
+ embed = tf.nn.embedding_lookup(self.embedding_table, ids)
+ return embed
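+# Usage sketch (hedged, since `my_edit_distance` is a custom op): the layer is
+# called as `EditDistance(params)([query, title])` on two batches of string
+# tensors; the op is assumed to return one integer distance per pair, which is
+# clipped to [0, embedding_size - 1] and used as a row index into the
+# embedding table, e.g. a distance of 3 selects row 3.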
diff --git a/easy_rec/python/layers/keras/data_augment.py b/easy_rec/python/layers/keras/data_augment.py
new file mode 100644
index 000000000..a11f08120
--- /dev/null
+++ b/easy_rec/python/layers/keras/data_augment.py
@@ -0,0 +1,133 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def item_mask(aug_data, length, mask_emb, mask_rate):
+ length1 = tf.cast(length, dtype=tf.float32)
+ num_mask = tf.cast(tf.math.floor(length1 * mask_rate), dtype=tf.int32)
+ max_len = tf.shape(aug_data)[0]
+ seq_mask = tf.sequence_mask(num_mask, length)
+ seq_mask = tf.random.shuffle(seq_mask)
+ padding = tf.sequence_mask(0, max_len - length)
+ seq_mask = tf.concat([seq_mask, padding], axis=0)
+
+ mask_emb = tf.tile(mask_emb, [max_len, 1])
+
+ masked_item_seq = tf.where(seq_mask, mask_emb, aug_data)
+ return masked_item_seq, length
+
+
+def item_crop(aug_data, length, crop_rate):
+ length1 = tf.cast(length, dtype=tf.float32)
+ max_len, _ = get_shape_list(aug_data)
+ max_length = tf.cast(max_len, dtype=tf.int32)
+
+ num_left = tf.cast(tf.math.floor(length1 * crop_rate), dtype=tf.int32)
+ crop_begin = tf.random.uniform([],
+ minval=0,
+ maxval=length - num_left,
+ dtype=tf.int32)
+ zeros = tf.zeros_like(aug_data)
+ x = aug_data[crop_begin:crop_begin + num_left]
+ y = zeros[:max_length - num_left]
+ cropped = tf.concat([x, y], axis=0)
+ cropped_item_seq = tf.where(
+ crop_begin + num_left < max_length, cropped,
+ tf.concat([aug_data[crop_begin:], zeros[:crop_begin]], axis=0))
+ return cropped_item_seq, num_left
+
+
+def item_reorder(aug_data, length, reorder_rate):
+ length1 = tf.cast(length, dtype=tf.float32)
+ num_reorder = tf.cast(tf.math.floor(length1 * reorder_rate), dtype=tf.int32)
+ reorder_begin = tf.random.uniform([],
+ minval=0,
+ maxval=length - num_reorder,
+ dtype=tf.int32)
+ shuffle_index = tf.range(reorder_begin, reorder_begin + num_reorder)
+ shuffle_index = tf.random.shuffle(shuffle_index)
+ x = tf.range(get_shape_list(aug_data)[0])
+ left = tf.slice(x, [0], [reorder_begin])
+ right = tf.slice(x, [reorder_begin + num_reorder], [-1])
+ reordered_item_index = tf.concat([left, shuffle_index, right], axis=0)
+ reordered_item_seq = tf.scatter_nd(
+ tf.expand_dims(reordered_item_index, axis=1), aug_data,
+ tf.shape(aug_data))
+ return reordered_item_seq, length
+
+
+def augment_fn(x, aug_param, mask):
+ seq, length = x
+
+ def crop_fn():
+ return item_crop(seq, length, aug_param.crop_rate)
+
+ def mask_fn():
+ return item_mask(seq, length, mask, aug_param.mask_rate)
+
+ def reorder_fn():
+ return item_reorder(seq, length, aug_param.reorder_rate)
+
+ trans_fn = []
+ if aug_param.crop_rate < 1.0:
+ trans_fn.append(crop_fn)
+ if aug_param.mask_rate > 0:
+ trans_fn.append(mask_fn)
+ if aug_param.reorder_rate > 0:
+ trans_fn.append(reorder_fn)
+
+ num_trans = len(trans_fn)
+ if num_trans == 0:
+ return seq, length
+
+ if num_trans == 1:
+ return trans_fn[0]()
+
+ method = tf.random.uniform([], minval=0, maxval=num_trans, dtype=tf.int32)
+ if num_trans == 2:
+ return tf.cond(tf.equal(method, 0), trans_fn[0], trans_fn[1])
+
+ aug_seq, aug_len = tf.cond(
+ tf.equal(method, 0), crop_fn,
+ lambda: tf.cond(tf.equal(method, 1), mask_fn, reorder_fn))
+ return aug_seq, aug_len
+
+
+def sequence_augment(seq_input, seq_len, mask, aug_param):
+ lengths = tf.cast(seq_len, dtype=tf.int32)
+ aug_seq, aug_len = tf.map_fn(
+ lambda elems: augment_fn(elems, aug_param, mask),
+ elems=(seq_input, lengths),
+ dtype=(tf.float32, tf.int32))
+
+ aug_seq = tf.reshape(aug_seq, tf.shape(seq_input))
+ return aug_seq, aug_len
+
+
+class SeqAugment(Layer):
+ """Do data augmentation for input sequence embedding."""
+
+ def __init__(self, params, name='seq_aug', reuse=None, **kwargs):
+ super(SeqAugment, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ self.seq_aug_params = params.get_pb_config()
+
+ def call(self, inputs, training=None, **kwargs):
+ assert isinstance(inputs, (list, tuple))
+ seq_input, seq_len = inputs[:2]
+
+ embedding_size = int(seq_input.shape[-1])
+ with tf.variable_scope(self.name, reuse=self.reuse):
+ mask_emb = tf.get_variable(
+ 'mask', [1, embedding_size], dtype=tf.float32, trainable=True)
+
+ aug_seq, aug_len = sequence_augment(seq_input, seq_len, mask_emb,
+ self.seq_aug_params)
+ return aug_seq, aug_len
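+# Usage sketch (hedged; `params` is assumed to wrap a pb config exposing
+# crop_rate / reorder_rate / mask_rate as read above):
+#   aug_seq, aug_len = SeqAugment(params)([seq_emb, seq_len])
+# where `seq_emb` is a padded [B, L, D] sequence embedding and `seq_len` holds
+# the valid lengths. Each example gets exactly one transform (crop, mask or
+# reorder), chosen uniformly among those enabled by the configured rates.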
diff --git a/easy_rec/python/layers/keras/din.py b/easy_rec/python/layers/keras/din.py
new file mode 100644
index 000000000..082677e0b
--- /dev/null
+++ b/easy_rec/python/layers/keras/din.py
@@ -0,0 +1,67 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras import MLP
+from easy_rec.python.layers.utils import Parameter
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+
+class DIN(Layer):
+
+ def __init__(self, params, name='din', reuse=None, **kwargs):
+ super(DIN, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ self.l2_reg = params.l2_regularizer
+ self.config = params.get_pb_config()
+ self.config.attention_dnn.use_final_bn = False
+ self.config.attention_dnn.use_final_bias = True
+ self.config.attention_dnn.final_activation = 'linear'
+ mlp_params = Parameter.make_from_pb(self.config.attention_dnn)
+ mlp_params.l2_regularizer = self.l2_reg
+ self.din_layer = MLP(mlp_params, 'din_attention', reuse=self.reuse)
+
+ def call(self, inputs, training=None, **kwargs):
+ keys, seq_len, query = inputs
+ assert query is not None, '[%s] target feature is empty' % self.name
+ query_emb_size = int(query.shape[-1])
+ seq_emb_size = keys.shape.as_list()[-1]
+ if query_emb_size != seq_emb_size:
+      logging.info(
+          'the embedding sizes of sequence [%d] and target item [%d] are not'
+          ' equal in feature group: %s', seq_emb_size, query_emb_size,
+          self.name)
+ if query_emb_size < seq_emb_size:
+ query = tf.pad(query, [[0, 0], [0, seq_emb_size - query_emb_size]])
+ else:
+        assert False, 'the embedding size of the target item must not be larger than that of the sequence'
+
+ batch_size, max_seq_len, _ = get_shape_list(keys, 3)
+ queries = tf.tile(tf.expand_dims(query, 1), [1, max_seq_len, 1])
+ din_all = tf.concat([queries, keys, queries - keys, queries * keys],
+ axis=-1)
+ output = self.din_layer(din_all, training) # [B, L, 1]
+ scores = tf.transpose(output, [0, 2, 1]) # [B, 1, L]
+
+ seq_mask = tf.sequence_mask(seq_len, max_seq_len, dtype=tf.bool)
+ seq_mask = tf.expand_dims(seq_mask, 1)
+ paddings = tf.ones_like(scores) * (-2**32 + 1)
+ scores = tf.where(seq_mask, scores, paddings) # [B, 1, L]
+ if self.config.attention_normalizer == 'softmax':
+ scores = tf.nn.softmax(scores) # (B, 1, L)
+ elif self.config.attention_normalizer == 'sigmoid':
+ scores = scores / (seq_emb_size**0.5)
+ scores = tf.nn.sigmoid(scores)
+ else:
+ raise ValueError('unsupported attention normalizer: ' +
+ self.config.attention_normalizer)
+
+ if query_emb_size < seq_emb_size:
+ keys = keys[:, :, :query_emb_size] # [B, L, E]
+ output = tf.squeeze(tf.matmul(scores, keys), axis=[1])
+ if self.config.need_target_feature:
+ output = tf.concat([output, query], axis=-1)
+ print('din output shape:', output.shape)
+ return output
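+# Usage sketch (hedged; tensor names are illustrative):
+#   att_out = DIN(params)([keys, seq_len, query])
+# with `keys` the sequence embedding [B, L, E], `seq_len` the valid lengths [B]
+# and `query` the target-item embedding [B, E']. When E' < E the query is
+# zero-padded for the attention MLP and the attended output is truncated back
+# to E' columns before the optional concat with the query.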
diff --git a/easy_rec/python/layers/keras/einsum_dense.py b/easy_rec/python/layers/keras/einsum_dense.py
new file mode 100644
index 000000000..7531644dc
--- /dev/null
+++ b/easy_rec/python/layers/keras/einsum_dense.py
@@ -0,0 +1,598 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import re
+import string
+
+import tensorflow as tf
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import constraints
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import regularizers
+from tensorflow.python.keras.layers import Layer
+
+
+class EinsumDense(Layer):
+ """A layer that uses `einsum` as the backing computation.
+
+ This layer can perform einsum calculations of arbitrary dimensionality.
+
+ Args:
+ equation: An equation describing the einsum to perform.
+ This equation must be a valid einsum string of the form
+ `ab,bc->ac`, `...ab,bc->...ac`, or
+ `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum
+ axis expression sequence.
+ output_shape: The expected shape of the output tensor
+ (excluding the batch dimension and any dimensions
+ represented by ellipses). You can specify `None` for any dimension
+ that is unknown or can be inferred from the input shape.
+ activation: Activation function to use. If you don't specify anything,
+ no activation is applied
+ (that is, a "linear" activation: `a(x) = x`).
+ bias_axes: A string containing the output dimension(s)
+ to apply a bias to. Each character in the `bias_axes` string
+ should correspond to a character in the output portion
+ of the `equation` string.
+ kernel_initializer: Initializer for the `kernel` weights matrix.
+ bias_initializer: Initializer for the bias vector.
+ kernel_regularizer: Regularizer function applied to the `kernel` weights
+ matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ kernel_constraint: Constraint function applied to the `kernel` weights
+ matrix.
+ bias_constraint: Constraint function applied to the bias vector.
+ lora_rank: Optional integer. If set, the layer's forward pass
+ will implement LoRA (Low-Rank Adaptation)
+ with the provided rank. LoRA sets the layer's kernel
+ to non-trainable and replaces it with a delta over the
+ original kernel, obtained via multiplying two lower-rank
+ trainable matrices
+ (the factorization happens on the last dimension).
+ This can be useful to reduce the
+ computation cost of fine-tuning large dense layers.
+ You can also enable LoRA on an existing
+ `EinsumDense` layer by calling `layer.enable_lora(rank)`.
+ **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
+
+ Examples:
+ **Biased dense layer with einsums**
+
+ This example shows how to instantiate a standard Keras dense layer using
+ einsum operations. This example is equivalent to
+ `keras.layers.Dense(64, use_bias=True)`.
+
+ >>> layer = tf.keras.layers.EinsumDense("ab,bc->ac",
+ ... output_shape=64,
+ ... bias_axes="c")
+ >>> input_tensor = tf.keras.Input(shape=[32])
+ >>> output_tensor = layer(input_tensor)
+ >>> output_tensor.shape
+ (None, 64)
+
+ **Applying a dense layer to a sequence**
+
+ This example shows how to instantiate a layer that applies the same dense
+ operation to every element in a sequence. Here, the `output_shape` has two
+ values (since there are two non-batch dimensions in the output); the first
+ dimension in the `output_shape` is `None`, because the sequence dimension
+ `b` has an unknown shape.
+
+ >>> layer = tf.keras.layers.EinsumDense("abc,cd->abd",
+ ... output_shape=(None, 64),
+ ... bias_axes="d")
+ >>> input_tensor = tf.keras.Input(shape=[32, 128])
+ >>> output_tensor = layer(input_tensor)
+ >>> output_tensor.shape
+ (None, 32, 64)
+
+ **Applying a dense layer to a sequence using ellipses**
+
+ This example shows how to instantiate a layer that applies the same dense
+ operation to every element in a sequence, but uses the ellipsis notation
+ instead of specifying the batch and sequence dimensions.
+
+ Because we are using ellipsis notation and have specified only one axis, the
+ `output_shape` arg is a single value. When instantiated in this way, the
+ layer can handle any number of sequence dimensions - including the case
+ where no sequence dimension exists.
+
+ >>> layer = tf.keras.layers.EinsumDense("...x,xy->...y",
+ ... output_shape=64,
+ ... bias_axes="y")
+ >>> input_tensor = tf.keras.Input(shape=[32, 128])
+ >>> output_tensor = layer(input_tensor)
+ >>> output_tensor.shape
+ (None, 32, 64)
+ """
+
+ def __init__(self,
+ equation,
+ output_shape,
+ activation=None,
+ bias_axes=None,
+ kernel_initializer='glorot_uniform',
+ bias_initializer='zeros',
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ lora_rank=None,
+ **kwargs):
+ super(EinsumDense, self).__init__(**kwargs)
+ self.equation = equation
+ if isinstance(output_shape, int):
+ self.partial_output_shape = (output_shape,)
+ else:
+ self.partial_output_shape = tuple(output_shape)
+ self.bias_axes = bias_axes
+ self.activation = activations.get(activation)
+ self.kernel_initializer = initializers.get(kernel_initializer)
+ self.bias_initializer = initializers.get(bias_initializer)
+ self.kernel_regularizer = regularizers.get(kernel_regularizer)
+ self.bias_regularizer = regularizers.get(bias_regularizer)
+ self.kernel_constraint = constraints.get(kernel_constraint)
+ self.bias_constraint = constraints.get(bias_constraint)
+ self.lora_rank = lora_rank
+ self.lora_enabled = False
+
+ def build(self, input_shape):
+ shape_data = _analyze_einsum_string(
+ self.equation,
+ self.bias_axes,
+ input_shape,
+ self.partial_output_shape,
+ )
+ kernel_shape, bias_shape, full_output_shape = shape_data
+ for i in range(len(kernel_shape)):
+ dim = kernel_shape[i]
+ if isinstance(dim, tf.Dimension):
+ kernel_shape[i] = dim.value
+ for i in range(len(bias_shape)):
+ dim = bias_shape[i]
+ if isinstance(dim, tf.Dimension):
+ bias_shape[i] = dim.value
+ for i in range(len(full_output_shape)):
+ dim = full_output_shape[i]
+ if isinstance(dim, tf.Dimension):
+ full_output_shape[i] = dim.value
+ self.full_output_shape = tuple(full_output_shape)
+ self._kernel = self.add_weight(
+ name='kernel',
+ shape=tuple(kernel_shape),
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ dtype=self.dtype,
+ trainable=True,
+ )
+ if bias_shape is not None:
+ self.bias = self.add_weight(
+ name='bias',
+ shape=tuple(bias_shape),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ dtype=self.dtype,
+ trainable=True,
+ )
+ else:
+ self.bias = None
+ self.built = True
+ if self.lora_rank:
+ self.enable_lora(self.lora_rank)
+
+ @property
+ def kernel(self):
+ if not self.built:
+ raise AttributeError(
+ 'You must build the layer before accessing `kernel`.')
+ if self.lora_enabled:
+ return self._kernel + tf.matmul(self.lora_kernel_a, self.lora_kernel_b)
+ return self._kernel
+
+ def compute_output_shape(self, _):
+ return self.full_output_shape
+
+ def call(self, inputs, training=None):
+ x = tf.einsum(self.equation, inputs, self.kernel)
+ if self.bias is not None:
+ x += self.bias
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
+ def enable_lora(self,
+ rank,
+ a_initializer='he_uniform',
+ b_initializer='zeros'):
+ if self.kernel_constraint:
+ raise ValueError('Lora is incompatible with kernel constraints. '
+ 'In order to enable lora on this layer, remove the '
+ '`kernel_constraint` argument.')
+ if not self.built:
+ raise ValueError("Cannot enable lora on a layer that isn't yet built.")
+ if self.lora_enabled:
+ raise ValueError('lora is already enabled. '
+ 'This can only be done once per layer.')
+ self._tracker.unlock()
+ self.lora_kernel_a = self.add_weight(
+ name='lora_kernel_a',
+ shape=(self.kernel.shape[:-1] + (rank,)),
+ initializer=initializers.get(a_initializer),
+ regularizer=self.kernel_regularizer,
+ )
+ self.lora_kernel_b = self.add_weight(
+ name='lora_kernel_b',
+ shape=(rank, self.kernel.shape[-1]),
+ initializer=initializers.get(b_initializer),
+ regularizer=self.kernel_regularizer,
+ )
+ self._kernel.trainable = False
+ self._tracker.lock()
+ self.lora_enabled = True
+ self.lora_rank = rank
+
+ def save_own_variables(self, store):
+ # Do nothing if the layer isn't yet built
+ if not self.built:
+ return
+ # The keys of the `store` will be saved as determined because the
+ # default ordering will change after quantization
+ kernel_value, kernel_scale = self._get_kernel_with_merged_lora()
+ target_variables = [kernel_value]
+ if self.bias is not None:
+ target_variables.append(self.bias)
+ for i, variable in enumerate(target_variables):
+ store[str(i)] = variable
+
+ def load_own_variables(self, store):
+ if not self.lora_enabled:
+ self._check_load_own_variables(store)
+ # Do nothing if the layer isn't yet built
+ if not self.built:
+ return
+ # The keys of the `store` will be saved as determined because the
+ # default ordering will change after quantization
+ target_variables = [self._kernel]
+ if self.bias is not None:
+ target_variables.append(self.bias)
+ for i, variable in enumerate(target_variables):
+ variable.assign(store[str(i)])
+ if self.lora_enabled:
+ self.lora_kernel_a.assign(tf.zeros(self.lora_kernel_a.shape))
+ self.lora_kernel_b.assign(tf.zeros(self.lora_kernel_b.shape))
+
+ def get_config(self):
+ base_config = super(EinsumDense, self).get_config()
+ config = {
+ 'output_shape':
+ self.partial_output_shape,
+ 'equation':
+ self.equation,
+ 'activation':
+ activations.serialize(self.activation),
+ 'bias_axes':
+ self.bias_axes,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'activity_regularizer':
+ regularizers.serialize(self.activity_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ }
+ if self.lora_rank:
+ config['lora_rank'] = self.lora_rank
+ config.update(base_config)
+ return config
+
+ def _check_load_own_variables(self, store):
+ all_vars = self._trainable_variables + self._non_trainable_variables
+ if len(store.keys()) != len(all_vars):
+ if len(all_vars) == 0 and not self.built:
+ raise ValueError(
+ "Layer '{name}' was never built "
+ "and thus it doesn't have any variables. "
+ 'However the weights file lists {num_keys} '
+ 'variables for this layer.\n'
+ 'In most cases, this error indicates that either:\n\n'
+ '1. The layer is owned by a parent layer that '
+ 'implements a `build()` method, but calling the '
+ "parent's `build()` method did NOT create the state of "
+ "the child layer '{name}'. A `build()` method "
+ 'must create ALL state for the layer, including '
+ 'the state of any children layers.\n\n'
+ '2. You need to implement '
+ 'the `def build_from_config(self, config)` method '
+ "on layer '{name}', to specify how to rebuild "
+ 'it during loading. '
+ 'In this case, you might also want to implement the '
+ 'method that generates the build config at saving time, '
+ '`def get_build_config(self)`. '
+ 'The method `build_from_config()` is meant '
+ 'to create the state '
+ 'of the layer (i.e. its variables) upon deserialization.'.format(
+ name=self.name, num_keys=len(store.keys())))
+      raise ValueError(
+          "Layer '{name}' expected {num_var} variables, but received "
+          '{num_key} variables during loading. '
+          'Expected: {names}'.format(
+              name=self.name,
+              num_var=len(all_vars),
+              num_key=len(store.keys()),
+              names=[v.name for v in all_vars]))
+
+ def _get_kernel_with_merged_lora(self):
+ kernel_value = self.kernel
+ kernel_scale = None
+ return kernel_value, kernel_scale
+
+
+def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
+ """Analyzes an einsum string to determine the required weight shape."""
+ dot_replaced_string = re.sub(r'\.\.\.', '0', equation)
+
+ # This is the case where no ellipses are present in the string.
+ split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
+ dot_replaced_string)
+ if split_string:
+ return _analyze_split_string(split_string, bias_axes, input_shape,
+ output_shape)
+
+ # This is the case where ellipses are present on the left.
+ split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
+ dot_replaced_string)
+ if split_string:
+ return _analyze_split_string(
+ split_string, bias_axes, input_shape, output_shape, left_elided=True)
+
+ # This is the case where ellipses are present on the right.
+ split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
+ dot_replaced_string)
+ if split_string:
+ return _analyze_split_string(split_string, bias_axes, input_shape,
+ output_shape)
+
+ raise ValueError(
+ "Invalid einsum equation '{equation}'. Equations must be in the form "
+ '[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
+ equation=equation))
+
+
+def _analyze_split_string(split_string,
+ bias_axes,
+ input_shape,
+ output_shape,
+ left_elided=False):
+ """Analyze an pre-split einsum string to find the weight shape."""
+ input_spec = split_string.group(1)
+ weight_spec = split_string.group(2)
+ output_spec = split_string.group(3)
+ elided = len(input_shape) - len(input_spec)
+ if isinstance(output_shape, int):
+ output_shape = [output_shape]
+ else:
+ output_shape = list(output_shape)
+
+ output_shape.insert(0, input_shape[0])
+
+ if elided > 0 and left_elided:
+ for i in range(1, elided):
+ # We already inserted the 0th input dimension at dim 0, so we need
+ # to start at location 1 here.
+ output_shape.insert(1, input_shape[i])
+ elif elided > 0 and not left_elided:
+ for i in range(len(input_shape) - elided, len(input_shape)):
+ output_shape.append(input_shape[i])
+
+ if left_elided:
+ # If we have beginning dimensions elided, we need to use negative
+ # indexing to determine where in the input dimension our values are.
+ input_dim_map = {
+ dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
+ }
+ # Because we've constructed the full output shape already, we don't need
+ # to do negative indexing.
+ output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
+ else:
+ input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
+ output_dim_map = {dim: i for i, dim in enumerate(output_spec)}
+
+ for dim in input_spec:
+ input_shape_at_dim = input_shape[input_dim_map[dim]]
+ if dim in output_dim_map:
+ output_shape_at_dim = output_shape[output_dim_map[dim]]
+ if (output_shape_at_dim is not None and
+ output_shape_at_dim != input_shape_at_dim):
+ raise ValueError(
+ 'Input shape and output shape do not match at shared '
+ "dimension '{dim}'. Input shape is {input_shape_at_dim}, "
+ 'and output shape is {output_shape}.'.format(
+ dim=dim,
+ input_shape_at_dim=input_shape_at_dim,
+ output_shape=output_shape[output_dim_map[dim]]))
+
+ for dim in output_spec:
+ if dim not in input_spec and dim not in weight_spec:
+      raise ValueError(
+          "Dimension '{dim}' was specified in the output "
+          "'{output_spec}' but has no corresponding dim in the input "
+          "spec '{input_spec}' or weight spec '{weight_spec}'".format(
+              dim=dim,
+              output_spec=output_spec,
+              input_spec=input_spec,
+              weight_spec=weight_spec))
+
+ weight_shape = []
+ for dim in weight_spec:
+ if dim in input_dim_map:
+ weight_shape.append(input_shape[input_dim_map[dim]])
+ elif dim in output_dim_map:
+ weight_shape.append(output_shape[output_dim_map[dim]])
+ else:
+ raise ValueError(
+ "Weight dimension '{dim}' did not have a match in either "
+ "the input spec '{input_spec}' or the output "
+ "spec '{output_spec}'. For this layer, the weight must "
+ 'be fully specified.'.format(
+ dim=dim, input_spec=input_spec, output_spec=output_spec))
+
+ if bias_axes is not None:
+ num_left_elided = elided if left_elided else 0
+ idx_map = {
+ char: output_shape[i + num_left_elided]
+ for i, char in enumerate(output_spec)
+ }
+
+ for char in bias_axes:
+ if char not in output_spec:
+ raise ValueError(
+ "Bias dimension '{char}' was requested, but is not part "
+ "of the output spec '{output_spec}'".format(
+ char=char, output_spec=output_spec))
+
+ first_bias_location = min([output_spec.find(char) for char in bias_axes])
+ bias_output_spec = output_spec[first_bias_location:]
+
+ bias_shape = [
+ idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
+ ]
+
+ if not left_elided:
+ for _ in range(elided):
+ bias_shape.append(1)
+ else:
+ bias_shape = None
+
+ return weight_shape, bias_shape, output_shape
+
+
+def _analyze_quantization_info(equation, input_shape):
+
+ def get_specs(equation, input_shape):
+ possible_labels = string.ascii_letters
+ dot_replaced_string = re.sub(r'\.\.\.', '0', equation)
+
+ # This is the case where no ellipses are present in the string.
+ split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
+ dot_replaced_string)
+ if split_string is not None:
+ input_spec = split_string.group(1)
+ weight_spec = split_string.group(2)
+ output_spec = split_string.group(3)
+ return input_spec, weight_spec, output_spec
+
+ # This is the case where ellipses are present on the left.
+ split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
+ dot_replaced_string)
+ if split_string is not None:
+ input_spec = split_string.group(1)
+ weight_spec = split_string.group(2)
+ output_spec = split_string.group(3)
+ elided = len(input_shape) - len(input_spec)
+ possible_labels = sorted(
+ set(possible_labels) - set(input_spec) - set(weight_spec) -
+ set(output_spec))
+ # Pad labels on the left to `input_spec` and `output_spec`
+ for i in range(elided):
+ input_spec = possible_labels[i] + input_spec
+ output_spec = possible_labels[i] + output_spec
+ return input_spec, weight_spec, output_spec
+
+ # This is the case where ellipses are present on the right.
+ split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
+ dot_replaced_string)
+ if split_string is not None:
+ input_spec = split_string.group(1)
+ weight_spec = split_string.group(2)
+ output_spec = split_string.group(3)
+ elided = len(input_shape) - len(input_spec)
+ possible_labels = sorted(
+ set(possible_labels) - set(input_spec) - set(weight_spec) -
+ set(output_spec))
+ # Pad labels on the right to `input_spec` and `output_spec`
+ for i in range(elided):
+ input_spec = input_spec + possible_labels[i]
+ output_spec = output_spec + possible_labels[i]
+ return input_spec, weight_spec, output_spec
+
+ raise ValueError(
+ "Invalid einsum equation '{equation}'. Equations must be in the "
+ 'form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
+ equation=equation))
+
+ input_spec, weight_spec, output_spec = get_specs(equation, input_shape)
+
+ # Determine the axes that should be reduced by the quantizer
+ input_reduced_axes = []
+ weight_reduced_axes = []
+ for i, label in enumerate(input_spec):
+ index = output_spec.find(label)
+ if index == -1:
+ input_reduced_axes.append(i)
+ for i, label in enumerate(weight_spec):
+ index = output_spec.find(label)
+ if index == -1:
+ weight_reduced_axes.append(i)
+
+ # Determine the axes of `ops.expand_dims`
+ input_expand_axes = []
+ weight_expand_axes = []
+ for i, label in enumerate(output_spec):
+ index_input = input_spec.find(label)
+ index_weight = weight_spec.find(label)
+ if index_input == -1:
+ input_expand_axes.append(i)
+ if index_weight == -1:
+ weight_expand_axes.append(i)
+
+ # Determine the axes of `ops.transpose`
+ input_transpose_axes = []
+ weight_transpose_axes = []
+ for i, label in enumerate(output_spec):
+ index_input = input_spec.find(label)
+ index_weight = weight_spec.find(label)
+ if index_input != -1:
+ input_transpose_axes.append(index_input)
+ if index_weight != -1:
+ weight_transpose_axes.append(index_weight)
+ # Postprocess the information:
+ # 1. Add dummy axes (1) to transpose_axes
+ # 2. Add axis to squeeze_axes if 1. failed
+ input_squeeze_axes = []
+ weight_squeeze_axes = []
+ for ori_index in input_reduced_axes:
+ try:
+ index = input_expand_axes.pop(0)
+ except IndexError:
+ input_squeeze_axes.append(ori_index)
+ input_transpose_axes.insert(index, ori_index)
+ for ori_index in weight_reduced_axes:
+ try:
+ index = weight_expand_axes.pop(0)
+ except IndexError:
+ weight_squeeze_axes.append(ori_index)
+ weight_transpose_axes.insert(index, ori_index)
+ # Prepare equation for `einsum_with_inputs_gradient`
+ custom_gradient_equation = '{output_spec},{weight_spec}->{input_spec}'.format(
+ output_spec=output_spec, input_spec=input_spec, weight_spec=weight_spec)
+ weight_reverse_transpose_axes = [
+ i for (_, i) in sorted((v, i)
+ for (i, v) in enumerate(weight_transpose_axes))
+ ]
+ return (
+ input_reduced_axes,
+ weight_reduced_axes,
+ input_transpose_axes,
+ weight_transpose_axes,
+ input_expand_axes,
+ weight_expand_axes,
+ input_squeeze_axes,
+ weight_squeeze_axes,
+ custom_gradient_equation,
+ weight_reverse_transpose_axes,
+ )
diff --git a/easy_rec/python/layers/keras/embedding.py b/easy_rec/python/layers/keras/embedding.py
new file mode 100644
index 000000000..77b513951
--- /dev/null
+++ b/easy_rec/python/layers/keras/embedding.py
@@ -0,0 +1,81 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Fused embedding layer."""
+import tensorflow as tf
+from tensorflow.python.keras.layers import Embedding
+from tensorflow.python.keras.layers import Layer
+
+
+def _combine(embeddings, weights, comb_fn):
+ # embeddings shape: [B, N, D]
+ if callable(comb_fn):
+ return comb_fn(embeddings, axis=1)
+ if weights is None:
+ return tf.reduce_mean(embeddings, axis=1)
+ if isinstance(weights, tf.SparseTensor):
+ if weights.dtype == tf.string:
+ weights = tf.sparse.to_dense(weights, default_value='0')
+ weights = tf.string_to_number(weights)
+ else:
+ weights = tf.sparse.to_dense(weights, default_value=0.0)
+ sum_weights = tf.reduce_sum(weights, axis=1, keepdims=True)
+ weights = tf.expand_dims(weights / sum_weights, axis=-1)
+ return tf.reduce_sum(embeddings * weights, axis=1)
+
+
+class EmbeddingLayer(Layer):
+
+ def __init__(self, params, name='embedding_layer', reuse=None, **kwargs):
+ super(EmbeddingLayer, self).__init__(name=name, **kwargs)
+ params.check_required(['vocab_size', 'embedding_dim'])
+ vocab_size = int(params.vocab_size)
+ combiner = params.get_or_default('combiner', 'weight')
+ if combiner == 'mean':
+ self.combine_fn = tf.reduce_mean
+ elif combiner == 'sum':
+ self.combine_fn = tf.reduce_sum
+ elif combiner == 'max':
+ self.combine_fn = tf.reduce_max
+ elif combiner == 'min':
+ self.combine_fn = tf.reduce_min
+ elif combiner == 'weight':
+ self.combine_fn = 'weight'
+ else:
+ raise ValueError('unsupported embedding combiner: ' + combiner)
+ self.embed_dim = int(params.embedding_dim)
+ self.embedding = Embedding(vocab_size, self.embed_dim)
+ self.do_concat = params.get_or_default('concat', True)
+
+ def call(self, inputs, training=None, **kwargs):
+ inputs, weights = inputs
+    # merge the inputs of multiple features into a single index tensor
+ flat_inputs = [tf.reshape(input_field, [-1]) for input_field in inputs]
+ all_indices = tf.concat(flat_inputs, axis=0)
+    # do a single embedding lookup against the shared embedding table
+ all_embeddings = self.embedding(all_indices)
+ is_multi = []
+    # compute the embedding of each feature
+ split_sizes = []
+ for input_field in inputs:
+ assert input_field.shape.ndims <= 2, 'dims of embedding layer input must be <= 2'
+ input_shape = tf.shape(input_field)
+ size = input_shape[0]
+ if input_field.shape.ndims > 1:
+ size *= input_shape[-1]
+ is_multi.append(True)
+ else:
+ is_multi.append(False)
+ split_sizes.append(size)
+ embeddings = tf.split(all_embeddings, split_sizes, axis=0)
+ for i in range(len(embeddings)):
+ if is_multi[i]:
+ batch_size = tf.shape(inputs[i])[0]
+ embeddings[i] = tf.cond(
+ tf.equal(tf.size(embeddings[i]), 0),
+ lambda: tf.zeros([batch_size, self.embed_dim]), lambda: _combine(
+ tf.reshape(embeddings[i], [batch_size, -1, self.embed_dim]),
+ weights[i], self.combine_fn))
+ if self.do_concat:
+ embeddings = tf.concat(embeddings, axis=-1)
+ print('Embedding layer:', self.name, embeddings)
+ return embeddings
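+# Usage sketch (hedged; assumes `params` provides vocab_size, embedding_dim and
+# an optional combiner such as 'mean'):
+#   emb = EmbeddingLayer(params)([id_tensors, weight_tensors])
+# `id_tensors` is a list of id inputs (rank 1 for single-value features, rank 2
+# for padded multi-value features); `weight_tensors` is a parallel list that is
+# only consulted for multi-value features when combiner='weight'.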
diff --git a/easy_rec/python/layers/keras/fibinet.py b/easy_rec/python/layers/keras/fibinet.py
new file mode 100644
index 000000000..220c57cb5
--- /dev/null
+++ b/easy_rec/python/layers/keras/fibinet.py
@@ -0,0 +1,251 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import itertools
+import logging
+
+import tensorflow as tf
+from tensorflow.python.keras.layers import Dense
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras.blocks import MLP
+from easy_rec.python.layers.keras.layer_norm import LayerNormalization
+from easy_rec.python.layers.utils import Parameter
+
+
+class SENet(Layer):
+ """SENET Layer used in FiBiNET.
+
+ Input shape
+ - A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
+ The ``embedding_size`` of each field can have different value.
+
+ Output shape
+ - A 2D tensor with shape: ``(batch_size,sum_of_embedding_size)``.
+
+ References:
+ 1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
+ Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
+ 2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
+ Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
+ """
+
+ def __init__(self, params, name='SENet', reuse=None, **kwargs):
+ super(SENet, self).__init__(name=name, **kwargs)
+ self.config = params.get_pb_config()
+ self.reuse = reuse
+ if tf.__version__ >= '2.0':
+ self.layer_norm = tf.keras.layers.LayerNormalization(name='output_ln')
+ else:
+ self.layer_norm = LayerNormalization(name='output_ln')
+
+ def build(self, input_shape):
+ g = self.config.num_squeeze_group
+ emb_size = 0
+ for shape in input_shape:
+ assert shape.ndims == 2, 'field embeddings must be rank 2 tensors'
+ dim = int(shape[-1])
+ assert dim >= g and dim % g == 0, 'field embedding dimension %d must be divisible by %d' % (
+ dim, g)
+ emb_size += dim
+
+ r = self.config.reduction_ratio
+ field_size = len(input_shape)
+ reduction_size = max(1, field_size * g * 2 // r)
+ self.reduce_layer = Dense(
+ units=reduction_size,
+ activation='relu',
+ kernel_initializer='he_normal',
+ name='W1')
+ self.excite_layer = Dense(
+ units=emb_size, kernel_initializer='glorot_normal', name='W2')
+ super(SENet, self).build(input_shape) # Be sure to call this somewhere!
+
+ def call(self, inputs, **kwargs):
+ g = self.config.num_squeeze_group
+
+ # Squeeze
+    # the embedding dimension must be divisible by g
+ group_embs = [
+ tf.reshape(emb, [-1, g, int(emb.shape[-1]) // g]) for emb in inputs
+ ]
+
+ squeezed = []
+ for emb in group_embs:
+ squeezed.append(tf.reduce_max(emb, axis=-1)) # [B, g]
+ squeezed.append(tf.reduce_mean(emb, axis=-1)) # [B, g]
+ z = tf.concat(squeezed, axis=1) # [bs, field_size * num_groups * 2]
+
+ # Excitation
+ a1 = self.reduce_layer(z)
+ weights = self.excite_layer(a1)
+
+ # Re-weight
+ inputs = tf.concat(inputs, axis=-1)
+ output = inputs * weights
+
+ # Fuse, add skip-connection
+ if self.config.use_skip_connection:
+ output += inputs
+
+ # Layer Normalization
+ if self.config.use_output_layer_norm:
+ output = self.layer_norm(output)
+ return output
+
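+# Worked size example for the squeeze/excite stages (illustrative numbers):
+# with 10 fields, num_squeeze_group=2 and reduction_ratio=4, the squeeze vector
+# z has 10 * 2 * 2 = 40 entries, W1 has max(1, 40 // 4) = 10 units and W2
+# projects back to the summed embedding size, so every input dimension gets its
+# own re-weighting factor.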
+
+def _full_interaction(v_i, v_j):
+ # [bs, 1, dim] x [bs, dim, 1] = [bs, 1]
+ interaction = tf.matmul(
+ tf.expand_dims(v_i, axis=1), tf.expand_dims(v_j, axis=-1))
+ return tf.squeeze(interaction, axis=1)
+
+
+class BiLinear(Layer):
+ """BilinearInteraction Layer used in FiBiNET.
+
+ Input shape
+ - A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
+ Its length is ``filed_size``.
+ The ``embedding_size`` of each field can have different value.
+
+ Output shape
+ - 2D tensor with shape: ``(batch_size,output_size)``.
+
+ Attributes:
+ num_output_units: the number of output units
+ type: ['all', 'each', 'interaction'], types of bilinear functions used in this layer
+ use_plus: whether to use bi-linear+
+
+ References:
+ 1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
+ Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
+ 2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
+ Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
+ """
+
+ def __init__(self, params, name='bilinear', reuse=None, **kwargs):
+ super(BiLinear, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ params.check_required(['num_output_units'])
+ bilinear_plus = params.get_or_default('use_plus', True)
+ self.output_size = params.num_output_units
+ self.bilinear_type = params.get_or_default('type', 'interaction').lower()
+ if self.bilinear_type not in ['all', 'each', 'interaction']:
+ raise NotImplementedError(
+ "bilinear_type only support: ['all', 'each', 'interaction']")
+ if bilinear_plus:
+ self.func = _full_interaction
+ else:
+ self.func = tf.multiply
+ self.output_layer = Dense(self.output_size, name='output')
+
+ def build(self, input_shape):
+ if type(input_shape) not in (tuple, list):
+ raise TypeError('input of BiLinear layer must be a list')
+ field_num = len(input_shape)
+ logging.info('Bilinear Layer with %d inputs' % field_num)
+ if field_num > 200:
+ logging.warning('Too many inputs for bilinear layer: %d' % field_num)
+ equal_dim = True
+ _dim = input_shape[0][-1]
+ for shape in input_shape:
+ assert shape.ndims == 2, 'field embeddings must be rank 2 tensors'
+ if shape[-1] != _dim:
+ equal_dim = False
+ if not equal_dim and self.bilinear_type != 'interaction':
+ raise ValueError(
+          'all embedding dimensions must be the same unless bilinear type is interaction'
+ )
+ dim = int(_dim)
+
+ if self.bilinear_type == 'all':
+ self.dot_layer = Dense(dim, name='all')
+ elif self.bilinear_type == 'each':
+ self.dot_layers = [
+ Dense(dim, name='each_%d' % i) for i in range(field_num - 1)
+ ]
+ else: # interaction
+ self.dot_layers = [
+ Dense(
+ units=int(input_shape[j][-1]), name='interaction_%d_%d' % (i, j))
+ for i, j in itertools.combinations(range(field_num), 2)
+ ]
+ super(BiLinear, self).build(input_shape) # Be sure to call this somewhere!
+
+ def call(self, inputs, **kwargs):
+ embeddings = inputs
+ field_num = len(embeddings)
+
+ # bi-linear+: dimension of `p` is [bs, f*(f-1)/2]
+ # bi-linear:
+ # - when equal_dim=True, dimension of `p` is [bs, f*(f-1)/2*k], k is embedding size
+ # - when equal_dim=False, dimension of `p` is [bs, (k_2+k_3+...+k_f)+...+(k_i+k_{i+1}+...+k_f)+...+k_f],
+ # - where k_i is the embedding size of the ith field
+ if self.bilinear_type == 'all':
+ v_dot = [self.dot_layer(v_i) for v_i in embeddings[:-1]]
+ p = [
+ self.func(v_dot[i], embeddings[j])
+ for i, j in itertools.combinations(range(field_num), 2)
+ ]
+ elif self.bilinear_type == 'each':
+ v_dot = [self.dot_layers[i](v_i) for i, v_i in enumerate(embeddings[:-1])]
+ p = [
+ self.func(v_dot[i], embeddings[j])
+ for i, j in itertools.combinations(range(field_num), 2)
+ ]
+    else:  # interaction
+      # index dot_layers by the enumeration order of the field pairs,
+      # matching how the layers were created in build()
+      p = [
+          self.func(self.dot_layers[k](embeddings[i]), embeddings[j])
+          for k, (i, j) in enumerate(
+              itertools.combinations(range(field_num), 2))
+      ]
+
+ return self.output_layer(tf.concat(p, axis=-1))
+
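+# Projection-count note (illustrative): for f input fields, build() creates one
+# shared Dense for type 'all', f - 1 Dense layers for type 'each' and
+# f * (f - 1) / 2 for type 'interaction' (one per field pair). With
+# use_plus=True each pair contributes a scalar to `p`, otherwise an
+# element-wise product vector of the pair's embedding size.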
+
+class FiBiNet(Layer):
+ """FiBiNet++:Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction.
+
+ References:
+ - [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
+ Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
+ """
+
+ def __init__(self, params, name='fibinet', reuse=None, **kwargs):
+ super(FiBiNet, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ self._config = params.get_pb_config()
+
+ se_params = Parameter.make_from_pb(self._config.senet)
+ self.senet_layer = SENet(
+ se_params, name=self.name + '/senet', reuse=self.reuse)
+
+ if self._config.HasField('bilinear'):
+ bi_params = Parameter.make_from_pb(self._config.bilinear)
+ self.bilinear_layer = BiLinear(
+ bi_params, name=self.name + '/bilinear', reuse=self.reuse)
+
+ if self._config.HasField('mlp'):
+ p = Parameter.make_from_pb(self._config.mlp)
+ p.l2_regularizer = params.l2_regularizer
+ self.final_mlp = MLP(p, name=self.name + '/mlp', reuse=reuse)
+ else:
+ self.final_mlp = None
+
+ def call(self, inputs, training=None, **kwargs):
+ feature_list = []
+
+ senet_output = self.senet_layer(inputs)
+ feature_list.append(senet_output)
+
+ if self._config.HasField('bilinear'):
+ feature_list.append(self.bilinear_layer(inputs))
+
+ if len(feature_list) > 1:
+ feature = tf.concat(feature_list, axis=-1)
+ else:
+ feature = feature_list[0]
+
+ if self.final_mlp is not None:
+ feature = self.final_mlp(feature, training=training)
+ return feature
diff --git a/easy_rec/python/layers/keras/interaction.py b/easy_rec/python/layers/keras/interaction.py
new file mode 100644
index 000000000..9b14f254a
--- /dev/null
+++ b/easy_rec/python/layers/keras/interaction.py
@@ -0,0 +1,416 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.utils.activation import get_activation
+
+
+class FM(tf.keras.layers.Layer):
+ """Factorization Machine models pairwise (order-2) feature interactions without linear term and bias.
+
+  References
+    - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
+
+  Input shape
+    - List of 2D tensor with shape: ``(batch_size,embedding_size)``.
+    - Or a 3D tensor with shape: ``(batch_size,field_size,embedding_size)``
+
+  Output shape
+    - 2D tensor with shape: ``(batch_size, 1)``.
+ """
+
+ def __init__(self, params, name='fm', reuse=None, **kwargs):
+ super(FM, self).__init__(name=name, **kwargs)
+ self.use_variant = params.get_or_default('use_variant', False)
+
+ def call(self, inputs, **kwargs):
+ if type(inputs) == list:
+ emb_dims = set(map(lambda x: int(x.shape[-1]), inputs))
+ if len(emb_dims) != 1:
+ dims = ','.join([str(d) for d in emb_dims])
+        raise ValueError('all embedding dims must be equal in FM layer: ' + dims)
+ with tf.name_scope(self.name):
+ fea = tf.stack(inputs, axis=1)
+ else:
+ assert inputs.shape.ndims == 3, 'input of FM layer must be a 3D tensor or a list of 2D tensors'
+ fea = inputs
+
+ with tf.name_scope(self.name):
+ square_of_sum = tf.square(tf.reduce_sum(fea, axis=1))
+ sum_of_square = tf.reduce_sum(tf.square(fea), axis=1)
+ cross_term = tf.subtract(square_of_sum, sum_of_square)
+ if self.use_variant:
+ cross_term = 0.5 * cross_term
+ else:
+ cross_term = 0.5 * tf.reduce_sum(cross_term, axis=-1, keepdims=True)
+ return cross_term
+
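+# A brief check of the identity used above (not part of the model): for field
+# embeddings v_1..v_k, 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), summed over the
+# embedding dimension, equals sum_{i<j} <v_i, v_j>, i.e. all pairwise dot
+# products. With use_variant=True the per-dimension vector is returned instead
+# of the scalar sum.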
+
+class DotInteraction(tf.keras.layers.Layer):
+ """Dot interaction layer of DLRM model..
+
+ See theory in the DLRM paper: https://arxiv.org/pdf/1906.00091.pdf,
+ section 2.1.3. Sparse activations and dense activations are combined.
+ Dot interaction is applied to a batch of input Tensors [e1,...,e_k] of the
+ same dimension and the output is a batch of Tensors with all distinct pairwise
+  dot products of the form dot(e_i, e_j) for i <= j if self_interaction is
+  True, otherwise dot(e_i, e_j) for i < j.
+
+ Attributes:
+ self_interaction: Boolean indicating if features should self-interact.
+ If it is True, then the diagonal entries of the interaction metric are
+ also taken.
+ skip_gather: An optimization flag. If it's set then the upper triangle part
+ of the dot interaction matrix dot(e_i, e_j) is set to 0. The resulting
+ activations will be of dimension [num_features * num_features] from which
+ half will be zeros. Otherwise activations will be only lower triangle part
+    of the interaction matrix. The latter saves space but is much slower.
+ name: String name of the layer.
+ """
+
+ def __init__(self, params, name=None, reuse=None, **kwargs):
+ super(DotInteraction, self).__init__(name=name, **kwargs)
+ self._self_interaction = params.get_or_default('self_interaction', False)
+ self._skip_gather = params.get_or_default('skip_gather', False)
+
+ def call(self, inputs, **kwargs):
+ """Performs the interaction operation on the tensors in the list.
+
+ The tensors represent as transformed dense features and embedded categorical
+ features.
+ Pre-condition: The tensors should all have the same shape.
+
+ Args:
+ inputs: List of features with shapes [batch_size, feature_dim].
+
+ Returns:
+ activations: Tensor representing interacted features. It has a dimension
+        `num_features * num_features` if skip_gather is True, otherwise
+ `num_features * (num_features + 1) / 2` if self_interaction is True and
+ `num_features * (num_features - 1) / 2` if self_interaction is False.
+ """
+ if isinstance(inputs, (list, tuple)):
+ # concat_features shape: batch_size, num_features, feature_dim
+ try:
+ concat_features = tf.stack(inputs, axis=1)
+ except (ValueError, tf.errors.InvalidArgumentError) as e:
+        raise ValueError("Input tensors' dimensions must be equal, original "
+                         'error message: {}'.format(e))
+ else:
+ assert inputs.shape.ndims == 3, 'input of dot func must be a 3D tensor or a list of 2D tensors'
+ concat_features = inputs
+
+ batch_size = tf.shape(concat_features)[0]
+
+ # Interact features, select lower-triangular portion, and re-shape.
+ xactions = tf.matmul(concat_features, concat_features, transpose_b=True)
+ num_features = xactions.shape[-1]
+ ones = tf.ones_like(xactions)
+ if self._self_interaction:
+ # Selecting lower-triangular portion including the diagonal.
+ lower_tri_mask = tf.linalg.band_part(ones, -1, 0)
+ upper_tri_mask = ones - lower_tri_mask
+ out_dim = num_features * (num_features + 1) // 2
+ else:
+ # Selecting lower-triangular portion not included the diagonal.
+ upper_tri_mask = tf.linalg.band_part(ones, 0, -1)
+ lower_tri_mask = ones - upper_tri_mask
+ out_dim = num_features * (num_features - 1) // 2
+
+ if self._skip_gather:
+ # Setting upper triangle part of the interaction matrix to zeros.
+ activations = tf.where(
+ condition=tf.cast(upper_tri_mask, tf.bool),
+ x=tf.zeros_like(xactions),
+ y=xactions)
+ out_dim = num_features * num_features
+ else:
+ activations = tf.boolean_mask(xactions, lower_tri_mask)
+ activations = tf.reshape(activations, (batch_size, out_dim))
+ return activations
+
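+# Output-width example (illustrative): with 4 features the layer yields
+# 4 * 3 / 2 = 6 pairwise dot products per example (self_interaction=False),
+# 4 * 5 / 2 = 10 with self_interaction=True, or a 4 * 4 = 16-wide output whose
+# upper triangle is zeroed when skip_gather=True.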
+
+class Cross(tf.keras.layers.Layer):
+ """Cross Layer in Deep & Cross Network to learn explicit feature interactions.
+
+ A layer that creates explicit and bounded-degree feature interactions
+ efficiently. The `call` method accepts `inputs` as a tuple of size 2
+ tensors. The first input `x0` is the base layer that contains the original
+ features (usually the embedding layer); the second input `xi` is the output
+ of the previous `Cross` layer in the stack, i.e., the i-th `Cross`
+ layer. For the first `Cross` layer in the stack, x0 = xi.
+
+ The output is x_{i+1} = x0 .* (W * xi + bias + diag_scale * xi) + xi,
+ where .* designates elementwise multiplication, W could be a full-rank
+ matrix, or a low-rank matrix U*V to reduce the computational cost, and
+ diag_scale increases the diagonal of W to improve training stability (
+ especially for the low-rank case).
+
+ References:
+ 1. [R. Wang et al.](https://arxiv.org/pdf/2008.13535.pdf)
+ See Eq. (1) for full-rank and Eq. (2) for low-rank version.
+ 2. [R. Wang et al.](https://arxiv.org/pdf/1708.05123.pdf)
+
+ Example:
+
+ ```python
+ # after embedding layer in a functional model:
+ input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64)
+ x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)
+ x1 = Cross()(x0, x0)
+ x2 = Cross()(x0, x1)
+ logits = tf.keras.layers.Dense(units=10)(x2)
+ model = tf.keras.Model(input, logits)
+ ```
+
+ Args:
+ projection_dim: project dimension to reduce the computational cost.
+ Default is `None` such that a full (`input_dim` by `input_dim`) matrix
+ W is used. If enabled, a low-rank matrix W = U*V will be used, where U
+ is of size `input_dim` by `projection_dim` and V is of size
+ `projection_dim` by `input_dim`. `projection_dim` need to be smaller
+ than `input_dim`/2 to improve the model efficiency. In practice, we've
+ observed that `projection_dim` = d/4 consistently preserved the
+ accuracy of a full-rank version.
+ diag_scale: a non-negative float used to increase the diagonal of the
+ kernel W by `diag_scale`, that is, W + diag_scale * I, where I is an
+ identity matrix.
+ use_bias: whether to add a bias term for this layer. If set to False,
+ no bias term will be used.
+ preactivation: Activation applied to output matrix of the layer, before
+ multiplication with the input. Can be used to control the scale of the
+ layer's outputs and improve stability.
+ kernel_initializer: Initializer to use on the kernel matrix.
+ bias_initializer: Initializer to use on the bias vector.
+ kernel_regularizer: Regularizer to use on the kernel matrix.
+ bias_regularizer: Regularizer to use on bias vector.
+
+ Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs.
+ Output shape: A single (batch_size, `input_dim`) dimensional output.
+ """
+
+ def __init__(self, params, name='cross', reuse=None, **kwargs):
+ super(Cross, self).__init__(name=name, **kwargs)
+ self._projection_dim = params.get_or_default('projection_dim', None)
+ self._diag_scale = params.get_or_default('diag_scale', 0.0)
+ self._use_bias = params.get_or_default('use_bias', True)
+ preactivation = params.get_or_default('preactivation', None)
+ preact = get_activation(preactivation)
+ self._preactivation = tf.keras.activations.get(preact)
+ kernel_initializer = params.get_or_default('kernel_initializer',
+ 'truncated_normal')
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ bias_initializer = params.get_or_default('bias_initializer', 'zeros')
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ kernel_regularizer = params.get_or_default('kernel_regularizer', None)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ bias_regularizer = params.get_or_default('bias_regularizer', None)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._input_dim = None
+ self._supports_masking = True
+
+ if self._diag_scale < 0: # pytype: disable=unsupported-operands
+ raise ValueError(
+ '`diag_scale` should be non-negative. Got `diag_scale` = {}'.format(
+ self._diag_scale))
+
+ def build(self, input_shape):
+ last_dim = input_shape[0][-1]
+
+ if self._projection_dim is None:
+ self._dense = tf.keras.layers.Dense(
+ last_dim,
+ kernel_initializer=_clone_initializer(self._kernel_initializer),
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ use_bias=self._use_bias,
+ dtype=self.dtype,
+ activation=self._preactivation,
+ )
+ else:
+ self._dense_u = tf.keras.layers.Dense(
+ self._projection_dim,
+ kernel_initializer=_clone_initializer(self._kernel_initializer),
+ kernel_regularizer=self._kernel_regularizer,
+ use_bias=False,
+ dtype=self.dtype,
+ )
+ self._dense_v = tf.keras.layers.Dense(
+ last_dim,
+ kernel_initializer=_clone_initializer(self._kernel_initializer),
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ use_bias=self._use_bias,
+ dtype=self.dtype,
+ activation=self._preactivation,
+ )
+ super(Cross, self).build(input_shape) # Be sure to call this somewhere!
+
+ def call(self, inputs, **kwargs):
+ """Computes the feature cross.
+
+ Args:
+ inputs: The input tensor(x0, x)
+ - x0: The input tensor
+ - x: Optional second input tensor. If provided, the layer will compute
+ crosses between x0 and x; if not provided, the layer will compute
+ crosses between x0 and itself.
+
+ Returns:
+ Tensor of crosses.
+ """
+ if isinstance(inputs, (list, tuple)):
+ x0, x = inputs
+ else:
+ x0, x = inputs, inputs
+
+ if not self.built:
+ self.build(x0.shape)
+
+ if x0.shape[-1] != x.shape[-1]:
+ raise ValueError(
+ '`x0` and `x` dimension mismatch! Got `x0` dimension {}, and x '
+ 'dimension {}. This case is not supported yet.'.format(
+ x0.shape[-1], x.shape[-1]))
+
+ if self._projection_dim is None:
+ prod_output = self._dense(x)
+ else:
+ prod_output = self._dense_v(self._dense_u(x))
+
+ # prod_output = tf.cast(prod_output, self.compute_dtype)
+
+ if self._diag_scale:
+ prod_output = prod_output + self._diag_scale * x
+
+ return x0 * prod_output + x
+
+ def get_config(self):
+ config = {
+ 'projection_dim':
+ self._projection_dim,
+ 'diag_scale':
+ self._diag_scale,
+ 'use_bias':
+ self._use_bias,
+ 'preactivation':
+ tf.keras.activations.serialize(self._preactivation),
+ 'kernel_initializer':
+ tf.keras.initializers.serialize(self._kernel_initializer),
+ 'bias_initializer':
+ tf.keras.initializers.serialize(self._bias_initializer),
+ 'kernel_regularizer':
+ tf.keras.regularizers.serialize(self._kernel_regularizer),
+ 'bias_regularizer':
+ tf.keras.regularizers.serialize(self._bias_regularizer),
+ }
+ base_config = super(Cross, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
+class CIN(tf.keras.layers.Layer):
+ """Compressed Interaction Network(CIN) module in xDeepFM model.
+
+ CIN layer is aimed at achieving high-order feature interactions at
+ vector-wise level rather than bit-wise level.
+
+
+ Reference:
+ [xDeepFM](https://arxiv.org/pdf/1803.05170)
+ xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems
+ """
+
+ def __init__(self, params, name='cin', reuse=None, **kwargs):
+ super(CIN, self).__init__(name=name, **kwargs)
+ self._name = name
+ self._hidden_feature_sizes = list(
+ params.get_or_default('hidden_feature_sizes', []))
+
+ assert isinstance(self._hidden_feature_sizes, list) and len(
+ self._hidden_feature_sizes
+ ) > 0, 'parameter hidden_feature_sizes must be a list of int with length greater than 0'
+
+ kernel_regularizer = params.get_or_default('kernel_regularizer', None)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ bias_regularizer = params.get_or_default('bias_regularizer', None)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+
+ def build(self, input_shape):
+ if len(input_shape) != 3:
+ raise ValueError(
+ 'Unexpected inputs dimensions %d, expect to be 3 dimensions' %
+ (len(input_shape)))
+
+    hidden_feature_sizes = [input_shape[1]] + list(self._hidden_feature_sizes)
+ tfv1 = tf.compat.v1 if tf.__version__ >= '2.0' else tf
+ with tfv1.variable_scope(self._name):
+ self.kernel_list = [
+ tfv1.get_variable(
+ name='cin_kernel_%d' % i,
+ shape=[
+ hidden_feature_sizes[i + 1], hidden_feature_sizes[i],
+ hidden_feature_sizes[0]
+ ],
+ initializer=tf.initializers.he_normal(),
+ regularizer=self._kernel_regularizer,
+ trainable=True) for i in range(len(self._hidden_feature_sizes))
+ ]
+ self.bias_list = [
+ tfv1.get_variable(
+ name='cin_bias_%d' % i,
+ shape=[hidden_feature_sizes[i + 1]],
+            initializer=tf.keras.initializers.Zeros(),
+ regularizer=self._bias_regularizer,
+ trainable=True) for i in range(len(self._hidden_feature_sizes))
+ ]
+
+ super(CIN, self).build(input_shape)
+
+ def call(self, input, **kwargs):
+ """Computes the compressed feature maps.
+
+ Args:
+ input: The 3D input tensor with shape (b, h0, d), where b is batch_size,
+ h0 is the number of features, d is the feature embedding dimension.
+
+ Returns:
+ 2D tensor of compressed feature map with shape (b, featuremap_num),
+      where b is the batch_size and featuremap_num is the sum of the hidden layer sizes.
+ """
+ x_0 = input
+ x_i = input
+ x_0_expanded = tf.expand_dims(x_0, 1)
+ pooled_feature_map_list = []
+ for i in range(len(self._hidden_feature_sizes)):
+ hk = self._hidden_feature_sizes[i]
+
+ x_i_expanded = tf.expand_dims(x_i, 2)
+ intermediate_tensor = tf.multiply(x_0_expanded, x_i_expanded)
+
+ intermediate_tensor_expanded = tf.expand_dims(intermediate_tensor, 1)
+ intermediate_tensor_expanded = tf.tile(intermediate_tensor_expanded,
+ [1, hk, 1, 1, 1])
+
+ feature_map_elementwise = tf.multiply(
+ intermediate_tensor_expanded,
+ tf.expand_dims(tf.expand_dims(self.kernel_list[i], -1), 0))
+ feature_map = tf.reduce_sum(
+ tf.reduce_sum(feature_map_elementwise, axis=3), axis=2)
+
+ feature_map = tf.add(
+ feature_map,
+ tf.expand_dims(tf.expand_dims(self.bias_list[i], axis=-1), axis=0))
+ feature_map = tf.nn.relu(feature_map)
+
+ x_i = feature_map
+ pooled_feature_map_list.append(tf.reduce_sum(feature_map, axis=-1))
+ return tf.concat(
+ pooled_feature_map_list, axis=-1) # shape = (b, h1 + ... + hk)
+
+ def get_config(self):
+    return super(CIN, self).get_config()
+
+
+def _clone_initializer(initializer):
+ return initializer.__class__.from_config(initializer.get_config())
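The tile/multiply/reduce sequence in `CIN.call` computes, for layer k, `feature_map[b, k, d] = sum_{i,j} kernel[k, i, j] * x_i[b, i, d] * x_0[b, j, d]`, followed by a ReLU and a sum-pool over the embedding axis. A minimal NumPy sketch of one such layer expressed as a single einsum (the bias term is omitted; shapes and weights are illustrative):

```python
import numpy as np

b, h0, hi, hk, d = 2, 3, 3, 4, 5
x0 = np.random.randn(b, h0, d).astype(np.float32)        # raw field embeddings
xi = np.random.randn(b, hi, d).astype(np.float32)        # previous CIN feature map
kernel = np.random.randn(hk, hi, h0).astype(np.float32)  # one entry of kernel_list

# outer product over field pairs, contracted by the kernel per embedding slot
feature_map = np.einsum('kij,bid,bjd->bkd', kernel, xi, x0)
feature_map = np.maximum(feature_map, 0.0)  # relu, as in call()
pooled = feature_map.sum(axis=-1)           # sum-pool over d -> (b, hk)
assert pooled.shape == (b, hk)
```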
diff --git a/easy_rec/python/layers/keras/layer_norm.py b/easy_rec/python/layers/keras/layer_norm.py
new file mode 100644
index 000000000..7d6c81d5f
--- /dev/null
+++ b/easy_rec/python/layers/keras/layer_norm.py
@@ -0,0 +1,364 @@
+"""Layer Normalization layer."""
+import tensorflow as tf
+from tensorflow.python.keras import constraints
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import regularizers
+from tensorflow.python.keras.layers import Layer
+
+
+def validate_axis(axis, input_shape):
+ """Validate an axis value and returns its standardized form.
+
+ Args:
+ axis: Value to validate. Can be an integer or a list/tuple of integers.
+ Integers may be negative.
+ input_shape: Reference input shape that the axis/axes refer to.
+
+ Returns:
+ Normalized form of `axis`, i.e. a list with all-positive values.
+ """
+ input_shape = tf.TensorShape(input_shape)
+ rank = input_shape.ndims
+ if not rank:
+ raise ValueError(
+ 'Input has undefined rank. Received: input_shape={input_shape}'.format(
+ input_shape=input_shape))
+
+ # Convert axis to list and resolve negatives
+ if isinstance(axis, int):
+ axis = [axis]
+ else:
+ axis = list(axis)
+ for idx, x in enumerate(axis):
+ if x < 0:
+ axis[idx] = rank + x
+
+ # Validate axes
+ for x in axis:
+ if x < 0 or x >= rank:
+ raise ValueError('Invalid value for `axis` argument. '
+ 'Expected 0 <= axis < inputs.rank (with '
+ 'inputs.rank={rank}). Received: axis={axis}'.format(
+ rank=rank, axis=tuple(axis)))
+ if len(axis) != len(set(axis)):
+ raise ValueError('Duplicate axis: {axis}'.format(axis=tuple(axis)))
+ return axis
+
+
+class LayerNormalization(Layer):
+ """Layer normalization layer (Ba et al., 2016).
+
+ Normalize the activations of the previous layer for each given example in a
+ batch independently, rather than across a batch like Batch Normalization.
+ i.e. applies a transformation that maintains the mean activation within each
+ example close to 0 and the activation standard deviation close to 1.
+
+ Given a tensor `inputs`, moments are calculated and normalization
+ is performed across the axes specified in `axis`.
+
+ Example:
+ >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
+ >>> print(data)
+ tf.Tensor(
+ [[ 0. 10.]
+ [20. 30.]
+ [40. 50.]
+ [60. 70.]
+ [80. 90.]], shape=(5, 2), dtype=float32)
+
+ >>> layer = tf.keras.layers.LayerNormalization(axis=1)
+ >>> output = layer(data)
+ >>> print(output)
+ tf.Tensor(
+ [[-1. 1.]
+ [-1. 1.]
+ [-1. 1.]
+ [-1. 1.]
+ [-1. 1.]], shape=(5, 2), dtype=float32)
+
+ Notice that with Layer Normalization the normalization happens across the
+ axes *within* each example, rather than across different examples in the
+ batch.
+
+ If `scale` or `center` are enabled, the layer will scale the normalized
+ outputs by broadcasting them with a trainable variable `gamma`, and center
+ the outputs by broadcasting with a trainable variable `beta`. `gamma` will
+ default to a ones tensor and `beta` will default to a zeros tensor, so that
+ centering and scaling are no-ops before training has begun.
+
+ So, with scaling and centering enabled the normalization equations
+ are as follows:
+
+ Let the intermediate activations for a mini-batch to be the `inputs`.
+
+ For each sample `x_i` in `inputs` with `k` features, we compute the mean and
+ variance of the sample:
+
+ ```python
+ mean_i = sum(x_i[j] for j in range(k)) / k
+ var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
+ ```
+
+ and then compute a normalized `x_i_normalized`, including a small factor
+ `epsilon` for numerical stability.
+
+ ```python
+ x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
+ ```
+
+ And finally `x_i_normalized ` is linearly transformed by `gamma` and `beta`,
+ which are learned parameters:
+
+ ```python
+ output_i = x_i_normalized * gamma + beta
+ ```
+
+ `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
+ this part of the inputs' shape must be fully defined.
+
+ For example:
+ >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
+ >>> layer.build([5, 20, 30, 40])
+ >>> print(layer.beta.shape)
+ (20, 30, 40)
+ >>> print(layer.gamma.shape)
+ (20, 30, 40)
+
+ Note that other implementations of layer normalization may choose to define
+ `gamma` and `beta` over a separate set of axes from the axes being
+ normalized across. For example, Group Normalization
+ ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
+ corresponds to a Layer Normalization that normalizes across height, width,
+ and channel and has `gamma` and `beta` span only the channel dimension.
+ So, this Layer Normalization implementation will not match a Group
+ Normalization layer with group size set to 1.
+
+ Args:
+ axis: Integer or List/Tuple. The axis or axes to normalize across.
+ Typically, this is the features axis/axes. The left-out axes are
+ typically the batch axis/axes. `-1` is the last dimension in the
+ input. Defaults to `-1`.
+ epsilon: Small float added to variance to avoid dividing by zero. Defaults
+ to 1e-3
+ center: If True, add offset of `beta` to normalized tensor. If False,
+ `beta` is ignored. Defaults to `True`.
+ scale: If True, multiply by `gamma`. If False, `gamma` is not used.
+ When the next layer is linear (also e.g. `nn.relu`), this can be
+ disabled since the scaling will be done by the next layer.
+ Defaults to `True`.
+ beta_initializer: Initializer for the beta weight. Defaults to zeros.
+ gamma_initializer: Initializer for the gamma weight. Defaults to ones.
+ beta_regularizer: Optional regularizer for the beta weight. None by
+ default.
+ gamma_regularizer: Optional regularizer for the gamma weight. None by
+ default.
+ beta_constraint: Optional constraint for the beta weight. None by default.
+ gamma_constraint: Optional constraint for the gamma weight. None by
+ default.
+
+ Input shape:
+ Arbitrary. Use the keyword argument `input_shape` (tuple of
+ integers, does not include the samples axis) when using this layer as the
+ first layer in a model.
+
+ Output shape:
+ Same shape as input.
+
+ Reference:
+ - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
+ """
+
+ def __init__(self,
+ axis=-1,
+ epsilon=1e-3,
+ center=True,
+ scale=True,
+ beta_initializer='zeros',
+ gamma_initializer='ones',
+ beta_regularizer=None,
+ gamma_regularizer=None,
+ beta_constraint=None,
+ gamma_constraint=None,
+ **kwargs):
+ super(LayerNormalization, self).__init__(**kwargs)
+ if isinstance(axis, (list, tuple)):
+ self.axis = list(axis)
+ elif isinstance(axis, int):
+ self.axis = axis
+ else:
+ raise TypeError('Expected an int or a list/tuple of ints for the '
+ "argument 'axis', but received: %r" % axis)
+
+ self.epsilon = epsilon
+ self.center = center
+ self.scale = scale
+ self.beta_initializer = initializers.get(beta_initializer)
+ self.gamma_initializer = initializers.get(gamma_initializer)
+ self.beta_regularizer = regularizers.get(beta_regularizer)
+ self.gamma_regularizer = regularizers.get(gamma_regularizer)
+ self.beta_constraint = constraints.get(beta_constraint)
+ self.gamma_constraint = constraints.get(gamma_constraint)
+
+ self.supports_masking = True
+
+ # Indicates whether a faster fused implementation can be used. This will
+    # be set to True or False in build()
+ self._fused = None
+
+ def _fused_can_be_used(self, ndims):
+ """Returns false if fused implementation cannot be used.
+
+ Check if the axis is contiguous and can be collapsed into the last axis.
+ The self.axis is assumed to have no duplicates.
+ """
+ if not tf.test.is_gpu_available():
+ return False
+ axis = sorted(self.axis)
+ can_use_fused = False
+
+ if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
+ can_use_fused = True
+
+ # fused_batch_norm will silently raise epsilon to be at least 1.001e-5,
+    # so we cannot use the fused version if epsilon is below that value.
+ # Also, the variable dtype must be float32, as fused_batch_norm only
+ # supports float32 variables.
+ if self.epsilon < 1.001e-5 or self.dtype != 'float32':
+ can_use_fused = False
+
+ return can_use_fused
+
+ def build(self, input_shape):
+ self.axis = validate_axis(self.axis, input_shape)
+ input_shape = tf.TensorShape(input_shape)
+ rank = input_shape.ndims
+
+ param_shape = [input_shape[dim] for dim in self.axis]
+ if self.scale:
+ self.gamma = self.add_weight(
+ name='gamma',
+ shape=param_shape,
+ initializer=self.gamma_initializer,
+ regularizer=self.gamma_regularizer,
+ constraint=self.gamma_constraint,
+ trainable=True,
+ )
+ else:
+ self.gamma = None
+
+ if self.center:
+ self.beta = self.add_weight(
+ name='beta',
+ shape=param_shape,
+ initializer=self.beta_initializer,
+ regularizer=self.beta_regularizer,
+ constraint=self.beta_constraint,
+ trainable=True,
+ )
+ else:
+ self.beta = None
+
+ self._fused = self._fused_can_be_used(rank)
+ super(LayerNormalization,
+ self).build(input_shape) # Be sure to call this somewhere!
+
+ def call(self, inputs):
+ # Compute the axes along which to reduce the mean / variance
+ input_shape = inputs.shape
+ ndims = len(input_shape)
+
+ # Broadcasting only necessary for norm when the axis is not just
+ # the last dimension
+ broadcast_shape = [1] * ndims
+ for dim in self.axis:
+ broadcast_shape[dim] = input_shape.dims[dim].value
+
+ def _broadcast(v):
+ if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]):
+ return tf.reshape(v, broadcast_shape)
+ return v
+
+ if not self._fused:
+ input_dtype = inputs.dtype
+ if (input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32'):
+ # If mixed precision is used, cast inputs to float32 so that
+ # this is at least as numerically stable as the fused version.
+ inputs = tf.cast(inputs, 'float32')
+
+ # Calculate the moments on the last axis (layer activations).
+ mean, variance = tf.nn.moments(inputs, self.axis, keep_dims=True)
+
+ scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
+
+ # Compute layer normalization using the batch_normalization
+ # function.
+ outputs = tf.nn.batch_normalization(
+ inputs,
+ mean,
+ variance,
+ offset=offset,
+ scale=scale,
+ variance_epsilon=self.epsilon,
+ )
+ outputs = tf.cast(outputs, input_dtype)
+ else:
+ # Collapse dims before self.axis, and dims in self.axis
+
+ axis = sorted(self.axis)
+ tensor_shape = tf.shape(inputs)
+ pre_dim = tf.reduce_prod(tensor_shape[:axis[0]])
+ in_dim = tf.reduce_prod(tensor_shape[axis[0]:])
+ squeezed_shape = [1, pre_dim, in_dim, 1]
+ # This fused operation requires reshaped inputs to be NCHW.
+ data_format = 'NCHW'
+
+ inputs = tf.reshape(inputs, squeezed_shape)
+
+ # self.gamma and self.beta have the wrong shape for
+ # fused_batch_norm, so we cannot pass them as the scale and offset
+ # parameters. Therefore, we create two constant tensors in correct
+ # shapes for fused_batch_norm and later construct a separate
+ # calculation on the scale and offset.
+ scale = tf.ones(tf.convert_to_tensor([pre_dim]), dtype=self.dtype)
+ offset = tf.zeros(tf.convert_to_tensor([pre_dim]), dtype=self.dtype)
+
+ # Compute layer normalization using the fused_batch_norm function.
+ outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
+ inputs,
+ scale=scale,
+ offset=offset,
+ epsilon=self.epsilon,
+ data_format=data_format,
+ )
+
+ outputs = tf.reshape(outputs, tensor_shape)
+
+ scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
+
+ if scale is not None:
+ outputs = outputs * tf.cast(scale, outputs.dtype)
+ if offset is not None:
+ outputs = outputs + tf.cast(offset, outputs.dtype)
+
+ # If some components of the shape got lost due to adjustments, fix that.
+ outputs.set_shape(input_shape)
+ return outputs
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+ def get_config(self):
+ config = {
+ 'axis': self.axis,
+ 'epsilon': self.epsilon,
+ 'center': self.center,
+ 'scale': self.scale,
+ 'beta_initializer': initializers.serialize(self.beta_initializer),
+ 'gamma_initializer': initializers.serialize(self.gamma_initializer),
+ 'beta_regularizer': regularizers.serialize(self.beta_regularizer),
+ 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
+ 'beta_constraint': constraints.serialize(self.beta_constraint),
+ 'gamma_constraint': constraints.serialize(self.gamma_constraint),
+ }
+ base_config = super(LayerNormalization, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
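As a sanity check on the non-fused path (moments over `axis`, then `tf.nn.batch_normalization`), the normalization can be reproduced with plain NumPy. A minimal sketch matching the docstring example, with `gamma`/`beta` at their identity defaults and the default `epsilon=1e-3`:

```python
import numpy as np

data = (np.arange(10).reshape(5, 2) * 10).astype(np.float32)
mean = data.mean(axis=-1, keepdims=True)
var = data.var(axis=-1, keepdims=True)
normed = (data - mean) / np.sqrt(var + 1e-3)

# every row normalizes to approximately [-1, 1], as in the docstring example
assert np.allclose(normed, [[-1.0, 1.0]] * 5, atol=1e-2)
```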
diff --git a/easy_rec/python/layers/keras/mask_net.py b/easy_rec/python/layers/keras/mask_net.py
new file mode 100644
index 000000000..bf687154e
--- /dev/null
+++ b/easy_rec/python/layers/keras/mask_net.py
@@ -0,0 +1,166 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+from tensorflow.python.keras.layers import Activation
+from tensorflow.python.keras.layers import Dense
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras.blocks import MLP
+from easy_rec.python.layers.keras.layer_norm import LayerNormalization
+from easy_rec.python.layers.utils import Parameter
+
+
+class MaskBlock(Layer):
+ """MaskBlock use in MaskNet.
+
+ Args:
+ projection_dim: project dimension to reduce the computational cost.
+ Default is `None` such that a full (`input_dim` by `aggregation_size`) matrix
+ W is used. If enabled, a low-rank matrix W = U*V will be used, where U
+ is of size `input_dim` by `projection_dim` and V is of size
+ `projection_dim` by `aggregation_size`. `projection_dim` need to be smaller
+ than `aggregation_size`/2 to improve the model efficiency. In practice, we've
+ observed that `projection_dim` = d/4 consistently preserved the
+ accuracy of a full-rank version.
+ """
+
+ def __init__(self, params, name='mask_block', reuse=None, **kwargs):
+ super(MaskBlock, self).__init__(name=name, **kwargs)
+ self.config = params.get_pb_config()
+ self.l2_reg = params.l2_regularizer
+ self._projection_dim = params.get_or_default('projection_dim', None)
+ self.reuse = reuse
+ self.final_relu = Activation('relu', name='relu')
+
+ def build(self, input_shape):
+ if type(input_shape) in (tuple, list):
+      assert len(input_shape) >= 2, 'MaskBlock must have at least two inputs'
+ input_dim = int(input_shape[0][-1])
+ mask_input_dim = int(input_shape[1][-1])
+ else:
+ input_dim, mask_input_dim = input_shape[-1], input_shape[-1]
+ if self.config.HasField('reduction_factor'):
+ aggregation_size = int(mask_input_dim * self.config.reduction_factor)
+    elif self.config.HasField('aggregation_size'):
+ aggregation_size = self.config.aggregation_size
+ else:
+ raise ValueError(
+ 'Need one of reduction factor or aggregation size for MaskBlock.')
+
+ self.aggr_layer = Dense(
+ aggregation_size,
+ activation='relu',
+ kernel_initializer='he_uniform',
+ kernel_regularizer=self.l2_reg,
+ name='aggregation')
+ self.weight_layer = Dense(input_dim, name='weights')
+ if self._projection_dim is not None:
+ logging.info('%s project dim is %d', self.name, self._projection_dim)
+ self.project_layer = Dense(
+ self._projection_dim,
+ kernel_regularizer=self.l2_reg,
+ use_bias=False,
+ name='project')
+ if self.config.input_layer_norm:
+      # It is recommended to apply layer norm before calling MaskBlock;
+      # otherwise the input is layer-normalized again on every call
+ if tf.__version__ >= '2.0':
+ self.input_layer_norm = tf.keras.layers.LayerNormalization(
+ name='input_ln')
+ else:
+ self.input_layer_norm = LayerNormalization(name='input_ln')
+
+ if self.config.HasField('output_size'):
+ self.output_layer = Dense(
+ self.config.output_size, use_bias=False, name='output')
+ if tf.__version__ >= '2.0':
+ self.output_layer_norm = tf.keras.layers.LayerNormalization(
+ name='output_ln')
+ else:
+ self.output_layer_norm = LayerNormalization(name='output_ln')
+ super(MaskBlock, self).build(input_shape)
+
+ def call(self, inputs, training=None, **kwargs):
+ if type(inputs) in (tuple, list):
+ net, mask_input = inputs[:2]
+ else:
+ net, mask_input = inputs, inputs
+
+ if self.config.input_layer_norm:
+ net = self.input_layer_norm(net)
+
+ if self._projection_dim is None:
+ aggr = self.aggr_layer(mask_input)
+ else:
+ u = self.project_layer(mask_input)
+ aggr = self.aggr_layer(u)
+
+ weights = self.weight_layer(aggr)
+ masked_net = net * weights
+
+ if not self.config.HasField('output_size'):
+ return masked_net
+
+ hidden = self.output_layer(masked_net)
+ ln_hidden = self.output_layer_norm(hidden)
+ return self.final_relu(ln_hidden)
+
+
+class MaskNet(Layer):
+ """MaskNet: Introducing Feature-Wise Multiplication to CTR Ranking Models by Instance-Guided Mask.
+
+ Refer: https://arxiv.org/pdf/2102.07619.pdf
+ """
+
+ def __init__(self, params, name='mask_net', reuse=None, **kwargs):
+ super(MaskNet, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ self.params = params
+ self.config = params.get_pb_config()
+ if self.config.HasField('mlp'):
+ p = Parameter.make_from_pb(self.config.mlp)
+ p.l2_regularizer = params.l2_regularizer
+ self.mlp = MLP(p, name='mlp', reuse=reuse)
+ else:
+ self.mlp = None
+
+ self.mask_layers = []
+ for i, block_conf in enumerate(self.config.mask_blocks):
+ params = Parameter.make_from_pb(block_conf)
+ params.l2_regularizer = self.params.l2_regularizer
+ mask_layer = MaskBlock(params, name='block_%d' % i, reuse=self.reuse)
+ self.mask_layers.append(mask_layer)
+
+ if self.config.input_layer_norm:
+ if tf.__version__ >= '2.0':
+ self.input_layer_norm = tf.keras.layers.LayerNormalization(
+ name='input_ln')
+ else:
+ self.input_layer_norm = LayerNormalization(name='input_ln')
+
+ def call(self, inputs, training=None, **kwargs):
+ if self.config.input_layer_norm:
+ inputs = self.input_layer_norm(inputs)
+
+ if self.config.use_parallel:
+ mask_outputs = [
+ mask_layer((inputs, inputs)) for mask_layer in self.mask_layers
+ ]
+ all_mask_outputs = tf.concat(mask_outputs, axis=1)
+ if self.mlp is not None:
+ output = self.mlp(all_mask_outputs, training=training)
+ else:
+ output = all_mask_outputs
+ return output
+ else:
+ net = inputs
+ for i, _ in enumerate(self.config.mask_blocks):
+ mask_layer = self.mask_layers[i]
+ net = mask_layer((net, inputs))
+
+ if self.mlp is not None:
+ output = self.mlp(net, training=training)
+ else:
+ output = net
+ return output
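The instance-guided mask in `MaskBlock.call` is a two-layer bottleneck over `mask_input` whose output gates `net` element-wise: `weights = W_out · relu(W_aggr · mask_input)`, then `masked_net = net * weights`. A minimal NumPy sketch of that path (layer norm, the optional low-rank projection, and the output block are omitted; all matrices are illustrative):

```python
import numpy as np

def relu(x):
  return np.maximum(x, 0.0)

batch, input_dim, aggregation_size = 2, 6, 12
net = np.random.randn(batch, input_dim).astype(np.float32)
mask_input = net  # self-masking case, as used by MaskNet's parallel mode
W_aggr = np.random.randn(input_dim, aggregation_size).astype(np.float32)
W_weight = np.random.randn(aggregation_size, input_dim).astype(np.float32)

weights = relu(mask_input.dot(W_aggr)).dot(W_weight)  # instance-guided mask
masked_net = net * weights                            # feature-wise gating
assert masked_net.shape == (batch, input_dim)
```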
diff --git a/easy_rec/python/layers/keras/multi_head_attention.py b/easy_rec/python/layers/keras/multi_head_attention.py
new file mode 100644
index 000000000..a5ca0b40d
--- /dev/null
+++ b/easy_rec/python/layers/keras/multi_head_attention.py
@@ -0,0 +1,717 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import string
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.keras import constraints
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import regularizers
+from tensorflow.python.keras.layers import Dropout
+from tensorflow.python.keras.layers import Layer
+from tensorflow.python.keras.layers import Softmax
+
+from easy_rec.python.layers.keras.activation import MaskedSoftmax
+from easy_rec.python.layers.keras.einsum_dense import EinsumDense
+
+
+class MultiHeadAttention(Layer):
+ """MultiHeadAttention layer.
+
+ This is an implementation of multi-headed attention as described in the
+ paper "Attention is all you Need"
+ [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762).
+  If `query`, `key`, `value` are the same, then
+ this is self-attention. Each time step in `query` attends to the
+ corresponding sequence in `key`, and returns a fixed-width vector.
+
+ This layer first projects `query`, `key` and `value`. These are
+ (effectively) a list of tensors of length `num_attention_heads`, where the
+  corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,
+  `(batch_size, <key/value dimensions>, key_dim)`,
+  `(batch_size, <key/value dimensions>, value_dim)`.
+
+ Then, the query and key tensors are dot-producted and scaled. These are
+ softmaxed to obtain attention probabilities. The value tensors are then
+ interpolated by these probabilities, then concatenated back to a single
+ tensor.
+
+ Finally, the result tensor with the last dimension as `value_dim` can take
+ a linear projection and return.
+
+ Args:
+ num_heads: Number of attention heads.
+ key_dim: Size of each attention head for query and key.
+ value_dim: Size of each attention head for value.
+ dropout: Dropout probability.
+ use_bias: Boolean, whether the dense layers use bias vectors/matrices.
+ output_shape: The expected shape of an output tensor, besides the batch
+ and sequence dims. If not specified, projects back to the query
+ feature dim (the query input's last dimension).
+ attention_axes: axes over which the attention is applied. `None` means
+ attention over all axes, but batch, heads, and features.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+ bias_constraint: Constraint for dense layer kernels.
+ use_causal_mask: A boolean to indicate whether to apply a causal mask to
+ prevent tokens from attending to future tokens (e.g., used in a
+ decoder Transformer).
+ return_attention_scores: A boolean to indicate whether the output should
+ be `(attention_output, attention_scores)` if `True`, or
+ `attention_output` if `False`. Defaults to `False`.
+
+ Call arguments:
+ query: Query tensor of shape `(B, T, dim)`, where `B` is the batch size,
+ `T` is the target sequence length, and dim is the feature dimension.
+ value: Value tensor of shape `(B, S, dim)`, where `B` is the batch size,
+ `S` is the source sequence length, and dim is the feature dimension.
+ key: Optional key tensor of shape `(B, S, dim)`. If not given, will
+ use `value` for both `key` and `value`, which is the most common
+ case.
+ attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
+ attention to certain positions. The boolean mask specifies which
+ query elements can attend to which key elements, 1 indicates
+ attention and 0 indicates no attention. Broadcasting can happen for
+ the missing batch dimensions and the head dimension.
+ training: Python boolean indicating whether the layer should behave in
+ training mode (adding dropout) or in inference mode (no dropout).
+ Will go with either using the training mode of the parent
+ layer/model, or `False` (inference) if there is no parent layer.
+
+ Returns:
+ attention_output: The result of the computation, of shape `(B, T, E)`,
+ where `T` is for target sequence shapes and `E` is the query input
+ last dimension if `output_shape` is `None`. Otherwise, the
+ multi-head outputs are projected to the shape specified by
+ `output_shape`.
+ attention_scores: (Optional) multi-head attention coefficients over
+ attention axes.
+ """
+
+ def __init__(self, params, name='multi_head_attention', reuse=None, **kwargs):
+ super(MultiHeadAttention, self).__init__(name=name, **kwargs)
+ self.supports_masking = True
+ self._num_heads = params.num_heads
+ self._key_dim = params.key_dim
+ # Cache 1.0 / math.sqrt(self._key_dim).
+ self._inverse_sqrt_key_dim = None
+ value_dim = params.get_or_default('value_dim', None)
+ self._value_dim = value_dim if value_dim else self._key_dim
+ self._dropout = params.get_or_default('dropout', 0.0)
+ self._use_bias = params.get_or_default('use_bias', True)
+ self._output_shape = params.get_or_default('output_shape', None)
+ self._kernel_initializer = initializers.get(
+ params.get_or_default('kernel_initializer', 'glorot_uniform'))
+ self._bias_initializer = initializers.get(
+ params.get_or_default('bias_initializer', 'zeros'))
+ self._kernel_regularizer = regularizers.get(
+ params.get_or_default('kernel_regularizer', None))
+ self._bias_regularizer = regularizers.get(
+ params.get_or_default('bias_regularizer', None))
+ self._activity_regularizer = regularizers.get(
+ params.get_or_default('activity_regularizer', None))
+ self._kernel_constraint = constraints.get(
+ params.get_or_default('kernel_constraint', None))
+ self._bias_constraint = constraints.get(
+ params.get_or_default('bias_constraint', None))
+ self._attention_axes = params.get_or_default('attention_axes', None)
+ self._use_causal_mask = params.get_or_default('use_causal_mask', False)
+ self._return_attention_scores = params.get_or_default(
+ 'return_attention_scores', False)
+
+ @property
+ def num_heads(self):
+ return self._num_heads
+
+ @property
+ def key_dim(self):
+ return self._key_dim
+
+ @property
+ def value_dim(self):
+ return self._value_dim
+
+ @property
+ def dropout(self):
+ return self._dropout
+
+ @property
+ def use_bias(self):
+ return self._use_bias
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ @property
+ def attention_axes(self):
+ return self._attention_axes
+
+ def get_config(self):
+ base_config = super(MultiHeadAttention, self).get_config()
+ config = {
+ 'num_heads':
+ self._num_heads,
+ 'key_dim':
+ self._key_dim,
+ 'value_dim':
+ self._value_dim,
+ 'dropout':
+ self._dropout,
+ 'use_bias':
+ self._use_bias,
+ 'output_shape':
+ self._output_shape,
+ 'attention_axes':
+ self._attention_axes,
+ 'kernel_initializer':
+ initializers.serialize(self._kernel_initializer),
+ 'bias_initializer':
+ initializers.serialize(self._bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self._kernel_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self._bias_regularizer),
+ 'activity_regularizer':
+ regularizers.serialize(self._activity_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self._kernel_constraint),
+ 'bias_constraint':
+ constraints.serialize(self._bias_constraint),
+ }
+ config.update(base_config)
+ return config
+
+ def build(self, input_shape):
+ """Builds layers and variables."""
+ if len(input_shape) == 3:
+ query_shape, value_shape, key_shape = input_shape
+ elif len(input_shape) == 2:
+ query_shape, value_shape = input_shape
+ key_shape = None
+ else:
+ raise ValueError('invalid input shape of MultiHeadAttention')
+
+ key_shape = value_shape if key_shape is None else key_shape
+ query_rank = len(query_shape)
+ value_rank = len(value_shape)
+ key_rank = len(key_shape)
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ query_rank - 1, bound_dims=1, output_dims=2)
+ self._query_dense = EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1,
+ [self._num_heads, self._key_dim]),
+ bias_axes=bias_axes if self._use_bias else None,
+ name='query',
+ **self._get_common_kwargs_for_sublayer())
+ self._query_dense.build(query_shape)
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ key_rank - 1, bound_dims=1, output_dims=2)
+ self._key_dense = EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1,
+ [self._num_heads, self._key_dim]),
+ bias_axes=bias_axes if self._use_bias else None,
+ name='key',
+ **self._get_common_kwargs_for_sublayer())
+ self._key_dense.build(key_shape)
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ value_rank - 1, bound_dims=1, output_dims=2)
+ self._value_dense = EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1,
+ [self._num_heads, self._value_dim]),
+ bias_axes=bias_axes if self._use_bias else None,
+ name='value',
+ **self._get_common_kwargs_for_sublayer())
+ self._value_dense.build(value_shape)
+ # Builds the attention computations for multi-head dot product
+ # attention. These computations could be wrapped into the keras
+ # attention layer once it supports multi-head einsum computations.
+ self._build_attention(output_rank)
+ self._output_dense = self._make_output_dense(
+ query_shape,
+ self._get_common_kwargs_for_sublayer(),
+ 'attention_output',
+ )
+ output_dense_input_shape = list(
+ self._query_dense.compute_output_shape(query_shape))
+ output_dense_input_shape[-1] = self._value_dim
+ self._output_dense.build(tuple(output_dense_input_shape))
+ self.built = True
+ print('MultiHeadAttention (%s) built' % self.name)
+
+ @property
+ def query_dense(self):
+ return self._query_dense
+
+ @property
+ def key_dense(self):
+ return self._key_dense
+
+ @property
+ def value_dense(self):
+ return self._value_dense
+
+ @property
+ def output_dense(self):
+ return self._output_dense
+
+ def _get_common_kwargs_for_sublayer(self):
+ common_kwargs = dict(
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ dtype=tf.float32,
+ )
+ # Create new clone of kernel/bias initializer, so that we don't reuse
+ # the initializer instance, which could lead to same init value since
+ # initializer is stateless.
+ kernel_initializer = self._kernel_initializer.__class__.from_config(
+ self._kernel_initializer.get_config())
+ bias_initializer = self._bias_initializer.__class__.from_config(
+ self._bias_initializer.get_config())
+ common_kwargs['kernel_initializer'] = kernel_initializer
+ common_kwargs['bias_initializer'] = bias_initializer
+ return common_kwargs
+
+ def _make_output_dense(self, query_shape, common_kwargs, name=None):
+ """Builds the output projection matrix.
+
+ Args:
+ query_shape: query tensor shape
+ common_kwargs: Common keyword arguments for einsum layer.
+ name: Name for the projection layer.
+
+ Returns:
+ Projection layer.
+ """
+ query_rank = len(query_shape)
+ if self._output_shape:
+ if hasattr(self._output_shape, '__len__'):
+ output_shape = self._output_shape
+ else:
+ output_shape = [self._output_shape]
+ else:
+ output_shape = [query_shape[-1]]
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ query_rank - 1, bound_dims=2, output_dims=len(output_shape))
+ return EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1, output_shape),
+ bias_axes=bias_axes if self._use_bias else None,
+ name=name,
+ **common_kwargs)
+
+ def _build_attention(self, rank):
+ """Builds multi-head dot-product attention computations.
+
+ This function builds attributes necessary for `_compute_attention` to
+ customize attention computation to replace the default dot-product
+ attention.
+
+ Args:
+ rank: the rank of query, key, value tensors.
+ """
+ if self._attention_axes is None:
+ self._attention_axes = tuple(range(1, rank - 2))
+ else:
+ self._attention_axes = tuple(self._attention_axes)
+ (
+ self._dot_product_equation,
+ self._combine_equation,
+ attn_scores_rank,
+ ) = _build_attention_equation(
+ rank, attn_axes=self._attention_axes)
+ norm_axes = tuple(
+ range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
+ self._softmax = Softmax(
+ axis=norm_axes) if tf.__version__ >= '2.0' else MaskedSoftmax(
+ axis=norm_axes)
+ self._dropout_layer = Dropout(rate=self._dropout)
+ self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
+
+ def _masked_softmax(self, attention_scores, attention_mask=None):
+ # Normalize the attention scores to probabilities.
+ # attention_scores = [B, N, T, S]
+ if attention_mask is not None:
+ # The expand dim happens starting from the `num_heads` dimension,
+      # (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
+ mask_expansion_axis = -len(self._attention_axes) * 2 - 1
+ for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
+ attention_mask = tf.expand_dims(
+ attention_mask, axis=mask_expansion_axis)
+ return self._softmax(attention_scores, mask=attention_mask)
+
+ def _compute_attention(self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ training=None):
+ """Applies Dot-product attention with query, key, value tensors.
+
+ This function defines the computation inside `call` with projected
+ multi-head Q, K, V inputs. Users can override this function for
+ customized attention implementation.
+
+ Args:
+ query: Projected query tensor of shape `(B, T, N, key_dim)`.
+ key: Projected key tensor of shape `(B, S, N, key_dim)`.
+ value: Projected value tensor of shape `(B, S, N, value_dim)`.
+ attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
+ attention to certain positions. It is generally not needed if
+ the `query` and `value` (and/or `key`) are masked.
+ training: Python boolean indicating whether the layer should behave
+ in training mode (adding dropout) or in inference mode (doing
+ nothing).
+
+ Returns:
+ attention_output: Multi-headed outputs of attention computation.
+ attention_scores: Multi-headed attention weights.
+ """
+ # Note: Applying scalar multiply at the smaller end of einsum improves
+ # XLA performance, but may introduce slight numeric differences in
+ # the Transformer attention head.
+ query = tf.multiply(query, tf.cast(self._inverse_sqrt_key_dim, query.dtype))
+
+ # Take the dot product between "query" and "key" to get the raw
+ # attention scores.
+ attention_scores = tf.einsum(self._dot_product_equation, key, query)
+
+ attention_scores = self._masked_softmax(attention_scores, attention_mask)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ if self.dropout:
+ final_attn_scores = self._dropout_layer(
+ attention_scores, training=training)
+ else:
+ final_attn_scores = attention_scores
+
+ # `context_layer` = [B, T, N, H]
+ attention_output = tf.einsum(self._combine_equation, final_attn_scores,
+ value)
+ return attention_output, attention_scores
+
+ def call(self, inputs, mask=None, training=None, **kwargs):
+ assert isinstance(
+ inputs, (tuple, list)), 'inputs of MultiHeadAttention must be a list'
+ query, value, key = (list(inputs) + [None] * 2)[:3]
+ if key is None:
+ key = value
+ if mask is None:
+ masks = [None] * 4
+ elif type(mask) in (list, tuple):
+ masks = (list(mask) + [None] * 4)[:4]
+ else:
+ masks = ([mask] + [None] * 3)[:4]
+ query_mask, value_mask, key_mask, attention_mask = masks
+ if attention_mask is None and value_mask is None:
+ value_mask = query_mask
+ attention_mask = self._compute_attention_mask(
+ query,
+ value,
+ query_mask=query_mask,
+ value_mask=value_mask,
+ key_mask=key_mask,
+ attention_mask=attention_mask,
+ use_causal_mask=self._use_causal_mask,
+ )
+
+ # N = `num_attention_heads`
+ # H = `size_per_head`
+ # `query` = [B, T, N ,H]
+ query = self._query_dense(query)
+
+ # `key` = [B, S, N, H]
+ key = self._key_dense(key)
+
+ # `value` = [B, S, N, H]
+ value = self._value_dense(value)
+ attention_output, attention_scores = self._compute_attention(
+ query, key, value, attention_mask, training)
+ attention_output = self._output_dense(attention_output)
+ if self._return_attention_scores:
+ return attention_output, attention_scores
+ return attention_output
+
+ def _compute_attention_mask(
+ self,
+ query,
+ value,
+ query_mask=None,
+ value_mask=None,
+ key_mask=None,
+ attention_mask=None,
+ use_causal_mask=False,
+ ):
+ """Computes the attention mask, using the Keras masks of the inputs.
+
+ * The `query`'s mask is reshaped from [B, T] to [B, T, 1].
+ * The `value`'s mask is reshaped from [B, S] to [B, 1, S].
+ * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
+ mask is ignored if `key` is `None` or if `key is value`.
+ * If `use_causal_mask=True`, then the causal mask is computed. Its shape
+ is [1, T, S].
+
+ All defined masks are merged using a logical AND operation (`&`).
+
+ In general, if the `query` and `value` are masked, then there is no need
+ to define the `attention_mask`.
+
+ Args:
+ query: Projected query tensor of shape `(B, T, N, key_dim)`.
+ value: Projected value tensor of shape `(B, T, N, value_dim)`.
+ attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
+ attention to certain positions.
+ use_causal_mask: A boolean to indicate whether to apply a causal
+ mask to prevent tokens from attending to future tokens (e.g.,
+ used in a decoder Transformer).
+
+ Returns:
+ attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
+ attention to certain positions, based on the Keras masks of the
+ `query`, `key`, `value`, and `attention_mask` tensors, and the
+ causal mask if `use_causal_mask=True`.
+ """
+ auto_mask = None
+ if query_mask is not None:
+ query_mask = tf.cast(query_mask, tf.bool) # defensive casting
+ # B = batch size, T = max query length
+ auto_mask = tf.expand_dims(query_mask, -1) # shape is [B, T, 1]
+ if value_mask is not None:
+ value_mask = tf.cast(value_mask, tf.bool) # defensive casting
+ # B = batch size, S == max value length
+ mask = tf.expand_dims(value_mask, -2) # shape is [B, 1, S]
+ auto_mask = mask if auto_mask is None else auto_mask & mask
+ if key_mask is not None:
+ key_mask = tf.cast(key_mask, tf.bool) # defensive casting
+ # B == batch size, S == max key length == max value length
+ mask = tf.expand_dims(key_mask, -2) # shape is [B, 1, S]
+ auto_mask = mask if auto_mask is None else auto_mask & mask
+ if use_causal_mask:
+ # the shape of the causal mask is [1, T, S]
+ mask = self._compute_causal_mask(query, value)
+ auto_mask = mask if auto_mask is None else auto_mask & mask
+ if auto_mask is not None:
+ # merge attention_mask & automatic mask, to shape [B, T, S]
+ attention_mask = (
+ auto_mask if attention_mask is None else
+ tf.cast(attention_mask, tf.bool) & auto_mask)
+ return attention_mask
+
+ def _compute_causal_mask(self, query, value=None):
+ """Computes a causal mask (e.g., for masked self-attention layers).
+
+ For example, if query and value both contain sequences of length 4,
+ this function returns a boolean tensor equal to:
+
+ ```
+ [[[True, False, False, False],
+ [True, True, False, False],
+ [True, True, True, False],
+ [True, True, True, True]]]
+ ```
+
+ Args:
+ query: query tensor of shape `(B, T, ...)`.
+ value: value tensor of shape `(B, S, ...)` (optional, defaults to
+ query).
+
+ Returns:
+ mask: a boolean tensor of shape `(1, T, S)` containing a lower
+ triangular matrix of shape `(T, S)`.
+ """
+ q_seq_length = tf.shape(query)[1]
+ v_seq_length = q_seq_length if value is None else tf.shape(value)[1]
+ ones_mask = tf.ones((1, q_seq_length, v_seq_length), dtype='int32')
+ row_index = tf.cumsum(ones_mask, axis=-2)
+ col_index = tf.cumsum(ones_mask, axis=-1)
+ return tf.greater_equal(row_index, col_index)
+
+ def compute_output_shape(self, input_shape):
+ if len(input_shape) == 3:
+ query_shape, value_shape, key_shape = input_shape
+ elif len(input_shape) == 2:
+ query_shape, value_shape = input_shape
+ key_shape = None
+ else:
+ raise ValueError('invalid input shape of MultiHeadAttention')
+ if key_shape is None:
+ key_shape = value_shape
+
+ if query_shape[-1] != value_shape[-1]:
+ raise ValueError(
+ 'The last dimension of `query_shape` and `value_shape` '
+ 'must be equal, but are {query_last_dim}, {value_last_dim}. '
+ 'Received: query_shape={query_shape}, value_shape={value_shape}'
+ .format(
+ query_shape=query_shape,
+ value_shape=value_shape,
+ query_last_dim=query_shape[-1],
+ value_last_dim=value_shape[-1]))
+
+ if value_shape[1:-1] != key_shape[1:-1]:
+ raise ValueError(
+ 'All dimensions of `value` and `key`, except the last one, '
+ 'must be equal. Received: value_shape={value_shape} and '
+ 'key_shape={key_shape}'.format(
+ key_shape=key_shape, value_shape=value_shape))
+
+ if self._output_shape:
+      if hasattr(self._output_shape, '__len__'):
+ return query_shape[:-1] + self._output_shape
+ else:
+ return query_shape[:-1] + [self._output_shape]
+
+ return query_shape
+
+
+def _index_to_einsum_variable(i):
+ """Coverts an index to a einsum variable name.
+
+ We simply map indices to lowercase characters, e.g. 0 -> 'a', 1 -> 'b'.
+ """
+ return string.ascii_lowercase[i]
+
+
+def _build_attention_equation(rank, attn_axes):
+ """Builds einsum equations for the attention computation.
+
+ Query, key, value inputs after projection are expected to have the shape as:
+  `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
+  `bs` and `<non-attention dims>` are treated as `<batch dims>`.
+
+  The attention operations can be generalized:
+  1. Query-key dot product:
+  (<batch dims>, <query attention dims>, num_heads, channels),
+  (<batch dims>, <key attention dims>, num_heads, channels) ->
+  (<batch dims>, num_heads, <query attention dims>, <key attention dims>)
+  2. Combination:
+  (<batch dims>, num_heads, <query attention dims>, <key attention dims>),
+  (<batch dims>, <value attention dims>, num_heads, channels) ->
+  (<batch dims>, <query attention dims>, num_heads, channels)
+
+ Args:
+ rank: Rank of query, key, value tensors.
+ attn_axes: List/tuple of axes, `[-1, rank)`,
+ that attention will be applied to.
+
+ Returns:
+ Einsum equations.
+ """
+ target_notation = ''
+ for i in range(rank):
+ target_notation += _index_to_einsum_variable(i)
+ # `batch_dims` includes the head dim.
+ batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
+ letter_offset = rank
+ source_notation = ''
+ for i in range(rank):
+ if i in batch_dims or i == rank - 1:
+ source_notation += target_notation[i]
+ else:
+ source_notation += _index_to_einsum_variable(letter_offset)
+ letter_offset += 1
+
+ product_notation = ''.join([target_notation[i] for i in batch_dims] +
+ [target_notation[i] for i in attn_axes] +
+ [source_notation[i] for i in attn_axes])
+ dot_product_equation = '%s,%s->%s' % (
+ source_notation,
+ target_notation,
+ product_notation,
+ )
+ attn_scores_rank = len(product_notation)
+ combine_equation = '%s,%s->%s' % (
+ product_notation,
+ source_notation,
+ target_notation,
+ )
+ return dot_product_equation, combine_equation, attn_scores_rank
+
+
+def _build_proj_equation(free_dims, bound_dims, output_dims):
+ """Builds an einsum equation for projections inside multi-head attention."""
+ input_str = ''
+ kernel_str = ''
+ output_str = ''
+ bias_axes = ''
+ letter_offset = 0
+ for i in range(free_dims):
+ char = _index_to_einsum_variable(i + letter_offset)
+ input_str += char
+ output_str += char
+
+ letter_offset += free_dims
+ for i in range(bound_dims):
+ char = _index_to_einsum_variable(i + letter_offset)
+ input_str += char
+ kernel_str += char
+
+ letter_offset += bound_dims
+ for i in range(output_dims):
+ char = _index_to_einsum_variable(i + letter_offset)
+ kernel_str += char
+ output_str += char
+ bias_axes += char
+ equation = '{input_str},{kernel_str}->{output_str}'.format(
+ input_str=input_str, kernel_str=kernel_str, output_str=output_str)
+ return equation, bias_axes, len(output_str)
+
+
+def _get_output_shape(output_rank, known_last_dims):
+ return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
+
+ # def __init__(self, params, name='multi_head_attention', reuse=None, **kwargs):
+ # super(MultiHeadAttention, self).__init__(name=name, **kwargs)
+ # self.num_heads = params.num_attention_heads
+ # self.d_model = params.hidden_size
+ # if self.d_model % self.num_heads != 0:
+ # raise ValueError(
+ # 'The hidden size (%d) is not a multiple of the number of attention '
+ # 'heads (%d)' % (self.d_model, self.num_heads))
+ # self.depth = self.d_model // self.num_heads
+ # self.wq = Dense(self.d_model)
+ # self.wk = Dense(self.d_model)
+ # self.wv = Dense(self.d_model)
+ # self.dense = Dense(self.d_model)
+ # att_params = Parameter.make_from_pb(params.attention)
+ # self.attention = Attention(att_params, 'scaled_dot_product_attention')
+ #
+ # # def split_heads(self, x, batch_size):
+ # # x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
+ # # return tf.transpose(x, perm=[0, 2, 1, 3])
+ #
+ # def call(self, inputs, training=None, **kwargs):
+ # q, v, k, mask = inputs
+ # batch_size = tf.shape(q)[0]
+ #
+ # q = self.wq(q)
+ # k = self.wk(k)
+ # v = self.wv(v)
+ #
+ # # q = self.split_heads(q, batch_size)
+ # # k = self.split_heads(k, batch_size)
+ # # v = self.split_heads(v, batch_size)
+ #
+ # attn = self.attention([q, v, k], mask=[mask, mask], training=training)
+ # return_attn_score = self.attention.return_attention_scores
+ # attention, attention_scores = attn if return_attn_score else attn, None
+ #
+ # # attention = tf.transpose(attention, perm=[0, 2, 1, 3])
+ # # attention = tf.reshape(attention, (batch_size, -1, self.d_model))
+ # output = self.dense(attention)
+ # if return_attn_score:
+ # return output, attention_scores
+ # return output
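For the common rank-3 `(B, T, dim)` inputs, `_build_proj_equation(2, bound_dims=1, output_dims=2)` yields the projection equation `'abc,cde->abde'` (so `(B, T, dim)` is projected to `(B, T, num_heads, head_dim)`), and `_build_attention_equation(rank=4, attn_axes=(1,))` yields `'aecd,abcd->acbe'` for the query-key dot product and `'acbe,aecd->abcd'` for recombining with the values. A minimal NumPy sketch checking the resulting shapes (scaling, masking and softmax are omitted):

```python
import numpy as np

B, T, S, N, H = 2, 3, 4, 2, 5
query = np.random.randn(B, T, N, H)
key = np.random.randn(B, S, N, H)
value = np.random.randn(B, S, N, H)

scores = np.einsum('aecd,abcd->acbe', key, query)      # (B, N, T, S)
assert scores.shape == (B, N, T, S)

context = np.einsum('acbe,aecd->abcd', scores, value)  # (B, T, N, H)
assert context.shape == (B, T, N, H)
```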
diff --git a/easy_rec/python/layers/keras/multi_task.py b/easy_rec/python/layers/keras/multi_task.py
new file mode 100644
index 000000000..dbb26ee86
--- /dev/null
+++ b/easy_rec/python/layers/keras/multi_task.py
@@ -0,0 +1,125 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+from tensorflow.python.keras.layers import Dense
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras.attention import Attention
+from easy_rec.python.layers.keras.blocks import MLP
+from easy_rec.python.layers.utils import Parameter
+from easy_rec.python.protos import seq_encoder_pb2
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class MMoE(Layer):
+ """Multi-gate Mixture-of-Experts model."""
+
+ def __init__(self, params, name='MMoE', reuse=None, **kwargs):
+ super(MMoE, self).__init__(name=name, **kwargs)
+ params.check_required(['num_expert', 'num_task'])
+ self._reuse = reuse
+ self._num_expert = params.num_expert
+ self._num_task = params.num_task
+ if params.has_field('expert_mlp'):
+ expert_params = Parameter.make_from_pb(params.expert_mlp)
+ expert_params.l2_regularizer = params.l2_regularizer
+ self._has_experts = True
+ self._experts = [
+ MLP(expert_params, 'expert_%d' % i, reuse=reuse)
+ for i in range(self._num_expert)
+ ]
+ else:
+ self._has_experts = False
+
+ self._gates = []
+ for task_id in range(self._num_task):
+ dense = Dense(
+ self._num_expert,
+ activation='softmax',
+ name='gate_%d' % task_id,
+ kernel_regularizer=params.l2_regularizer)
+ self._gates.append(dense)
+
+ def call(self, inputs, training=None, **kwargs):
+ if self._num_expert == 0:
+ logging.warning('num_expert of MMoE layer `%s` is 0' % self.name)
+ return inputs
+ if self._has_experts:
+ expert_fea_list = [
+ expert(inputs, training=training) for expert in self._experts
+ ]
+ else:
+ expert_fea_list = inputs
+ experts_fea = tf.stack(expert_fea_list, axis=1)
+    # When the built-in MLP is not used as the experts, the gate input is the
+    # last (extra) element of the inputs list
+ gate_input = inputs if self._has_experts else inputs[self._num_expert]
+ task_input_list = []
+ for task_id in range(self._num_task):
+ gate = self._gates[task_id](gate_input)
+ gate = tf.expand_dims(gate, -1)
+ task_input = tf.multiply(experts_fea, gate)
+ task_input = tf.reduce_sum(task_input, axis=1)
+ task_input_list.append(task_input)
+ return task_input_list
+
+
+class AITMTower(Layer):
+ """Adaptive Information Transfer Multi-task (AITM) Tower."""
+
+ def __init__(self, params, name='AITMTower', reuse=None, **kwargs):
+ super(AITMTower, self).__init__(name=name, **kwargs)
+ self.project_dim = params.get_or_default('project_dim', None)
+ self.stop_gradient = params.get_or_default('stop_gradient', True)
+ self.transfer = None
+ if params.has_field('transfer_mlp'):
+ mlp_cfg = Parameter.make_from_pb(params.transfer_mlp)
+ mlp_cfg.l2_regularizer = params.l2_regularizer
+ self.transfer = MLP(mlp_cfg, name='transfer')
+ self.queries = []
+ self.keys = []
+ self.values = []
+ self.attention = None
+
+ def build(self, input_shape):
+ if not isinstance(input_shape, (tuple, list)):
+ super(AITMTower, self).build(input_shape)
+ return
+ dim = self.project_dim if self.project_dim else int(input_shape[0][-1])
+ for i in range(len(input_shape)):
+ self.queries.append(Dense(dim, name='query_%d' % i))
+ self.keys.append(Dense(dim, name='key_%d' % i))
+ self.values.append(Dense(dim, name='value_%d' % i))
+ attn_cfg = seq_encoder_pb2.Attention()
+ attn_cfg.scale_by_dim = True
+ attn_params = Parameter.make_from_pb(attn_cfg)
+ self.attention = Attention(attn_params)
+ super(AITMTower, self).build(input_shape)
+
+ def call(self, inputs, training=None, **kwargs):
+ if not isinstance(inputs, (tuple, list)):
+ return inputs
+
+ queries = []
+ keys = []
+ values = []
+ for i, tower in enumerate(inputs):
+ if i == 0: # current tower
+ queries.append(self.queries[i](tower))
+ keys.append(self.keys[i](tower))
+ values.append(self.values[i](tower))
+ else:
+ dep = tf.stop_gradient(tower) if self.stop_gradient else tower
+ if self.transfer is not None:
+ dep = self.transfer(dep, training=training)
+ queries.append(self.queries[i](dep))
+ keys.append(self.keys[i](dep))
+ values.append(self.values[i](dep))
+ query = tf.stack(queries, axis=1)
+ key = tf.stack(keys, axis=1)
+ value = tf.stack(values, axis=1)
+ attn = self.attention([query, value, key])
+ return attn[:, 0, :]
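Per task, `MMoE.call` computes a softmax gate over the experts and takes the gate-weighted sum of the expert outputs. A minimal NumPy sketch of that gating, assuming the expert feature vectors have already been computed (the gate kernels are illustrative):

```python
import numpy as np

def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

batch, num_expert, dim, num_task = 2, 3, 4, 2
experts_fea = np.random.randn(batch, num_expert, dim).astype(np.float32)
gate_input = np.random.randn(batch, dim).astype(np.float32)
gate_kernels = [np.random.randn(dim, num_expert).astype(np.float32)
                for _ in range(num_task)]

task_inputs = []
for kernel in gate_kernels:
  gate = softmax(gate_input.dot(kernel))[..., None]     # (batch, num_expert, 1)
  task_inputs.append((experts_fea * gate).sum(axis=1))  # (batch, dim)

assert len(task_inputs) == num_task
assert task_inputs[0].shape == (batch, dim)
```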
diff --git a/easy_rec/python/layers/keras/numerical_embedding.py b/easy_rec/python/layers/keras/numerical_embedding.py
new file mode 100644
index 000000000..65cc77d52
--- /dev/null
+++ b/easy_rec/python/layers/keras/numerical_embedding.py
@@ -0,0 +1,376 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import math
+import os
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.compat.array_ops import repeat
+from easy_rec.python.utils.activation import get_activation
+from easy_rec.python.utils.tf_utils import get_ps_num_from_tf_config
+
+curr_dir, _ = os.path.split(__file__)
+parent_dir = os.path.dirname(curr_dir)
+ops_idr = os.path.dirname(parent_dir)
+ops_dir = os.path.join(ops_idr, 'ops')
+if 'PAI' in tf.__version__:
+ ops_dir = os.path.join(ops_dir, '1.12_pai')
+elif tf.__version__.startswith('1.12'):
+ ops_dir = os.path.join(ops_dir, '1.12')
+elif tf.__version__.startswith('1.15'):
+ if 'IS_ON_PAI' in os.environ:
+ ops_dir = os.path.join(ops_dir, 'DeepRec')
+ else:
+ ops_dir = os.path.join(ops_dir, '1.15')
+elif tf.__version__.startswith('2.12'):
+ ops_dir = os.path.join(ops_dir, '2.12')
+
+logging.info('ops_dir is %s' % ops_dir)
+custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so')
+try:
+ custom_ops = tf.load_op_library(custom_op_path)
+ logging.info('load custom op from %s succeed' % custom_op_path)
+except Exception as ex:
+ logging.warning('load custom op from %s failed: %s' %
+ (custom_op_path, str(ex)))
+ custom_ops = None
+
+
+class NLinear(Layer):
+ """N linear layers for N token (feature) embeddings.
+
+ To understand this module, let's revise `tf.layers.dense`. When `tf.layers.dense` is
+ applied to three-dimensional inputs of the shape
+ ``(batch_size, n_tokens, d_embedding)``, then the same linear transformation is
+ applied to each of ``n_tokens`` token (feature) embeddings.
+
+ By contrast, `NLinear` allocates one linear layer per token (``n_tokens`` layers in total).
+ One such layer can be represented as ``tf.layers.dense(d_in, d_out)``.
+ So, the i-th linear transformation is applied to the i-th token embedding, as
+ illustrated in the following pseudocode::
+
+ layers = [tf.layers.dense(d_in, d_out) for _ in range(n_tokens)]
+        x = tf.random.normal([batch_size, n_tokens, d_in])
+ result = tf.stack([layers[i](x[:, i]) for i in range(n_tokens)], 1)
+
+ Examples:
+ .. testcode::
+
+ batch_size = 2
+ n_features = 3
+ d_embedding_in = 4
+ d_embedding_out = 5
+      x = tf.random.normal([batch_size, n_features, d_embedding_in])
+ m = NLinear(n_features, d_embedding_in, d_embedding_out)
+ assert m(x).shape == (batch_size, n_features, d_embedding_out)
+ """
+
+ def __init__(self,
+ n_tokens,
+ d_in,
+ d_out,
+ bias=True,
+ name='nd_linear',
+ **kwargs):
+ """Init with input shapes.
+
+ Args:
+ n_tokens: the number of tokens (features)
+ d_in: the input dimension
+ d_out: the output dimension
+ bias: indicates if the underlying linear layers have biases
+ name: layer name
+ """
+ super(NLinear, self).__init__(name=name, **kwargs)
+ self.weight = self.add_weight(
+ 'weights', [1, n_tokens, d_in, d_out], dtype=tf.float32)
+ if bias:
+ initializer = tf.constant_initializer(0.0)
+ self.bias = self.add_weight(
+ 'bias', [1, n_tokens, d_out],
+ dtype=tf.float32,
+ initializer=initializer)
+ else:
+ self.bias = None
+
+ def call(self, x, **kwargs):
+ if x.shape.ndims != 3:
+ raise ValueError(
+ 'The input must have three dimensions (batch_size, n_tokens, d_embedding)'
+ )
+ if x.shape[2] != self.weight.shape[2]:
+ raise ValueError('invalid input embedding dimension %d, expect %d' %
+ (int(x.shape[2]), int(self.weight.shape[2])))
+
+ x = x[..., None] * self.weight # [B, N, D, D_out]
+ x = tf.reduce_sum(x, axis=-2) # [B, N, D_out]
+ if self.bias is not None:
+ x = x + self.bias
+ return x
+
+
+class PeriodicEmbedding(Layer):
+ """Periodic embeddings for numerical features described in [1].
+
+ References:
+ * [1] Yury Gorishniy, Ivan Rubachev, Artem Babenko,
+ "On Embeddings for Numerical Features in Tabular Deep Learning", 2022
+ https://arxiv.org/pdf/2203.05556.pdf
+
+ Attributes:
+ embedding_dim: the embedding size, must be an even positive integer.
+ sigma: the scale of the weight initialization.
+ **This is a super important parameter which significantly affects performance**.
+ Its optimal value can be dramatically different for different datasets, so
+ no "default value" can exist for this parameter, and it must be tuned for
+ each dataset. In the original paper, during hyperparameter tuning, this
+ parameter was sampled from the distribution ``LogUniform[1e-2, 1e2]``.
+ A similar grid would be ``[1e-2, 1e-1, 1e0, 1e1, 1e2]``.
+ If possible, add more intermediate values to this grid.
+ output_3d_tensor: whether to output a 3d tensor
+ output_tensor_list: whether to output the list of embedding
+ """
+
+ def __init__(self, params, name='periodic_embedding', reuse=None, **kwargs):
+ super(PeriodicEmbedding, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ params.check_required(['embedding_dim', 'sigma'])
+ self.embedding_dim = int(params.embedding_dim)
+ if self.embedding_dim % 2:
+ raise ValueError('embedding_dim must be even')
+ sigma = params.sigma
+ self.initializer = tf.random_normal_initializer(stddev=sigma)
+ self.add_linear_layer = params.get_or_default('add_linear_layer', True)
+ self.linear_activation = params.get_or_default('linear_activation', 'relu')
+ self.output_tensor_list = params.get_or_default('output_tensor_list', False)
+ self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)
+
+ def build(self, input_shape):
+ if input_shape.ndims != 2:
+ raise ValueError('inputs of PeriodicEmbedding must have 2 dimensions.')
+ self.num_features = int(input_shape[-1])
+ num_ps = get_ps_num_from_tf_config()
+ partitioner = None
+ if num_ps > 0:
+ partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
+ emb_dim = self.embedding_dim // 2
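+ # sin and cos each take emb_dim dims, together producing embedding_dim per feature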
+ self.coef = self.add_weight(
+ 'coefficients',
+ shape=[1, self.num_features, emb_dim],
+ partitioner=partitioner,
+ initializer=self.initializer)
+ if self.add_linear_layer:
+ self.linear = NLinear(
+ self.num_features,
+ self.embedding_dim,
+ self.embedding_dim,
+ name='nd_linear')
+ super(PeriodicEmbedding, self).build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ features = inputs[..., None] # [B, N, 1]
+ v = 2 * math.pi * self.coef * features # [B, N, E]
+ emb = tf.concat([tf.sin(v), tf.cos(v)], axis=-1) # [B, N, 2E]
+
+ dim = self.embedding_dim
+ if self.add_linear_layer:
+ emb = self.linear(emb)
+ act = get_activation(self.linear_activation)
+ if callable(act):
+ emb = act(emb)
+ output = tf.reshape(emb, [-1, self.num_features * dim])
+
+ if self.output_tensor_list:
+ return output, tf.unstack(emb, axis=1)
+ if self.output_3d_tensor:
+ return output, emb
+ return output
+
+
+class AutoDisEmbedding(Layer):
+ """An Embedding Learning Framework for Numerical Features in CTR Prediction.
+
+ Refer: https://arxiv.org/pdf/2012.08986v2.pdf
+ """
+
+ def __init__(self, params, name='auto_dis_embedding', reuse=None, **kwargs):
+ super(AutoDisEmbedding, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ params.check_required(['embedding_dim', 'num_bins', 'temperature'])
+ self.emb_dim = int(params.embedding_dim)
+ self.num_bins = int(params.num_bins)
+ self.temperature = params.temperature
+ self.keep_prob = params.get_or_default('keep_prob', 0.8)
+ self.output_tensor_list = params.get_or_default('output_tensor_list', False)
+ self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)
+
+ def build(self, input_shape):
+ if input_shape.ndims != 2:
+ raise ValueError('inputs of AutoDisEmbedding must have 2 dimensions.')
+ self.num_features = int(input_shape[-1])
+ num_ps = get_ps_num_from_tf_config()
+ partitioner = None
+ if num_ps > 0:
+ partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
+ self.meta_emb = self.add_weight(
+ 'meta_embedding',
+ shape=[self.num_features, self.num_bins, self.emb_dim],
+ partitioner=partitioner)
+ self.proj_w = self.add_weight(
+ 'project_w',
+ shape=[1, self.num_features, self.num_bins],
+ partitioner=partitioner)
+ self.proj_mat = self.add_weight(
+ 'project_mat',
+ shape=[self.num_features, self.num_bins, self.num_bins],
+ partitioner=partitioner)
+ super(AutoDisEmbedding, self).build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ x = tf.expand_dims(inputs, axis=-1) # [B, N, 1]
+ hidden = tf.nn.leaky_relu(self.proj_w * x) # [B, N, num_bin]
+ # matmul in older tf versions (e.g. 1.12) does not support broadcasting, so use einsum instead
+ # y = tf.matmul(mat, hidden[..., None]) # [B, N, num_bin, 1]
+ # y = tf.squeeze(y, axis=3) # [B, N, num_bin]
+ y = tf.einsum('nik,bnk->bni', self.proj_mat, hidden) # [B, N, num_bin]
+
+ # keep_prob is reused as the skip-connection factor alpha in AutoDis
+ alpha = self.keep_prob
+ x_bar = y + alpha * hidden # [B, N, num_bin]
+ x_hat = tf.nn.softmax(x_bar / self.temperature) # [B, N, num_bin]
+
+ # emb = tf.matmul(x_hat[:, :, None, :], meta_emb) # [B, N, 1, D]
+ # emb = tf.squeeze(emb, axis=2) # [B, N, D]
+ emb = tf.einsum('bnk,nkd->bnd', x_hat, self.meta_emb)
+ output = tf.reshape(emb, [-1, self.emb_dim * self.num_features]) # [B, N*D]
+
+ if self.output_tensor_list:
+ return output, tf.unstack(emb, axis=1)
+ if self.output_3d_tensor:
+ return output, emb
+ return output
+
+
+class NaryDisEmbedding(Layer):
+ """Numerical Feature Representation with Hybrid 𝑁 -ary Encoding, CIKM 2022..
+
+ Refer: https://dl.acm.org/doi/pdf/10.1145/3511808.3557090
+ """
+
+ def __init__(self, params, name='nary_dis_embedding', reuse=None, **kwargs):
+ super(NaryDisEmbedding, self).__init__(name=name, **kwargs)
+ self.reuse = reuse
+ self.nary_carry = custom_ops.nary_carry
+ params.check_required(['embedding_dim', 'carries'])
+ self.emb_dim = int(params.embedding_dim)
+ self.carries = params.get_or_default('carries', [2, 9])
+ self.num_replicas = params.get_or_default('num_replicas', 1)
+ assert self.num_replicas >= 1, 'num replicas must be >= 1'
+ self.lengths = list(map(self.max_length, self.carries))
+ self.vocab_size = int(sum(self.lengths))
+ self.multiplier = params.get_or_default('multiplier', 1.0)
+ self.intra_ary_pooling = params.get_or_default('intra_ary_pooling', 'sum')
+ self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)
+ self.output_tensor_list = params.get_or_default('output_tensor_list', False)
+ logging.info(
+ '{} carries: {}, lengths: {}, vocab_size: {}, intra_ary: {}, replicas: {}, multiplier: {}'
+ .format(self.name, ','.join(map(str, self.carries)),
+ ','.join(map(str, self.lengths)), self.vocab_size,
+ self.intra_ary_pooling, self.num_replicas, self.multiplier))
+
+ @staticmethod
+ def max_length(carry):
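+ # number of digits needed to represent a uint32 (2^32 - 1) in base `carry`,
+ # times `carry` possible values per digit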
+ bits = math.log(4294967295, carry)
+ return (math.floor(bits) + 1) * carry
+
+ def build(self, input_shape):
+ assert isinstance(input_shape,
+ tf.TensorShape), 'NaryDisEmbedding only takes 1 input'
+ self.num_features = int(input_shape[-1])
+ logging.info('%s has %d input features', self.name, self.num_features)
+ vocab_size = self.num_features * self.vocab_size
+ emb_dim = self.emb_dim * self.num_replicas
+ num_ps = get_ps_num_from_tf_config()
+ partitioner = None
+ if num_ps > 0:
+ partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
+ self.embedding_table = self.add_weight(
+ 'embed_table', shape=[vocab_size, emb_dim], partitioner=partitioner)
+ super(NaryDisEmbedding, self).build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ if inputs.shape.ndims != 2:
+ raise ValueError('inputs of NaryDisEmbedding must have 2 dimensions.')
+ if self.multiplier != 1.0:
+ inputs *= self.multiplier
+ inputs = tf.to_int32(inputs)
+ offset, emb_indices, emb_splits = 0, [], []
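+ # decompose each integer feature into its base-`carry` digits with the custom
+ # nary_carry op; the accumulated offset appears to shift indices so that each
+ # carry encoding uses its own slice of the embedding table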
+ with ops.device('/CPU:0'):
+ for carry, length in zip(self.carries, self.lengths):
+ values, splits = self.nary_carry(inputs, carry=carry, offset=offset)
+ offset += length
+ emb_indices.append(values)
+ emb_splits.append(splits)
+ indices = tf.concat(emb_indices, axis=0)
+ splits = tf.concat(emb_splits, axis=0)
+ # embedding shape: [B*N*C, D]
+ embedding = tf.nn.embedding_lookup(self.embedding_table, indices)
+
+ total_length = tf.size(splits)
+ if self.intra_ary_pooling == 'sum':
+ if tf.__version__ >= '2.0':
+ segment_ids = tf.repeat(tf.range(total_length), repeats=splits)
+ else:
+ segment_ids = repeat(tf.range(total_length), repeats=splits)
+ embedding = tf.math.segment_sum(embedding, segment_ids)
+ elif self.intra_ary_pooling == 'mean':
+ if tf.__version__ >= '2.0':
+ segment_ids = tf.repeat(tf.range(total_length), repeats=splits)
+ else:
+ segment_ids = repeat(tf.range(total_length), repeats=splits)
+ embedding = tf.math.segment_mean(embedding, segment_ids)
+ else:
+ raise ValueError('Unsupported intra ary pooling method %s' %
+ self.intra_ary_pooling)
+ # B: batch size
+ # N: num features
+ # C: num carries
+ # D: embedding dimension
+ # R: num replicas
+ # shape of embedding: [B*N*C, R*D]
+ N = self.num_features
+ C = len(self.carries)
+ D = self.emb_dim
+ if self.num_replicas == 1:
+ embedding = tf.reshape(embedding, [C, -1, D]) # [C, B*N, D]
+ embedding = tf.transpose(embedding, perm=[1, 0, 2]) # [B*N, C, D]
+ embedding = tf.reshape(embedding, [-1, C * D]) # [B*N, C*D]
+ output = tf.reshape(embedding, [-1, N * C * D]) # [B, N*C*D]
+ if self.output_tensor_list:
+ return output, tf.split(embedding, N) # [B, C*D] * N
+ if self.output_3d_tensor:
+ embedding = tf.reshape(embedding, [-1, N, C * D]) # [B, N, C*D]
+ return output, embedding
+ return output
+
+ # self.num_replicas > 1:
+ replicas = tf.split(embedding, self.num_replicas, axis=1)
+ outputs = []
+ outputs2 = []
+ for replica in replicas:
+ # shape of replica: [B*N*C, D]
+ embedding = tf.reshape(replica, [C, -1, D]) # [C, B*N, D]
+ embedding = tf.transpose(embedding, perm=[1, 0, 2]) # [B*N, C, D]
+ embedding = tf.reshape(embedding, [-1, C * D]) # [B*N, C*D]
+ output = tf.reshape(embedding, [-1, N * C * D]) # [B, N*C*D]
+ outputs.append(output)
+ if self.output_tensor_list:
+ embedding = tf.split(embedding, N) # [B, C*D] * N
+ outputs2.append(embedding)
+ elif self.output_3d_tensor:
+ embedding = tf.reshape(embedding, [-1, N, C * D]) # [B, N, C*D]
+ outputs2.append(embedding)
+ return outputs + outputs2
diff --git a/easy_rec/python/layers/keras/ppnet.py b/easy_rec/python/layers/keras/ppnet.py
new file mode 100644
index 000000000..431034924
--- /dev/null
+++ b/easy_rec/python/layers/keras/ppnet.py
@@ -0,0 +1,194 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Convenience blocks for building models."""
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.layers.keras.activation import activation_layer
+from easy_rec.python.utils.tf_utils import add_elements_to_collection
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class GateNN(tf.keras.layers.Layer):
+
+ def __init__(self,
+ params,
+ output_units=None,
+ name='gate_nn',
+ reuse=None,
+ **kwargs):
+ super(GateNN, self).__init__(name=name, **kwargs)
+ output_dim = output_units if output_units is not None else params.output_dim
+ hidden_dim = params.get_or_default('hidden_dim', output_dim)
+ initializer = params.get_or_default('initializer', 'he_uniform')
+ do_batch_norm = params.get_or_default('use_bn', False)
+ activation = params.get_or_default('activation', 'relu')
+ dropout_rate = params.get_or_default('dropout_rate', 0.0)
+
+ self._sub_layers = []
+ dense = tf.keras.layers.Dense(
+ units=hidden_dim,
+ use_bias=not do_batch_norm,
+ kernel_initializer=initializer)
+ self._sub_layers.append(dense)
+
+ if do_batch_norm:
+ bn = tf.keras.layers.BatchNormalization(trainable=True)
+ self._sub_layers.append(bn)
+
+ act_layer = activation_layer(activation)
+ self._sub_layers.append(act_layer)
+
+ if 0.0 < dropout_rate < 1.0:
+ dropout = tf.keras.layers.Dropout(dropout_rate)
+ self._sub_layers.append(dropout)
+ elif dropout_rate >= 1.0:
+ raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)
+
+ dense = tf.keras.layers.Dense(
+ units=output_dim,
+ activation='sigmoid',
+ use_bias=not do_batch_norm,
+ kernel_initializer=initializer,
+ name='weight')
+ self._sub_layers.append(dense)
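+ # as in PEPNet, the sigmoid gate is scaled to the range (0, 2) so it can
+ # amplify as well as suppress hidden units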
+ self._sub_layers.append(lambda x: x * 2)
+
+ def call(self, x, training=None, **kwargs):
+ """Performs the forward computation of the block."""
+ for layer in self._sub_layers:
+ cls = layer.__class__.__name__
+ if cls in ('Dropout', 'BatchNormalization', 'Dice'):
+ x = layer(x, training=training)
+ if cls in ('BatchNormalization', 'Dice') and training:
+ add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+ else:
+ x = layer(x)
+ return x
+
+
+class PPNet(tf.keras.layers.Layer):
+ """PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information.
+
+ Attributes:
+ units: Sequential list of layer sizes.
+ use_bias: Whether to include a bias term.
+ activation: Type of activation to use on all except the last layer.
+ final_activation: Type of activation to use on last layer.
+ **kwargs: Extra args passed to the Keras Layer base class.
+ """
+
+ def __init__(self, params, name='ppnet', reuse=None, **kwargs):
+ super(PPNet, self).__init__(name=name, **kwargs)
+ params.check_required('mlp')
+ self.full_gate_input = params.get_or_default('full_gate_input', True)
+ mode = params.get_or_default('mode', 'lazy')
+ gate_params = params.gate_params
+ params = params.mlp
+ params.check_required('hidden_units')
+ use_bn = params.get_or_default('use_bn', True)
+ use_final_bn = params.get_or_default('use_final_bn', True)
+ use_bias = params.get_or_default('use_bias', False)
+ use_final_bias = params.get_or_default('use_final_bias', False)
+ dropout_rate = list(params.get_or_default('dropout_ratio', []))
+ activation = params.get_or_default('activation', 'relu')
+ initializer = params.get_or_default('initializer', 'he_uniform')
+ final_activation = params.get_or_default('final_activation', None)
+ use_bn_after_act = params.get_or_default('use_bn_after_activation', False)
+ units = list(params.hidden_units)
+ logging.info(
+ 'MLP(%s) units: %s, dropout: %r, activate=%s, use_bn=%r, final_bn=%r,'
+ ' final_activate=%s, bias=%r, initializer=%s, bn_after_activation=%r' %
+ (name, units, dropout_rate, activation, use_bn, use_final_bn,
+ final_activation, use_bias, initializer, use_bn_after_act))
+ assert len(units) > 0, 'MLP(%s) requires at least one hidden unit' % name
+ self.reuse = reuse
+
+ num_dropout = len(dropout_rate)
+ self._sub_layers = []
+
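+ # when mode is not 'lazy', a gate is applied to the input of the first dense
+ # layer; in 'lazy' mode that input gate is skipped and an extra gate is
+ # appended after the final layer instead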
+ if mode != 'lazy':
+ self._sub_layers.append(GateNN(gate_params, None, 'gate_0'))
+ for i, num_units in enumerate(units[:-1]):
+ name = 'layer_%d' % i
+ drop_rate = dropout_rate[i] if i < num_dropout else 0.0
+ self.add_rich_layer(num_units, use_bn, drop_rate, activation, initializer,
+ use_bias, use_bn_after_act, name,
+ params.l2_regularizer)
+ self._sub_layers.append(
+ GateNN(gate_params, num_units, 'gate_%d' % (i + 1)))
+
+ n = len(units) - 1
+ drop_rate = dropout_rate[n] if num_dropout > n else 0.0
+ name = 'layer_%d' % n
+ self.add_rich_layer(units[-1], use_final_bn, drop_rate, final_activation,
+ initializer, use_final_bias, use_bn_after_act, name,
+ params.l2_regularizer)
+ if mode == 'lazy':
+ self._sub_layers.append(
+ GateNN(gate_params, units[-1], 'gate_%d' % (n + 1)))
+
+ def add_rich_layer(self,
+ num_units,
+ use_bn,
+ dropout_rate,
+ activation,
+ initializer,
+ use_bias,
+ use_bn_after_activation,
+ name,
+ l2_reg=None):
+ act_layer = activation_layer(activation, name='%s/act' % name)
+ if use_bn and not use_bn_after_activation:
+ dense = tf.keras.layers.Dense(
+ units=num_units,
+ use_bias=use_bias,
+ kernel_initializer=initializer,
+ kernel_regularizer=l2_reg,
+ name='%s/dense' % name)
+ self._sub_layers.append(dense)
+ bn = tf.keras.layers.BatchNormalization(
+ name='%s/bn' % name, trainable=True)
+ self._sub_layers.append(bn)
+ self._sub_layers.append(act_layer)
+ else:
+ dense = tf.keras.layers.Dense(
+ num_units,
+ use_bias=use_bias,
+ kernel_initializer=initializer,
+ kernel_regularizer=l2_reg,
+ name='%s/dense' % name)
+ self._sub_layers.append(dense)
+ self._sub_layers.append(act_layer)
+ if use_bn and use_bn_after_activation:
+ bn = tf.keras.layers.BatchNormalization(name='%s/bn' % name)
+ self._sub_layers.append(bn)
+
+ if 0.0 < dropout_rate < 1.0:
+ dropout = tf.keras.layers.Dropout(dropout_rate, name='%s/dropout' % name)
+ self._sub_layers.append(dropout)
+ elif dropout_rate >= 1.0:
+ raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)
+
+ def call(self, inputs, training=None, **kwargs):
+ """Performs the forward computation of the block."""
+ x, gate_input = inputs
+ if self.full_gate_input:
+ with tf.name_scope(self.name):
+ gate_input = tf.concat([tf.stop_gradient(x), gate_input], axis=-1)
+
+ for layer in self._sub_layers:
+ cls = layer.__class__.__name__
+ if cls == 'GateNN':
+ gate = layer(gate_input)
+ x *= gate
+ elif cls in ('Dropout', 'BatchNormalization', 'Dice'):
+ x = layer(x, training=training)
+ if cls in ('BatchNormalization', 'Dice') and training:
+ add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+ else:
+ x = layer(x)
+ return x
diff --git a/easy_rec/python/layers/keras/transformer.py b/easy_rec/python/layers/keras/transformer.py
new file mode 100644
index 000000000..d71a02831
--- /dev/null
+++ b/easy_rec/python/layers/keras/transformer.py
@@ -0,0 +1,192 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.keras.layers import Dense
+from tensorflow.python.keras.layers import Dropout
+from tensorflow.python.keras.layers import Embedding
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras import MultiHeadAttention
+from easy_rec.python.layers.keras.layer_norm import LayerNormalization
+from easy_rec.python.layers.utils import Parameter
+from easy_rec.python.protos import seq_encoder_pb2
+
+
+class TransformerBlock(Layer):
+ """A transformer block combines multi-head attention and feed-forward networks with layer normalization and dropout.
+
+ Purpose: Combines attention and feed-forward layers with residual connections and normalization.
+ Components: Multi-head attention, feed-forward network, dropout, and layer normalization.
+ Output: Enhanced representation after applying attention and feed-forward layers.
+ """
+
+ def __init__(self, params, name='transformer_block', reuse=None, **kwargs):
+ super(TransformerBlock, self).__init__(name=name, **kwargs)
+ d_model = params.hidden_size
+ num_heads = params.num_attention_heads
+ mha_cfg = seq_encoder_pb2.MultiHeadAttention()
+ mha_cfg.num_heads = num_heads
+ mha_cfg.key_dim = d_model // num_heads
+ mha_cfg.dropout = params.get_or_default('attention_probs_dropout_prob', 0.0)
+ mha_cfg.return_attention_scores = False
+ args = Parameter.make_from_pb(mha_cfg)
+ self.mha = MultiHeadAttention(args, 'multi_head_attn')
+ dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
+ ffn_units = params.get_or_default('intermediate_size', d_model)
+ ffn_act = params.get_or_default('hidden_act', 'relu')
+ self.ffn_dense1 = Dense(ffn_units, activation=ffn_act)
+ self.ffn_dense2 = Dense(d_model)
+ if tf.__version__ >= '2.0':
+ self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+ self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+ else:
+ self.layer_norm1 = LayerNormalization(epsilon=1e-6)
+ self.layer_norm2 = LayerNormalization(epsilon=1e-6)
+ self.dropout1 = Dropout(dropout_rate)
+ self.dropout2 = Dropout(dropout_rate)
+
+ def call(self, inputs, training=None, **kwargs):
+ x, mask = inputs
+ attn_output = self.mha([x, x, x], mask=mask, training=training)
+ attn_output = self.dropout1(attn_output, training=training)
+ out1 = self.layer_norm1(x + attn_output)
+ ffn_mid = self.ffn_dense1(out1)
+ ffn_output = self.ffn_dense2(ffn_mid)
+ ffn_output = self.dropout2(ffn_output, training=training)
+ out2 = self.layer_norm2(out1 + ffn_output)
+ return out2
+
+
+# Positional Encoding, https://www.tensorflow.org/text/tutorials/transformer
+def positional_encoding(length, depth):
+ depth = depth / 2
+ positions = np.arange(length)[:, np.newaxis] # (seq, 1)
+ depths = np.arange(depth)[np.newaxis, :] / depth # (1, depth)
+ angle_rates = 1 / (10000**depths) # (1, depth)
+ angle_rads = positions * angle_rates # (pos, depth)
+ pos_encoding = np.concatenate(
+ [np.sin(angle_rads), np.cos(angle_rads)], axis=-1)
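+ # e.g. positional_encoding(4, 6) yields a [4, 6] float32 tensor whose first 3
+ # columns are sin terms and whose last 3 columns are cos terms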
+ return tf.cast(pos_encoding, dtype=tf.float32)
+
+
+class PositionalEmbedding(Layer):
+
+ def __init__(self, vocab_size, d_model, max_position, name='pos_embedding'):
+ super(PositionalEmbedding, self).__init__(name=name)
+ self.d_model = d_model
+ self.embedding = Embedding(vocab_size, d_model)
+ self.pos_encoding = positional_encoding(length=max_position, depth=d_model)
+
+ def call(self, x, training=None):
+ length = tf.shape(x)[1]
+ x = self.embedding(x)
+ # This factor sets the relative scale of the embedding and positional_encoding.
+ x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
+ x = x + self.pos_encoding[tf.newaxis, :length, :]
+ return x
+
+
+class TransformerEncoder(Layer):
+ """The encoder consists of a stack of encoder layers.
+
+ It converts the input sequence into a set of embeddings enriched with positional information.
+ Purpose: Encodes the input sequence into a set of embeddings.
+ Components: Embedding layer, positional encoding, and a stack of transformer blocks.
+ Output: Encoded representation of the input sequence.
+ """
+
+ def __init__(self, params, name='transformer_encoder', reuse=None, **kwargs):
+ super(TransformerEncoder, self).__init__(name=name, **kwargs)
+ d_model = params.hidden_size
+ dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
+ max_position = params.get_or_default('max_position_embeddings', 512)
+ num_layers = params.get_or_default('num_hidden_layers', 1)
+ vocab_size = params.vocab_size
+ logging.info('vocab size of TransformerEncoder(%s) is %d', name, vocab_size)
+ self.output_all = params.get_or_default('output_all_token_embeddings', True)
+ self.pos_encoding = PositionalEmbedding(vocab_size, d_model, max_position)
+ self.dropout = Dropout(dropout_rate)
+ self.enc_layers = [
+ TransformerBlock(params, 'layer_%d' % i) for i in range(num_layers)
+ ]
+ self._vocab_size = vocab_size
+ self._max_position = max_position
+
+ @property
+ def vocab_size(self):
+ return self._vocab_size
+
+ @property
+ def max_position(self):
+ return self._max_position
+
+ def call(self, inputs, training=None, **kwargs):
+ x, mask = inputs
+ # `x` is token-IDs shape: (batch, seq_len)
+ x = self.pos_encoding(x) # Shape `(batch_size, seq_len, d_model)`.
+ x = self.dropout(x, training=training)
+ for block in self.enc_layers:
+ x = block([x, mask], training)
+ # x Shape `(batch_size, seq_len, d_model)`.
+ return x if self.output_all else x[:, 0, :]
+
+
+class TextEncoder(Layer):
+
+ def __init__(self, params, name='text_encoder', reuse=None, **kwargs):
+ super(TextEncoder, self).__init__(name=name, **kwargs)
+ self.separator = params.get_or_default('separator', ' ')
+ self.cls_token = '[CLS]' + self.separator
+ self.sep_token = self.separator + '[SEP]' + self.separator
+ params.transformer.output_all_token_embeddings = False
+ trans_params = Parameter.make_from_pb(params.transformer)
+ vocab_file = params.get_or_default('vocab_file', None)
+ self.vocab = None
+ self.default_token_id = params.get_or_default('default_token_id', 0)
+ if vocab_file is not None:
+ self.vocab = tf.feature_column.categorical_column_with_vocabulary_file(
+ 'tokens',
+ vocabulary_file=vocab_file,
+ default_value=self.default_token_id)
+ logging.info('vocab file of TextEncoder(%s) is %s', name, vocab_file)
+ trans_params.vocab_size = self.vocab.vocabulary_size
+ self.encoder = TransformerEncoder(trans_params, name='transformer')
+
+ def call(self, inputs, training=None, **kwargs):
+ if type(inputs) not in (tuple, list):
+ inputs = [inputs]
+ inputs = [tf.squeeze(text) for text in inputs]
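+ # build "[CLS] sent1 [SEP] sent2 [SEP] ..." strings for the batch, then split
+ # them into tokens by the configured separator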
+ batch_size = tf.shape(inputs[0])
+ cls = tf.fill(batch_size, self.cls_token)
+ sep = tf.fill(batch_size, self.sep_token)
+ sentences = [cls]
+ for sentence in inputs:
+ sentences.append(sentence)
+ sentences.append(sep)
+ text = tf.strings.join(sentences)
+ tokens = tf.strings.split(text, self.separator)
+ if self.vocab is not None:
+ features = {'tokens': tokens}
+ token_ids = self.vocab._transform_feature(features)
+ token_ids = tf.sparse.to_dense(
+ token_ids, default_value=self.default_token_id, name='token_ids')
+ length = tf.shape(token_ids)[-1]
+ token_ids = tf.cond(
+ tf.less_equal(length, self.encoder.max_position), lambda: token_ids,
+ lambda: tf.slice(token_ids, [0, 0], [-1, self.encoder.max_position]))
+ mask = tf.not_equal(token_ids, self.default_token_id, name='mask')
+ else:
+ tokens = tf.sparse.to_dense(tokens, default_value='')
+ length = tf.shape(tokens)[-1]
+ tokens = tf.cond(
+ tf.less_equal(length, self.encoder.max_position), lambda: tokens,
+ lambda: tf.slice(tokens, [0, 0], [-1, self.encoder.max_position]))
+ token_ids = tf.string_to_hash_bucket_fast(
+ tokens, self.encoder.vocab_size, name='token_ids')
+ mask = tf.not_equal(tokens, '', name='mask')
+
+ encoding = self.encoder([token_ids, mask], training=training)
+ return encoding
diff --git a/easy_rec/python/layers/multihead_cross_attention.py b/easy_rec/python/layers/multihead_cross_attention.py
new file mode 100644
index 000000000..f230ac974
--- /dev/null
+++ b/easy_rec/python/layers/multihead_cross_attention.py
@@ -0,0 +1,749 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+
+from easy_rec.python.compat.layers import layer_norm as tf_layer_norm
+from easy_rec.python.utils.activation import gelu
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def create_initializer(initializer_range=0.02):
+ """Creates a `truncated_normal_initializer` with the given range."""
+ return tf.truncated_normal_initializer(stddev=initializer_range)
+
+
+def dropout(input_tensor, dropout_prob):
+ """Perform dropout.
+
+ Args:
+ input_tensor: float Tensor.
+ dropout_prob: Python float. The probability of dropping out a value (NOT of
+ *keeping* a dimension as in `tf.nn.dropout`).
+
+ Returns:
+ A version of `input_tensor` with dropout applied.
+ """
+ if dropout_prob is None or dropout_prob == 0.0:
+ return input_tensor
+
+ output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
+ return output
+
+
+def attention_layer(from_tensor,
+ to_tensor,
+ size_per_head,
+ num_attention_heads=1,
+ attention_mask=None,
+ query_act=None,
+ key_act=None,
+ value_act=None,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ do_return_2d_tensor=False,
+ batch_size=None,
+ from_seq_length=None,
+ to_seq_length=None,
+ reuse=None):
+ """Performs multi-headed attention from `from_tensor` to `to_tensor`.
+
+ This is an implementation of multi-headed attention based on "Attention is all you Need".
+ If `from_tensor` and `to_tensor` are the same, then this is self-attention.
+ Each timestep in `from_tensor` attends to the corresponding sequence in `to_tensor`,
+ and returns a fixed-width vector.
+ This function first projects `from_tensor` into a "query" tensor and `to_tensor` into "key" and "value" tensors.
+ These are (effectively) a list of tensors of length `num_attention_heads`, where each tensor is of shape:
+ [batch_size, seq_length, size_per_head].
+ Then, the query and key tensors are dot-producted and scaled. These are
+ softmaxed to obtain attention probabilities. The value tensors are then
+ interpolated by these probabilities, then concatenated back to a single
+ tensor and returned.
+ In practice, the multi-headed attention are done with transposes and
+ reshapes rather than actual separate tensors.
+
+ Args:
+ from_tensor: float Tensor of shape [batch_size, from_seq_length,
+ from_width].
+ to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
+ size_per_head: int. Size of each attention head.
+ num_attention_heads: int. Number of attention heads.
+ attention_mask: (optional) int32 Tensor of shape [batch_size,
+ from_seq_length, to_seq_length]. The values should be 1 or 0. The
+ attention scores will effectively be set to -infinity for any positions in
+ the mask that are 0, and will be unchanged for positions that are 1.
+ query_act: (optional) Activation function for the query transform.
+ key_act: (optional) Activation function for the key transform.
+ value_act: (optional) Activation function for the value transform.
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
+ attention probabilities.
+ initializer_range: float. Range of the weight initializer.
+ do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
+ * from_seq_length, num_attention_heads * size_per_head]. If False, the
+ output will be of shape [batch_size, from_seq_length, num_attention_heads
+ * size_per_head].
+ batch_size: (Optional) int. If the input is 2D, this might be the batch size
+ of the 3D version of the `from_tensor` and `to_tensor`.
+ from_seq_length: (Optional) If the input is 2D, this might be the seq length
+ of the 3D version of the `from_tensor`.
+ to_seq_length: (Optional) If the input is 2D, this might be the seq length
+ of the 3D version of the `to_tensor`.
+ reuse: whether to reuse this layer
+
+ Returns:
+ float Tensor of shape [batch_size, from_seq_length,
+ num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
+ true, this will be of shape [batch_size * from_seq_length,
+ num_attention_heads * size_per_head]).
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ """
+
+ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
+ seq_length, width):
+ output_tensor = tf.reshape(
+ input_tensor, [batch_size, seq_length, num_attention_heads, width])
+
+ output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
+ return output_tensor
+
+ from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
+ to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
+
+ if len(from_shape) != len(to_shape):
+ raise ValueError(
+ 'The rank of `from_tensor` must match the rank of `to_tensor`.')
+
+ if len(from_shape) == 3:
+ batch_size = from_shape[0]
+ from_seq_length = from_shape[1]
+ to_seq_length = to_shape[1]
+ elif len(from_shape) == 2:
+ if (batch_size is None or from_seq_length is None or to_seq_length is None):
+ raise ValueError(
+ 'When passing in rank 2 tensors to attention_layer, the values '
+ 'for `batch_size`, `from_seq_length`, and `to_seq_length` '
+ 'must all be specified.')
+
+ # Scalar dimensions referenced here:
+ # B = batch size (number of sequences)
+ # F = `from_tensor` sequence length
+ # T = `to_tensor` sequence length
+ # N = `num_attention_heads`
+ # H = `size_per_head`
+
+ from_tensor_2d = reshape_to_matrix(from_tensor)
+ to_tensor_2d = reshape_to_matrix(to_tensor)
+
+ # `query_layer` = [B*F, N*H]
+ query_layer = tf.layers.dense(
+ from_tensor_2d,
+ num_attention_heads * size_per_head,
+ activation=query_act,
+ name='query',
+ kernel_initializer=create_initializer(initializer_range),
+ reuse=reuse)
+
+ # `key_layer` = [B*T, N*H]
+ key_layer = tf.layers.dense(
+ to_tensor_2d,
+ num_attention_heads * size_per_head,
+ activation=key_act,
+ name='key',
+ kernel_initializer=create_initializer(initializer_range),
+ reuse=reuse)
+
+ # `value_layer` = [B*T, N*H]
+ value_layer = tf.layers.dense(
+ to_tensor_2d,
+ num_attention_heads * size_per_head,
+ activation=value_act,
+ name='value',
+ kernel_initializer=create_initializer(initializer_range),
+ reuse=reuse)
+
+ # `query_layer` = [B, N, F, H]
+ query_layer = transpose_for_scores(query_layer, batch_size,
+ num_attention_heads, from_seq_length,
+ size_per_head)
+
+ # `key_layer` = [B, N, T, H]
+ key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
+ to_seq_length, size_per_head)
+
+ # Take the dot product between "query" and "key" to get the raw
+ # attention scores.
+ # `attention_scores` = [B, N, F, T]
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ attention_scores = tf.multiply(attention_scores,
+ 1.0 / math.sqrt(float(size_per_head)))
+
+ if attention_mask is not None:
+ # `attention_mask` = [B, 1, F, T]
+ attention_mask = tf.expand_dims(attention_mask, axis=[1])
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
+
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_scores += adder
+
+ # Normalize the attention scores to probabilities.
+ # `attention_probs` = [B, N, F, T]
+ attention_probs = tf.nn.softmax(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
+
+ # `value_layer` = [B, T, N, H]
+ value_layer = tf.reshape(
+ value_layer,
+ [batch_size, to_seq_length, num_attention_heads, size_per_head])
+
+ # `value_layer` = [B, N, T, H]
+ value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
+
+ # `context_layer` = [B, N, F, H]
+ context_layer = tf.matmul(attention_probs, value_layer)
+
+ # `context_layer` = [B, F, N, H]
+ context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
+
+ if do_return_2d_tensor:
+ # `context_layer` = [B*F, N*H]
+ context_layer = tf.reshape(
+ context_layer,
+ [batch_size * from_seq_length, num_attention_heads * size_per_head])
+ else:
+ # `context_layer` = [B, F, N*H]
+ context_layer = tf.reshape(
+ context_layer,
+ [batch_size, from_seq_length, num_attention_heads * size_per_head])
+
+ return context_layer
+
+
+def transformer_encoder(input_tensor,
+ attention_mask=None,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ intermediate_act_fn=gelu,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ initializer_range=0.02,
+ reuse=None,
+ name='transformer'):
+ """Multi-headed, multi-layer Transformer from "Attention is All You Need".
+
+ This is almost an exact implementation of the original Transformer encoder.
+ See the original paper:
+ https://arxiv.org/abs/1706.03762
+ Args:
+ input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
+ attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
+ seq_length], with 1 for positions that can be attended to and 0 in
+ positions that should not be.
+ hidden_size: int. Hidden size of the Transformer.
+ num_hidden_layers: int. Number of layers (blocks) in the Transformer.
+ num_attention_heads: int. Number of attention heads in the Transformer.
+ intermediate_size: int. The size of the "intermediate" (a.k.a., feed
+ forward) layer.
+ intermediate_act_fn: function. The non-linear activation function to apply
+ to the output of the intermediate/feed-forward layer.
+ hidden_dropout_prob: float. Dropout probability for the hidden layers.
+ attention_probs_dropout_prob: float. Dropout probability of the attention
+ probabilities.
+ initializer_range: float. Range of the initializer (stddev of truncated
+ normal).
+ reuse: whether to reuse this encoder
+ name: scope name prefix
+
+ Returns:
+ float Tensor of shape [batch_size, seq_length, hidden_size], the final
+ hidden layer of the Transformer.
+
+ Raises:
+ ValueError: A Tensor shape or parameter is invalid.
+ """
+ if hidden_size % num_attention_heads != 0:
+ raise ValueError(
+ 'The hidden size (%d) is not a multiple of the number of attention '
+ 'heads (%d)' % (hidden_size, num_attention_heads))
+
+ attention_head_size = int(hidden_size / num_attention_heads)
+ input_shape = get_shape_list(input_tensor, expected_rank=3)
+ batch_size = input_shape[0]
+ seq_length = input_shape[1]
+ input_width = input_shape[2]
+
+ # The Transformer performs sum residuals on all layers so the input needs
+ # to be the same as the hidden size.
+ if input_width != hidden_size:
+ raise ValueError('The width of the input tensor (%d) != hidden size (%d)' %
+ (input_width, hidden_size))
+
+ # We keep the representation as a 2D tensor to avoid re-shaping it back and
+ # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
+ # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
+ # help the optimizer.
+ prev_output = reshape_to_matrix(input_tensor)
+
+ for layer_idx in range(num_hidden_layers):
+ with tf.variable_scope('%s_layer_%d' % (name, layer_idx)):
+ layer_input = prev_output
+
+ with tf.variable_scope('attention'):
+ with tf.variable_scope('self'):
+ # [batch_size * from_seq_length, num_attention_heads * size_per_head]
+ attention_output = attention_layer(
+ from_tensor=layer_input,
+ to_tensor=layer_input,
+ size_per_head=attention_head_size,
+ num_attention_heads=num_attention_heads,
+ attention_mask=attention_mask,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ initializer_range=initializer_range,
+ do_return_2d_tensor=True,
+ batch_size=batch_size,
+ from_seq_length=seq_length,
+ to_seq_length=seq_length,
+ reuse=reuse)
+
+ # Run a linear projection of `hidden_size` then add a residual
+ # with `layer_input`.
+ with tf.variable_scope('output', reuse=reuse):
+ attention_output = tf.layers.dense(
+ attention_output,
+ hidden_size,
+ kernel_initializer=create_initializer(initializer_range))
+ attention_output = dropout(attention_output, hidden_dropout_prob)
+ attention_output = layer_norm(attention_output + layer_input)
+
+ # The activation is only applied to the "intermediate" hidden layer.
+ with tf.variable_scope('intermediate', reuse=reuse):
+ intermediate_output = tf.layers.dense(
+ attention_output,
+ intermediate_size,
+ activation=intermediate_act_fn,
+ kernel_initializer=create_initializer(initializer_range))
+
+ # Down-project back to `hidden_size` then add the residual.
+ with tf.variable_scope('output', reuse=reuse):
+ layer_output = tf.layers.dense(
+ intermediate_output,
+ hidden_size,
+ kernel_initializer=create_initializer(initializer_range))
+ layer_output = dropout(layer_output, hidden_dropout_prob)
+ layer_output = layer_norm(layer_output + attention_output)
+ prev_output = layer_output
+
+ final_output = reshape_from_matrix(prev_output, input_shape)
+ return final_output
+
+
+def cross_attention_block(from_tensor,
+ to_tensor,
+ layer_idx,
+ size_per_head,
+ cross_attention_mask=None,
+ self_attention_mask=None,
+ num_attention_heads=1,
+ intermediate_size=512,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ initializer_range=0.02,
+ name=''):
+ """Multi-headed cross attention block.
+
+ Args:
+ from_tensor: float Tensor of shape [batch_size, from_seq_length,
+ from_width].
+ to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
+ layer_idx: int. layer id in the Transformer.
+ size_per_head: int. Size of each attention head.
+ cross_attention_mask: (optional) int32 Tensor of shape [batch_size, from_seq_length,
+ to_seq_length], with 1 for positions that can be attended to and 0 in
+ positions that should not be.
+ self_attention_mask: (optional) int32 Tensor of shape [batch_size, from_seq_length,
+ from_seq_length], with 1 for positions that can be attended to and 0 in
+ positions that should not be.
+ num_attention_heads: int. Number of attention heads in the Transformer.
+ intermediate_size: int. The size of the "intermediate" (a.k.a., feed
+ forward) layer.
+ hidden_dropout_prob: float. Dropout probability for the hidden layers.
+ attention_probs_dropout_prob: float. Dropout probability of the attention
+ probabilities.
+ initializer_range: float. Range of the initializer (stddev of truncated
+ normal).
+ name: scope name prefix
+
+ Returns:
+ float Tensor of shape [batch_size, seq_length, hidden_size], the final
+ hidden layer of the Transformer.
+
+ Raises:
+ ValueError: A Tensor shape or parameter is invalid.
+ """
+ input_shape = get_shape_list(from_tensor, expected_rank=3)
+ batch_size = input_shape[0]
+ from_seq_length = input_shape[1]
+
+ input_shape = get_shape_list(to_tensor, expected_rank=3)
+ to_seq_length = input_shape[1]
+
+ with tf.variable_scope('%scross_layer_%d' % (name, layer_idx)):
+ with tf.variable_scope('attention'):
+ with tf.variable_scope('cross'):
+ # [batch_size * from_seq_length, num_attention_heads * size_per_head]
+ cross_attention_output = attention_layer(
+ from_tensor=from_tensor,
+ to_tensor=to_tensor,
+ size_per_head=size_per_head,
+ num_attention_heads=num_attention_heads,
+ attention_mask=cross_attention_mask,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ initializer_range=initializer_range,
+ do_return_2d_tensor=True,
+ batch_size=batch_size,
+ from_seq_length=from_seq_length,
+ to_seq_length=to_seq_length)
+
+ with tf.variable_scope('self'):
+ # [batch_size * from_seq_length, num_attention_heads * size_per_head]
+ self_attention_output = attention_layer(
+ from_tensor=cross_attention_output,
+ to_tensor=cross_attention_output,
+ size_per_head=size_per_head,
+ num_attention_heads=num_attention_heads,
+ attention_mask=self_attention_mask,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ initializer_range=initializer_range,
+ do_return_2d_tensor=True,
+ batch_size=batch_size,
+ from_seq_length=from_seq_length,
+ to_seq_length=from_seq_length)
+
+ with tf.variable_scope('output'):
+ attention_output = dropout(self_attention_output, hidden_dropout_prob)
+ attention_output = layer_norm(attention_output + cross_attention_output)
+
+ # The activation is only applied to the "intermediate" hidden layer.
+ with tf.variable_scope('intermediate'):
+ intermediate_output = tf.layers.dense(
+ attention_output,
+ intermediate_size,
+ activation=tf.nn.relu,
+ kernel_initializer=create_initializer(initializer_range))
+
+ # Down-project back to `hidden_size` then add the residual.
+ with tf.variable_scope('output'):
+ layer_output = tf.layers.dense(
+ intermediate_output,
+ num_attention_heads * size_per_head,
+ kernel_initializer=create_initializer(initializer_range))
+ layer_output = dropout(layer_output, hidden_dropout_prob)
+ # [batch_size * from_seq_length, num_attention_heads * size_per_head]
+ layer_output = layer_norm(layer_output + attention_output)
+
+ final_output = reshape_from_matrix(
+ layer_output,
+ [batch_size, from_seq_length, num_attention_heads * size_per_head])
+ return final_output # [batch_size, from_seq_length, num_attention_heads * size_per_head]
+
+
+def cross_attention_tower(left_tensor,
+ right_tensor,
+ num_hidden_layers=1,
+ num_attention_heads=12,
+ left_size_per_head=64,
+ right_size_per_head=64,
+ left_intermediate_size=0,
+ right_intermediate_size=0,
+ left_input_mask=None,
+ right_input_mask=None,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ initializer_range=0.02,
+ name=''):
+ """Multi-headed, multi layer cross attention block.
+
+ Args:
+ left_tensor: float Tensor of shape [batch_size, left_seq_length,
+ from_width].
+ right_tensor: float Tensor of shape [batch_size, right_seq_length, to_width].
+ num_hidden_layers: int. Number of layers (blocks) in the Transformer.
+ num_attention_heads: int. Number of attention heads in the Transformer.
+ left_size_per_head: int. Size of each attention head of left tower.
+ right_size_per_head: int. Size of each attention head of right tower.
+ left_intermediate_size: int. The size of the "intermediate" (a.k.a., feed
+ forward) layer of the left tower. A value <= 0 means `num_attention_heads
+ * left_size_per_head`.
+ right_intermediate_size: int. The size of the "intermediate" (a.k.a., feed
+ forward) layer of the right tower. A value <= 0 means `num_attention_heads
+ * right_size_per_head`.
+ left_input_mask: the mask for `left_tensor`
+ right_input_mask: the mask for `right_tensor`
+ hidden_dropout_prob: float. Dropout probability for the hidden layers.
+ attention_probs_dropout_prob: float. Dropout probability of the attention
+ probabilities.
+ initializer_range: float. Range of the initializer (stddev of truncated
+ normal).
+ name: scope name prefix
+
+ Returns:
+ tuple of float Tensors of shape ([batch_size, left_seq_length, hidden_size],
+ [batch_size, right_seq_length, hidden_size]),
+ where hidden_size = num_attention_heads * size_per_head
+
+ Raises:
+ ValueError: A Tensor shape or parameter is invalid.
+ """
+ if left_intermediate_size <= 0:
+ left_intermediate_size = num_attention_heads * left_size_per_head
+ if right_intermediate_size <= 0:
+ right_intermediate_size = num_attention_heads * right_size_per_head
+
+ left_attention_mask = None
+ if left_input_mask is not None:
+ left_attention_mask = create_attention_mask_from_input_mask(
+ left_tensor, left_input_mask)
+
+ left_2_right_attention_mask = None
+ if right_input_mask is not None:
+ left_2_right_attention_mask = create_attention_mask_from_input_mask(
+ left_tensor, right_input_mask)
+
+ right_attention_mask = None
+ if right_input_mask is not None:
+ right_attention_mask = create_attention_mask_from_input_mask(
+ right_tensor, right_input_mask)
+
+ right_2_left_attention_mask = None
+ if left_input_mask is not None:
+ right_2_left_attention_mask = create_attention_mask_from_input_mask(
+ right_tensor, left_input_mask)
+
+ prev_left_output = left_tensor
+ prev_right_output = right_tensor
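+ # at every layer the left tower attends to the right tower's previous output
+ # and vice versa, so information flows in both directions layer by layer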
+ for layer_idx in range(num_hidden_layers):
+ left_output = cross_attention_block(
+ prev_left_output,
+ prev_right_output,
+ layer_idx,
+ num_attention_heads=num_attention_heads,
+ size_per_head=left_size_per_head,
+ intermediate_size=left_intermediate_size,
+ hidden_dropout_prob=hidden_dropout_prob,
+ cross_attention_mask=left_2_right_attention_mask,
+ self_attention_mask=left_attention_mask,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ initializer_range=initializer_range,
+ name='%sleft_to_right_' % name)
+ right_output = cross_attention_block(
+ prev_right_output,
+ prev_left_output,
+ layer_idx,
+ num_attention_heads=num_attention_heads,
+ size_per_head=right_size_per_head,
+ intermediate_size=right_intermediate_size,
+ hidden_dropout_prob=hidden_dropout_prob,
+ cross_attention_mask=right_2_left_attention_mask,
+ self_attention_mask=right_attention_mask,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ initializer_range=initializer_range,
+ name='%sright_to_left_' % name)
+ prev_left_output = left_output
+ prev_right_output = right_output
+ return prev_left_output, prev_right_output
+
+
+def layer_norm(input_tensor, name=None):
+ """Run layer normalization on the last dimension of the tensor."""
+ return tf_layer_norm(
+ inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
+
+
+def reshape_to_matrix(input_tensor):
+ """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
+ ndims = input_tensor.shape.ndims
+ if ndims < 2:
+ raise ValueError('Input tensor must have at least rank 2. Shape = %s' %
+ (input_tensor.shape))
+ if ndims == 2:
+ return input_tensor
+
+ width = input_tensor.shape[-1]
+ output_tensor = tf.reshape(input_tensor, [-1, width])
+ return output_tensor
+
+
+def reshape_from_matrix(output_tensor, orig_shape_list):
+ """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
+ if len(orig_shape_list) == 2:
+ return output_tensor
+
+ output_shape = get_shape_list(output_tensor)
+
+ orig_dims = orig_shape_list[0:-1]
+ width = output_shape[-1]
+
+ return tf.reshape(output_tensor, orig_dims + [width])
+
+
+def create_attention_mask_from_input_mask(from_tensor, to_mask):
+ """Create 3D attention mask from a 2D tensor mask.
+
+ Args:
+ from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
+ to_mask: int32 Tensor of shape [batch_size, to_seq_length].
+
+ Returns:
+ float Tensor of shape [batch_size, from_seq_length, to_seq_length].
+ """
+ from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
+ batch_size = from_shape[0]
+ from_seq_length = from_shape[1]
+
+ to_shape = get_shape_list(to_mask, expected_rank=2)
+ to_seq_length = to_shape[1]
+
+ to_mask = tf.cast(
+ tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
+
+ # We don't assume that `from_tensor` is a mask (although it could be). We
+ # don't actually care if we attend *from* padding tokens (only *to* padding)
+ # tokens so we create a tensor of all ones.
+ #
+ # `broadcast_ones` = [batch_size, from_seq_length, 1]
+ broadcast_ones = tf.ones(
+ shape=tf.stack([batch_size, from_seq_length, 1]), dtype=tf.float32)
+
+ # Here we broadcast along two dimensions to create the mask.
+ mask = broadcast_ones * to_mask
+
+ return mask
+
+
+def embedding_postprocessor(input_tensor,
+ use_token_type=False,
+ token_type_ids=None,
+ token_type_vocab_size=16,
+ token_type_embedding_name='token_type_embeddings',
+ reuse_token_type=None,
+ use_position_embeddings=True,
+ position_embedding_name='position_embeddings',
+ reuse_position_embedding=None,
+ initializer_range=0.02,
+ max_position_embeddings=512,
+ dropout_prob=0.1):
+ """Performs various post-processing on a word embedding tensor.
+
+ Args:
+ input_tensor: float Tensor of shape [batch_size, seq_length,
+ embedding_size].
+ use_token_type: bool. Whether to add embeddings for `token_type_ids`.
+ token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
+ Must be specified if `use_token_type` is True.
+ token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
+ token_type_embedding_name: string. The name of the embedding table variable
+ for token type ids.
+ reuse_token_type: bool. Whether to reuse token type embedding variable.
+ use_position_embeddings: bool. Whether to add position embeddings for the
+ position of each token in the sequence.
+ position_embedding_name: string. The name of the embedding table variable
+ for positional embeddings.
+ reuse_position_embedding: bool. Whether to reuse position embedding variable.
+ initializer_range: float. Range of the weight initialization.
+ max_position_embeddings: int. Maximum sequence length that might ever be
+ used with this model. This can be longer than the sequence length of
+ input_tensor, but cannot be shorter.
+ dropout_prob: float. Dropout probability applied to the final output tensor.
+
+ Returns:
+ float tensor with same shape as `input_tensor`.
+
+ Raises:
+ ValueError: One of the tensor shapes or input values is invalid.
+ """
+ input_shape = get_shape_list(input_tensor, expected_rank=3)
+ batch_size = input_shape[0]
+ seq_length = input_shape[1]
+ width = input_shape[2]
+
+ output = input_tensor
+
+ if use_token_type:
+ if token_type_ids is None:
+ raise ValueError('`token_type_ids` must be specified if'
+ '`use_token_type` is True.')
+ with tf.variable_scope('token_type', reuse=reuse_token_type):
+ token_type_table = tf.get_variable(
+ name=token_type_embedding_name,
+ shape=[token_type_vocab_size, width],
+ initializer=create_initializer(initializer_range))
+ # This vocab will be small so we always do one-hot here, since it is always
+ # faster for a small vocabulary.
+ flat_token_type_ids = tf.reshape(token_type_ids, [-1])
+ one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
+ token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
+ token_type_embeddings = tf.reshape(token_type_embeddings,
+ [batch_size, seq_length, width])
+ output += token_type_embeddings
+
+ if use_position_embeddings:
+ assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
+ with tf.control_dependencies([assert_op]):
+ with tf.variable_scope(
+ 'position_embedding', reuse=reuse_position_embedding):
+ full_position_embeddings = tf.get_variable(
+ name=position_embedding_name,
+ shape=[max_position_embeddings, width],
+ initializer=create_initializer(initializer_range))
+ # Since the position embedding table is a learned variable, we create it
+ # using a (long) sequence length `max_position_embeddings`. The actual
+ # sequence length might be shorter than this, for faster training of
+ # tasks that do not have long sequences.
+ #
+ # So `full_position_embeddings` is effectively an embedding table
+ # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
+ # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
+ # perform a slice.
+ position_embeddings = tf.slice(full_position_embeddings, [0, 0],
+ [seq_length, -1])
+ num_dims = len(output.shape.as_list())
+
+ # Only the last two dimensions are relevant (`seq_length` and `width`), so
+ # we broadcast among the first dimensions, which is typically just
+ # the batch size.
+ position_broadcast_shape = []
+ for _ in range(num_dims - 2):
+ position_broadcast_shape.append(1)
+ position_broadcast_shape.extend([seq_length, width])
+ position_embeddings = tf.reshape(position_embeddings,
+ position_broadcast_shape)
+ output += position_embeddings
+
+ output = layer_norm_and_dropout(output, dropout_prob)
+ return output
+
+
+def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
+ """Runs layer normalization followed by dropout."""
+ output_tensor = layer_norm(input_tensor, name)
+ output_tensor = dropout(output_tensor, dropout_prob)
+ return output_tensor
diff --git a/easy_rec/python/layers/senet.py b/easy_rec/python/layers/senet.py
new file mode 100644
index 000000000..777079341
--- /dev/null
+++ b/easy_rec/python/layers/senet.py
@@ -0,0 +1,73 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class SENet:
+ """Squeeze and Excite Network.
+
+ Input shape
+ - A list of 2D tensors with shape: ``(batch_size, embedding_size)``.
+ The ``embedding_size`` may differ from field to field.
+
+ Args:
+ num_fields: int, number of fields.
+ num_squeeze_group: int, number of groups for squeeze.
+ reduction_ratio: int, reduction ratio for squeeze.
+ l2_reg: float, l2 regularizer for embedding.
+ name: str, name of the layer.
+ """
+
+ def __init__(self,
+ num_fields,
+ num_squeeze_group,
+ reduction_ratio,
+ l2_reg,
+ name='SENet'):
+ self.num_fields = num_fields
+ self.num_squeeze_group = num_squeeze_group
+ self.reduction_ratio = reduction_ratio
+ self._l2_reg = l2_reg
+ self._name = name
+
+ def __call__(self, inputs):
+ g = self.num_squeeze_group
+ f = self.num_fields
+ r = self.reduction_ratio
+ reduction_size = max(1, f * g * 2 // r)
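+ # each field contributes num_squeeze_group groups x 2 statistics (max and
+ # mean), so the squeezed vector has size f * g * 2 before reduction by r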
+
+ emb_size = 0
+ for emb in inputs:
+ emb_size += int(emb.shape[-1])
+
+ group_embs = [
+ tf.reshape(emb, [-1, g, int(emb.shape[-1]) // g]) for emb in inputs
+ ]
+
+ squeezed = []
+ for emb in group_embs:
+ squeezed.append(tf.reduce_max(emb, axis=-1)) # [B, g]
+ squeezed.append(tf.reduce_mean(emb, axis=-1)) # [B, g]
+ z = tf.concat(squeezed, axis=1) # [bs, field_size * num_groups * 2]
+
+ reduced = tf.layers.dense(
+ inputs=z,
+ units=reduction_size,
+ kernel_regularizer=self._l2_reg,
+ activation='relu',
+ name='%s/reduce' % self._name)
+
+ excited_weights = tf.layers.dense(
+ inputs=reduced,
+ units=emb_size,
+ kernel_initializer='glorot_normal',
+ name='%s/excite' % self._name)
+
+ # Re-weight
+ inputs = tf.concat(inputs, axis=-1)
+ output = inputs * excited_weights
+
+ return output
diff --git a/easy_rec/python/layers/seq_input_layer.py b/easy_rec/python/layers/seq_input_layer.py
index ee27f8039..a52904dd1 100644
--- a/easy_rec/python/layers/seq_input_layer.py
+++ b/easy_rec/python/layers/seq_input_layer.py
@@ -4,7 +4,10 @@
import logging
import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
+from easy_rec.python.compat import regularizers
from easy_rec.python.compat.feature_column import feature_column
from easy_rec.python.feature_column.feature_column import FeatureColumnParser
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
@@ -15,18 +18,25 @@
class SeqInputLayer(object):
- def __init__(self, feature_configs, feature_groups_config):
+ def __init__(self,
+ feature_configs,
+ feature_groups_config,
+ embedding_regularizer=None,
+ ev_params=None):
self._feature_groups_config = {
x.group_name: x for x in feature_groups_config
}
wide_and_deep_dict = self.get_wide_deep_dict()
- self._fc_parser = FeatureColumnParser(feature_configs, wide_and_deep_dict)
+ self._fc_parser = FeatureColumnParser(
+ feature_configs, wide_and_deep_dict, ev_params=ev_params)
+ self._embedding_regularizer = embedding_regularizer
def __call__(self,
features,
group_name,
feature_name_to_output_tensors={},
- allow_key_search=True):
+ allow_key_search=True,
+ scope_name=None):
feature_column_dict = self._fc_parser.deep_columns
feature_column_dict.update(self._fc_parser.sequence_columns)
@@ -42,17 +52,26 @@ def _seq_embed_summary_name(input_name):
input_name = input_name.split('/')[:2]
return 'sequence_feature/' + '/'.join(input_name)
- with tf.variable_scope(group_name, reuse=tf.AUTO_REUSE):
+ if scope_name is None:
+ scope_name = group_name
+ # name_scope is specified to avoid adding _1 _2 after scope_name
+ with variable_scope.variable_scope(
+ scope_name,
+ reuse=variable_scope.AUTO_REUSE), ops.name_scope(scope_name + '/'):
key_tensors = []
hist_tensors = []
+ check_op_list = []
for x in feature_dict.seq_att_map:
for key in x.key:
if key not in feature_name_to_output_tensors or (
feature_name_to_output_tensors[key] is None and allow_key_search):
qfc = feature_column_dict[key]
- with tf.variable_scope(qfc._var_scope_name):
- key_tensors.append(
- feature_column_dict[key]._get_dense_tensor(builder))
+ with variable_scope.variable_scope(qfc._var_scope_name):
+ tmp_key_tensor = feature_column_dict[key]._get_dense_tensor(
+ builder)
+ regularizers.apply_regularization(
+ self._embedding_regularizer, weights_list=[tmp_key_tensor])
+ key_tensors.append(tmp_key_tensor)
elif feature_name_to_output_tensors[key] is None:
assert feature_name_to_output_tensors[
key] is not None, 'When allow_key_search is False, key: %s should defined in same feature group.' % key
@@ -63,13 +82,22 @@ def _seq_embed_summary_name(input_name):
for key_tensor in key_tensors:
tf.summary.histogram(
_seq_embed_summary_name(key_tensor.name), key_tensor)
-
+ cur_hist_seqs = []
for hist_seq in x.hist_seq:
seq_fc = feature_column_dict[hist_seq]
- with tf.variable_scope(seq_fc._var_scope_name):
- hist_tensors.append(
+ with variable_scope.variable_scope(seq_fc._var_scope_name):
+ cur_hist_seqs.append(
feature_column_dict[hist_seq]._get_sequence_dense_tensor(
builder))
+ hist_tensors.extend(cur_hist_seqs)
+
+ aux_hist_emb_list = []
+ for aux_hist_seq in x.aux_hist_seq:
+ seq_fc = feature_column_dict[aux_hist_seq]
+ with variable_scope.variable_scope(seq_fc._var_scope_name):
+ aux_hist_embedding, _ = feature_column_dict[
+ aux_hist_seq]._get_sequence_dense_tensor(builder)
+ aux_hist_emb_list.append(aux_hist_embedding)
if tf_summary:
for hist_embed, hist_seq_len in hist_tensors:
@@ -78,11 +106,21 @@ def _seq_embed_summary_name(input_name):
tf.summary.histogram(
_seq_embed_summary_name(hist_seq_len.name), hist_seq_len)
- features = {
- 'key': tf.concat(key_tensors, axis=-1),
- 'hist_seq_emb': tf.concat([x[0] for x in hist_tensors], axis=-1),
- 'hist_seq_len': hist_tensors[0][1]
- }
+ for idx in range(1, len(cur_hist_seqs)):
+ check_op = tf.assert_equal(
+ cur_hist_seqs[0][1],
+ cur_hist_seqs[idx][1],
+ message='SequenceFeature Error: The size of %s not equal to the size of %s.'
+ % (x.hist_seq[idx], x.hist_seq[0]))
+ check_op_list.append(check_op)
+
+ with tf.control_dependencies(check_op_list):
+ features = {
+ 'key': tf.concat(key_tensors, axis=-1),
+ 'hist_seq_emb': tf.concat([x[0] for x in hist_tensors], axis=-1),
+ 'hist_seq_len': hist_tensors[0][1],
+ 'aux_hist_seq_emb_list': aux_hist_emb_list
+ }
return features
def get_wide_deep_dict(self):
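
Not part of the patch: the length-consistency guard added above, shown in isolation. `tf.assert_equal` builds a check op and `tf.control_dependencies` forces it to run before the concatenated sequence features are used; the tensors below are illustrative.

import tensorflow as tf

len_a = tf.constant([3, 5, 2])  # lengths of the first history sequence
len_b = tf.constant([3, 5, 2])  # lengths of another sequence in the same group
check_op = tf.assert_equal(
    len_a, len_b, message='SequenceFeature Error: sequence sizes differ.')
with tf.control_dependencies([check_op]):
  # tensors created here are only evaluated after the check passes
  hist_seq_len = tf.identity(len_a)
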
diff --git a/easy_rec/python/layers/sequence_feature_layer.py b/easy_rec/python/layers/sequence_feature_layer.py
new file mode 100644
index 000000000..fd01b5b2c
--- /dev/null
+++ b/easy_rec/python/layers/sequence_feature_layer.py
@@ -0,0 +1,249 @@
+import logging
+import os
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+
+from easy_rec.python.compat import regularizers
+from easy_rec.python.layers import dnn
+from easy_rec.python.layers import seq_input_layer
+from easy_rec.python.utils import conditional
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class SequenceFeatureLayer(object):
+
+ def __init__(self,
+ feature_configs,
+ feature_groups_config,
+ ev_params=None,
+ embedding_regularizer=None,
+ kernel_regularizer=None,
+ is_training=False,
+ is_predicting=False):
+ self._seq_feature_groups_config = []
+ for x in feature_groups_config:
+ for y in x.sequence_features:
+ self._seq_feature_groups_config.append(y)
+ self._seq_input_layer = None
+ if len(self._seq_feature_groups_config) > 0:
+ self._seq_input_layer = seq_input_layer.SeqInputLayer(
+ feature_configs,
+ self._seq_feature_groups_config,
+ embedding_regularizer=embedding_regularizer,
+ ev_params=ev_params)
+ self._embedding_regularizer = embedding_regularizer
+ self._kernel_regularizer = kernel_regularizer
+ self._is_training = is_training
+ self._is_predicting = is_predicting
+
+ def negative_sampler_target_attention(self,
+ dnn_config,
+ deep_fea,
+ concat_features,
+ name,
+ need_key_feature=True,
+ allow_key_transform=False):
+ cur_id, hist_id_col, seq_len, aux_hist_emb_list = deep_fea['key'], deep_fea[
+ 'hist_seq_emb'], deep_fea['hist_seq_len'], deep_fea[
+ 'aux_hist_seq_emb_list']
+
+ seq_max_len = tf.shape(hist_id_col)[1]
+ seq_emb_dim = hist_id_col.shape[2]
+ cur_id_dim = tf.shape(cur_id)[-1]
+ batch_size = tf.shape(hist_id_col)[0]
+
+ pos_feature = cur_id[:batch_size]
+ neg_feature = cur_id[batch_size:]
+ cur_id = tf.concat([
+ pos_feature[:, tf.newaxis, :],
+ tf.tile(neg_feature[tf.newaxis, :, :], multiples=[batch_size, 1, 1])
+ ],
+ axis=1) # noqa: E126
+ neg_num_add_1 = tf.shape(cur_id)[1]
+ hist_id_col_tmp = tf.tile(
+ hist_id_col[:, :, :], multiples=[1, neg_num_add_1, 1])
+ hist_id_col = tf.reshape(
+ hist_id_col_tmp, [batch_size * neg_num_add_1, seq_max_len, seq_emb_dim])
+
+ concat_features = tf.tile(
+ concat_features[:, tf.newaxis, :], multiples=[1, neg_num_add_1, 1])
+ seq_len = tf.tile(seq_len, multiples=[neg_num_add_1])
+
+ if allow_key_transform and (cur_id_dim != seq_emb_dim):
+ cur_id = tf.layers.dense(
+ cur_id, seq_emb_dim, name='sequence_key_transform_layer')
+
+ cur_ids = tf.tile(cur_id, [1, 1, seq_max_len])
+ cur_ids = tf.reshape(
+ cur_ids,
+ tf.shape(hist_id_col)) # (B * neg_num_add_1, seq_max_len, seq_emb_dim)
+
+ din_net = tf.concat(
+ [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
+ axis=-1) # (B * neg_num_add_1, seq_max_len, seq_emb_dim*4)
+
+ din_layer = dnn.DNN(
+ dnn_config,
+ self._kernel_regularizer,
+ name,
+ self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
+ din_net = din_layer(din_net)
+ scores = tf.reshape(din_net, [-1, 1, seq_max_len]) # (B * neg_num_add_1, 1, seq_max_len)
+
+ seq_len = tf.expand_dims(seq_len, 1)
+ mask = tf.sequence_mask(seq_len)
+ padding = tf.ones_like(scores) * (-2**32 + 1)
+ scores = tf.where(mask, scores,
+ padding) # [B*neg_num_add_1, 1, seq_max_len]
+
+ # Scale
+ scores = tf.nn.softmax(scores) # (B * neg_num_add_1, 1, seq_max_len)
+ hist_din_emb = tf.matmul(scores,
+ hist_id_col) # [B * neg_num_add_1, 1, seq_emb_dim]
+ hist_din_emb = tf.reshape(hist_din_emb,
+ [batch_size, neg_num_add_1, seq_emb_dim
+ ]) # [B, neg_num_add_1, seq_emb_dim]
+ if len(aux_hist_emb_list) > 0:
+ all_hist_dim_emb = [hist_din_emb]
+ for hist_col in aux_hist_emb_list:
+ cur_aux_hist = tf.matmul(scores, hist_col)
+ outputs = tf.reshape(cur_aux_hist, [-1, seq_emb_dim])
+ all_hist_dim_emb.append(outputs)
+ hist_din_emb = tf.concat(all_hist_dim_emb, axis=1)
+ if not need_key_feature:
+ return hist_din_emb, concat_features
+ din_output = tf.concat([hist_din_emb, cur_id], axis=2)
+ return din_output, concat_features
+
+ def target_attention(self,
+ dnn_config,
+ deep_fea,
+ name,
+ need_key_feature=True,
+ allow_key_transform=False,
+ transform_dnn=False):
+ cur_id, hist_id_col, seq_len, aux_hist_emb_list = deep_fea['key'], deep_fea[
+ 'hist_seq_emb'], deep_fea['hist_seq_len'], deep_fea[
+ 'aux_hist_seq_emb_list']
+
+ seq_max_len = tf.shape(hist_id_col)[1]
+ seq_emb_dim = hist_id_col.shape[2]
+ cur_id_dim = cur_id.shape[-1]
+
+ if allow_key_transform and (cur_id_dim != seq_emb_dim):
+ if seq_emb_dim > cur_id_dim and not transform_dnn:
+ cur_id = tf.pad(cur_id, [[0, 0], [0, seq_emb_dim - cur_id_dim]])
+ else:
+ cur_key_layer_name = 'sequence_key_transform_layer_' + name
+ cur_id = tf.layers.dense(cur_id, seq_emb_dim, name=cur_key_layer_name)
+ cur_fea_layer_name = 'sequence_fea_transform_layer_' + name
+ hist_id_col = tf.layers.dense(
+ hist_id_col, seq_emb_dim, name=cur_fea_layer_name)
+ else:
+ cur_id = cur_id[:tf.shape(hist_id_col)[0], ...] # for negative sampler
+
+ cur_ids = tf.tile(cur_id, [1, seq_max_len])
+ cur_ids = tf.reshape(cur_ids,
+ tf.shape(hist_id_col)) # (B, seq_max_len, seq_emb_dim)
+
+ din_net = tf.concat(
+ [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
+ axis=-1) # (B, seq_max_len, seq_emb_dim*4)
+
+ din_layer = dnn.DNN(
+ dnn_config,
+ self._kernel_regularizer,
+ name,
+ self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
+ din_net = din_layer(din_net)
+ scores = tf.reshape(din_net, [-1, 1, seq_max_len]) # (B, 1, seq_max_len)
+
+ seq_len = tf.expand_dims(seq_len, 1)
+ mask = tf.sequence_mask(seq_len)
+ padding = tf.ones_like(scores) * (-2**32 + 1)
+ scores = tf.where(mask, scores, padding) # [B, 1, seq_max_len]
+
+ # Scale
+ scores = tf.nn.softmax(scores) # (B, 1, seq_max_len)
+ hist_din_emb = tf.matmul(scores, hist_id_col) # [B, 1, seq_emb_dim]
+ hist_din_emb = tf.reshape(hist_din_emb,
+ [-1, seq_emb_dim]) # [B, seq_emb_dim]
+ if len(aux_hist_emb_list) > 0:
+ all_hist_dim_emb = [hist_din_emb]
+ for hist_col in aux_hist_emb_list:
+ aux_hist_dim = hist_col.shape[-1]
+ cur_aux_hist = tf.matmul(scores, hist_col)
+ outputs = tf.reshape(cur_aux_hist, [-1, aux_hist_dim])
+ all_hist_dim_emb.append(outputs)
+ hist_din_emb = tf.concat(all_hist_dim_emb, axis=1)
+ if not need_key_feature:
+ return hist_din_emb
+ din_output = tf.concat([hist_din_emb, cur_id], axis=1)
+ return din_output
+
+ def __call__(self,
+ features,
+ concat_features,
+ all_seq_att_map_config,
+ feature_name_to_output_tensors=None,
+ negative_sampler=False,
+ scope_name=None):
+ logging.info('use sequence feature layer.')
+ all_seq_fea = []
+ # process all sequence features
+ for seq_att_map_config in all_seq_att_map_config:
+ group_name = seq_att_map_config.group_name
+ allow_key_search = seq_att_map_config.allow_key_search
+ need_key_feature = seq_att_map_config.need_key_feature
+ allow_key_transform = seq_att_map_config.allow_key_transform
+ transform_dnn = seq_att_map_config.transform_dnn
+
+ place_on_cpu = os.getenv('place_embedding_on_cpu')
+ place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+ with conditional(self._is_predicting and place_on_cpu,
+ ops.device('/CPU:0')):
+ seq_features = self._seq_input_layer(features, group_name,
+ feature_name_to_output_tensors,
+ allow_key_search, scope_name)
+
+ # regularization for sequence keys is applied inside seq_input_layer; apply it to the history sequence embeddings here.
+
+ regularizers.apply_regularization(
+ self._embedding_regularizer,
+ weights_list=[seq_features['hist_seq_emb']])
+ seq_dnn_config = None
+ if seq_att_map_config.HasField('seq_dnn'):
+ seq_dnn_config = seq_att_map_config.seq_dnn
+ else:
+ logging.info(
+ 'seq_dnn not set in seq_att_groups, will use default settings')
+ # If not set seq_dnn, will use default settings
+ from easy_rec.python.protos.dnn_pb2 import DNN
+ seq_dnn_config = DNN()
+ seq_dnn_config.hidden_units.extend([128, 64, 32, 1])
+ cur_target_attention_name = 'seq_dnn' + group_name
+ if negative_sampler:
+ seq_fea, concat_features = self.negative_sampler_target_attention(
+ seq_dnn_config,
+ seq_features,
+ concat_features,
+ name=cur_target_attention_name,
+ need_key_feature=need_key_feature,
+ allow_key_transform=allow_key_transform)
+ else:
+ seq_fea = self.target_attention(
+ seq_dnn_config,
+ seq_features,
+ name=cur_target_attention_name,
+ need_key_feature=need_key_feature,
+ allow_key_transform=allow_key_transform,
+ transform_dnn=transform_dnn)
+ all_seq_fea.append(seq_fea)
+ return concat_features, all_seq_fea
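
Not part of the patch: a condensed, standalone sketch of the DIN-style target attention that `target_attention` above implements, with a single dense layer standing in for the configurable `seq_dnn`; names and shapes are illustrative and a TF1-style graph is assumed.

import tensorflow as tf

def din_attention(cur_id, hist_emb, seq_len):
  # cur_id: [B, D] target item, hist_emb: [B, T, D] behavior sequence, seq_len: [B]
  seq_max_len = tf.shape(hist_emb)[1]
  cur_ids = tf.reshape(tf.tile(cur_id, [1, seq_max_len]), tf.shape(hist_emb))
  att_in = tf.concat(
      [cur_ids, hist_emb, cur_ids - hist_emb, cur_ids * hist_emb], axis=-1)
  scores = tf.layers.dense(att_in, 1)                    # [B, T, 1]
  scores = tf.transpose(scores, [0, 2, 1])               # [B, 1, T]
  # mask padded positions before the softmax
  mask = tf.sequence_mask(tf.expand_dims(seq_len, 1), seq_max_len)
  padding = tf.ones_like(scores) * (-2**32 + 1)
  scores = tf.nn.softmax(tf.where(mask, scores, padding))
  return tf.squeeze(tf.matmul(scores, hist_emb), axis=1)  # [B, D]
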
diff --git a/easy_rec/python/layers/uniter.py b/easy_rec/python/layers/uniter.py
new file mode 100644
index 000000000..3018bad61
--- /dev/null
+++ b/easy_rec/python/layers/uniter.py
@@ -0,0 +1,301 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.layers import multihead_cross_attention
+from easy_rec.python.utils.activation import get_activation
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class Uniter(object):
+ """UNITER: UNiversal Image-TExt Representation Learning.
+
+ See the original paper:
+ https://arxiv.org/abs/1909.11740
+ """
+
+ def __init__(self, model_config, feature_configs, features, uniter_config,
+ input_layer):
+ self._model_config = uniter_config
+ tower_num = 0
+ self._img_features = None
+ if input_layer.has_group('image'):
+ self._img_features, _ = input_layer(features, 'image')
+ tower_num += 1
+ self._general_features = None
+ if input_layer.has_group('general'):
+ self._general_features, _ = input_layer(features, 'general')
+ tower_num += 1
+ self._txt_seq_features = None
+ if input_layer.has_group('text'):
+ self._txt_seq_features, _, _ = input_layer(
+ features, 'text', is_combine=False)
+ tower_num += 1
+ self._use_token_type = True if tower_num > 1 else False
+ self._other_features = None
+ if input_layer.has_group('other'): # e.g. statistical feature
+ self._other_features, _ = input_layer(features, 'other')
+ tower_num += 1
+ assert tower_num > 0, 'at least one of the feature groups must be configured: [image, text, general, other]'
+
+ self._general_feature_num = 0
+ self._txt_feature_num, self._img_feature_num = 0, 0
+ general_feature_names = set()
+ img_feature_names, txt_feature_names = set(), set()
+ for fea_group in model_config.feature_groups:
+ if fea_group.group_name == 'general':
+ self._general_feature_num = len(fea_group.feature_names)
+ general_feature_names = set(fea_group.feature_names)
+ assert self._general_feature_num == len(general_feature_names), (
+ 'there are duplicate features in `general` feature group')
+ elif fea_group.group_name == 'image':
+ self._img_feature_num = len(fea_group.feature_names)
+ img_feature_names = set(fea_group.feature_names)
+ assert self._img_feature_num == len(img_feature_names), (
+ 'there are duplicate features in `image` feature group')
+ elif fea_group.group_name == 'text':
+ self._txt_feature_num = len(fea_group.feature_names)
+ txt_feature_names = set(fea_group.feature_names)
+ assert self._txt_feature_num == len(txt_feature_names), (
+ 'there are duplicate features in `text` feature group')
+
+ if self._txt_feature_num > 1 or self._img_feature_num > 1:
+ self._use_token_type = True
+ self._token_type_vocab_size = self._txt_feature_num
+ if self._img_feature_num > 0:
+ self._token_type_vocab_size += 1
+ if self._general_feature_num > 0:
+ self._token_type_vocab_size += 1
+
+ max_seq_len = 0
+ txt_fea_emb_dim_list = []
+ general_emb_dim_list = []
+ img_fea_emb_dim_list = []
+ for feature_config in feature_configs:
+ fea_name = feature_config.input_names[0]
+ if feature_config.HasField('feature_name'):
+ fea_name = feature_config.feature_name
+ if fea_name in img_feature_names:
+ img_fea_emb_dim_list.append(feature_config.raw_input_dim)
+ if fea_name in general_feature_names:
+ general_emb_dim_list.append(feature_config.embedding_dim)
+ if fea_name in txt_feature_names:
+ txt_fea_emb_dim_list.append(feature_config.embedding_dim)
+ if feature_config.HasField('max_seq_len'):
+ assert feature_config.max_seq_len > 0, (
+ 'feature config `max_seq_len` must be greater than 0 for feature: '
+ + fea_name)
+ if feature_config.max_seq_len > max_seq_len:
+ max_seq_len = feature_config.max_seq_len
+
+ unique_dim_num = len(set(txt_fea_emb_dim_list))
+ assert unique_dim_num <= 1 and len(
+ txt_fea_emb_dim_list
+ ) == self._txt_feature_num, (
+ 'Uniter requires that all `text` feature dimensions must be consistent.'
+ )
+ unique_dim_num = len(set(img_fea_emb_dim_list))
+ assert unique_dim_num <= 1 and len(
+ img_fea_emb_dim_list
+ ) == self._img_feature_num, (
+ 'Uniter requires that all `image` feature dimensions must be consistent.'
+ )
+ unique_dim_num = len(set(general_emb_dim_list))
+ assert unique_dim_num <= 1 and len(
+ general_emb_dim_list
+ ) == self._general_feature_num, (
+ 'Uniter requires that all `general` feature dimensions must be consistent.'
+ )
+
+ if self._txt_feature_num > 0 and uniter_config.use_position_embeddings:
+ assert uniter_config.max_position_embeddings > 0, (
+ 'model config `max_position_embeddings` must be greater than 0. ')
+ assert uniter_config.max_position_embeddings >= max_seq_len, (
+ 'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
+ '`max_seq_len`, which is %d' % max_seq_len)
+
+ self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
+ self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
+ self._general_emb_size = general_emb_dim_list[
+ 0] if general_emb_dim_list else 0
+ if self._img_features is not None:
+ assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'
+
+ def text_embeddings(self, token_type_id):
+ all_txt_features = []
+ input_masks = []
+ hidden_size = self._model_config.hidden_size
+ if self._general_features is not None:
+ general_features = self._general_features
+ if self._general_emb_size != hidden_size:
+ # Run a linear projection of `hidden_size`
+ general_features = tf.reshape(
+ general_features, shape=[-1, self._general_emb_size])
+ general_features = tf.layers.dense(
+ general_features, hidden_size, name='txt_projection')
+ general_features = tf.reshape(
+ general_features, shape=[-1, self._general_feature_num, hidden_size])
+
+ batch_size = tf.shape(general_features)[0]
+ general_features = multihead_cross_attention.embedding_postprocessor(
+ general_features,
+ use_token_type=self._use_token_type,
+ token_type_ids=tf.ones(
+ shape=tf.stack([batch_size, self._general_feature_num]),
+ dtype=tf.int32) * token_type_id,
+ token_type_vocab_size=self._token_type_vocab_size,
+ reuse_token_type=tf.AUTO_REUSE,
+ use_position_embeddings=False,
+ dropout_prob=self._model_config.hidden_dropout_prob)
+
+ all_txt_features.append(general_features)
+ mask = tf.ones(
+ shape=tf.stack([batch_size, self._general_feature_num]),
+ dtype=tf.int32)
+ input_masks.append(mask)
+
+ if self._txt_seq_features is not None:
+
+ def dynamic_mask(x, max_len):
+ ones = tf.ones(shape=tf.stack([x]), dtype=tf.int32)
+ zeros = tf.zeros(shape=tf.stack([max_len - x]), dtype=tf.int32)
+ return tf.concat([ones, zeros], axis=0)
+
+ token_type_id += len(all_txt_features)
+ for i, (seq_fea, seq_len) in enumerate(self._txt_seq_features):
+ batch_size, max_seq_len, emb_size = get_shape_list(seq_fea, 3)
+ if emb_size != hidden_size:
+ seq_fea = tf.reshape(seq_fea, shape=[-1, emb_size])
+ seq_fea = tf.layers.dense(
+ seq_fea, hidden_size, name='txt_seq_projection_%d' % i)
+ seq_fea = tf.reshape(seq_fea, shape=[-1, max_seq_len, hidden_size])
+
+ seq_fea = multihead_cross_attention.embedding_postprocessor(
+ seq_fea,
+ use_token_type=self._use_token_type,
+ token_type_ids=tf.ones(
+ shape=tf.stack([batch_size, max_seq_len]), dtype=tf.int32) *
+ (i + token_type_id),
+ token_type_vocab_size=self._token_type_vocab_size,
+ reuse_token_type=tf.AUTO_REUSE,
+ use_position_embeddings=self._model_config.use_position_embeddings,
+ max_position_embeddings=self._model_config.max_position_embeddings,
+ position_embedding_name='txt_position_embeddings_%d' % i,
+ dropout_prob=self._model_config.hidden_dropout_prob)
+ all_txt_features.append(seq_fea)
+
+ input_mask = tf.map_fn(
+ fn=lambda t: dynamic_mask(t, max_seq_len),
+ elems=tf.to_int32(seq_len))
+ input_masks.append(input_mask)
+
+ return all_txt_features, input_masks
+
+ def image_embeddings(self):
+ if self._img_features is None:
+ return None
+ hidden_size = self._model_config.hidden_size
+ image_features = self._img_features
+ if self._img_emb_size != hidden_size:
+ # Run a linear projection of `hidden_size`
+ image_features = tf.reshape(
+ image_features, shape=[-1, self._img_emb_size])
+ image_features = tf.layers.dense(
+ image_features, hidden_size, name='img_projection')
+ image_features = tf.reshape(
+ image_features, shape=[-1, self._img_feature_num, hidden_size])
+
+ batch_size = tf.shape(image_features)[0]
+ img_fea = multihead_cross_attention.embedding_postprocessor(
+ image_features,
+ use_token_type=self._use_token_type,
+ token_type_ids=tf.zeros(
+ shape=tf.stack([batch_size, self._img_feature_num]),
+ dtype=tf.int32),
+ token_type_vocab_size=self._token_type_vocab_size,
+ reuse_token_type=tf.AUTO_REUSE,
+ use_position_embeddings=self._model_config.use_position_embeddings,
+ max_position_embeddings=self._model_config.max_position_embeddings,
+ position_embedding_name='img_position_embeddings',
+ dropout_prob=self._model_config.hidden_dropout_prob)
+ return img_fea
+
+ def __call__(self, is_training, *args, **kwargs):
+ if not is_training:
+ self._model_config.hidden_dropout_prob = 0.0
+ self._model_config.attention_probs_dropout_prob = 0.0
+
+ sub_modules = []
+
+ img_fea = self.image_embeddings()
+ start_token_id = 1 if self._img_feature_num > 0 else 0
+ txt_features, txt_masks = self.text_embeddings(start_token_id)
+
+ if img_fea is not None:
+ batch_size = tf.shape(img_fea)[0]
+ elif txt_features:
+ batch_size = tf.shape(txt_features[0])[0]
+ else:
+ batch_size = None
+
+ hidden_size = self._model_config.hidden_size
+ if batch_size is not None:
+ all_features = []
+ masks = []
+ cls_emb = tf.get_variable(name='cls_emb', shape=[1, 1, hidden_size])
+ cls_emb = tf.tile(cls_emb, [batch_size, 1, 1])
+ all_features.append(cls_emb)
+
+ mask = tf.ones(shape=tf.stack([batch_size, 1]), dtype=tf.int32)
+ masks.append(mask)
+
+ if img_fea is not None:
+ all_features.append(img_fea)
+ mask = tf.ones(
+ shape=tf.stack([batch_size, self._img_feature_num]), dtype=tf.int32)
+ masks.append(mask)
+
+ if txt_features:
+ all_features.extend(txt_features)
+ masks.extend(txt_masks)
+
+ all_fea = tf.concat(all_features, axis=1)
+ input_mask = tf.concat(masks, axis=1)
+ attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
+ from_tensor=all_fea, to_mask=input_mask)
+ hidden_act = get_activation(self._model_config.hidden_act)
+ attention_fea = multihead_cross_attention.transformer_encoder(
+ all_fea,
+ hidden_size=hidden_size,
+ num_hidden_layers=self._model_config.num_hidden_layers,
+ num_attention_heads=self._model_config.num_attention_heads,
+ attention_mask=attention_mask,
+ intermediate_size=self._model_config.intermediate_size,
+ intermediate_act_fn=hidden_act,
+ hidden_dropout_prob=self._model_config.hidden_dropout_prob,
+ attention_probs_dropout_prob=self._model_config
+ .attention_probs_dropout_prob,
+ initializer_range=self._model_config.initializer_range,
+ name='uniter') # shape: [batch_size, seq_length, hidden_size]
+ print('attention_fea:', attention_fea.shape)
+ mm_fea = attention_fea[:, 0, :] # [CLS] feature
+ sub_modules.append(mm_fea)
+
+ if self._other_features is not None:
+ if self._model_config.HasField('other_feature_dnn'):
+ l2_reg = kwargs['l2_reg'] if 'l2_reg' in kwargs else 0
+ other_dnn_layer = dnn.DNN(self._model_config.other_feature_dnn, l2_reg,
+ 'other_dnn', is_training)
+ other_fea = other_dnn_layer(self._other_features)
+ else:
+ other_fea = self._other_features
+ sub_modules.append(other_fea)
+
+ if len(sub_modules) == 1:
+ return sub_modules[0]
+ output = tf.concat(sub_modules, axis=-1)
+ return output
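
Not part of the patch: the token-type bookkeeping performed by the constructor above, traced with made-up feature counts. Every text sequence gets its own token type, while all image features share one type and all general features share another.

txt_feature_num, img_feature_num, general_feature_num = 2, 1, 3
token_type_vocab_size = txt_feature_num  # one type per text sequence
if img_feature_num > 0:
  token_type_vocab_size += 1             # shared type for all image features
if general_feature_num > 0:
  token_type_vocab_size += 1             # shared type for all general features
print(token_type_vocab_size)             # 4: image=0, general=1, text=2 and 3
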
diff --git a/easy_rec/python/layers/utils.py b/easy_rec/python/layers/utils.py
new file mode 100644
index 000000000..7eb86b791
--- /dev/null
+++ b/easy_rec/python/layers/utils.py
@@ -0,0 +1,248 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common util functions used by layers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from google.protobuf import struct_pb2
+from google.protobuf.descriptor import FieldDescriptor
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import variables
+
+try:
+ from tensorflow.python.ops import kv_variable_ops
+except ImportError:
+ kv_variable_ops = None
+
+ColumnNameInCollection = {}
+
+
+def _tensor_to_map(tensor):
+ return {
+ 'node_path': tensor.name,
+ 'shape': tensor.shape.as_list() if tensor.shape else None,
+ 'dtype': tensor.dtype.name
+ }
+
+
+def _tensor_to_tensorinfo(tensor):
+ tensor_info = {}
+ if isinstance(tensor, sparse_tensor.SparseTensor):
+ tensor_info['is_dense'] = False
+ tensor_info['values'] = _tensor_to_map(tensor.values)
+ tensor_info['indices'] = _tensor_to_map(tensor.indices)
+ tensor_info['dense_shape'] = _tensor_to_map(tensor.dense_shape)
+ else:
+ tensor_info['is_dense'] = True
+ tensor_info.update(_tensor_to_map(tensor))
+ return tensor_info
+
+
+def add_tensor_to_collection(collection_name, name, tensor):
+ tensor_info = _tensor_to_tensorinfo(tensor)
+ tensor_info['name'] = name
+ update_attr_to_collection(collection_name, tensor_info)
+
+
+def append_tensor_to_collection(collection_name, name, key, tensor):
+ tensor_info = _tensor_to_tensorinfo(tensor)
+ append_attr_to_collection(collection_name, name, key, tensor_info)
+
+
+def _collection_item_key(col, name):
+ return '%d#%s' % (id(col), name)
+
+
+def _process_item(collection_name, name, func):
+ col = ops.get_collection_ref(collection_name)
+ item_found = {}
+ idx_found = -1
+
+ # add id(col) because col may re-new sometimes
+ key = _collection_item_key(col, name)
+ if key in ColumnNameInCollection:
+ idx_found = ColumnNameInCollection[key]
+ if idx_found >= len(col):
+ raise Exception(
+ 'Find column name in collection failed: index out of range')
+
+ item_found = json.loads(col[idx_found])
+ if item_found['name'] != name:
+ raise Exception(
+ 'Find column name in collection failed: item name not match')
+ func(item_found)
+ col[idx_found] = json.dumps(item_found)
+ else:
+ func(item_found)
+ col.append(json.dumps(item_found))
+ ColumnNameInCollection[key] = len(col) - 1
+
+
+def append_attr_to_collection(collection_name, name, key, value):
+
+ def append(item_found):
+ if key not in item_found:
+ item_found[key] = []
+ item_found[key].append(value)
+
+ _process_item(collection_name, name, append)
+
+
+def update_attr_to_collection(collection_name, attrs):
+
+ def update(item_found):
+ item_found.update(attrs)
+
+ _process_item(collection_name, attrs['name'], update)
+
+
+def unique_name_in_collection(collection_name, name):
+ col = ops.get_collection_ref(collection_name)
+ unique_name = name
+ index = 0
+ while True:
+ key = _collection_item_key(col, unique_name)
+ if key not in ColumnNameInCollection:
+ break
+ index += 1
+ unique_name = '%s_%d' % (name, index)
+ return unique_name
+
+
+def gen_embedding_attrs(column=None,
+ variable=None,
+ bucket_size=None,
+ combiner=None,
+ is_embedding_var=None):
+ attrs = dict()
+ attrs['name'] = column.name
+ attrs['bucket_size'] = bucket_size
+ attrs['combiner'] = combiner
+ attrs['is_embedding_var'] = is_embedding_var
+ attrs['weights_op_path'] = variable.name
+ if kv_variable_ops:
+ if isinstance(variable, kv_variable_ops.EmbeddingVariable):
+ attrs['is_embedding_var'] = True
+ attrs['embedding_var_keys'] = variable._shared_name + '-keys'
+ attrs['embedding_var_values'] = variable._shared_name + '-values'
+ elif (isinstance(variable, variables.PartitionedVariable)) and \
+ (isinstance(variable._get_variable_list()[0], kv_variable_ops.EmbeddingVariable)):
+ attrs['embedding_var_keys'] = [v._shared_name + '-keys' for v in variable]
+ attrs['embedding_var_values'] = [
+ v._shared_name + '-values' for v in variable
+ ]
+ else:
+ attrs['is_embedding_var'] = False
+ else:
+ attrs['is_embedding_var'] = False
+ return attrs
+
+
+def mark_input_src(name, src_desc):
+ ops.add_to_collection(ops.GraphKeys.RANK_SERVICE_INPUT_SRC,
+ json.dumps({
+ 'name': name,
+ 'src': src_desc
+ }))
+
+
+def is_proto_message(pb_obj, field):
+ if not hasattr(pb_obj, 'DESCRIPTOR'):
+ return False
+ if field not in pb_obj.DESCRIPTOR.fields_by_name:
+ return False
+ field_type = pb_obj.DESCRIPTOR.fields_by_name[field].type
+ return field_type == FieldDescriptor.TYPE_MESSAGE
+
+
+class Parameter(object):
+
+ def __init__(self, params, is_struct, l2_reg=None):
+ self.params = params
+ self.is_struct = is_struct
+ self._l2_reg = l2_reg
+
+ @staticmethod
+ def make_from_pb(config):
+ return Parameter(config, False)
+
+ def get_pb_config(self):
+ assert not self.is_struct, 'Struct parameter can not convert to pb config'
+ return self.params
+
+ @property
+ def l2_regularizer(self):
+ return self._l2_reg
+
+ @l2_regularizer.setter
+ def l2_regularizer(self, value):
+ self._l2_reg = value
+
+ def __getattr__(self, key):
+ if self.is_struct:
+ if key not in self.params:
+ return None
+ value = self.params[key]
+ if type(value) == struct_pb2.Struct:
+ return Parameter(value, True, self._l2_reg)
+ else:
+ return value
+ value = getattr(self.params, key)
+ if is_proto_message(self.params, key):
+ return Parameter(value, False, self._l2_reg)
+ return value
+
+ def __getitem__(self, key):
+ return self.__getattr__(key)
+
+ def get_or_default(self, key, def_val):
+ if self.is_struct:
+ if key in self.params:
+ if def_val is None:
+ return self.params[key]
+ value = self.params[key]
+ if type(value) == float:
+ return type(def_val)(value)
+ return value
+ return def_val
+ else: # pb message
+ value = getattr(self.params, key, def_val)
+ if hasattr(value, '__len__'): # repeated
+ return value if len(value) > 0 else def_val
+ try:
+ if self.params.HasField(key):
+ return value
+ except ValueError:
+ pass
+ return def_val # maybe not equal to the default value of msg field
+
+ def check_required(self, keys):
+ if not self.is_struct:
+ return
+ if not isinstance(keys, (list, tuple)):
+ keys = [keys]
+ for key in keys:
+ if key not in self.params:
+ raise KeyError('%s must be set in params' % key)
+
+ def has_field(self, key):
+ if self.is_struct:
+ return key in self.params
+ else:
+ return self.params.HasField(key)
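
Not part of the patch: a small usage sketch of the `Parameter` wrapper above, backed by a `google.protobuf` Struct; the field names here are made up.

from google.protobuf import struct_pb2

from easy_rec.python.layers.utils import Parameter

st = struct_pb2.Struct()
st.update({'hidden_units': [64.0, 32.0], 'dropout_rate': 0.1})
params = Parameter(st, is_struct=True)

params.check_required(['hidden_units'])             # raises KeyError if absent
units = [int(u) for u in params['hidden_units']]    # Struct numbers are floats
act = params.get_or_default('activation', 'relu')   # falls back to the default
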
diff --git a/easy_rec/python/layers/variational_dropout_layer.py b/easy_rec/python/layers/variational_dropout_layer.py
index 56fe32501..0eeddcf7b 100644
--- a/easy_rec/python/layers/variational_dropout_layer.py
+++ b/easy_rec/python/layers/variational_dropout_layer.py
@@ -1,5 +1,7 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+
import numpy as np
import tensorflow as tf
@@ -21,10 +23,11 @@ class VariationalDropoutLayer(object):
def __init__(self,
variational_dropout_config,
features_dimension,
- is_training=False):
+ is_training=False,
+ name=''):
self._config = variational_dropout_config
self.features_dimension = features_dimension
- self.features_total_dimension = sum(self.features_dimension)
+ self.features_total_dimension = sum(self.features_dimension.values())
if self.variational_dropout_wise():
self._dropout_param_size = self.features_total_dimension
@@ -34,11 +37,15 @@ def __init__(self,
self.drop_param_shape = [self._dropout_param_size]
self.evaluate = not is_training
+ logit_p_name = 'logit_p' if name == 'all' else 'logit_p_%s' % name
self.logit_p = tf.get_variable(
- name='logit_p',
+ name=logit_p_name,
shape=self.drop_param_shape,
dtype=tf.float32,
initializer=None)
+ tf.add_to_collection(
+ 'variational_dropout',
+ json.dumps([name, list(self.features_dimension.items())]))
def get_lambda(self):
return self._config.regularization_lambda
@@ -49,8 +56,7 @@ def variational_dropout_wise(self):
def build_expand_index(self, batch_size):
# Build index_list--->[[0,0],[0,0],[0,0],[0,0],[0,1]......]
expanded_index = []
- for i in range(len(self.features_dimension)):
- index_loop_count = self.features_dimension[i]
+ for i, index_loop_count in enumerate(self.features_dimension.values()):
for m in range(index_loop_count):
expanded_index.append([i])
expanded_index = tf.tile(expanded_index, [batch_size, 1])
diff --git a/easy_rec/python/loss/contrastive_loss.py b/easy_rec/python/loss/contrastive_loss.py
new file mode 100644
index 000000000..3fd2be645
--- /dev/null
+++ b/easy_rec/python/loss/contrastive_loss.py
@@ -0,0 +1,79 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def l2_loss(x1, x2):
+ """Compute euclidean distance of two embeddings."""
+ distance = tf.reduce_sum(tf.square(x1 - x2), axis=-1)
+ return tf.reduce_mean(distance)
+
+
+def info_nce_loss(query, positive, temperature=0.1):
+ """Calculates the InfoNCE loss for self-supervised learning.
+
+ This contrastive loss enforces the embeddings of similar (positive) samples to be close
+ and those of different (negative) samples to be distant.
+ A query embedding is compared with one positive key and with one or more negative keys.
+
+ References:
+ https://arxiv.org/abs/1807.03748v2
+ https://arxiv.org/abs/2010.05113
+ """
+ # Check input dimensionality.
+ if query.shape.ndims != 2:
+ raise ValueError('`query` must have 2 dimensions.')
+ if positive.shape.ndims != 2:
+ raise ValueError('`positive` must have 2 dimensions.')
+ # Embedding vectors should have same number of components.
+ if query.shape[-1] != positive.shape[-1]:
+ raise ValueError(
+ 'Vectors of `query` and `positive` should have the same number of components.'
+ )
+
+ # Negative keys are implicitly off-diagonal positive keys.
+
+ # Cosine between all combinations
+ logits = tf.matmul(query, positive, transpose_b=True)
+ logits /= temperature
+
+ # Positive keys are the entries on the diagonal
+ batch_size = tf.shape(query)[0]
+ labels = tf.range(batch_size)
+
+ return tf.losses.sparse_softmax_cross_entropy(labels, logits)
+
+
+def get_mask_matrix(batch_size):
+ mat = tf.ones((batch_size, batch_size), dtype=tf.bool)
+ diag = tf.zeros([batch_size], dtype=tf.bool)
+ mask = tf.linalg.set_diag(mat, diag)
+ mask = tf.tile(mask, [2, 2])
+ return mask
+
+
+def nce_loss(z_i, z_j, temperature=1.0):
+ """Contrastive nce loss for homogeneous embeddings.
+
+ See the paper: Contrastive Learning for Sequential Recommendation
+ """
+ batch_size = tf.shape(z_i)[0]
+ N = 2 * batch_size
+ z = tf.concat((z_i, z_j), axis=0)
+ sim = tf.matmul(z, tf.transpose(z)) / temperature
+ sim_i_j = tf.matrix_diag_part(
+ tf.slice(sim, [batch_size, 0], [batch_size, batch_size]))
+ sim_j_i = tf.matrix_diag_part(
+ tf.slice(sim, [0, batch_size], [batch_size, batch_size]))
+ positive_samples = tf.reshape(tf.concat((sim_i_j, sim_j_i), axis=0), (N, 1))
+ mask = get_mask_matrix(batch_size)
+ negative_samples = tf.reshape(tf.boolean_mask(sim, mask), (N, -1))
+
+ labels = tf.zeros(N, dtype=tf.int32)
+ logits = tf.concat((positive_samples, negative_samples), axis=1)
+
+ loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+ return loss
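
Not part of the patch: a toy call of the contrastive losses above on two augmented views of the same batch; the embeddings are random stand-ins and are L2-normalized so the dot products act as cosine similarities.

import tensorflow as tf

from easy_rec.python.loss.contrastive_loss import info_nce_loss, nce_loss

view_a = tf.nn.l2_normalize(tf.random.normal([32, 64]), axis=-1)
view_b = tf.nn.l2_normalize(tf.random.normal([32, 64]), axis=-1)

ssl_loss = info_nce_loss(view_a, view_b, temperature=0.1)  # in-batch negatives
cl_loss = nce_loss(view_a, view_b, temperature=1.0)        # 2N-way variant
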
diff --git a/easy_rec/python/loss/f1_reweight_loss.py b/easy_rec/python/loss/f1_reweight_loss.py
new file mode 100644
index 000000000..3f9689f4d
--- /dev/null
+++ b/easy_rec/python/loss/f1_reweight_loss.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def f1_reweight_sigmoid_cross_entropy(labels,
+ logits,
+ beta_square,
+ label_smoothing=0,
+ weights=None):
+ """Refer paper: Adaptive Scaling for Sparse Detection in Information Extraction."""
+ probs = tf.nn.sigmoid(logits)
+ if len(logits.shape.as_list()) == 1:
+ logits = tf.expand_dims(logits, -1)
+ if len(labels.shape.as_list()) == 1:
+ labels = tf.expand_dims(labels, -1)
+ labels = tf.to_float(labels)
+ batch_size = tf.shape(labels)[0]
+ batch_size_float = tf.to_float(batch_size)
+ num_pos = tf.reduce_sum(labels, axis=0)
+ num_neg = batch_size_float - num_pos
+ tp = tf.reduce_sum(probs, axis=0)
+ tn = batch_size_float - tp
+ neg_weight = tp / (beta_square * num_pos + num_neg - tn + 1e-8)
+ neg_weight_tile = tf.tile(tf.expand_dims(neg_weight, 0), [batch_size, 1])
+ final_weights = tf.where(
+ tf.equal(labels, 1.0), tf.ones_like(labels), neg_weight_tile)
+ if weights is not None:
+ weights = tf.cast(weights, tf.float32)
+ if len(weights.shape.as_list()) == 1:
+ weights = tf.expand_dims(weights, -1)
+ final_weights *= weights
+ return tf.losses.sigmoid_cross_entropy(
+ labels, logits, final_weights, label_smoothing=label_smoothing)
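
Not part of the patch: a toy call of the F1-reweighted loss above; `beta_square` controls the precision/recall trade-off (larger values down-weight negatives and favour recall), and the values below are illustrative.

import tensorflow as tf

from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy

labels = tf.constant([1, 0, 0, 0, 1])
logits = tf.constant([2.0, -1.0, 0.5, -3.0, 1.5])
loss = f1_reweight_sigmoid_cross_entropy(labels, logits, beta_square=4.0)
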
diff --git a/easy_rec/python/loss/focal_loss.py b/easy_rec/python/loss/focal_loss.py
new file mode 100644
index 000000000..9ef6a94a7
--- /dev/null
+++ b/easy_rec/python/loss/focal_loss.py
@@ -0,0 +1,93 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def sigmoid_focal_loss_with_logits(labels,
+ logits,
+ gamma=2.0,
+ alpha=None,
+ ohem_ratio=1.0,
+ sample_weights=None,
+ label_smoothing=0,
+ name=''):
+ """Implements the focal loss function.
+
+ Focal loss was first introduced in the RetinaNet paper
+ (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
+ classification when you have highly imbalanced classes. It down-weights
+ well-classified examples and focuses on hard examples. The loss value is
+ much high for a sample which is misclassified by the classifier as compared
+ to the loss value corresponding to a well-classified example. One of the
+ best use-cases of focal loss is its usage in object detection where the
+ imbalance between the background class and other classes is extremely high.
+
+ Args:
+ labels: `[batch_size]` target integer labels in `{0, 1}`.
+ logits: Float `[batch_size]` logits outputs of the network.
+ alpha: balancing factor.
+ gamma: modulating factor.
+ ohem_ratio: the percent of hard examples to be mined
+ sample_weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+ be either `1`, or the same as the corresponding `losses` dimension).
+ label_smoothing: If greater than `0` then smooth the labels.
+ name: the name of loss
+
+ Returns:
+ A scalar weighted loss float `Tensor`.
+
+ Raises:
+ ValueError: If the shape of `sample_weights` is invalid or the value of
+ `gamma` is less than zero
+ """
+ loss_name = name if name else 'focal_loss'
+ assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
+ if gamma and gamma < 0:
+ raise ValueError('Value of gamma should be greater than or equal to zero')
+ logging.info(
+ '[{}] gamma: {}, alpha: {}, ohem_ratio: {}, label smoothing: {}'.format(
+ loss_name, gamma, alpha, ohem_ratio, label_smoothing))
+
+ y_true = tf.cast(labels, logits.dtype)
+
+ # convert the predictions into probabilities
+ y_pred = tf.nn.sigmoid(logits)
+ epsilon = 1e-7
+ y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)
+ p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
+ weights = tf.pow((1 - p_t), gamma)
+
+ if alpha is not None:
+ alpha_factor = y_true * alpha + ((1 - alpha) * (1 - y_true))
+ weights *= alpha_factor
+
+ if sample_weights is not None:
+ if tf.is_numeric_tensor(sample_weights):
+ logging.info('[%s] use sample weight' % loss_name)
+ weights *= tf.cast(sample_weights, tf.float32)
+ elif sample_weights != 1.0:
+ logging.info('[%s] use sample weight: %f' % (loss_name, sample_weights))
+ weights *= sample_weights
+
+ if ohem_ratio == 1.0:
+ return tf.losses.sigmoid_cross_entropy(
+ y_true, logits, weights=weights, label_smoothing=label_smoothing)
+
+ losses = tf.losses.sigmoid_cross_entropy(
+ y_true,
+ logits,
+ weights=weights,
+ label_smoothing=label_smoothing,
+ reduction=tf.losses.Reduction.NONE)
+ k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
+ k = tf.to_int32(tf.math.rint(k))
+ topk = tf.nn.top_k(losses, k)
+ losses = tf.boolean_mask(topk.values, topk.values > 0)
+ return tf.reduce_mean(losses)
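
Not part of the patch: a toy call of the focal loss above with class balancing and online hard example mining; the label and logit values are illustrative only.

import tensorflow as tf

from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits

labels = tf.constant([1, 0, 0, 1, 0])
logits = tf.constant([0.3, -2.0, 1.2, 2.5, -0.7])
loss = sigmoid_focal_loss_with_logits(
    labels, logits, gamma=2.0, alpha=0.25, ohem_ratio=0.8)
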
diff --git a/easy_rec/python/loss/jrc_loss.py b/easy_rec/python/loss/jrc_loss.py
new file mode 100644
index 000000000..b5165d3c2
--- /dev/null
+++ b/easy_rec/python/loss/jrc_loss.py
@@ -0,0 +1,128 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import numpy as np
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def jrc_loss(labels,
+ logits,
+ session_ids,
+ alpha=0.5,
+ loss_weight_strategy='fixed',
+ sample_weights=1.0,
+ same_label_loss=True,
+ name=''):
+ """Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model.
+
+ https://arxiv.org/abs/2208.06164
+
+ Args:
+ labels: a `Tensor` with shape [batch_size]. e.g. click or not click in the session.
+ logits: a `Tensor` with shape [batch_size, 2]. e.g. the value of last neuron before activation.
+ session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to compute the GAUC metric, e.g. user_id
+ alpha: the weight to balance ranking loss and calibration loss
+ loss_weight_strategy: str, the strategy for balancing ce_loss and ge_loss
+ sample_weights: Coefficients for the loss. This must be scalar or broadcastable to
+ `labels` (i.e. same rank and each dimension is either 1 or the same).
+ same_label_loss: whether to enable ge_loss for samples with the same label in a session.
+ name: the name of loss
+ """
+ loss_name = name if name else 'jrc_loss'
+ logging.info('[{}] alpha: {}, loss_weight_strategy: {}'.format(
+ loss_name, alpha, loss_weight_strategy))
+
+ ce_loss = tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights=sample_weights)
+
+ labels = tf.expand_dims(labels, 1) # [B, 1]
+ labels = tf.concat([1 - labels, labels], axis=1) # [B, 2]
+
+ batch_size = tf.shape(logits)[0]
+
+ # Mask: shape [B, B], mask[i,j]=1 indicates the i-th sample
+ # and j-th sample are in the same context
+ mask = tf.equal(
+ tf.expand_dims(session_ids, 1), tf.expand_dims(session_ids, 0))
+ mask = tf.to_float(mask)
+
+ # Tile logits and label: [B, 2]->[B, B, 2]
+ logits = tf.tile(tf.expand_dims(logits, 1), [1, batch_size, 1])
+ y = tf.tile(tf.expand_dims(labels, 1), [1, batch_size, 1])
+
+ # Set logits that are not in the same context to -inf
+ mask3d = tf.expand_dims(mask, 2)
+ y = tf.to_float(y) * mask3d
+ logits = logits + (1 - mask3d) * -1e9
+ y_neg, y_pos = y[:, :, 0], y[:, :, 1]
+ l_neg, l_pos = logits[:, :, 0], logits[:, :, 1]
+
+ if tf.is_numeric_tensor(sample_weights):
+ logging.info('[%s] use sample weight' % loss_name)
+ weights = tf.expand_dims(tf.cast(sample_weights, tf.float32), 0)
+ pairwise_weights = tf.tile(weights, tf.stack([batch_size, 1]))
+ y_pos *= pairwise_weights
+ y_neg *= pairwise_weights
+
+ # Compute list-wise generative loss -log p(x|y, z)
+ if same_label_loss:
+ logging.info('[%s] enable same_label_loss' % loss_name)
+ loss_pos = -tf.reduce_sum(y_pos * tf.nn.log_softmax(l_pos, axis=0), axis=0)
+ loss_neg = -tf.reduce_sum(y_neg * tf.nn.log_softmax(l_neg, axis=0), axis=0)
+ ge_loss = tf.reduce_mean(
+ (loss_pos + loss_neg) / tf.reduce_sum(mask, axis=0))
+ else:
+ logging.info('[%s] disable same_label_loss' % loss_name)
+ diag = tf.one_hot(tf.range(batch_size), batch_size)
+ l_pos = l_pos + (1 - diag) * y_pos * -1e9
+ l_neg = l_neg + (1 - diag) * y_neg * -1e9
+ loss_pos = -tf.linalg.diag_part(y_pos * tf.nn.log_softmax(l_pos, axis=0))
+ loss_neg = -tf.linalg.diag_part(y_neg * tf.nn.log_softmax(l_neg, axis=0))
+ ge_loss = tf.reduce_mean(loss_pos + loss_neg)
+
+ tf.summary.scalar('loss/%s_ce' % loss_name, ce_loss)
+ tf.summary.scalar('loss/%s_ge' % loss_name, ge_loss)
+
+ # The final JRC model
+ if loss_weight_strategy == 'fixed':
+ loss = alpha * ce_loss + (1 - alpha) * ge_loss
+ elif loss_weight_strategy == 'random_uniform':
+ weight = tf.random_uniform([])
+ loss = weight * ce_loss + (1 - weight) * ge_loss
+ tf.summary.scalar('loss/%s_ce_weight' % loss_name, weight)
+ tf.summary.scalar('loss/%s_ge_weight' % loss_name, 1 - weight)
+ elif loss_weight_strategy == 'random_normal':
+ weights = tf.random_normal([2])
+ loss_weight = tf.nn.softmax(weights)
+ loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
+ tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
+ tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
+ elif loss_weight_strategy == 'random_bernoulli':
+ bern = tf.distributions.Bernoulli(probs=0.5, dtype=tf.float32)
+ weights = bern.sample(2)
+ loss_weight = tf.cond(
+ tf.equal(tf.reduce_sum(weights), 1), lambda: weights,
+ lambda: tf.convert_to_tensor([0.5, 0.5]))
+ loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
+ tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
+ tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
+ elif loss_weight_strategy == 'uncertainty':
+ uncertainty1 = tf.Variable(
+ 0, name='%s_ranking_loss_weight' % loss_name, dtype=tf.float32)
+ tf.summary.scalar('loss/%s_ranking_uncertainty' % loss_name, uncertainty1)
+ uncertainty2 = tf.Variable(
+ 0, name='%s_calibration_loss_weight' % loss_name, dtype=tf.float32)
+ tf.summary.scalar('loss/%s_calibration_uncertainty' % loss_name,
+ uncertainty2)
+ loss = tf.exp(-uncertainty1) * ce_loss + 0.5 * uncertainty1
+ loss += tf.exp(-uncertainty2) * ge_loss + 0.5 * uncertainty2
+ else:
+ raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' %
+ loss_weight_strategy)
+ if np.isscalar(sample_weights) and sample_weights != 1.0:
+ return loss * sample_weights
+ return loss
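
Not part of the patch: a toy call of the JRC loss above; `logits` has two columns (non-click, click) and `session_ids` groups samples that share a context, e.g. a user or request.

import tensorflow as tf

from easy_rec.python.loss.jrc_loss import jrc_loss

labels = tf.constant([1, 0, 0, 1])
logits = tf.constant([[0.1, 0.4], [1.2, -0.3], [0.5, 0.2], [-0.1, 0.9]])
session_ids = tf.constant([101, 101, 102, 102])
loss = jrc_loss(labels, logits, session_ids, alpha=0.5)
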
diff --git a/easy_rec/python/loss/listwise_loss.py b/easy_rec/python/loss/listwise_loss.py
new file mode 100644
index 000000000..24bd5864f
--- /dev/null
+++ b/easy_rec/python/loss/listwise_loss.py
@@ -0,0 +1,161 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.utils.load_class import load_by_path
+
+
+def _list_wise_loss(x, labels, logits, session_ids, label_is_logits):
+ mask = tf.equal(x, session_ids)
+ logits = tf.boolean_mask(logits, mask)
+ labels = tf.boolean_mask(labels, mask)
+ y = tf.nn.softmax(labels) if label_is_logits else labels
+ y_hat = tf.nn.log_softmax(logits)
+ return -tf.reduce_sum(y * y_hat)
+
+
+def _list_prob_loss(x, labels, logits, session_ids):
+ mask = tf.equal(x, session_ids)
+ logits = tf.boolean_mask(logits, mask)
+ labels = tf.boolean_mask(labels, mask)
+ y = labels / tf.reduce_sum(labels)
+ y_hat = tf.nn.log_softmax(logits)
+ return -tf.reduce_sum(y * y_hat)
+
+
+def listwise_rank_loss(labels,
+ logits,
+ session_ids,
+ transform_fn=None,
+ temperature=1.0,
+ label_is_logits=False,
+ scale_logits=False,
+ weights=1.0,
+ name='listwise_loss'):
+ r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+ Definition:
+ $$
+ \mathcal{L}(\{y\}, \{s\}) =
+ -\sum_i y_i \log\left( \frac{\exp(s_i)}{\sum_j \exp(s_j)} \right)
+ $$
+
+ Args:
+ labels: A `Tensor` of the same shape as `logits` representing graded
+ relevance.
+ logits: A `Tensor` with shape [batch_size].
+ session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to compute the GAUC metric, e.g. user_id
+ transform_fn: an affine transformation function of labels
+ temperature: (Optional) The temperature to use for scaling the logits.
+ label_is_logits: Whether `labels` is expected to be a logits tensor.
+ By default, we consider that `labels` encodes a probability distribution.
+ scale_logits: Whether to scale the logits.
+ weights: sample weights
+ name: the name of loss
+ """
+ loss_name = name if name else 'listwise_rank_loss'
+ logging.info('[{}] temperature: {}, scale logits: {}'.format(
+ loss_name, temperature, scale_logits))
+ labels = tf.to_float(labels)
+ if scale_logits:
+ with tf.variable_scope(loss_name):
+ w = tf.get_variable(
+ 'scale_w',
+ dtype=tf.float32,
+ shape=(1,),
+ initializer=tf.ones_initializer())
+ b = tf.get_variable(
+ 'scale_b',
+ dtype=tf.float32,
+ shape=(1,),
+ initializer=tf.zeros_initializer())
+ logits = logits * tf.abs(w) + b
+ if temperature != 1.0:
+ logits /= temperature
+ if label_is_logits:
+ labels /= temperature
+ if transform_fn is not None:
+ trans_fn = load_by_path(transform_fn)
+ labels = trans_fn(labels)
+
+ sessions, _ = tf.unique(tf.squeeze(session_ids))
+ tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+ losses = tf.map_fn(
+ lambda x: _list_wise_loss(x, labels, logits, session_ids, label_is_logits
+ ),
+ sessions,
+ dtype=tf.float32)
+ if tf.is_numeric_tensor(weights):
+ logging.error('[%s] tensor-valued sample weights are not supported and will be ignored' % loss_name)
+ return tf.reduce_mean(losses)
+ else:
+ return tf.reduce_mean(losses) * weights
+
+
+def listwise_distill_loss(labels,
+ logits,
+ session_ids,
+ transform_fn=None,
+ temperature=1.0,
+ label_clip_max_value=512,
+ scale_logits=False,
+ weights=1.0,
+ name='listwise_distill_loss'):
+ r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+ Definition:
+ $$
+ \mathcal{L}(\{y\}, \{s\}) =
+ -\sum_i y_i \log\left( \frac{\exp(s_i)}{\sum_j \exp(s_j)} \right)
+ $$
+
+ Args:
+ labels: A `Tensor` of the same shape as `logits` representing the rank position of a base model.
+ logits: A `Tensor` with shape [batch_size].
+ session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to compute the GAUC metric, e.g. user_id
+ transform_fn: a transformation function of labels.
+ temperature: (Optional) The temperature to use for scaling the logits.
+ label_clip_max_value: clip the labels to this value.
+ scale_logits: Whether to scale the logits.
+ weights: sample weights
+ name: the name of loss
+ """
+ loss_name = name if name else 'listwise_distill_loss'
+ logging.info('[{}] temperature: {}'.format(loss_name, temperature))
+ labels = tf.to_float(labels) # supposed to be positions of a teacher model
+ labels = tf.clip_by_value(labels, 1, label_clip_max_value)
+ if transform_fn is not None:
+ trans_fn = load_by_path(transform_fn)
+ labels = trans_fn(labels)
+ else:
+ labels = tf.log1p(label_clip_max_value) - tf.log(labels)
+
+ if scale_logits:
+ with tf.variable_scope(loss_name):
+ w = tf.get_variable(
+ 'scale_w',
+ dtype=tf.float32,
+ shape=(1,),
+ initializer=tf.ones_initializer())
+ b = tf.get_variable(
+ 'scale_b',
+ dtype=tf.float32,
+ shape=(1,),
+ initializer=tf.zeros_initializer())
+ logits = logits * tf.abs(w) + b
+ if temperature != 1.0:
+ logits /= temperature
+
+ sessions, _ = tf.unique(tf.squeeze(session_ids))
+ tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+ losses = tf.map_fn(
+ lambda x: _list_prob_loss(x, labels, logits, session_ids),
+ sessions,
+ dtype=tf.float32)
+ if tf.is_numeric_tensor(weights):
+ logging.error('[%s] tensor-valued sample weights are not supported and will be ignored' % loss_name)
+ return tf.reduce_mean(losses)
+ else:
+ return tf.reduce_mean(losses) * weights
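
Not part of the patch: toy calls of the two listwise losses above; each distinct session id defines one list that is soft-maxed independently, and for the distillation variant the labels are rank positions produced by a teacher model. All values are illustrative.

import tensorflow as tf

from easy_rec.python.loss.listwise_loss import listwise_distill_loss, listwise_rank_loss

labels = tf.constant([3.0, 1.0, 0.0, 2.0, 0.0])   # graded relevance
logits = tf.constant([0.8, 0.1, -0.5, 0.6, -0.2])
session_ids = tf.constant([1, 1, 1, 2, 2])

rank_loss = listwise_rank_loss(labels, logits, session_ids, temperature=0.5)

teacher_positions = tf.constant([1.0, 4.0, 9.0, 2.0, 30.0])
distill_loss = listwise_distill_loss(teacher_positions, logits, session_ids)
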
diff --git a/easy_rec/python/loss/pairwise_loss.py b/easy_rec/python/loss/pairwise_loss.py
index 9e16e3bdb..604f1ce2e 100644
--- a/easy_rec/python/loss/pairwise_loss.py
+++ b/easy_rec/python/loss/pairwise_loss.py
@@ -1,27 +1,307 @@
-# coding=utf-8
+# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import tensorflow as tf
+from tensorflow.python.ops.losses.losses_impl import compute_weighted_loss
+
+from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
+from easy_rec.python.utils.shape_utils import get_shape_list
if tf.__version__ >= '2.0':
tf = tf.compat.v1
-def pairwise_loss(labels, logits):
- pairwise_logits = tf.expand_dims(logits, -1) - tf.expand_dims(logits, 0)
- logging.info('[pairwise_loss] pairwise logits: {}'.format(pairwise_logits))
+def pairwise_loss(labels,
+ logits,
+ session_ids=None,
+ margin=0,
+ temperature=1.0,
+ weights=1.0,
+ name=''):
+ """Deprecated Pairwise loss. Also see `pairwise_logistic_loss` below.
+
+ Args:
+ labels: a `Tensor` with shape [batch_size]. e.g. click or not click in the session.
+ logits: a `Tensor` with shape [batch_size]. e.g. the value of last neuron before activation.
+ session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to compute the GAUC metric, e.g. user_id
+ margin: the margin between positive and negative sample pair
+ temperature: (Optional) The temperature to use for scaling the logits.
+ weights: sample weights
+ name: the name of loss
+ """
+ logging.warning(
+ 'The old `pairwise_loss` is being deprecated. '
+ 'Please use the new `pairwise_logistic_loss` or `pairwise_focal_loss`')
+ loss_name = name if name else 'pairwise_loss'
+ logging.info('[{}] margin: {}, temperature: {}'.format(
+ loss_name, margin, temperature))
+ if temperature != 1.0:
+ logits /= temperature
+ pairwise_logits = tf.math.subtract(
+ tf.expand_dims(logits, -1), tf.expand_dims(logits, 0)) - margin
pairwise_mask = tf.greater(
- tf.expand_dims(labels, -1) - tf.expand_dims(labels, 0), 0)
- logging.info('[pairwise_loss] mask: {}'.format(pairwise_mask))
+ tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
+ if session_ids is not None:
+ logging.info('[%s] use session ids' % loss_name)
+ group_equal = tf.equal(
+ tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
+ pairwise_mask = tf.logical_and(pairwise_mask, group_equal)
pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
- logging.info('[pairwise_loss] after masking: {}'.format(pairwise_logits))
+ num_pair = tf.size(pairwise_logits)
+ tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)
+
+ if tf.is_numeric_tensor(weights):
+ logging.info('[%s] use sample weight' % loss_name)
+ weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
+ batch_size, _ = get_shape_list(weights, 2)
+ pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
+ pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
+ else:
+ pairwise_weights = weights
pairwise_pseudo_labels = tf.ones_like(pairwise_logits)
- loss = tf.losses.sigmoid_cross_entropy(pairwise_pseudo_labels,
- pairwise_logits)
+ loss = tf.losses.sigmoid_cross_entropy(
+ pairwise_pseudo_labels, pairwise_logits, weights=pairwise_weights)
# set rank loss to zero if a batch has no positive sample.
- loss = tf.where(tf.is_nan(loss), tf.zeros_like(loss), loss)
+ # loss = tf.where(tf.is_nan(loss), tf.zeros_like(loss), loss)
+ return loss
+
+
+def pairwise_focal_loss(labels,
+ logits,
+ session_ids=None,
+ hinge_margin=None,
+ gamma=2,
+ alpha=None,
+ ohem_ratio=1.0,
+ temperature=1.0,
+ weights=1.0,
+ name=''):
+ loss_name = name if name else 'pairwise_focal_loss'
+ assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
+ logging.info(
+ '[{}] hinge margin: {}, gamma: {}, alpha: {}, ohem_ratio: {}, temperature: {}'
+ .format(loss_name, hinge_margin, gamma, alpha, ohem_ratio, temperature))
+
+ if temperature != 1.0:
+ logits /= temperature
+ pairwise_logits = tf.expand_dims(logits, -1) - tf.expand_dims(logits, 0)
+
+ pairwise_mask = tf.greater(
+ tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
+ if hinge_margin is not None:
+ hinge_mask = tf.less(pairwise_logits, hinge_margin)
+ pairwise_mask = tf.logical_and(pairwise_mask, hinge_mask)
+ if session_ids is not None:
+ logging.info('[%s] use session ids' % loss_name)
+ group_equal = tf.equal(
+ tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
+ pairwise_mask = tf.logical_and(pairwise_mask, group_equal)
+
+ pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
+ num_pair = tf.size(pairwise_logits)
+ tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)
+
+ if tf.is_numeric_tensor(weights):
+ logging.info('[%s] use sample weight' % loss_name)
+ weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
+ batch_size, _ = get_shape_list(weights, 2)
+ pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
+ pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
+ else:
+ pairwise_weights = weights
+
+ pairwise_pseudo_labels = tf.ones_like(pairwise_logits)
+ loss = sigmoid_focal_loss_with_logits(
+ pairwise_pseudo_labels,
+ pairwise_logits,
+ gamma=gamma,
+ alpha=alpha,
+ ohem_ratio=ohem_ratio,
+ sample_weights=pairwise_weights)
return loss
+
+
+def pairwise_logistic_loss(labels,
+ logits,
+ session_ids=None,
+ temperature=1.0,
+ hinge_margin=None,
+ weights=1.0,
+ ohem_ratio=1.0,
+ use_label_margin=False,
+ name=''):
+ r"""Computes pairwise logistic loss between `labels` and `logits`, equivalent to RankNet loss.
+
+ Definition:
+ $$
+ \mathcal{L}(\{y\}, \{s\}) =
+ \sum_i \sum_j I[y_i > y_j] \log(1 + \exp(-(s_i - s_j)))
+ $$
+
+ Args:
+ labels: A `Tensor` of the same shape as `logits` representing graded
+ relevance.
+ logits: A `Tensor` with shape [batch_size].
+ session_ids: a `Tensor` with shape [batch_size]. Session ids of each
+ sample, used to compute the GAUC metric, e.g. user_id
+ temperature: (Optional) The temperature to use for scaling the logits.
+ hinge_margin: the margin between positive and negative logits
+ weights: A scalar, a `Tensor` with shape [batch_size] for each sample
+ ohem_ratio: the percent of hard examples to be mined
+ use_label_margin: whether to use the diff `label[i]-label[j]` as margin
+ name: the name of loss
+ """
+ loss_name = name if name else 'pairwise_logistic_loss'
+ assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
+ logging.info('[{}] hinge margin: {}, ohem_ratio: {}, temperature: {}'.format(
+ loss_name, hinge_margin, ohem_ratio, temperature))
+
+ if temperature != 1.0:
+ logits /= temperature
+ if use_label_margin:
+ labels /= temperature
+
+ pairwise_logits = tf.math.subtract(
+ tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
+ if use_label_margin:
+ pairwise_logits -= tf.math.subtract(
+ tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
+ elif hinge_margin is not None:
+ pairwise_logits -= hinge_margin
+
+ pairwise_mask = tf.greater(
+ tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
+ if session_ids is not None:
+ logging.info('[%s] use session ids' % loss_name)
+ group_equal = tf.equal(
+ tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
+ pairwise_mask = tf.logical_and(pairwise_mask, group_equal)
+
+ pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
+ num_pair = tf.size(pairwise_logits)
+ tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)
+
+ # The following is the same as log(1 + exp(-pairwise_logits)).
+ losses = tf.nn.relu(-pairwise_logits) + tf.math.log1p(
+ tf.exp(-tf.abs(pairwise_logits)))
+
+ if tf.is_numeric_tensor(weights):
+ logging.info('[%s] use sample weight' % loss_name)
+ weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
+ batch_size, _ = get_shape_list(weights, 2)
+ pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
+ pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
+ else:
+ pairwise_weights = weights
+
+ if ohem_ratio == 1.0:
+ return compute_weighted_loss(losses, pairwise_weights)
+
+ losses = compute_weighted_loss(
+ losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
+ k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
+ k = tf.to_int32(tf.math.rint(k))
+ topk = tf.nn.top_k(losses, k)
+ losses = tf.boolean_mask(topk.values, topk.values > 0)
+ return tf.reduce_mean(losses)
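For reference, the two-term expression used above is the standard numerically stable rewrite of the logistic pair loss; for any real x,

$$
\log(1 + e^{-x}) = \max(0, -x) + \log(1 + e^{-|x|}),
$$

since for $x \ge 0$ the first term vanishes and the second is the original expression, while for $x < 0$ it equals $-x + \log(1 + e^{x}) = \log(e^{-x} + 1)$; either way $e^{-|x|} \le 1$, so the exponential never overflows float32.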
+
+
+def pairwise_hinge_loss(labels,
+ logits,
+ session_ids=None,
+ temperature=1.0,
+ margin=1.0,
+ weights=1.0,
+ ohem_ratio=1.0,
+ label_is_logits=True,
+ use_label_margin=True,
+ use_exponent=False,
+ name=''):
+ r"""Computes pairwise hinge loss between `labels` and `logits`.
+
+ Definition:
+ $$
+ \mathcal{L}(\{y\}, \{s\}) =
+ \sum_i \sum_j I[y_i > y_j] \max(0, 1 - (s_i - s_j))
+ $$
+
+ Args:
+ labels: A `Tensor` of the same shape as `logits` representing graded
+ relevance.
+ logits: A `Tensor` with shape [batch_size].
+    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
+      sample, used to compute the GAUC metric, e.g. user_id.
+ temperature: (Optional) The temperature to use for scaling the logits.
+ margin: the margin between positive and negative logits
+ weights: A scalar, a `Tensor` with shape [batch_size] for each sample
+ ohem_ratio: the percent of hard examples to be mined
+ label_is_logits: Whether `labels` is expected to be a logits tensor.
+ use_label_margin: whether to use the diff `label[i]-label[j]` as margin
+ use_exponent: whether to use exponential difference
+ name: the name of loss
+ """
+ loss_name = name if name else 'pairwise_hinge_loss'
+ assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
+ logging.info(
+ '[{}] margin: {}, ohem_ratio: {}, temperature: {}, use_exponent: {}, label_is_logits: {}, use_label_margin: {}'
+ .format(loss_name, margin, ohem_ratio, temperature, use_exponent,
+ label_is_logits, use_label_margin))
+
+ if temperature != 1.0:
+ logits /= temperature
+ if label_is_logits:
+ labels /= temperature
+ if use_exponent:
+ labels = tf.nn.sigmoid(labels)
+    logits = tf.nn.sigmoid(logits)
+
+ pairwise_logits = tf.math.subtract(
+ tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
+ pairwise_labels = tf.math.subtract(
+ tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
+
+ pairwise_mask = tf.greater(pairwise_labels, 0)
+ if session_ids is not None:
+ logging.info('[%s] use session ids' % loss_name)
+ group_equal = tf.equal(
+ tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
+ pairwise_mask = tf.logical_and(pairwise_mask, group_equal)
+
+ pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
+ pairwise_labels = tf.boolean_mask(pairwise_labels, pairwise_mask)
+ num_pair = tf.size(pairwise_logits)
+ tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)
+
+ if use_label_margin:
+ diff = pairwise_labels - pairwise_logits
+ else:
+ diff = margin - pairwise_logits
+ if use_exponent:
+ threshold = 88.0 # the max value of float32 is 3.4028235e+38
+ safe_diff = tf.clip_by_value(diff, -threshold, threshold)
+ losses = tf.nn.relu(tf.exp(safe_diff) - 1.0)
+ else:
+ losses = tf.nn.relu(diff)
+
+ if tf.is_numeric_tensor(weights):
+ logging.info('[%s] use sample weight' % loss_name)
+ weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
+ batch_size, _ = get_shape_list(weights, 2)
+ pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
+ pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
+ else:
+ pairwise_weights = weights
+
+ if ohem_ratio == 1.0:
+ return compute_weighted_loss(losses, pairwise_weights)
+
+ losses = compute_weighted_loss(
+ losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
+ k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
+ k = tf.to_int32(tf.math.rint(k))
+ topk = tf.nn.top_k(losses, k)
+ losses = tf.boolean_mask(topk.values, topk.values > 0)
+ return tf.reduce_mean(losses)
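A minimal sketch (illustrative values) of the `ohem_ratio` branch shared by the pairwise losses above: keep only the hardest fraction of pair losses, drop the zero-valued ones, and average the rest.

```python
# Hedged sketch of the OHEM branch; the real code uses tf.nn.top_k on weighted losses.
import numpy as np

pair_losses = np.array([3.0, 0.0, 1.2, 0.4, 2.5])   # per-pair weighted losses
ohem_ratio = 0.6

k = int(np.rint(pair_losses.size * ohem_ratio))      # 5 * 0.6 -> 3 hardest pairs
topk = np.sort(pair_losses)[::-1][:k]                 # [3.0, 2.5, 1.2]
mined = topk[topk > 0]                                 # discard zero-valued losses
print(mined.mean())                                    # (3.0 + 2.5 + 1.2) / 3 ~= 2.233
```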
diff --git a/easy_rec/python/loss/softmax_loss_with_negative_mining.py b/easy_rec/python/loss/softmax_loss_with_negative_mining.py
index 417aad527..99f92d4af 100644
--- a/easy_rec/python/loss/softmax_loss_with_negative_mining.py
+++ b/easy_rec/python/loss/softmax_loss_with_negative_mining.py
@@ -2,8 +2,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
-from easy_rec.python.utils.shape_utils import get_shape_list
-
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -40,7 +38,8 @@ def softmax_loss_with_negative_mining(user_emb,
weights=1.0,
gamma=1.0,
margin=0,
- t=1):
+ t=1,
+ seed=None):
"""Compute the softmax loss based on the cosine distance explained below.
Given mini batches for `user_emb` and `item_emb`, this function computes for each element in `user_emb`
@@ -62,36 +61,50 @@ def softmax_loss_with_negative_mining(user_emb,
gamma: smooth coefficient of softmax
margin: the margin between positive pair and negative pair
t: coefficient of support vector guided softmax loss
+ seed: A Python integer. Used to create a random seed for the distribution.
+      See `tf.set_random_seed` for behavior.
+
Return:
support vector guided softmax loss of positive labels
"""
- batch_size = get_shape_list(item_emb)[0]
- assert 0 < num_negative_samples < batch_size, '`num_negative_samples` should be in range [1, batch_size)'
+ assert 0 < num_negative_samples, '`num_negative_samples` should be greater than 0'
- if not embed_normed:
- user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
- item_emb = tf.nn.l2_normalize(item_emb, axis=-1)
+ batch_size = tf.shape(item_emb)[0]
+ is_valid = tf.assert_less(
+ num_negative_samples,
+ batch_size,
+ message='`num_negative_samples` should be less than batch_size')
+ with tf.control_dependencies([is_valid]):
+ if not embed_normed:
+ user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
+ item_emb = tf.nn.l2_normalize(item_emb, axis=-1)
- vectors = [item_emb]
- for i in range(num_negative_samples):
- shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32)
- neg_item_emb = tf.roll(item_emb, shift, axis=0)
- vectors.append(neg_item_emb)
- # all_embeddings's shape: (batch_size, num_negative_samples + 1, vec_dim)
- all_embeddings = tf.stack(vectors, axis=1)
+ vectors = [item_emb]
+ for i in range(num_negative_samples):
+ shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32, seed=seed)
+ neg_item_emb = tf.roll(item_emb, shift, axis=0)
+ vectors.append(neg_item_emb)
+ # all_embeddings's shape: (batch_size, num_negative_samples + 1, vec_dim)
+ all_embeddings = tf.stack(vectors, axis=1)
- mask = tf.greater(labels, 0)
- mask_user_emb = tf.boolean_mask(user_emb, mask)
- mask_item_emb = tf.boolean_mask(all_embeddings, mask)
- if isinstance(weights, tf.Tensor):
- weights = tf.boolean_mask(weights, mask)
+ mask = tf.greater(labels, 0)
+ mask_user_emb = tf.boolean_mask(user_emb, mask)
+ mask_item_emb = tf.boolean_mask(all_embeddings, mask)
+ if isinstance(weights, tf.Tensor):
+ weights = tf.boolean_mask(weights, mask)
- # sim_scores's shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
- sim_scores = tf.keras.backend.batch_dot(
- mask_user_emb, mask_item_emb, axes=(1, 2))
- pos_score = tf.slice(sim_scores, [0, 0], [-1, 1])
- neg_scores = tf.slice(sim_scores, [0, 1], [-1, -1])
+ # sim_scores's shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
+ sim_scores = tf.keras.backend.batch_dot(
+ mask_user_emb, mask_item_emb, axes=(1, 2))
+ pos_score = tf.slice(sim_scores, [0, 0], [-1, 1])
+ neg_scores = tf.slice(sim_scores, [0, 1], [-1, -1])
- loss = support_vector_guided_softmax_loss(
- pos_score, neg_scores, margin=margin, t=t, smooth=gamma, weights=weights)
+ loss = support_vector_guided_softmax_loss(
+ pos_score,
+ neg_scores,
+ margin=margin,
+ t=t,
+ smooth=gamma,
+ weights=weights)
return loss
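A small NumPy sketch (assumed shapes, illustrative values) of the rolling trick above: each random shift in [1, batch_size) re-pairs every row with another row's item embedding, so the positive item is never reused as its own negative.

```python
# Hedged sketch of the tf.roll-based in-batch negative sampling.
import numpy as np

item_emb = np.array([[1., 0.], [0., 1.], [1., 1.]])   # batch_size=3, dim=2
num_negative_samples = 2
rng = np.random.default_rng(0)

vectors = [item_emb]                                    # column 0 holds the positives
for _ in range(num_negative_samples):
    shift = rng.integers(1, len(item_emb))              # never 0, so never the positive itself
    vectors.append(np.roll(item_emb, shift, axis=0))

all_embeddings = np.stack(vectors, axis=1)              # (batch_size, num_neg + 1, dim)
print(all_embeddings.shape)                              # (3, 3, 2)
```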
diff --git a/easy_rec/python/loss/zero_inflated_lognormal.py b/easy_rec/python/loss/zero_inflated_lognormal.py
new file mode 100644
index 000000000..e3ae3110e
--- /dev/null
+++ b/easy_rec/python/loss/zero_inflated_lognormal.py
@@ -0,0 +1,76 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Zero-inflated lognormal loss for lifetime value prediction."""
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def zero_inflated_lognormal_pred(logits):
+ """Calculates predicted mean of zero inflated lognormal logits.
+
+ Arguments:
+ logits: [batch_size, 3] tensor of logits.
+
+ Returns:
+ positive_probs: [batch_size, 1] tensor of positive probability.
+ preds: [batch_size, 1] tensor of predicted mean.
+ """
+ logits = tf.convert_to_tensor(logits, dtype=tf.float32)
+ positive_probs = tf.keras.backend.sigmoid(logits[..., :1])
+ loc = logits[..., 1:2]
+ scale = tf.keras.backend.softplus(logits[..., 2:])
+ preds = (
+ positive_probs *
+ tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))
+ return positive_probs, preds
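For reference, the predicted mean above is the lognormal expectation scaled by the purchase probability; with $p$ the sigmoid of the first logit, $\mu$ the second logit, and $\sigma$ the softplus of the third,

$$
E[Y] = p \cdot \exp\!\left(\mu + \tfrac{1}{2}\sigma^{2}\right),
$$

which is exactly `positive_probs * exp(loc + 0.5 * square(scale))` as computed in the function.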
+
+
+def zero_inflated_lognormal_loss(labels, logits, name=''):
+ """Computes the zero inflated lognormal loss.
+
+ Usage with tf.keras API:
+
+ ```python
+ model = tf.keras.Model(inputs, outputs)
+ model.compile('sgd', loss=zero_inflated_lognormal)
+ ```
+
+ Arguments:
+ labels: True targets, tensor of shape [batch_size, 1].
+ logits: Logits of output layer, tensor of shape [batch_size, 3].
+ name: the name of loss
+
+ Returns:
+ Zero inflated lognormal loss value.
+ """
+ loss_name = name if name else 'ziln_loss'
+ labels = tf.cast(labels, dtype=tf.float32)
+ if labels.shape.ndims == 1:
+ labels = tf.expand_dims(labels, 1) # [B, 1]
+ positive = tf.cast(labels > 0, tf.float32)
+
+ logits = tf.convert_to_tensor(logits, dtype=tf.float32)
+ logits.shape.assert_is_compatible_with(
+ tf.TensorShape(labels.shape[:-1].as_list() + [3]))
+
+ positive_logits = logits[..., :1]
+ classification_loss = tf.keras.backend.binary_crossentropy(
+ positive, positive_logits, from_logits=True)
+ classification_loss = tf.keras.backend.mean(classification_loss)
+ tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss)
+
+ loc = logits[..., 1:2]
+ scale = tf.math.maximum(
+ tf.keras.backend.softplus(logits[..., 2:]),
+ tf.math.sqrt(tf.keras.backend.epsilon()))
+ safe_labels = positive * labels + (
+ 1 - positive) * tf.keras.backend.ones_like(labels)
+ regression_loss = -tf.keras.backend.mean(
+ positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels))
+ tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss)
+ return classification_loss + regression_loss
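A minimal usage sketch of the new helpers (assuming graph mode, as in the rest of EasyRec, since the loss writes `tf.summary` scalars; tensor values are illustrative):

```python
# Hedged usage sketch; shapes follow the docstrings above.
import tensorflow as tf
from easy_rec.python.loss.zero_inflated_lognormal import (
    zero_inflated_lognormal_loss, zero_inflated_lognormal_pred)

labels = tf.constant([[0.0], [12.5], [3.2]])            # LTV targets, zeros allowed
logits = tf.constant([[-1.0, 1.0, 0.5],
                      [2.0, 2.3, 0.1],
                      [0.5, 1.1, 0.3]])                  # [p_logit, loc, raw_scale] per sample

loss = zero_inflated_lognormal_loss(labels, logits)      # classification + masked regression
probs, preds = zero_inflated_lognormal_pred(logits)      # purchase prob and expected LTV
```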
diff --git a/easy_rec/python/main.py b/easy_rec/python/main.py
index 7afd453b8..1747155fe 100644
--- a/easy_rec/python/main.py
+++ b/easy_rec/python/main.py
@@ -9,6 +9,7 @@
import logging
import math
import os
+import time
import six
import tensorflow as tf
@@ -16,26 +17,37 @@
import easy_rec
from easy_rec.python.builders import strategy_builder
+from easy_rec.python.compat import estimator_train
from easy_rec.python.compat import exporter
from easy_rec.python.input.input import Input
from easy_rec.python.model.easy_rec_estimator import EasyRecEstimator
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import load_class
+from easy_rec.python.utils.config_util import get_eval_input_path
+from easy_rec.python.utils.config_util import get_model_dir_path
+from easy_rec.python.utils.config_util import get_train_input_path
+from easy_rec.python.utils.config_util import set_eval_input_path
from easy_rec.python.utils.export_big_model import export_big_model
from easy_rec.python.utils.export_big_model import export_big_model_to_oss
+try:
+ import horovod.tensorflow as hvd
+except Exception:
+ hvd = None
+
if tf.__version__ >= '2.0':
- gfile = tf.compat.v1.gfile
from tensorflow.core.protobuf import config_pb2
ConfigProto = config_pb2.ConfigProto
GPUOptions = config_pb2.GPUOptions
+
+ tf = tf.compat.v1
else:
- gfile = tf.gfile
GPUOptions = tf.GPUOptions
ConfigProto = tf.ConfigProto
@@ -55,7 +67,9 @@
def _get_input_fn(data_config,
feature_configs,
data_path=None,
- export_config=None):
+ export_config=None,
+ check_mode=False,
+ **kwargs):
"""Build estimator input function.
Args:
@@ -78,7 +92,9 @@ def _get_input_fn(data_config,
feature_configs,
data_path,
task_index=task_id,
- task_num=task_num)
+ task_num=task_num,
+ check_mode=check_mode,
+ **kwargs)
input_fn = input_obj.create_input(export_config)
return input_fn
@@ -86,13 +102,37 @@ def _get_input_fn(data_config,
def _create_estimator(pipeline_config, distribution=None, params={}):
model_config = pipeline_config.model_config
train_config = pipeline_config.train_config
- gpu_options = GPUOptions(allow_growth=False)
+  gpu_options = GPUOptions(allow_growth=True)
+
+ logging.info(
+ 'train_config.train_distribute=%s[value=%d]' %
+ (DistributionStrategy.Name(pipeline_config.train_config.train_distribute),
+ pipeline_config.train_config.train_distribute))
+
+ # set gpu options only under hvd scenes
+ if hvd is not None and pipeline_config.train_config.train_distribute in [
+ DistributionStrategy.EmbeddingParallelStrategy,
+ DistributionStrategy.SokStrategy, DistributionStrategy.HorovodStrategy
+ ]:
+ local_rnk = hvd.local_rank()
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ logging.info('local_rnk=%d num_gpus=%d' % (local_rnk, len(gpus)))
+ if len(gpus) > 0:
+ tf.config.experimental.set_visible_devices(gpus[local_rnk], 'GPU')
+ gpu_options.visible_device_list = str(local_rnk)
+
session_config = ConfigProto(
gpu_options=gpu_options,
allow_soft_placement=True,
log_device_placement=params.get('log_device_placement', False),
inter_op_parallelism_threads=train_config.inter_op_parallelism_threads,
intra_op_parallelism_threads=train_config.intra_op_parallelism_threads)
+
+ if constant.NO_ARITHMETRIC_OPTI in os.environ:
+    logging.info('arithmetic_optimization is disabled to improve performance')
+ session_config.graph_options.rewrite_options.arithmetic_optimization = \
+ session_config.graph_options.rewrite_options.OFF
+
session_config.device_filters.append('/job:ps')
model_cls = EasyRecModel.create_class(model_config.model_class)
@@ -109,7 +149,7 @@ def _create_estimator(pipeline_config, distribution=None, params={}):
run_config = tf.estimator.RunConfig(
model_dir=pipeline_config.model_dir,
- log_step_count_steps=train_config.log_step_count_steps,
+ log_step_count_steps=None, # train_config.log_step_count_steps,
save_summary_steps=train_config.save_summary_steps,
save_checkpoints_steps=save_checkpoints_steps,
save_checkpoints_secs=save_checkpoints_secs,
@@ -123,7 +163,7 @@ def _create_estimator(pipeline_config, distribution=None, params={}):
return estimator, run_config
-def _create_eval_export_spec(pipeline_config, eval_data):
+def _create_eval_export_spec(pipeline_config, eval_data, check_mode=False):
data_config = pipeline_config.data_config
# feature_configs = pipeline_config.feature_configs
feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
@@ -135,9 +175,17 @@ def _create_eval_export_spec(pipeline_config, eval_data):
logging.info('eval_steps = %d' % eval_steps)
else:
eval_steps = None
+ input_fn_kwargs = {'pipeline_config': pipeline_config}
+ if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
+ input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
# create eval input
- export_input_fn = _get_input_fn(data_config, feature_configs, None,
- export_config)
+ export_input_fn = _get_input_fn(
+ data_config,
+ feature_configs,
+ None,
+ export_config,
+ check_mode=check_mode,
+ **input_fn_kwargs)
if export_config.exporter_type == 'final':
exporters = [
FinalExporter(name='final', serving_input_receiver_fn=export_input_fn)
@@ -178,7 +226,8 @@ def _metric_cmp_fn(best_eval_result, current_eval_result):
# set throttle_secs to a small number, so that we can control evaluation
# interval steps by checkpoint saving steps
- eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data)
+ eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
+ **input_fn_kwargs)
eval_spec = tf.estimator.EvalSpec(
name='val',
input_fn=eval_input_fn,
@@ -190,25 +239,28 @@ def _metric_cmp_fn(best_eval_result, current_eval_result):
def _check_model_dir(model_dir, continue_train):
if not continue_train:
- if not gfile.IsDirectory(model_dir):
- gfile.MakeDirs(model_dir)
+ if not tf.gfile.IsDirectory(model_dir):
+ tf.gfile.MakeDirs(model_dir)
else:
- assert len(gfile.Glob(model_dir + '/model.ckpt-*.meta')) == 0, \
+ assert len(tf.gfile.Glob(model_dir + '/model.ckpt-*.meta')) == 0, \
'model_dir[=%s] already exists and not empty(if you ' \
'want to continue train on current model_dir please ' \
'delete dir %s or specify --continue_train[internal use only])' % (
model_dir, model_dir)
else:
- if not gfile.IsDirectory(model_dir):
+ if not tf.gfile.IsDirectory(model_dir):
       logging.info('%s does not exist, creating it automatically' % model_dir)
- gfile.MakeDirs(model_dir)
+ tf.gfile.MakeDirs(model_dir)
def _get_ckpt_path(pipeline_config, checkpoint_path):
if checkpoint_path != '' and checkpoint_path is not None:
- ckpt_path = checkpoint_path
- elif gfile.IsDirectory(pipeline_config.model_dir):
- ckpt_path = tf.train.latest_checkpoint(pipeline_config.model_dir)
+ if tf.gfile.IsDirectory(checkpoint_path):
+ ckpt_path = estimator_utils.latest_checkpoint(checkpoint_path)
+ else:
+ ckpt_path = checkpoint_path
+ elif tf.gfile.IsDirectory(pipeline_config.model_dir):
+ ckpt_path = estimator_utils.latest_checkpoint(pipeline_config.model_dir)
logging.info('checkpoint_path is not specified, '
'will use latest checkpoint %s from %s' %
(ckpt_path, pipeline_config.model_dir))
@@ -231,7 +283,8 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):
Returns:
None, the model will be saved into pipeline_config.model_dir
"""
- assert gfile.Exists(pipeline_config_path), 'pipeline_config_path not exists'
+ assert tf.gfile.Exists(
+ pipeline_config_path), 'pipeline_config_path not exists'
pipeline_config = config_util.get_configs_from_pipeline_file(
pipeline_config_path)
@@ -240,10 +293,13 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):
return pipeline_config
-def _train_and_evaluate_impl(pipeline_config, continue_train=False):
+def _train_and_evaluate_impl(pipeline_config,
+ continue_train=False,
+ check_mode=False,
+ fit_on_eval=False,
+ fit_on_eval_steps=None):
train_config = pipeline_config.train_config
data_config = pipeline_config.data_config
- # feature_configs = pipeline_config.feature_configs
feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
if train_config.train_distribute != DistributionStrategy.NoStrategy\
@@ -253,50 +309,95 @@ def _train_and_evaluate_impl(pipeline_config, continue_train=False):
% pipeline_config.train_config.train_distribute)
pipeline_config.train_config.sync_replicas = False
- if pipeline_config.WhichOneof('train_path') == 'kafka_train_input':
- train_data = pipeline_config.kafka_train_input
- elif pipeline_config.WhichOneof('train_path') == 'datahub_train_input':
- train_data = pipeline_config.datahub_train_input
- else:
- train_data = pipeline_config.train_input_path
-
- if pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input':
- eval_data = pipeline_config.kafka_eval_input
- elif pipeline_config.WhichOneof('eval_path') == 'datahub_eval_input':
- eval_data = pipeline_config.datahub_eval_input
- else:
- eval_data = pipeline_config.eval_input_path
+ train_data = get_train_input_path(pipeline_config)
+ eval_data = get_eval_input_path(pipeline_config)
distribution = strategy_builder.build(train_config)
+ params = {}
+ if train_config.is_profiling:
+ params['log_device_placement'] = True
estimator, run_config = _create_estimator(
- pipeline_config, distribution=distribution)
+ pipeline_config, distribution=distribution, params=params)
- master_stat_file = os.path.join(pipeline_config.model_dir, 'master.stat')
version_file = os.path.join(pipeline_config.model_dir, 'version')
if estimator_utils.is_chief():
_check_model_dir(pipeline_config.model_dir, continue_train)
config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
- with gfile.GFile(version_file, 'w') as f:
+ with tf.gfile.GFile(version_file, 'w') as f:
f.write(easy_rec.__version__ + '\n')
- if gfile.Exists(master_stat_file):
- gfile.Remove(master_stat_file)
- train_steps = pipeline_config.train_config.num_steps
- if train_steps <= 0:
- train_steps = None
- logging.warn('will train INFINITE number of steps')
- else:
- logging.info('train_steps = %d' % train_steps)
+ train_steps = None
+ if train_config.HasField('num_steps') and train_config.num_steps > 0:
+ train_steps = train_config.num_steps
+ assert train_steps is not None or data_config.num_epochs > 0, (
+      'either num_steps or num_epochs must be set to an integer > 0.')
+
+ if train_steps and data_config.num_epochs:
+ logging.info('Both num_steps and num_epochs are set.')
+ is_sync = train_config.sync_replicas
+ batch_size = data_config.batch_size
+ epoch_str = 'sample_num * %d / %d' % (data_config.num_epochs, batch_size)
+ if is_sync:
+ _, worker_num = estimator_utils.get_task_index_and_num()
+ epoch_str += ' / ' + str(worker_num)
+ logging.info('Will train min(%d, %s) steps...' % (train_steps, epoch_str))
+
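As a worked example of the logged bound (hypothetical numbers): with 1,000,000 training samples, num_epochs=2, batch_size=1024 and 4 synchronized workers, the epoch-driven limit is 1,000,000 * 2 / 1024 / 4 ≈ 488 steps, so training stops at min(num_steps, 488).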
+ input_fn_kwargs = {'pipeline_config': pipeline_config}
+ if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
+ input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
+
# create train input
- train_input_fn = _get_input_fn(data_config, feature_configs, train_data)
+ train_input_fn = _get_input_fn(
+ data_config,
+ feature_configs,
+ train_data,
+ check_mode=check_mode,
+ **input_fn_kwargs)
# Currently only a single Eval Spec is allowed.
train_spec = tf.estimator.TrainSpec(
input_fn=train_input_fn, max_steps=train_steps)
- # create eval spec
- eval_spec = _create_eval_export_spec(pipeline_config, eval_data)
- from easy_rec.python.compat import estimator_train
- estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
+
+ embedding_parallel = train_config.train_distribute in (
+ DistributionStrategy.SokStrategy,
+ DistributionStrategy.EmbeddingParallelStrategy)
+
+ if embedding_parallel:
+ estimator.train(
+ input_fn=train_input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks),
+ saving_listeners=train_spec.saving_listeners)
+ train_input_fn.input_creator.stop()
+ else:
+ # create eval spec
+ eval_spec = _create_eval_export_spec(
+ pipeline_config, eval_data, check_mode=check_mode)
+ estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
logging.info('Train and evaluate finish')
+ if fit_on_eval and (not estimator_utils.is_evaluator()):
+ tf.reset_default_graph()
+ logging.info('Start continue training on eval data')
+ eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
+ **input_fn_kwargs)
+ if fit_on_eval_steps is not None:
+ # wait estimator train done to get the correct train_steps
+ while not estimator_train.estimator_train_done(estimator):
+ time.sleep(1)
+ train_steps = estimator_utils.get_trained_steps(estimator.model_dir)
+ logging.info('\ttrain_steps=%d fit_on_eval_steps=%d' %
+ (train_steps, fit_on_eval_steps))
+ fit_on_eval_steps += train_steps
+ # Do not use estimator_train.train_and_evaluate as it starts tf.Server,
+ # which is redundant and reports port not available error.
+ estimator.train(
+ input_fn=eval_input_fn,
+ max_steps=fit_on_eval_steps,
+ hooks=list(train_spec.hooks),
+ saving_listeners=train_spec.saving_listeners if hasattr(
+ train_spec, 'saving_listeners') else None)
+ logging.info('Finished training on eval data')
+ # return estimator for custom training using estimator.train
+ return estimator
def evaluate(pipeline_config,
@@ -330,16 +431,10 @@ def evaluate(pipeline_config,
fg_util.load_fg_json_to_config(pipeline_config)
if eval_data_path is not None:
logging.info('Evaluating on data: %s' % eval_data_path)
- if isinstance(eval_data_path, list):
- pipeline_config.eval_input_path = ','.join(eval_data_path)
- else:
- pipeline_config.eval_input_path = eval_data_path
- train_config = pipeline_config.train_config
+ set_eval_input_path(pipeline_config, eval_data_path)
- if pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input':
- eval_data = pipeline_config.kafka_eval_input
- else:
- eval_data = pipeline_config.eval_input_path
+ train_config = pipeline_config.train_config
+ eval_data = get_eval_input_path(pipeline_config)
server_target = None
if 'TF_CONFIG' in os.environ:
@@ -405,6 +500,7 @@ def evaluate(pipeline_config,
# worker_device='/job:master/task:0', cluster=cluster)):
eval_result = estimator.evaluate(
eval_spec.input_fn, eval_spec.steps, checkpoint_path=ckpt_path)
+ eval_spec.input_fn.input_creator.stop()
logging.info('Evaluate finish')
print('eval_result = ', eval_result)
@@ -413,7 +509,7 @@ def evaluate(pipeline_config,
model_dir = pipeline_config.model_dir
eval_result_file = os.path.join(model_dir, eval_result_filename)
logging.info('save eval result to file %s' % eval_result_file)
- with gfile.GFile(eval_result_file, 'w') as ofile:
+ with tf.gfile.GFile(eval_result_file, 'w') as ofile:
result_to_write = {}
for key in sorted(eval_result):
# skip logging binary data
@@ -428,7 +524,7 @@ def evaluate(pipeline_config,
def distribute_evaluate(pipeline_config,
eval_checkpoint_path='',
eval_data_path=None,
- eval_result_filename='eval_result.txt'):
+ eval_result_filename='distribute_eval_result.txt'):
"""Evaluate a EasyRec model defined in pipeline_config_path.
Evaluate the model defined in pipeline_config_path on the eval data,
@@ -454,16 +550,24 @@ def distribute_evaluate(pipeline_config,
pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
if eval_data_path is not None:
logging.info('Evaluating on data: %s' % eval_data_path)
- if isinstance(eval_data_path, list):
- pipeline_config.eval_input_path = ','.join(eval_data_path)
- else:
- pipeline_config.eval_input_path = eval_data_path
+ set_eval_input_path(pipeline_config, eval_data_path)
train_config = pipeline_config.train_config
-
- if pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input':
- eval_data = pipeline_config.kafka_eval_input
- else:
- eval_data = pipeline_config.eval_input_path
+ eval_data = get_eval_input_path(pipeline_config)
+ data_config = pipeline_config.data_config
+ if data_config.HasField('sampler'):
+ logging.warning(
+        'Evaluation with a negative sampler is not accurate, recommend using hitrate.py instead!'
+ )
+ eval_result = {}
+ return eval_result
+ model_dir = get_model_dir_path(pipeline_config)
+ eval_tmp_results_dir = os.path.join(model_dir, 'distribute_eval_tmp_results')
+ if not tf.gfile.IsDirectory(eval_tmp_results_dir):
+ logging.info('create eval tmp results dir {}'.format(eval_tmp_results_dir))
+ tf.gfile.MakeDirs(eval_tmp_results_dir)
+ assert tf.gfile.IsDirectory(
+      eval_tmp_results_dir), 'failed to create tmp results dir.'
+ os.environ['eval_tmp_results_dir'] = eval_tmp_results_dir
server_target = None
cur_job_name = None
@@ -495,16 +599,7 @@ def distribute_evaluate(pipeline_config,
server_target = server.target
print('server_target = %s' % server_target)
- distribution = strategy_builder.build(train_config)
- estimator, run_config = _create_estimator(pipeline_config, distribution)
- eval_spec = _create_eval_export_spec(pipeline_config, eval_data)
- ckpt_path = _get_ckpt_path(pipeline_config, eval_checkpoint_path)
-
if server_target:
- # evaluate with parameter server
- input_iter = eval_spec.input_fn(
- mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
- input_feas, input_lbls = input_iter.get_next()
from tensorflow.python.training.device_setter import replica_device_setter
from tensorflow.python.framework.ops import device
from tensorflow.python.training.monitored_session import MonitoredSession
@@ -512,20 +607,34 @@ def distribute_evaluate(pipeline_config,
from tensorflow.python.training.monitored_session import WorkerSessionCreator
from easy_rec.python.utils.estimator_utils import EvaluateExitBarrierHook
cur_work_device = '/job:' + cur_job_name + '/task:' + str(cur_task_index)
+ cur_ps_num = len(tf_config['cluster']['ps'])
with device(
- replica_device_setter(worker_device=cur_work_device, cluster=cluster)):
+ replica_device_setter(
+ ps_tasks=cur_ps_num, worker_device=cur_work_device,
+ cluster=cluster)):
+ distribution = strategy_builder.build(train_config)
+ estimator, run_config = _create_estimator(pipeline_config, distribution)
+ eval_spec = _create_eval_export_spec(pipeline_config, eval_data)
+ ckpt_path = _get_ckpt_path(pipeline_config, eval_checkpoint_path)
+ ckpt_dir = os.path.dirname(ckpt_path)
+ input_iter = eval_spec.input_fn(
+ mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
+ input_feas, input_lbls = input_iter.get_next()
estimator_spec = estimator._distribute_eval_model_fn(
input_feas, input_lbls, run_config)
session_config = ConfigProto(
- allow_soft_placement=True, log_device_placement=True)
+ allow_soft_placement=True,
+ log_device_placement=True,
+ device_filters=['/job:ps',
+ '/job:worker/task:%d' % cur_task_index])
if cur_job_name == 'master':
metric_variables = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
model_ready_for_local_init_op = tf.variables_initializer(metric_variables)
global_variables = tf.global_variables()
remain_variables = list(
set(global_variables).difference(set(metric_variables)))
- cur_saver = tf.train.Saver(var_list=remain_variables)
+ cur_saver = tf.train.Saver(var_list=remain_variables, sharded=True)
cur_scaffold = tf.train.Scaffold(
saver=cur_saver,
ready_for_local_init_op=model_ready_for_local_init_op)
@@ -541,15 +650,14 @@ def distribute_evaluate(pipeline_config,
update_ops = [eval_metric_ops[x][1] for x in eval_metric_ops.keys()]
metric_ops = {x: eval_metric_ops[x][0] for x in eval_metric_ops.keys()}
update_op = tf.group(update_ops)
- count = 0
cur_worker_num = len(tf_config['cluster']['worker']) + 1
if cur_job_name == 'master':
cur_stop_grace_period_sesc = 120
- cur_hooks = EvaluateExitBarrierHook(cur_worker_num, True, ckpt_path,
+ cur_hooks = EvaluateExitBarrierHook(cur_worker_num, True, ckpt_dir,
metric_ops)
else:
cur_stop_grace_period_sesc = 10
- cur_hooks = EvaluateExitBarrierHook(cur_worker_num, False, ckpt_path,
+ cur_hooks = EvaluateExitBarrierHook(cur_worker_num, False, ckpt_dir,
metric_ops)
with MonitoredSession(
session_creator=cur_sess_creator,
@@ -557,19 +665,11 @@ def distribute_evaluate(pipeline_config,
stop_grace_period_secs=cur_stop_grace_period_sesc) as sess:
while True:
try:
- count += 1
sess.run(update_op)
except tf.errors.OutOfRangeError:
break
eval_result = cur_hooks.eval_result
- else:
- # this way does not work, wait to be debugged
- # the variables are not placed to parameter server
- # with tf.device(
- # replica_device_setter(
- # worker_device='/job:master/task:0', cluster=cluster)):
- eval_result = estimator.evaluate(
- eval_spec.input_fn, eval_spec.steps, checkpoint_path=ckpt_path)
+
logging.info('Evaluate finish')
# write eval result to file
@@ -579,8 +679,8 @@ def distribute_evaluate(pipeline_config,
if cur_job_name == 'master':
print('eval_result = ', eval_result)
logging.info('eval_result = {0}'.format(eval_result))
- with gfile.GFile(eval_result_file, 'w') as ofile:
- result_to_write = {}
+ with tf.gfile.GFile(eval_result_file, 'w') as ofile:
+ result_to_write = {'eval_method': 'distribute'}
for key in sorted(eval_result):
# skip logging binary data
if isinstance(eval_result[key], six.binary_type):
@@ -616,12 +716,9 @@ def predict(pipeline_config, checkpoint_path='', data_path=None):
fg_util.load_fg_json_to_config(pipeline_config)
if data_path is not None:
logging.info('Predict on data: %s' % data_path)
- pipeline_config.eval_input_path = data_path
+ set_eval_input_path(pipeline_config, data_path)
train_config = pipeline_config.train_config
- if pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input':
- eval_data = pipeline_config.kafka_eval_input
- else:
- eval_data = pipeline_config.eval_input_path
+ eval_data = get_eval_input_path(pipeline_config)
distribution = strategy_builder.build(train_config)
estimator, _ = _create_estimator(pipeline_config, distribution)
@@ -669,8 +766,8 @@ def export(export_dir,
AssertionError, if:
* pipeline_config_path does not exist
"""
- if not gfile.Exists(export_dir):
- gfile.MakeDirs(export_dir)
+ if not tf.gfile.Exists(export_dir):
+ tf.gfile.MakeDirs(export_dir)
pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
if pipeline_config.fg_json_path:
@@ -684,7 +781,8 @@ def export(export_dir,
asset_file_dict = {}
for asset_file in asset_files.split(','):
asset_file = asset_file.strip()
- if ':' not in asset_file or asset_file.startswith('oss:'):
+ if ':' not in asset_file or asset_file.startswith(
+ 'oss:') or asset_file.startswith('hdfs:'):
_, asset_name = os.path.split(asset_file)
else:
asset_name, asset_file = asset_file.split(':', 1)
@@ -694,26 +792,33 @@ def export(export_dir,
# construct serving input fn
export_config = pipeline_config.export_config
data_config = pipeline_config.data_config
+ input_fn_kwargs = {'pipeline_config': pipeline_config}
+ if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
+ input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
serving_input_fn = _get_input_fn(data_config, feature_configs, None,
- export_config)
+ export_config, **input_fn_kwargs)
+ ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path)
if 'oss_path' in extra_params:
+ if pipeline_config.train_config.HasField('incr_save_config'):
+ incr_save_config = pipeline_config.train_config.incr_save_config
+ extra_params['incr_update'] = {}
+ incr_save_type = incr_save_config.WhichOneof('incr_update')
+ logging.info('incr_save_type=%s' % incr_save_type)
+ if incr_save_type:
+ extra_params['incr_update'][incr_save_type] = getattr(
+ incr_save_config, incr_save_type)
return export_big_model_to_oss(export_dir, pipeline_config, extra_params,
- serving_input_fn, estimator, checkpoint_path,
+ serving_input_fn, estimator, ckpt_path,
verbose)
if 'redis_url' in extra_params:
return export_big_model(export_dir, pipeline_config, extra_params,
- serving_input_fn, estimator, checkpoint_path,
- verbose)
-
- if not checkpoint_path:
- checkpoint_path = estimator_utils.latest_checkpoint(
- pipeline_config.model_dir)
+ serving_input_fn, estimator, ckpt_path, verbose)
final_export_dir = estimator.export_savedmodel(
export_dir_base=export_dir,
serving_input_receiver_fn=serving_input_fn,
- checkpoint_path=checkpoint_path,
+ checkpoint_path=ckpt_path,
strip_default_attrs=True)
# add export ts as version info
@@ -725,11 +830,49 @@ def export(export_dir,
]
export_ts = export_ts[-1]
saved_pb_path = os.path.join(final_export_dir, 'saved_model.pb')
- with gfile.GFile(saved_pb_path, 'rb') as fin:
+ with tf.gfile.GFile(saved_pb_path, 'rb') as fin:
saved_model.ParseFromString(fin.read())
saved_model.meta_graphs[0].meta_info_def.meta_graph_version = export_ts
- with gfile.GFile(saved_pb_path, 'wb') as fout:
+ with tf.gfile.GFile(saved_pb_path, 'wb') as fout:
fout.write(saved_model.SerializeToString())
logging.info('model has been exported to %s successfully' % final_export_dir)
return final_export_dir
+
+
+def export_checkpoint(pipeline_config=None,
+ export_path='',
+ checkpoint_path='',
+ asset_files=None,
+ verbose=False,
+ mode=tf.estimator.ModeKeys.PREDICT):
+ """Export the EasyRec model as checkpoint."""
+ pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
+ if pipeline_config.fg_json_path:
+ fg_util.load_fg_json_to_config(pipeline_config)
+ feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
+ data_config = pipeline_config.data_config
+
+ input_fn_kwargs = {'pipeline_config': pipeline_config}
+ if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
+ input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
+
+ # create estimator
+ params = {'log_device_placement': verbose}
+ if asset_files:
+ logging.info('will add asset files: %s' % asset_files)
+ params['asset_files'] = asset_files
+ estimator, _ = _create_estimator(pipeline_config, params=params)
+
+ # construct serving input fn
+ export_config = pipeline_config.export_config
+ serving_input_fn = _get_input_fn(data_config, feature_configs, None,
+ export_config, **input_fn_kwargs)
+ ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path)
+ estimator.export_checkpoint(
+ export_path=export_path,
+ serving_input_receiver_fn=serving_input_fn,
+ checkpoint_path=ckpt_path,
+ mode=mode)
+
+ logging.info('model checkpoint has been exported successfully')
diff --git a/easy_rec/python/model/autoint.py b/easy_rec/python/model/autoint.py
index fc9c05ca5..b7013486e 100644
--- a/easy_rec/python/model/autoint.py
+++ b/easy_rec/python/model/autoint.py
@@ -28,11 +28,11 @@ def __init__(self,
self._features, _ = self._input_layer(self._feature_dict, 'all')
self._feature_num = len(self._model_config.feature_groups[0].feature_names)
self._seq_key_num = 0
- if self._model_config.feature_groups[0].HasField('sequence_features'):
- self._feature_num += len(self._model_config.feature_groups[0]
- .sequence_features.seq_att_map[0].hist_seq)
- self._seq_key_num = len(self._model_config.feature_groups[0]
- .sequence_features.seq_att_map[0].key)
+ if len(self._model_config.feature_groups[0].sequence_features) > 0:
+ for seq_fea in self._model_config.feature_groups[0].sequence_features:
+ for seq_att in seq_fea.seq_att_map:
+ self._feature_num += len(seq_att.hist_seq)
+ self._seq_key_num += len(seq_att.key)
self._model_config = self._model_config.autoint
assert isinstance(self._model_config, AutoIntConfig)
diff --git a/easy_rec/python/model/cmbf.py b/easy_rec/python/model/cmbf.py
new file mode 100644
index 000000000..0f0a8f3aa
--- /dev/null
+++ b/easy_rec/python/model/cmbf.py
@@ -0,0 +1,47 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.layers import cmbf
+from easy_rec.python.layers import dnn
+from easy_rec.python.model.rank_model import RankModel
+
+from easy_rec.python.protos.cmbf_pb2 import CMBF as CMBFConfig # NOQA
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class CMBF(RankModel):
+ """CMBF: Cross-Modal-Based Fusion Recommendation Algorithm.
+
+ This is almost an exact implementation of the original CMBF model.
+ See the original paper:
+ https://www.mdpi.com/1424-8220/21/16/5275
+ """
+
+ def __init__(self,
+ model_config,
+ feature_configs,
+ features,
+ labels=None,
+ is_training=False):
+ super(CMBF, self).__init__(model_config, feature_configs, features, labels,
+ is_training)
+ assert self._model_config.WhichOneof('model') == 'cmbf', (
+ 'invalid model config: %s' % self._model_config.WhichOneof('model'))
+
+ self._cmbf_layer = cmbf.CMBF(model_config, feature_configs, features,
+ self._model_config.cmbf.config,
+ self._input_layer)
+ self._model_config = self._model_config.cmbf
+
+ def build_predict_graph(self):
+ hidden = self._cmbf_layer(self._is_training, l2_reg=self._l2_reg)
+ final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+ 'final_dnn', self._is_training)
+ all_fea = final_dnn_layer(hidden)
+
+ final = tf.layers.dense(all_fea, self._num_class, name='output')
+ self._add_to_prediction_dict(final)
+ return self._prediction_dict
diff --git a/easy_rec/python/model/collaborative_metric_learning.py b/easy_rec/python/model/collaborative_metric_learning.py
index 84c87ccaa..b19537239 100644
--- a/easy_rec/python/model/collaborative_metric_learning.py
+++ b/easy_rec/python/model/collaborative_metric_learning.py
@@ -3,12 +3,12 @@
from easy_rec.python.core.metrics import metric_learning_average_precision_at_k
from easy_rec.python.core.metrics import metric_learning_recall_at_k
from easy_rec.python.layers import dnn
-from easy_rec.python.layers.common_layers import gelu
from easy_rec.python.layers.common_layers import highway
from easy_rec.python.loss.circle_loss import circle_loss
from easy_rec.python.loss.multi_similarity import ms_loss
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType
+from easy_rec.python.utils.activation import gelu
from easy_rec.python.utils.proto_util import copy_obj
from easy_rec.python.protos.collaborative_metric_learning_pb2 import CoMetricLearningI2I as MetricLearningI2IConfig # NOQA
@@ -48,21 +48,22 @@ def __init__(
raise ValueError('unsupported loss type: %s' %
LossType.Name(self._loss_type))
- self._highway_features = {}
- self._highway_num = len(self._model_config.highway)
- for _id in range(self._highway_num):
- highway_cfg = self._model_config.highway[_id]
- highway_feature, _ = self._input_layer(self._feature_dict,
- highway_cfg.input)
- self._highway_features[highway_cfg.input] = highway_feature
+ if not self.has_backbone:
+ self._highway_features = {}
+ self._highway_num = len(self._model_config.highway)
+ for _id in range(self._highway_num):
+ highway_cfg = self._model_config.highway[_id]
+ highway_feature, _ = self._input_layer(self._feature_dict,
+ highway_cfg.input)
+ self._highway_features[highway_cfg.input] = highway_feature
- self.input_features = []
- if self._model_config.HasField('input'):
- input_feature, _ = self._input_layer(self._feature_dict,
- self._model_config.input)
- self.input_features.append(input_feature)
+ self.input_features = []
+ if self._model_config.HasField('input'):
+ input_feature, _ = self._input_layer(self._feature_dict,
+ self._model_config.input)
+ self.input_features.append(input_feature)
- self.dnn = copy_obj(self._model_config.dnn)
+ self.dnn = copy_obj(self._model_config.dnn)
if self._labels is not None:
if self._model_config.HasField('session_id'):
@@ -79,32 +80,35 @@ def __init__(
self.sample_id = None
def build_predict_graph(self):
- for _id in range(self._highway_num):
- highway_cfg = self._model_config.highway[_id]
- highway_fea = tf.layers.batch_normalization(
- self._highway_features[highway_cfg.input],
- training=self._is_training,
- trainable=True,
- name='highway_%s_bn' % highway_cfg.input)
- highway_fea = highway(
- highway_fea,
- highway_cfg.emb_size,
- activation=gelu,
- scope='highway_%s' % _id)
- print('highway_fea: ', highway_fea)
- self.input_features.append(highway_fea)
-
- feature = tf.concat(self.input_features, axis=1)
-
- num_dnn_layer = len(self.dnn.hidden_units)
- last_hidden = self.dnn.hidden_units.pop()
- dnn_net = dnn.DNN(self.dnn, self._l2_reg, 'dnn', self._is_training)
- net_output = dnn_net(feature)
- tower_emb = tf.layers.dense(
- inputs=net_output,
- units=last_hidden,
- kernel_regularizer=self._l2_reg,
- name='dnn/dnn_%d' % (num_dnn_layer - 1))
+ if self.has_backbone:
+ tower_emb = self.backbone
+ else:
+ for _id in range(self._highway_num):
+ highway_cfg = self._model_config.highway[_id]
+ highway_fea = tf.layers.batch_normalization(
+ self._highway_features[highway_cfg.input],
+ training=self._is_training,
+ trainable=True,
+ name='highway_%s_bn' % highway_cfg.input)
+ highway_fea = highway(
+ highway_fea,
+ highway_cfg.emb_size,
+ activation=gelu,
+ scope='highway_%s' % _id)
+ print('highway_fea: ', highway_fea)
+ self.input_features.append(highway_fea)
+
+ feature = tf.concat(self.input_features, axis=1)
+
+ num_dnn_layer = len(self.dnn.hidden_units)
+ last_hidden = self.dnn.hidden_units.pop()
+ dnn_net = dnn.DNN(self.dnn, self._l2_reg, 'dnn', self._is_training)
+ net_output = dnn_net(feature)
+ tower_emb = tf.layers.dense(
+ inputs=net_output,
+ units=last_hidden,
+ kernel_regularizer=self._l2_reg,
+ name='dnn/dnn_%d' % (num_dnn_layer - 1))
if self._model_config.output_l2_normalized_emb:
norm_emb = tf.nn.l2_normalize(tower_emb, axis=-1)
diff --git a/easy_rec/python/model/dat.py b/easy_rec/python/model/dat.py
new file mode 100644
index 000000000..5c312299c
--- /dev/null
+++ b/easy_rec/python/model/dat.py
@@ -0,0 +1,138 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.model.match_model import MatchModel
+from easy_rec.python.protos.dat_pb2 import DAT as DATConfig
+from easy_rec.python.protos.loss_pb2 import LossType
+from easy_rec.python.utils.proto_util import copy_obj
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class DAT(MatchModel):
+ """Dual Augmented Two-tower Model."""
+
+ def __init__(self,
+ model_config,
+ feature_configs,
+ features,
+ labels=None,
+ is_training=False):
+ super(DAT, self).__init__(model_config, feature_configs, features, labels,
+ is_training)
+ assert self._model_config.WhichOneof('model') == 'dat', \
+ 'invalid model config: %s' % self._model_config.WhichOneof('model')
+
+ feature_group_names = [
+ fg.group_name for fg in self._model_config.feature_groups
+ ]
+ assert 'user' in feature_group_names, 'user feature group not found'
+ assert 'item' in feature_group_names, 'item feature group not found'
+ assert 'user_id_augment' in feature_group_names, 'user_id_augment feature group not found'
+ assert 'item_id_augment' in feature_group_names, 'item_id_augment feature group not found'
+
+ self._model_config = self._model_config.dat
+ assert isinstance(self._model_config, DATConfig)
+
+ self.user_tower = copy_obj(self._model_config.user_tower)
+ self.user_deep_feature, _ = self._input_layer(self._feature_dict, 'user')
+ self.user_augmented_vec, _ = self._input_layer(self._feature_dict,
+ 'user_id_augment')
+
+ self.item_tower = copy_obj(self._model_config.item_tower)
+ self.item_deep_feature, _ = self._input_layer(self._feature_dict, 'item')
+ self.item_augmented_vec, _ = self._input_layer(self._feature_dict,
+ 'item_id_augment')
+
+ self._user_tower_emb = None
+ self._item_tower_emb = None
+
+ def build_predict_graph(self):
+ num_user_dnn_layer = len(self.user_tower.dnn.hidden_units)
+ last_user_hidden = self.user_tower.dnn.hidden_units.pop()
+ user_dnn = dnn.DNN(self.user_tower.dnn, self._l2_reg, 'user_dnn',
+ self._is_training)
+
+ user_tower_feature = tf.concat(
+ [self.user_deep_feature, self.user_augmented_vec], axis=-1)
+ user_tower_emb = user_dnn(user_tower_feature)
+ user_tower_emb = tf.layers.dense(
+ inputs=user_tower_emb,
+ units=last_user_hidden,
+ kernel_regularizer=self._l2_reg,
+ name='user_dnn/dnn_%d' % (num_user_dnn_layer - 1))
+
+ num_item_dnn_layer = len(self.item_tower.dnn.hidden_units)
+ last_item_hidden = self.item_tower.dnn.hidden_units.pop()
+ item_dnn = dnn.DNN(self.item_tower.dnn, self._l2_reg, 'item_dnn',
+ self._is_training)
+
+ item_tower_feature = tf.concat(
+ [self.item_deep_feature, self.item_augmented_vec], axis=-1)
+ item_tower_emb = item_dnn(item_tower_feature)
+ item_tower_emb = tf.layers.dense(
+ inputs=item_tower_emb,
+ units=last_item_hidden,
+ kernel_regularizer=self._l2_reg,
+ name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))
+
+ user_tower_emb = self.norm(user_tower_emb)
+ item_tower_emb = self.norm(item_tower_emb)
+ temperature = self._model_config.temperature
+
+ y_pred = self.sim(user_tower_emb, item_tower_emb) / temperature
+
+ if self._is_point_wise:
+ raise ValueError('Currently DAT model only supports list wise mode.')
+
+ if self._loss_type == LossType.CLASSIFICATION:
+ raise ValueError(
+ 'Currently DAT model only supports SOFTMAX_CROSS_ENTROPY loss.')
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ y_pred = self._mask_in_batch(y_pred)
+ self._prediction_dict['logits'] = y_pred
+ self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
+ else:
+ self._prediction_dict['y'] = y_pred
+
+ self._prediction_dict['user_tower_emb'] = user_tower_emb
+ self._prediction_dict['item_tower_emb'] = item_tower_emb
+ self._prediction_dict['user_emb'] = tf.reduce_join(
+ tf.as_string(user_tower_emb), axis=-1, separator=',')
+ self._prediction_dict['item_emb'] = tf.reduce_join(
+ tf.as_string(item_tower_emb), axis=-1, separator=',')
+
+ augmented_p_u = tf.stop_gradient(user_tower_emb)
+ augmented_p_i = tf.stop_gradient(item_tower_emb)
+
+ self._prediction_dict['augmented_p_u'] = augmented_p_u
+ self._prediction_dict['augmented_p_i'] = augmented_p_i
+
+ self._prediction_dict['augmented_a_u'] = self.user_augmented_vec
+ self._prediction_dict['augmented_a_i'] = self.item_augmented_vec
+
+ return self._prediction_dict
+
+ def get_outputs(self):
+ if self._loss_type == LossType.CLASSIFICATION:
+ raise ValueError(
+ 'Currently DAT model only supports SOFTMAX_CROSS_ENTROPY loss.')
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ self._prediction_dict['logits'] = tf.squeeze(
+ self._prediction_dict['logits'], axis=-1)
+ self._prediction_dict['probs'] = tf.nn.sigmoid(
+ self._prediction_dict['logits'])
+ return [
+ 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
+ 'item_tower_emb', 'augmented_p_u', 'augmented_p_i', 'augmented_a_u',
+ 'augmented_a_i'
+ ]
+ else:
+ raise ValueError('invalid loss type: %s' % str(self._loss_type))
+
+ def build_output_dict(self):
+ output_dict = super(DAT, self).build_output_dict()
+ return output_dict
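For context (hedged, not part of this hunk): in the DAT paper the exported `augmented_a_u`/`augmented_a_i` vectors are trained by an adaptive-mimic loss that pulls them toward the gradient-stopped embedding of the opposite tower on positive pairs, which is why `augmented_p_u`/`augmented_p_i` are wrapped in `tf.stop_gradient` above; roughly

$$
L_{mimic} = \sum_{(u,i):\, y_{ui}=1} \left( \| a_u - \mathrm{sg}(p_i) \|_2^2 + \| a_i - \mathrm{sg}(p_u) \|_2^2 \right),
$$

with sg denoting stop_gradient. The mimic loss itself is assumed to be wired up in the match model's loss building code rather than in this file.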
diff --git a/easy_rec/python/model/dbmtl.py b/easy_rec/python/model/dbmtl.py
index e0e2db607..6c69d33ca 100644
--- a/easy_rec/python/model/dbmtl.py
+++ b/easy_rec/python/model/dbmtl.py
@@ -2,8 +2,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
+from easy_rec.python.layers import cmbf
from easy_rec.python.layers import dnn
from easy_rec.python.layers import mmoe
+from easy_rec.python.layers import uniter
from easy_rec.python.model.multi_task_model import MultiTaskModel
from easy_rec.python.protos.dbmtl_pb2 import DBMTL as DBMTLConfig
@@ -26,19 +28,38 @@ def __init__(self,
self._model_config = self._model_config.dbmtl
assert isinstance(self._model_config, DBMTLConfig)
- self._features, _ = self._input_layer(self._feature_dict, 'all')
+ if self._model_config.HasField('bottom_cmbf'):
+ self._cmbf_layer = cmbf.CMBF(model_config, feature_configs, features,
+ self._model_config.bottom_cmbf,
+ self._input_layer)
+ elif self._model_config.HasField('bottom_uniter'):
+ self._uniter_layer = uniter.Uniter(model_config, feature_configs,
+ features,
+ self._model_config.bottom_uniter,
+ self._input_layer)
+ elif not self.has_backbone:
+ self._features, self._feature_list = self._input_layer(
+ self._feature_dict, 'all')
+ else:
+ assert False, 'invalid code branch'
self._init_towers(self._model_config.task_towers)
def build_predict_graph(self):
- if self._model_config.HasField('bottom_dnn'):
- bottom_dnn = dnn.DNN(
- self._model_config.bottom_dnn,
- self._l2_reg,
- name='bottom_dnn',
- is_training=self._is_training)
- bottom_fea = bottom_dnn(self._features)
- else:
- bottom_fea = self._features
+ bottom_fea = self.backbone
+ if bottom_fea is None:
+ if self._model_config.HasField('bottom_cmbf'):
+ bottom_fea = self._cmbf_layer(self._is_training, l2_reg=self._l2_reg)
+ elif self._model_config.HasField('bottom_uniter'):
+ bottom_fea = self._uniter_layer(self._is_training, l2_reg=self._l2_reg)
+ elif self._model_config.HasField('bottom_dnn'):
+ bottom_dnn = dnn.DNN(
+ self._model_config.bottom_dnn,
+ self._l2_reg,
+ name='bottom_dnn',
+ is_training=self._is_training)
+ bottom_fea = bottom_dnn(self._features)
+ else:
+ bottom_fea = self._features
# MMOE block
if self._model_config.HasField('expert_dnn'):
diff --git a/easy_rec/python/model/deepfm.py b/easy_rec/python/model/deepfm.py
index 8734de8d5..d1414c050 100644
--- a/easy_rec/python/model/deepfm.py
+++ b/easy_rec/python/model/deepfm.py
@@ -4,7 +4,6 @@
from easy_rec.python.layers import dnn
from easy_rec.python.layers import fm
-from easy_rec.python.layers import input_layer
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos.deepfm_pb2 import DeepFM as DeepFMConfig
@@ -43,13 +42,8 @@ def build_input_layer(self, model_config, feature_configs):
has_final = len(model_config.deepfm.final_dnn.hidden_units) > 0
if not has_final:
assert model_config.deepfm.wide_output_dim == model_config.num_class
- self._input_layer = input_layer.InputLayer(
- feature_configs,
- model_config.feature_groups,
- wide_output_dim=model_config.deepfm.wide_output_dim,
- use_embedding_variable=model_config.use_embedding_variable,
- embedding_regularizer=self._emb_reg,
- kernel_regularizer=self._l2_reg)
+ self._wide_output_dim = model_config.deepfm.wide_output_dim
+ super(DeepFM, self).build_input_layer(model_config, feature_configs)
def build_predict_graph(self):
# Wide
@@ -58,6 +52,7 @@ def build_predict_graph(self):
# FM
fm_fea = fm.FM(name='fm_feature')(self._fm_features)
+ self._fm_outputs = fm_fea
# Deep
deep_layer = dnn.DNN(self._model_config.dnn, self._l2_reg, 'deep_feature',
@@ -94,3 +89,18 @@ def build_predict_graph(self):
self._add_to_prediction_dict(output)
return self._prediction_dict
+
+ def build_feature_output_dict(self):
+ outputs = super(DeepFM, self).build_feature_output_dict()
+ outputs.update({
+ 'wide_features':
+ tf.reduce_join(
+ tf.as_string(self._wide_features), axis=-1, separator=','),
+ 'deep_features':
+ tf.reduce_join(
+ tf.as_string(self._deep_features), axis=-1, separator=','),
+ 'fm_outputs':
+ tf.reduce_join(
+ tf.as_string(self._fm_outputs), axis=-1, separator=',')
+ })
+ return outputs
diff --git a/easy_rec/python/model/dropoutnet.py b/easy_rec/python/model/dropoutnet.py
index de12776e5..683677531 100644
--- a/easy_rec/python/model/dropoutnet.py
+++ b/easy_rec/python/model/dropoutnet.py
@@ -7,16 +7,13 @@
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.utils.proto_util import copy_obj
-from easy_rec.python.utils.shape_utils import get_shape_list
from easy_rec.python.protos.dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA
from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
from easy_rec.python.protos.dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA
-
if tf.__version__ >= '2.0':
tf = tf.compat.v1
losses = tf.losses
-metrics = tf.metrics
def cosine_similarity(user_emb, item_emb):
@@ -25,6 +22,15 @@ def cosine_similarity(user_emb, item_emb):
return user_item_sim
+def bernoulli_dropout(x, rate, training=False):
+ if rate == 0.0 or not training:
+ return x
+ keep_rate = 1.0 - rate
+ dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype)
+ mask = dist.sample(sample_shape=tf.stack([tf.shape(x)[0], 1]))
+ return x * mask / keep_rate
+
+
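A NumPy sketch of the new `bernoulli_dropout` (illustrative values): whole preference vectors are dropped per sample and survivors are rescaled by 1/keep_rate so the expectation is unchanged, unlike the previous uniform masking which did not rescale.

```python
# Hedged NumPy sketch of bernoulli_dropout: per-row keep/drop with rescaling.
import numpy as np

rng = np.random.default_rng(0)

def bernoulli_dropout_np(x, rate, training=True):
    if rate == 0.0 or not training:
        return x
    keep_rate = 1.0 - rate
    mask = rng.binomial(1, keep_rate, size=(x.shape[0], 1))  # one decision per sample
    return x * mask / keep_rate                               # E[output] == input

x = np.ones((4, 3))
print(bernoulli_dropout_np(x, rate=0.5))   # each row is either all 0.0 or all 2.0
```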
class DropoutNet(EasyRecModel):
def __init__(self,
@@ -70,8 +76,6 @@ def __init__(self,
assert self.item_content_feature is not None or self.item_preference_feature is not None, 'no item feature'
def build_predict_graph(self):
- batch_size = get_shape_list(self.item_content_feature)[0]
-
num_user_dnn_layer = len(self.user_tower_layers.hidden_units)
last_user_hidden = self.user_tower_layers.hidden_units.pop()
num_item_dnn_layer = len(self.item_tower_layers.hidden_units)
@@ -87,15 +91,9 @@ def build_predict_graph(self):
content_feature = user_content_dnn(self.user_content_feature)
user_features.append(content_feature)
if self.user_preference_feature is not None:
- if self._is_training:
- prob = tf.random.uniform([batch_size])
- user_prefer_feature = tf.where(
- tf.less(prob, self._model_config.user_dropout_rate),
- tf.zeros_like(self.user_preference_feature),
- self.user_preference_feature)
- else:
- user_prefer_feature = self.user_preference_feature
-
+ user_prefer_feature = bernoulli_dropout(
+ self.user_preference_feature, self._model_config.user_dropout_rate,
+ self._is_training)
user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg,
'user_preference', self._is_training)
prefer_feature = user_prefer_dnn(user_prefer_feature)
@@ -121,15 +119,9 @@ def build_predict_graph(self):
content_feature = item_content_dnn(self.item_content_feature)
item_features.append(content_feature)
if self.item_preference_feature is not None:
- if self._is_training:
- prob = tf.random.uniform([batch_size])
- item_prefer_feature = tf.where(
- tf.less(prob, self._model_config.item_dropout_rate),
- tf.zeros_like(self.item_preference_feature),
- self.item_preference_feature)
- else:
- item_prefer_feature = self.item_preference_feature
-
+ item_prefer_feature = bernoulli_dropout(
+ self.item_preference_feature, self._model_config.item_dropout_rate,
+ self._is_training)
item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg,
'item_preference', self._is_training)
prefer_feature = item_prefer_dnn(item_prefer_feature)
@@ -188,6 +180,7 @@ def build_loss_graph(self):
return self._loss_dict
def build_metric_graph(self, eval_config):
+ from easy_rec.python.core.easyrec_metrics import metrics_tf as metrics
metric_dict = {}
labels = list(self._labels.values())[0]
sim_score = self._prediction_dict['similarity']
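
The bernoulli_dropout helper introduced above replaces the earlier tf.where-based masking: each example's whole preference vector is dropped with probability `rate` and the survivors are rescaled by 1/(1 - rate), so the expected value of the input is preserved (inverted dropout). A small NumPy sketch of the same idea, with made-up shapes:

    import numpy as np

    def bernoulli_dropout_np(x, rate, training=True, seed=0):
      # x: [batch_size, dim]; drops whole rows with probability `rate`
      if rate == 0.0 or not training:
        return x
      keep_rate = 1.0 - rate
      rng = np.random.RandomState(seed)
      mask = rng.binomial(1, keep_rate, size=(x.shape[0], 1)).astype(x.dtype)
      # inverted dropout: rescaling keeps E[output] == x
      return x * mask / keep_rate

    x = np.ones((4, 3), dtype=np.float32)
    print(bernoulli_dropout_np(x, rate=0.5))
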
diff --git a/easy_rec/python/model/dssm.py b/easy_rec/python/model/dssm.py
index 20f873677..e35d69030 100644
--- a/easy_rec/python/model/dssm.py
+++ b/easy_rec/python/model/dssm.py
@@ -1,12 +1,9 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
-import logging
-
import tensorflow as tf
-from easy_rec.python.builders import loss_builder
from easy_rec.python.layers import dnn
-from easy_rec.python.model.easy_rec_model import EasyRecModel
+from easy_rec.python.model.match_model import MatchModel
from easy_rec.python.protos.dssm_pb2 import DSSM as DSSMConfig
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.simi_pb2 import Similarity
@@ -15,10 +12,9 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
losses = tf.losses
-metrics = tf.metrics
-class DSSM(EasyRecModel):
+class DSSM(MatchModel):
def __init__(self,
model_config,
@@ -28,81 +24,19 @@ def __init__(self,
is_training=False):
super(DSSM, self).__init__(model_config, feature_configs, features, labels,
is_training)
- self._loss_type = self._model_config.loss_type
- self._num_class = self._model_config.num_class
assert self._model_config.WhichOneof('model') == 'dssm', \
'invalid model config: %s' % self._model_config.WhichOneof('model')
self._model_config = self._model_config.dssm
assert isinstance(self._model_config, DSSMConfig)
- if self._loss_type == LossType.CLASSIFICATION:
- assert self._num_class == 1
-
# copy_obj so that any modification will not affect original config
self.user_tower = copy_obj(self._model_config.user_tower)
self.user_tower_feature, _ = self._input_layer(self._feature_dict, 'user')
- self.user_id = self.user_tower.id
# copy_obj so that any modification will not affect original config
self.item_tower = copy_obj(self._model_config.item_tower)
self.item_tower_feature, _ = self._input_layer(self._feature_dict, 'item')
- self.item_id = self.item_tower.id
-
- if self._loss_type in [LossType.CLASSIFICATION, LossType.L2_LOSS]:
- self._is_point_wise = True
- logging.info('Use point wise dssm.')
- else:
- self._is_point_wise = False
- logging.info('Use list wise dssm.')
-
- def _list_wise_sim(self, user_emb, item_emb):
- batch_size = tf.shape(user_emb)[0]
- hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)
-
- if hard_neg_indices is not None:
- tf.logging.info('With hard negative examples')
- noclk_size = tf.shape(hard_neg_indices)[0]
- pos_item_emb, neg_item_emb, hard_neg_item_emb = tf.split(
- item_emb, [batch_size, -1, noclk_size], axis=0)
- else:
- pos_item_emb = item_emb[:batch_size]
- neg_item_emb = item_emb[batch_size:]
-
- pos_user_item_sim = tf.reduce_sum(
- tf.multiply(user_emb, pos_item_emb), axis=1, keep_dims=True)
- neg_user_item_sim = tf.matmul(user_emb, tf.transpose(neg_item_emb))
-
- if hard_neg_indices is not None:
- user_emb_expand = tf.gather(user_emb, hard_neg_indices[:, 0])
- hard_neg_user_item_sim = tf.reduce_sum(
- tf.multiply(user_emb_expand, hard_neg_item_emb), axis=1)
- # scatter hard negatives sim update neg_user_item_sim
- neg_sim_shape = tf.shape(neg_user_item_sim, out_type=tf.int64)
- hard_neg_mask = tf.scatter_nd(
- hard_neg_indices,
- tf.ones_like(hard_neg_user_item_sim, dtype=tf.bool),
- shape=neg_sim_shape)
- hard_neg_user_item_sim = tf.scatter_nd(
- hard_neg_indices, hard_neg_user_item_sim, shape=neg_sim_shape)
- neg_user_item_sim = tf.where(
- hard_neg_mask, x=hard_neg_user_item_sim, y=neg_user_item_sim)
-
- user_item_sim = tf.concat([pos_user_item_sim, neg_user_item_sim], axis=1)
- return user_item_sim
-
- def _point_wise_sim(self, user_emb, item_emb):
- user_item_sim = tf.reduce_sum(
- tf.multiply(user_emb, item_emb), axis=1, keep_dims=True)
- return user_item_sim
-
- def sim(self, user_emb, item_emb):
- if self._is_point_wise:
- return self._point_wise_sim(user_emb, item_emb)
- else:
- return self._list_wise_sim(user_emb, item_emb)
-
- def norm(self, fea):
- fea_norm = tf.nn.l2_normalize(fea, axis=1)
- return fea_norm
+ self._user_tower_emb = None
+ self._item_tower_emb = None
def build_predict_graph(self):
num_user_dnn_layer = len(self.user_tower.dnn.hidden_units)
@@ -127,130 +61,94 @@ def build_predict_graph(self):
kernel_regularizer=self._l2_reg,
name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))
- if self._loss_type == LossType.CLASSIFICATION:
- if self._model_config.simi_func == Similarity.COSINE:
- user_tower_emb = self.norm(user_tower_emb)
- item_tower_emb = self.norm(item_tower_emb)
+ if self._model_config.simi_func == Similarity.COSINE:
+ user_tower_emb = self.norm(user_tower_emb)
+ item_tower_emb = self.norm(item_tower_emb)
+ temperature = self._model_config.temperature
+ else:
+ temperature = 1.0
- user_item_sim = self.sim(user_tower_emb, item_tower_emb)
- y_pred = user_item_sim
+ user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
if self._model_config.scale_simi:
sim_w = tf.get_variable(
'sim_w',
dtype=tf.float32,
- shape=(1, 1),
+ shape=(1),
initializer=tf.ones_initializer())
sim_b = tf.get_variable(
'sim_b',
dtype=tf.float32,
shape=(1),
initializer=tf.zeros_initializer())
- y_pred = tf.matmul(user_item_sim, tf.abs(sim_w)) + sim_b
+ y_pred = user_item_sim * tf.abs(sim_w) + sim_b
+ else:
+ y_pred = user_item_sim
+
+ if self._is_point_wise:
y_pred = tf.reshape(y_pred, [-1])
if self._loss_type == LossType.CLASSIFICATION:
self._prediction_dict['logits'] = y_pred
self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ y_pred = self._mask_in_batch(y_pred)
self._prediction_dict['logits'] = y_pred
self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
else:
self._prediction_dict['y'] = y_pred
+ self._prediction_dict['user_tower_emb'] = user_tower_emb
+ self._prediction_dict['item_tower_emb'] = item_tower_emb
self._prediction_dict['user_emb'] = tf.reduce_join(
tf.as_string(user_tower_emb), axis=-1, separator=',')
self._prediction_dict['item_emb'] = tf.reduce_join(
tf.as_string(item_tower_emb), axis=-1, separator=',')
return self._prediction_dict
- def build_loss_graph(self):
- if self._is_point_wise:
- return self._build_point_wise_loss_graph()
- else:
- return self._build_list_wise_loss_graph()
-
- def _build_list_wise_loss_graph(self):
- if self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
- hit_prob = self._prediction_dict['probs'][:, :1]
- self._loss_dict['cross_entropy_loss'] = -tf.reduce_mean(
- tf.log(hit_prob + 1e-12))
- logging.info('softmax cross entropy loss is used')
- else:
- raise ValueError('invalid loss type: %s' % str(self._loss_type))
- return self._loss_dict
-
- def _build_point_wise_loss_graph(self):
- label = list(self._labels.values())[0]
+ def get_outputs(self):
if self._loss_type == LossType.CLASSIFICATION:
- pred = self._prediction_dict['logits']
- loss_name = 'cross_entropy_loss'
+ return [
+ 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
+ 'item_tower_emb'
+ ]
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ self._prediction_dict['logits'] = tf.squeeze(
+ self._prediction_dict['logits'], axis=-1)
+ self._prediction_dict['probs'] = tf.nn.sigmoid(
+ self._prediction_dict['logits'])
+ return [
+ 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
+ 'item_tower_emb'
+ ]
elif self._loss_type == LossType.L2_LOSS:
- pred = self._prediction_dict['y']
- loss_name = 'l2_loss'
+ return ['y', 'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb']
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))
- self._loss_dict[loss_name] = loss_builder.build(
- self._loss_type,
- label=label,
- pred=pred,
- loss_weight=self._sample_weight)
-
- # build kd loss
- kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
- self._labels)
- self._loss_dict.update(kd_loss_dict)
- return self._loss_dict
-
- def build_metric_graph(self, eval_config):
- if self._is_point_wise:
- return self._build_point_wise_metric_graph(eval_config)
- else:
- return self._build_list_wise_metric_graph(eval_config)
-
- def _build_list_wise_metric_graph(self, eval_config):
- metric_dict = {}
- for metric in eval_config.metrics_set:
- if metric.WhichOneof('metric') == 'recall_at_topk':
- logits = self._prediction_dict['logits']
- label = tf.zeros_like(logits[:, :1], dtype=tf.int64)
- metric_dict['recall_at_top%d' %
- metric.recall_at_topk.topk] = metrics.recall_at_k(
- label, logits, metric.recall_at_topk.topk)
- else:
- ValueError('invalid metric type: %s' % str(metric))
- return metric_dict
-
- def _build_point_wise_metric_graph(self, eval_config):
- metric_dict = {}
- label = list(self._labels.values())[0]
- for metric in eval_config.metrics_set:
- if metric.WhichOneof('metric') == 'auc':
- assert self._loss_type == LossType.CLASSIFICATION
- metric_dict['auc'] = metrics.auc(label, self._prediction_dict['probs'])
- elif metric.WhichOneof('metric') == 'recall_at_topk':
- assert self._loss_type == LossType.CLASSIFICATION
- metric_dict['recall_at_topk%d' %
- metric.recall_at_topk.topk] = metrics.recall_at_k(
- label, self._prediction_dict['probs'],
- metric.recall_at_topk.topk)
- elif metric.WhichOneof('metric') == 'mean_absolute_error':
- assert self._loss_type == LossType.L2_LOSS
- metric_dict['mean_absolute_error'] = metrics.mean_absolute_error(
- label, self._prediction_dict['y'])
- elif metric.WhichOneof('metric') == 'accuracy':
- assert self._loss_type == LossType.CLASSIFICATION
- metric_dict['accuracy'] = metrics.accuracy(
- label, self._prediction_dict['probs'])
- else:
- ValueError('invalid metric type: %s' % str(metric))
- return metric_dict
-
- def get_outputs(self):
- if self._loss_type in (LossType.CLASSIFICATION,
- LossType.SOFTMAX_CROSS_ENTROPY):
- return ['logits', 'probs', 'user_emb', 'item_emb']
- elif self._loss_type == LossType.L2_LOSS:
- return ['y', 'user_emb', 'item_emb']
- else:
- raise ValueError('invalid loss type: %s' % str(self._loss_type))
+ def build_output_dict(self):
+ output_dict = super(DSSM, self).build_output_dict()
+ output_dict['user_tower_feature'] = tf.reduce_join(
+ tf.as_string(self.user_tower_feature), axis=-1, separator=',')
+ output_dict['item_tower_feature'] = tf.reduce_join(
+ tf.as_string(self.item_tower_feature), axis=-1, separator=',')
+ return output_dict
+
+ def build_rtp_output_dict(self):
+ output_dict = super(DSSM, self).build_rtp_output_dict()
+ if 'user_tower_emb' not in self._prediction_dict:
+ raise ValueError(
+          'User tower embedding does not exist. Please check the predict graph.')
+ output_dict['user_embedding_output'] = tf.identity(
+ self._prediction_dict['user_tower_emb'], name='user_embedding_output')
+ if 'item_tower_emb' not in self._prediction_dict:
+ raise ValueError(
+          'Item tower embedding does not exist. Please check the predict graph.')
+ output_dict['item_embedding_output'] = tf.identity(
+ self._prediction_dict['item_tower_emb'], name='item_embedding_output')
+ if self._loss_type == LossType.CLASSIFICATION:
+ if 'probs' not in self._prediction_dict:
+ raise ValueError(
+            'Probs output does not exist. Please check the predict graph.')
+ output_dict['rank_predict'] = tf.identity(
+ self._prediction_dict['probs'], name='rank_predict')
+ return output_dict
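
With simi_func == COSINE, both tower embeddings are L2-normalized before the dot product, and the resulting cosine similarity is divided by the configured temperature before feeding the softmax (or sigmoid) loss. A worked NumPy sketch, with made-up embeddings and a hypothetical temperature of 0.1:

    import numpy as np

    def l2_norm(x):
      return x / np.linalg.norm(x, axis=-1, keepdims=True)

    user = l2_norm(np.array([[0.3, 0.4, 0.5]], dtype=np.float32))
    items = l2_norm(np.array([[0.3, 0.4, 0.5],     # positive item
                              [-0.5, 0.1, 0.2]],   # in-batch negative
                             dtype=np.float32))
    temperature = 0.1  # hypothetical value of dssm.temperature
    logits = user.dot(items.T) / temperature
    # softmax over [positive, negatives]; a lower temperature sharpens it
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    print(logits, probs)
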
diff --git a/easy_rec/python/model/dssm_senet.py b/easy_rec/python/model/dssm_senet.py
new file mode 100644
index 000000000..c84d52161
--- /dev/null
+++ b/easy_rec/python/model/dssm_senet.py
@@ -0,0 +1,143 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.layers import senet
+from easy_rec.python.model.dssm import DSSM
+from easy_rec.python.model.match_model import MatchModel
+from easy_rec.python.protos.loss_pb2 import LossType
+from easy_rec.python.protos.simi_pb2 import Similarity
+from easy_rec.python.utils.proto_util import copy_obj
+
+from easy_rec.python.protos.dssm_senet_pb2 import DSSM_SENet as DSSM_SENet_Config # NOQA
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+losses = tf.losses
+
+
+class DSSM_SENet(DSSM):
+
+ def __init__(self,
+ model_config,
+ feature_configs,
+ features,
+ labels=None,
+ is_training=False):
+
+ MatchModel.__init__(self, model_config, feature_configs, features, labels,
+ is_training)
+
+ assert self._model_config.WhichOneof('model') == 'dssm_senet', \
+ 'invalid model config: %s' % self._model_config.WhichOneof('model')
+ self._model_config = self._model_config.dssm_senet
+ assert isinstance(self._model_config, DSSM_SENet_Config)
+
+ # copy_obj so that any modification will not affect original config
+ self.user_tower = copy_obj(self._model_config.user_tower)
+
+ self.user_seq_features, self.user_plain_features, self.user_feature_list = self._input_layer(
+ self._feature_dict, 'user', is_combine=False)
+ self.user_num_fields = len(self.user_feature_list)
+
+ # copy_obj so that any modification will not affect original config
+ self.item_tower = copy_obj(self._model_config.item_tower)
+
+ self.item_seq_features, self.item_plain_features, self.item_feature_list = self._input_layer(
+ self._feature_dict, 'item', is_combine=False)
+ self.item_num_fields = len(self.item_feature_list)
+
+ self._user_tower_emb = None
+ self._item_tower_emb = None
+
+ def build_predict_graph(self):
+ user_senet = senet.SENet(
+ num_fields=self.user_num_fields,
+ num_squeeze_group=self.user_tower.senet.num_squeeze_group,
+ reduction_ratio=self.user_tower.senet.reduction_ratio,
+ l2_reg=self._l2_reg,
+ name='user_senet')
+ user_senet_output_list = user_senet(self.user_feature_list)
+ user_senet_output = tf.concat(user_senet_output_list, axis=-1)
+
+ num_user_dnn_layer = len(self.user_tower.dnn.hidden_units)
+ last_user_hidden = self.user_tower.dnn.hidden_units.pop()
+ user_dnn = dnn.DNN(self.user_tower.dnn, self._l2_reg, 'user_dnn',
+ self._is_training)
+ user_tower_emb = user_dnn(user_senet_output)
+ user_tower_emb = tf.layers.dense(
+ inputs=user_tower_emb,
+ units=last_user_hidden,
+ kernel_regularizer=self._l2_reg,
+ name='user_dnn/dnn_%d' % (num_user_dnn_layer - 1))
+
+ item_senet = senet.SENet(
+ num_fields=self.item_num_fields,
+ num_squeeze_group=self.item_tower.senet.num_squeeze_group,
+ reduction_ratio=self.item_tower.senet.reduction_ratio,
+ l2_reg=self._l2_reg,
+ name='item_senet')
+
+ item_senet_output_list = item_senet(self.item_feature_list)
+ item_senet_output = tf.concat(item_senet_output_list, axis=-1)
+
+ num_item_dnn_layer = len(self.item_tower.dnn.hidden_units)
+ last_item_hidden = self.item_tower.dnn.hidden_units.pop()
+ item_dnn = dnn.DNN(self.item_tower.dnn, self._l2_reg, 'item_dnn',
+ self._is_training)
+ item_tower_emb = item_dnn(item_senet_output)
+ item_tower_emb = tf.layers.dense(
+ inputs=item_tower_emb,
+ units=last_item_hidden,
+ kernel_regularizer=self._l2_reg,
+ name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))
+
+ if self._model_config.simi_func == Similarity.COSINE:
+ user_tower_emb = self.norm(user_tower_emb)
+ item_tower_emb = self.norm(item_tower_emb)
+ temperature = self._model_config.temperature
+ else:
+ temperature = 1.0
+
+ user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
+ if self._model_config.scale_simi:
+ sim_w = tf.get_variable(
+ 'sim_w',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.ones_initializer())
+ sim_b = tf.get_variable(
+ 'sim_b',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.zeros_initializer())
+ y_pred = user_item_sim * tf.abs(sim_w) + sim_b
+ else:
+ y_pred = user_item_sim
+
+ if self._is_point_wise:
+ y_pred = tf.reshape(y_pred, [-1])
+
+ if self._loss_type == LossType.CLASSIFICATION:
+ self._prediction_dict['logits'] = y_pred
+ self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ y_pred = self._mask_in_batch(y_pred)
+ self._prediction_dict['logits'] = y_pred
+ self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
+ else:
+ self._prediction_dict['y'] = y_pred
+
+ self._prediction_dict['user_tower_emb'] = user_tower_emb
+ self._prediction_dict['item_tower_emb'] = item_tower_emb
+ self._prediction_dict['user_emb'] = tf.reduce_join(
+ tf.as_string(user_tower_emb), axis=-1, separator=',')
+ self._prediction_dict['item_emb'] = tf.reduce_join(
+ tf.as_string(item_tower_emb), axis=-1, separator=',')
+ return self._prediction_dict
+
+ def build_output_dict(self):
+ output_dict = MatchModel.build_output_dict(self)
+
+ return output_dict
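
DSSM_SENet differs from plain DSSM only in the towers' inputs: the per-field embeddings are first passed through a SENet block, which learns one gate per feature field and re-scales each field embedding before concatenation and the tower DNN. A simplified NumPy sketch of the squeeze-and-excitation idea (an illustration of the mechanism only, not the actual senet.SENet layer, which also supports num_squeeze_group and trains its two dense layers):

    import numpy as np

    def senet_reweight(field_embs, reduction_ratio=2, seed=0):
      # field_embs: list of [batch, dim] arrays, one per feature field
      rng = np.random.RandomState(seed)
      num_fields = len(field_embs)
      # squeeze: one scalar summary per field
      z = np.stack([e.mean(axis=1) for e in field_embs], axis=1)  # [batch, fields]
      # excitation: bottleneck MLP producing one gate per field
      hidden = max(1, num_fields // reduction_ratio)
      w1 = rng.randn(num_fields, hidden)
      w2 = rng.randn(hidden, num_fields)
      gates = 1.0 / (1.0 + np.exp(-np.maximum(z.dot(w1), 0).dot(w2)))  # sigmoid
      # re-scale each field embedding by its gate
      return [e * gates[:, i:i + 1] for i, e in enumerate(field_embs)]

    fields = [np.ones((2, 4)), 2 * np.ones((2, 4)), 3 * np.ones((2, 4))]
    senet_out = np.concatenate(senet_reweight(fields), axis=1)
    print(senet_out.shape)  # (2, 12), fed into the tower DNN
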
diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py
index b64c04f06..95385936c 100644
--- a/easy_rec/python/model/easy_rec_estimator.py
+++ b/easy_rec/python/model/easy_rec_estimator.py
@@ -2,6 +2,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import print_function
+import json
import logging
import os
import re
@@ -9,23 +10,54 @@
from collections import OrderedDict
import tensorflow as tf
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import saver
from easy_rec.python.builders import optimizer_builder
from easy_rec.python.compat import optimizers
+from easy_rec.python.compat import sync_replicas_optimizer
from easy_rec.python.compat.early_stopping import custom_early_stop_hook
+from easy_rec.python.compat.early_stopping import deadline_stop_hook
from easy_rec.python.compat.early_stopping import find_early_stop_var
+from easy_rec.python.compat.early_stopping import oss_stop_hook
from easy_rec.python.compat.early_stopping import stop_if_no_decrease_hook
from easy_rec.python.compat.early_stopping import stop_if_no_increase_hook
+from easy_rec.python.compat.ops import GraphKeys
+from easy_rec.python.input.input import Input
+from easy_rec.python.layers.utils import _tensor_to_tensorinfo
from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
from easy_rec.python.protos.train_pb2 import DistributionStrategy
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import embedding_utils
from easy_rec.python.utils import estimator_utils
+from easy_rec.python.utils import hvd_utils
+from easy_rec.python.utils import pai_util
from easy_rec.python.utils.multi_optimizer import MultiOptimizer
+from easy_rec.python.compat.embedding_parallel_saver import EmbeddingParallelSaver # NOQA
+
+try:
+ import horovod.tensorflow as hvd
+except Exception:
+ hvd = None
+
+try:
+ from sparse_operation_kit import experiment as sok
+ from easy_rec.python.compat import sok_optimizer
+except Exception:
+ sok = None
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
+tf.estimator.Estimator._assert_members_are_not_overridden = lambda x: x
+
class EasyRecEstimator(tf.estimator.Estimator):
@@ -40,6 +72,40 @@ def __init__(self, pipeline_config, model_cls, run_config, params):
config=run_config,
params=params)
+ def evaluate(self,
+ input_fn,
+ steps=None,
+ hooks=None,
+ checkpoint_path=None,
+ name=None):
+ # support for datahub/kafka offset restore
+ input_fn.input_creator.restore(checkpoint_path)
+ return super(EasyRecEstimator, self).evaluate(input_fn, steps, hooks,
+ checkpoint_path, name)
+
+ def train(self,
+ input_fn,
+ hooks=None,
+ steps=None,
+ max_steps=None,
+ saving_listeners=None):
+ # support for datahub/kafka offset restore
+ checkpoint_path = estimator_utils.latest_checkpoint(self.model_dir)
+ if checkpoint_path is not None:
+ input_fn.input_creator.restore(checkpoint_path)
+ elif self.train_config.HasField('fine_tune_checkpoint'):
+ fine_tune_ckpt = self.train_config.fine_tune_checkpoint
+ if fine_tune_ckpt.endswith('/') or gfile.IsDirectory(fine_tune_ckpt +
+ '/'):
+ fine_tune_ckpt = estimator_utils.latest_checkpoint(fine_tune_ckpt)
+ print(
+            'fine_tune_checkpoint[%s] is a directory, will use the latest checkpoint: %s'
+ % (self.train_config.fine_tune_checkpoint, fine_tune_ckpt))
+ self.train_config.fine_tune_checkpoint = fine_tune_ckpt
+ input_fn.input_creator.restore(fine_tune_ckpt)
+ return super(EasyRecEstimator, self).train(input_fn, hooks, steps,
+ max_steps, saving_listeners)
+
@property
def feature_configs(self):
if len(self._pipeline_config.feature_configs) > 0:
@@ -62,11 +128,32 @@ def eval_config(self):
def train_config(self):
return self._pipeline_config.train_config
+ @property
+ def incr_save_config(self):
+ return self.train_config.incr_save_config if self.train_config.HasField(
+ 'incr_save_config') else None
+
@property
def export_config(self):
return self._pipeline_config.export_config
+ @property
+ def embedding_parallel(self):
+ return self.train_config.train_distribute in (
+ DistributionStrategy.SokStrategy,
+ DistributionStrategy.EmbeddingParallelStrategy)
+
+ @property
+ def saver_cls(self):
+    # when embedding parallel is used, the extended saver class
+    # (EmbeddingParallelSaver) is used to save sharded embeddings
+ tmp_saver_cls = saver.Saver
+ if self.embedding_parallel:
+ tmp_saver_cls = EmbeddingParallelSaver
+ return tmp_saver_cls
+
def _train_model_fn(self, features, labels, run_config):
+ tf.keras.backend.set_learning_phase(1)
model = self._model_cls(
self.model_config,
self.feature_configs,
@@ -98,9 +185,29 @@ def _train_model_fn(self, features, labels, run_config):
for key in loss_dict:
tf.summary.scalar(key, loss_dict[key], family='loss')
+ if Input.DATA_OFFSET in features:
+ task_index, task_num = estimator_utils.get_task_index_and_num()
+ data_offset_var = tf.get_variable(
+ name=Input.DATA_OFFSET,
+ dtype=tf.string,
+ shape=[task_num],
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, Input.DATA_OFFSET],
+ trainable=False)
+ update_offset = tf.assign(data_offset_var[task_index],
+ features[Input.DATA_OFFSET])
+ ops.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_offset)
+ else:
+ data_offset_var = None
+
# update op, usually used for batch-norm
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
+      # register for incremental updates, such as batchnorm moving_mean and moving_variance
+ global_vars = {x.name: x for x in tf.global_variables()}
+ for x in update_ops:
+ if isinstance(x, ops.Operation) and x.inputs[0].name in global_vars:
+ ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
+ global_vars[x.inputs[0].name])
update_op = tf.group(*update_ops, name='update_barrier')
with tf.control_dependencies([update_op]):
loss = tf.identity(loss, name='total_loss')
@@ -118,21 +225,37 @@ def _train_model_fn(self, features, labels, run_config):
opt, learning_rate = optimizer_builder.build(tmp_config)
tf.summary.scalar('learning_rate', learning_rate[0])
all_opts.append(opt)
- grouped_vars = model.get_grouped_vars()
+ grouped_vars = model.get_grouped_vars(len(all_opts))
assert len(grouped_vars) == len(optimizer_config), \
'the number of var group(%d) != the number of optimizers(%d)' \
% (len(grouped_vars), len(optimizer_config))
optimizer = MultiOptimizer(all_opts, grouped_vars)
+ if self.train_config.train_distribute == DistributionStrategy.SokStrategy:
+ optimizer = sok_optimizer.OptimizerWrapper(optimizer)
+
hooks = []
+ if estimator_utils.has_hvd():
+ assert not self.train_config.sync_replicas, \
+ 'sync_replicas should not be set when using horovod'
+ bcast_hook = hvd_utils.BroadcastGlobalVariablesHook(0)
+ hooks.append(bcast_hook)
+
# for distributed and synced training
if self.train_config.sync_replicas and run_config.num_worker_replicas > 1:
logging.info('sync_replicas: num_worker_replias = %d' %
run_config.num_worker_replicas)
- optimizer = tf.train.SyncReplicasOptimizer(
- optimizer,
- replicas_to_aggregate=run_config.num_worker_replicas,
- total_num_replicas=run_config.num_worker_replicas)
+ if pai_util.is_on_pai():
+ optimizer = tf.train.SyncReplicasOptimizer(
+ optimizer,
+ replicas_to_aggregate=run_config.num_worker_replicas,
+ total_num_replicas=run_config.num_worker_replicas,
+ sparse_accumulator_type=self.train_config.sparse_accumulator_type)
+ else:
+ optimizer = sync_replicas_optimizer.SyncReplicasOptimizer(
+ optimizer,
+ replicas_to_aggregate=run_config.num_worker_replicas,
+ total_num_replicas=run_config.num_worker_replicas)
hooks.append(
optimizer.make_session_run_hook(run_config.is_chief, num_tokens=0))
@@ -168,6 +291,12 @@ def _train_model_fn(self, features, labels, run_config):
self.export_config.max_check_steps,
eval_dir=eval_dir))
+ if self.train_config.enable_oss_stop_signal:
+ hooks.append(oss_stop_hook(self))
+
+ if self.train_config.HasField('dead_line'):
+ hooks.append(deadline_stop_hook(self, self.train_config.dead_line))
+
summaries = ['global_gradient_norm']
if self.train_config.summary_model_vars:
summaries.extend(['gradient_norm', 'gradients'])
@@ -204,6 +333,9 @@ def _train_model_fn(self, features, labels, run_config):
else:
all_train_vars = tf.trainable_variables()
+ if self.embedding_parallel:
+ logging.info('embedding_parallel is enabled')
+
train_op = optimizers.optimize_loss(
loss=loss,
global_step=tf.train.get_global_step(),
@@ -216,7 +348,9 @@ def _train_model_fn(self, features, labels, run_config):
colocate_gradients_with_ops=True,
not_apply_grad_after_first_step=run_config.is_chief and
self._pipeline_config.data_config.chief_redundant,
- name='') # Preventing scope prefix on all variables.
+ name='', # Preventing scope prefix on all variables.
+ incr_save=(self.incr_save_config is not None),
+ embedding_parallel=self.embedding_parallel)
# online evaluation
metric_update_op_dict = None
@@ -225,7 +359,8 @@ def _train_model_fn(self, features, labels, run_config):
metric_dict = model.build_metric_graph(self.eval_config)
for k, v in metric_dict.items():
metric_update_op_dict['%s/batch' % k] = v[1]
- tf.summary.scalar('%s/batch' % k, v[1])
+ if isinstance(v[1], tf.Tensor):
+ tf.summary.scalar('%s/batch' % k, v[1])
train_op = tf.group([train_op] + list(metric_update_op_dict.values()))
if estimator_utils.is_chief():
hooks.append(
@@ -247,67 +382,63 @@ def _train_model_fn(self, features, labels, run_config):
# logging
logging_dict = OrderedDict()
- logging_dict['lr'] = learning_rate[0]
logging_dict['step'] = tf.train.get_global_step()
+ logging_dict['lr'] = learning_rate[0]
logging_dict.update(loss_dict)
if metric_update_op_dict is not None:
logging_dict.update(metric_update_op_dict)
- tensor_order = logging_dict.keys()
-
- def format_fn(tensor_dict):
- stats = []
- for k in tensor_order:
- tensor_value = tensor_dict[k]
- stats.append('%s = %s' % (k, tensor_value))
- return ','.join(stats)
log_step_count_steps = self.train_config.log_step_count_steps
-
- logging_hook = tf.train.LoggingTensorHook(
- logging_dict, every_n_iter=log_step_count_steps, formatter=format_fn)
+ logging_hook = basic_session_run_hooks.LoggingTensorHook(
+ logging_dict,
+ every_n_iter=log_step_count_steps,
+ formatter=estimator_utils.tensor_log_format_func)
hooks.append(logging_hook)
if self.train_config.train_distribute in [
DistributionStrategy.CollectiveAllReduceStrategy,
+ DistributionStrategy.MirroredStrategy,
DistributionStrategy.MultiWorkerMirroredStrategy
]:
# for multi worker strategy, we could not replace the
# inner CheckpointSaverHook, so just use it.
scaffold = tf.train.Scaffold()
- chief_hooks = []
else:
var_list = (
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
- initialize_var_list = [
- x for x in var_list if 'WorkQueue' not in str(type(x))
- ]
+
+ # exclude data_offset_var
+ var_list = [x for x in var_list if x != data_offset_var]
# early_stop flag will not be saved in checkpoint
# and could not be restored from checkpoint
early_stop_var = find_early_stop_var(var_list)
+ var_list = [x for x in var_list if x != early_stop_var]
+
+ initialize_var_list = [
+ x for x in var_list if 'WorkQueue' not in str(type(x))
+ ]
+
# incompatiable shape restore will not be saved in checkpoint
# but must be able to restore from checkpoint
incompatiable_shape_restore = tf.get_collection('T_E_M_P_RESTROE')
- if early_stop_var is not None:
- var_list = [x for x in var_list if x != early_stop_var]
- local_init_op = tf.group([
- tf.initializers.local_variables(),
- tf.initializers.variables([early_stop_var] +
- incompatiable_shape_restore)
- ])
- elif len(incompatiable_shape_restore) > 0:
- local_init_op = tf.group([
- tf.initializers.local_variables(),
- tf.initializers.variables(incompatiable_shape_restore)
- ])
- else:
- local_init_op = None
+
+ local_init_ops = [tf.train.Scaffold.default_local_init_op()]
+ if data_offset_var is not None and estimator_utils.is_chief():
+ local_init_ops.append(tf.initializers.variables([data_offset_var]))
+ if early_stop_var is not None and estimator_utils.is_chief():
+ local_init_ops.append(tf.initializers.variables([early_stop_var]))
+ if len(incompatiable_shape_restore) > 0:
+ local_init_ops.append(
+ tf.initializers.variables(incompatiable_shape_restore))
+
scaffold = tf.train.Scaffold(
- saver=tf.train.Saver(
+ saver=self.saver_cls(
var_list=var_list,
sharded=True,
- max_to_keep=self.train_config.keep_checkpoint_max),
- local_init_op=local_init_op,
+ max_to_keep=self.train_config.keep_checkpoint_max,
+ save_relative_paths=True),
+ local_init_op=tf.group(local_init_ops),
ready_for_local_init_op=tf.report_uninitialized_variables(
var_list=initialize_var_list))
# saver hook
@@ -316,10 +447,15 @@ def format_fn(tensor_dict):
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
scaffold=scaffold,
- write_graph=self.train_config.write_graph)
- chief_hooks = []
- if estimator_utils.is_chief():
+ write_graph=self.train_config.write_graph,
+ data_offset_var=data_offset_var,
+ increment_save_config=self.incr_save_config)
+ if estimator_utils.is_chief() or self.embedding_parallel:
hooks.append(saver_hook)
+ if estimator_utils.is_chief():
+ hooks.append(
+ basic_session_run_hooks.StepCounterHook(
+ every_n_steps=log_step_count_steps, output_dir=self.model_dir))
# profiling hook
if self.train_config.is_profiling and estimator_utils.is_chief():
@@ -333,10 +469,10 @@ def format_fn(tensor_dict):
predictions=predict_dict,
train_op=train_op,
scaffold=scaffold,
- training_chief_hooks=chief_hooks,
training_hooks=hooks)
def _eval_model_fn(self, features, labels, run_config):
+ tf.keras.backend.set_learning_phase(0)
start = time.time()
model = self._model_cls(
self.model_config,
@@ -348,6 +484,7 @@ def _eval_model_fn(self, features, labels, run_config):
loss_dict = model.build_loss_graph()
loss = tf.add_n(list(loss_dict.values()))
loss_dict['total_loss'] = loss
+
metric_dict = model.build_metric_graph(self.eval_config)
for loss_key in loss_dict.keys():
loss_tensor = loss_dict[loss_key]
@@ -355,15 +492,27 @@ def _eval_model_fn(self, features, labels, run_config):
metric_dict['loss/loss/' + loss_key] = tf.metrics.mean(loss_tensor)
tf.logging.info('metric_dict keys: %s' % metric_dict.keys())
+ var_list = (
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
+ ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
+
+ metric_variables = ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)
+ model_ready_for_local_init_op = tf.variables_initializer(metric_variables)
+ scaffold = tf.train.Scaffold(
+ saver=self.saver_cls(
+ var_list=var_list, sharded=True, save_relative_paths=True),
+ ready_for_local_init_op=model_ready_for_local_init_op)
end = time.time()
tf.logging.info('eval graph construct finished. Time %.3fs' % (end - start))
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
+ scaffold=scaffold,
predictions=predict_dict,
eval_metric_ops=metric_dict)
def _distribute_eval_model_fn(self, features, labels, run_config):
+ tf.keras.backend.set_learning_phase(0)
start = time.time()
model = self._model_cls(
self.model_config,
@@ -375,7 +524,7 @@ def _distribute_eval_model_fn(self, features, labels, run_config):
loss_dict = model.build_loss_graph()
loss = tf.add_n(list(loss_dict.values()))
loss_dict['total_loss'] = loss
- metric_dict = model.build_distribute_metric_graph(self.eval_config)
+ metric_dict = model.build_metric_graph(self.eval_config)
for loss_key in loss_dict.keys():
loss_tensor = loss_dict[loss_key]
# add key-prefix to make loss metric key in the same family of train loss
@@ -405,7 +554,7 @@ def _distribute_eval_model_fn(self, features, labels, run_config):
model_ready_for_local_init_op = tf.variables_initializer(metric_variables)
remain_variables = list(
set(global_variables).difference(set(metric_variables)))
- cur_saver = tf.train.Saver(var_list=remain_variables)
+ cur_saver = tf.train.Saver(var_list=remain_variables, sharded=True)
scaffold = tf.train.Scaffold(
saver=cur_saver, ready_for_local_init_op=model_ready_for_local_init_op)
return tf.estimator.EstimatorSpec(
@@ -416,21 +565,27 @@ def _distribute_eval_model_fn(self, features, labels, run_config):
scaffold=scaffold)
def _export_model_fn(self, features, labels, run_config, params):
+ tf.keras.backend.set_learning_phase(0)
model = self._model_cls(
self.model_config,
self.feature_configs,
features,
labels=None,
is_training=False)
- predict_dict = model.build_predict_graph()
+ model.build_predict_graph()
- # add output info to estimator spec
+ export_config = self._pipeline_config.export_config
outputs = {}
- output_list = model.get_outputs()
- for out in output_list:
- assert out in predict_dict, \
- 'output node %s not in prediction_dict, can not be exported' % out
- outputs[out] = predict_dict[out]
+ logging.info('building default outputs')
+ outputs.update(model.build_output_dict())
+ if export_config.export_features:
+ logging.info('building output features')
+ outputs.update(model.build_feature_output_dict())
+ if export_config.export_rtp_outputs:
+ logging.info('building RTP outputs')
+ outputs.update(model.build_rtp_output_dict())
+
+ for out in outputs:
tf.logging.info(
'output %s shape: %s type: %s' %
(out, outputs[out].get_shape().as_list(), outputs[out].dtype))
@@ -441,33 +596,144 @@ def _export_model_fn(self, features, labels, run_config, params):
# save train pipeline.config for debug purpose
pipeline_path = os.path.join(self._model_dir, 'pipeline.config')
- if tf.gfile.Exists(pipeline_path):
- tf.add_to_collection(
+ if gfile.Exists(pipeline_path):
+ ops.add_to_collection(
tf.GraphKeys.ASSET_FILEPATHS,
tf.constant(pipeline_path, dtype=tf.string, name='pipeline.config'))
else:
print('train pipeline_path(%s) does not exist' % pipeline_path)
+ # restore DENSE_UPDATE_VARIABLES collection
+ dense_train_var_path = os.path.join(self.model_dir,
+ constant.DENSE_UPDATE_VARIABLES)
+ if gfile.Exists(dense_train_var_path):
+ with gfile.GFile(dense_train_var_path, 'r') as fin:
+ var_name_to_id_map = json.load(fin)
+ var_name_id_lst = [
+ (x, var_name_to_id_map[x]) for x in var_name_to_id_map
+ ]
+ var_name_id_lst.sort(key=lambda x: x[1])
+ all_vars = {x.op.name: x for x in tf.global_variables()}
+ for var_name, var_id in var_name_id_lst:
+ assert var_name in all_vars, 'dense_train_var[%s] is not found' % var_name
+ ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
+ all_vars[var_name])
+
# add more asset files
- if 'asset_files' in params:
+ if len(export_config.asset_files) > 0:
+ for asset_file in export_config.asset_files:
+ if asset_file.startswith('!'):
+ asset_file = asset_file[1:]
+ _, asset_name = os.path.split(asset_file)
+ ops.add_to_collection(
+ ops.GraphKeys.ASSET_FILEPATHS,
+ tf.constant(asset_file, dtype=tf.string, name=asset_name))
+ elif 'asset_files' in params:
for asset_name in params['asset_files']:
asset_file = params['asset_files'][asset_name]
- tf.add_to_collection(
+ ops.add_to_collection(
tf.GraphKeys.ASSET_FILEPATHS,
tf.constant(asset_file, dtype=tf.string, name=asset_name))
+ if self._pipeline_config.HasField('fg_json_path'):
+ fg_path = self._pipeline_config.fg_json_path
+ if fg_path[0] == '!':
+ fg_path = fg_path[1:]
+ ops.add_to_collection(
+ tf.GraphKeys.ASSET_FILEPATHS,
+ tf.constant(fg_path, dtype=tf.string, name='fg.json'))
+
+ var_list = (
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
+ ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
+
+ scaffold = tf.train.Scaffold(
+ saver=self.saver_cls(
+ var_list=var_list, sharded=True, save_relative_paths=True))
+
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
loss=None,
+ scaffold=scaffold,
predictions=outputs,
export_outputs=export_outputs)
def _model_fn(self, features, labels, mode, config, params):
os.environ['tf.estimator.mode'] = mode
os.environ['tf.estimator.ModeKeys.TRAIN'] = tf.estimator.ModeKeys.TRAIN
+ if self._pipeline_config.feature_config.embedding_on_cpu:
+ os.environ['place_embedding_on_cpu'] = 'True'
+ if self._pipeline_config.fg_json_path:
+ EasyRecEstimator._write_rtp_fg_config_to_col(
+ fg_config_path=self._pipeline_config.fg_json_path)
+ EasyRecEstimator._write_rtp_inputs_to_col(features)
+
+ if self.embedding_parallel:
+ embedding_utils.set_embedding_parallel()
+
if mode == tf.estimator.ModeKeys.TRAIN:
return self._train_model_fn(features, labels, config)
elif mode == tf.estimator.ModeKeys.EVAL:
return self._eval_model_fn(features, labels, config)
elif mode == tf.estimator.ModeKeys.PREDICT:
return self._export_model_fn(features, labels, config, params)
+
+ @staticmethod
+ def _write_rtp_fg_config_to_col(fg_config=None, fg_config_path=None):
+ """Write RTP config to RTP-specified graph collections.
+
+ Args:
+ fg_config: JSON-dict RTP config. If set, fg_config_path will be ignored.
+ fg_config_path: path to the RTP config file.
+ """
+ if fg_config is None:
+ if fg_config_path.startswith('!'):
+ fg_config_path = fg_config_path[1:]
+ with gfile.GFile(fg_config_path, 'r') as f:
+ fg_config = json.load(f)
+ col = ops.get_collection_ref(GraphKeys.RANK_SERVICE_FG_CONF)
+ if len(col) == 0:
+ col.append(json.dumps(fg_config))
+ else:
+ col[0] = json.dumps(fg_config)
+
+ @staticmethod
+ def _write_rtp_inputs_to_col(features):
+ """Write input nodes information to RTP-specified graph collections.
+
+ Args:
+ features: the feature dictionary used as model input.
+ """
+ feature_info_map = dict()
+ for feature_name, feature_value in features.items():
+ feature_info = _tensor_to_tensorinfo(feature_value)
+ feature_info_map[feature_name] = feature_info
+ col = ops.get_collection_ref(GraphKeys.RANK_SERVICE_FEATURE_NODE)
+ if len(col) == 0:
+ col.append(json.dumps(feature_info_map))
+ else:
+ col[0] = json.dumps(feature_info_map)
+
+ def export_checkpoint(self,
+ export_path=None,
+ serving_input_receiver_fn=None,
+ checkpoint_path=None,
+ mode=tf.estimator.ModeKeys.PREDICT):
+ with context.graph_mode():
+ if not checkpoint_path:
+ # Locate the latest checkpoint
+ checkpoint_path = estimator_utils.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise ValueError("Couldn't find trained model at %s." % self._model_dir)
+ with ops.Graph().as_default():
+ input_receiver = serving_input_receiver_fn()
+ estimator_spec = self._call_model_fn(
+ features=input_receiver.features,
+ labels=getattr(input_receiver, 'labels', None),
+ mode=mode,
+ config=self.config)
+ with tf_session.Session(config=self._session_config) as session:
+ graph_saver = estimator_spec.scaffold.saver or saver.Saver(
+ sharded=True)
+ graph_saver.restore(session, checkpoint_path)
+ graph_saver.save(session, export_path)
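
The new train() override resolves fine_tune_checkpoint when it points at a directory (or ends with '/') by picking the latest checkpoint inside it before restoring input offsets. A self-contained sketch of that resolution step, using tf.train.latest_checkpoint in place of EasyRec's estimator_utils.latest_checkpoint:

    import tensorflow as tf

    if tf.__version__ >= '2.0':
      tf = tf.compat.v1

    def resolve_fine_tune_ckpt(path):
      # a directory is resolved to its most recent checkpoint;
      # a concrete checkpoint path is returned unchanged
      if path.endswith('/') or tf.gfile.IsDirectory(path + '/'):
        return tf.train.latest_checkpoint(path)
      return path

    # returns None when the directory holds no checkpoint yet
    print(resolve_fine_tune_ckpt('/tmp/hypothetical_model_dir/'))
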
diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py
index 7ea15a564..f2408ba47 100644
--- a/easy_rec/python/model/easy_rec_model.py
+++ b/easy_rec/python/model/easy_rec_model.py
@@ -2,21 +2,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
+import os
import re
from abc import abstractmethod
import six
import tensorflow as tf
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops.variables import PartitionedVariable
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from easy_rec.python.compat import regularizers
from easy_rec.python.layers import input_layer
+from easy_rec.python.layers.backbone import Backbone
from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import restore_filter
from easy_rec.python.utils.load_class import get_register_class_meta
+try:
+ import horovod.tensorflow as hvd
+ from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
+ from sparse_operation_kit import experiment as sok
+except Exception:
+ dynamic_variable_ops = None
+ sok = None
+
+try:
+ from tensorflow.python.framework.load_library import load_op_library
+ import easy_rec
+ load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
+ load_embed_lib = load_op_library(load_embed_lib_path)
+except Exception as ex:
+ logging.warning('load libload_embed.so failed: %s' % str(ex))
+ load_embed_lib = None
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -36,10 +57,31 @@ def __init__(self,
self._base_model_config = model_config
self._model_config = model_config
self._is_training = is_training
+ self._is_predicting = labels is None
self._feature_dict = features
- self._emb_reg = regularizers.l2_regularizer(self.embedding_regularization)
- self._l2_reg = regularizers.l2_regularizer(self.l2_regularization)
+ # embedding variable parameters
+ self._global_ev_params = None
+ if model_config.HasField('ev_params'):
+ self._global_ev_params = model_config.ev_params
+
+ if self.embedding_regularization > 0:
+ self._emb_reg = regularizers.l2_regularizer(self.embedding_regularization)
+ else:
+ self._emb_reg = None
+
+ if self.l2_regularization > 0:
+ self._l2_reg = regularizers.l2_regularizer(self.l2_regularization)
+ else:
+ self._l2_reg = None
+
+ # only used by model with wide feature groups, e.g. WideAndDeep
+ self._wide_output_dim = -1
+ if self.has_backbone:
+ wide_dim = Backbone.wide_embed_dim(model_config.backbone)
+ if wide_dim:
+ self._wide_output_dim = wide_dim
+ logging.info('set `wide_output_dim` to %d' % wide_dim)
self._feature_configs = feature_configs
self.build_input_layer(model_config, feature_configs)
@@ -47,12 +89,44 @@ def __init__(self,
self._labels = labels
self._prediction_dict = {}
self._loss_dict = {}
+ self._metric_dict = {}
# add sample weight from inputs
self._sample_weight = 1.0
if constant.SAMPLE_WEIGHT in features:
self._sample_weight = features[constant.SAMPLE_WEIGHT]
+ self._backbone_output = None
+ self._backbone_net = self.build_backbone_network()
+
+ def build_backbone_network(self):
+ if self.has_backbone:
+ return Backbone(
+ self._base_model_config.backbone,
+ self._feature_dict,
+ input_layer=self._input_layer,
+ l2_reg=self._l2_reg)
+ return None
+
+ @property
+ def has_backbone(self):
+ return self._base_model_config.HasField('backbone')
+
+ @property
+ def backbone(self):
+ if self._backbone_output:
+ return self._backbone_output
+ if self._backbone_net:
+ kwargs = {
+ 'loss_dict': self._loss_dict,
+ 'metric_dict': self._metric_dict,
+ 'prediction_dict': self._prediction_dict,
+ 'labels': self._labels,
+ constant.SAMPLE_WEIGHT: self._sample_weight
+ }
+ return self._backbone_net(self._is_training, **kwargs)
+ return None
+
@property
def embedding_regularization(self):
return self._base_model_config.embedding_regularization
@@ -61,6 +135,10 @@ def embedding_regularization(self):
def kd(self):
return self._base_model_config.kd
+ @property
+ def feature_groups(self):
+ return self._base_model_config.feature_groups
+
@property
def l2_regularization(self):
model_config = getattr(self._base_model_config,
@@ -69,7 +147,7 @@ def l2_regularization(self):
if hasattr(model_config, 'dense_regularization') and \
model_config.HasField('dense_regularization'):
# backward compatibility
- tf.logging.warn(
+ logging.warn(
'dense_regularization is deprecated, please use l2_regularization')
l2_regularization = model_config.dense_regularization
elif hasattr(model_config, 'l2_regularization'):
@@ -80,12 +158,14 @@ def build_input_layer(self, model_config, feature_configs):
self._input_layer = input_layer.InputLayer(
feature_configs,
model_config.feature_groups,
- use_embedding_variable=model_config.use_embedding_variable,
+ wide_output_dim=self._wide_output_dim,
+ ev_params=self._global_ev_params,
embedding_regularizer=self._emb_reg,
kernel_regularizer=self._l2_reg,
variational_dropout_config=model_config.variational_dropout
if model_config.HasField('variational_dropout') else None,
- is_training=False)
+ is_training=self._is_training,
+ is_predicting=self._is_predicting)
@abstractmethod
def build_predict_graph(self):
@@ -95,14 +175,47 @@ def build_predict_graph(self):
def build_loss_graph(self):
pass
- @abstractmethod
def build_metric_graph(self, eval_config):
- pass
+ return self._metric_dict
@abstractmethod
def get_outputs(self):
pass
+ def build_output_dict(self):
+ """For exporting: get standard output nodes."""
+ outputs = {}
+ for name in self.get_outputs():
+ if name not in self._prediction_dict:
+ raise KeyError(
+ 'output node {} not in prediction_dict, can not be exported'.format(
+ name))
+ outputs[name] = self._prediction_dict[name]
+ return outputs
+
+ def build_feature_output_dict(self):
+ """For exporting: get output feature nodes."""
+ outputs = {}
+ for feature_name in self._feature_dict:
+ out_name = 'feature_' + feature_name
+ feature_value = self._feature_dict[feature_name]
+ if isinstance(feature_value, tf.SparseTensor):
+ sparse_values = feature_value.values
+ if sparse_values.dtype != tf.string:
+ sparse_values = tf.as_string(sparse_values)
+ feature_value = tf.sparse_to_dense(feature_value.indices,
+ feature_value.dense_shape,
+ sparse_values, '')
+ elif feature_value.dtype != tf.string:
+ feature_value = tf.as_string(feature_value)
+ feature_value = tf.reduce_join(feature_value, axis=-1, separator=',')
+ outputs[out_name] = feature_value
+ return outputs
+
+ def build_rtp_output_dict(self):
+    """For exporting: get output nodes for RTP inference."""
+ return {}
+
def restore(self,
ckpt_path,
include_global_step=False,
@@ -130,11 +243,6 @@ def restore(self,
name2var_map = self._get_restore_vars(ckpt_var_map_path)
logging.info('start to restore from %s' % ckpt_path)
- if ckpt_path.endswith('/') or tf.gfile.IsDirectory(ckpt_path + '/'):
- ckpt_path = estimator_utils.latest_checkpoint(ckpt_path)
- print('ckpt_path is model_dir, will use the latest checkpoint: %s' %
- ckpt_path)
-
ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
if not include_global_step:
@@ -153,7 +261,7 @@ def restore(self,
for x in shape_arr[1:]:
var_shape[0] += x[0]
var_shape = tensor_shape.TensorShape(var_shape)
- variable = PartitionedVariable(
+ variable = variables.PartitionedVariable(
variable_name,
var_shape,
variable[0].dtype,
@@ -163,7 +271,7 @@ def restore(self,
var_shape = variable.shape.as_list()
if ckpt_var_shape == var_shape:
vars_in_ckpt[variable_name] = list(variable) if isinstance(
- variable, PartitionedVariable) else variable
+ variable, variables.PartitionedVariable) else variable
elif len(ckpt_var_shape) == len(var_shape):
if force_restore_shape_compatible:
# create a variable compatible with checkpoint to restore
@@ -190,6 +298,43 @@ def restore(self,
logging.warning(
'Variable [%s] is available in checkpoint, but '
'incompatible shape dims with model variable.', variable_name)
+ elif 'EmbeddingVariable' in str(type(variable)):
+ if '%s-keys' % variable_name not in ckpt_var2shape_map:
+ continue
+ print('restore embedding_variable %s' % variable_name)
+ from tensorflow.python.training import saver
+ names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
+ saveable_objects = []
+ for name, op in names_to_saveables.items():
+ for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
+ saveable_objects.append(s)
+ init_op = saveable_objects[0].restore([ckpt_path], None)
+ variable._initializer_op = init_op
+ elif type(variable) == list and 'EmbeddingVariable' in str(
+ type(variable[0])):
+ if '%s/part_0-keys' % variable_name not in ckpt_var2shape_map:
+ continue
+ print('restore partitioned embedding_variable %s' % variable_name)
+ from tensorflow.python.training import saver
+ for part_var in variable:
+ names_to_saveables = saver.BaseSaverBuilder.OpListToDict([part_var])
+ saveable_objects = []
+ for name, op in names_to_saveables.items():
+ for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
+ saveable_objects.append(s)
+ init_op = saveable_objects[0].restore([ckpt_path], None)
+ part_var._initializer_op = init_op
+ elif sok is not None and isinstance(variable, sok.DynamicVariable):
+ print('restore dynamic_variable %s' % variable_name)
+ keys, vals = load_embed_lib.load_kv_embed(
+ task_index=hvd.rank(),
+ task_num=hvd.size(),
+ embed_dim=variable._dimension,
+ var_name='embed-' + variable.name.replace('/', '__'),
+ ckpt_path=ckpt_path)
+ with ops.control_dependencies([variable._initializer_op]):
+ variable._initializer_op = dynamic_variable_ops.dummy_var_assign(
+ variable.handle, keys, vals)
else:
fail_restore_vars.append(variable_name)
for variable_name in fail_restore_vars:
@@ -225,8 +370,7 @@ def _get_restore_vars(self, ckpt_var_map_path):
for one_var in all_vars:
var_name = re.sub(VAR_SUFIX_PATTERN, '', one_var.name)
if re.search(PARTITION_PATTERN,
- var_name) and (not var_name.endswith('/AdamAsync_2') and
- not var_name.endswith('/AdamAsync_3')):
+ var_name) and one_var._save_slice_info is not None:
var_name = re.sub(PARTITION_PATTERN, '', var_name)
is_part = True
else:
@@ -238,13 +382,13 @@ def _get_restore_vars(self, ckpt_var_map_path):
name2var[var_name] = [one_var] if is_part else one_var
if ckpt_var_map_path != '':
- if not tf.gfile.Exists(ckpt_var_map_path):
+ if not gfile.Exists(ckpt_var_map_path):
logging.warning('%s not exist' % ckpt_var_map_path)
return name2var
# load var map
name_map = {}
- with open(ckpt_var_map_path, 'r') as fin:
+ with gfile.GFile(ckpt_var_map_path, 'r') as fin:
for one_line in fin:
one_line = one_line.strip()
line_tok = [x for x in one_line.split() if x != '']
@@ -252,14 +396,16 @@ def _get_restore_vars(self, ckpt_var_map_path):
logging.warning('Failed to process: %s' % one_line)
continue
name_map[line_tok[0]] = line_tok[1]
- var_map = {}
+ update_map = {}
+ old_keys = []
for var_name in name2var:
if var_name in name_map:
in_ckpt_name = name_map[var_name]
- var_map[in_ckpt_name] = name2var[var_name]
- else:
- logging.warning('Failed to find in var_map_file(%s): %s' %
- (ckpt_var_map_path, var_name))
+ update_map[in_ckpt_name] = name2var[var_name]
+ old_keys.append(var_name)
+ for tmp_key in old_keys:
+ del name2var[tmp_key]
+ name2var.update(update_map)
return name2var
else:
var_filter, scope_update = self.get_restore_filter()
@@ -297,10 +443,25 @@ def get_restore_filter(self):
return restore_filter.CombineFilter(all_filters,
restore_filter.Logical.AND), None
- def get_grouped_vars(self):
- """Get grouped variables, each group will be optimized by a separate optimizer.
+ def get_grouped_vars(self, opt_num):
+ """Group the vars into different optimization groups.
+
+ Each group will be optimized by a separate optimizer.
+
+ Args:
+ opt_num: number of optimizers from easyrec config.
Return:
- grouped_vars: list of list of variables
+ list of list of variables.
"""
- raise NotImplementedError()
+    assert opt_num == 2, 'only supports 2 optimizers: one for embeddings, one for the other layers'
+
+ embedding_vars = []
+ deep_vars = []
+ for tmp_var in variables.trainable_variables():
+ if tmp_var.name.startswith(
+ 'input_layer') or '/embedding_weights' in tmp_var.name:
+ embedding_vars.append(tmp_var)
+ else:
+ deep_vars.append(tmp_var)
+ return [embedding_vars, deep_vars]
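
The default get_grouped_vars above supports exactly two optimizers and groups trainable variables by name: anything under input_layer or containing /embedding_weights goes to the first (embedding) optimizer, everything else to the second. A small sketch of that grouping rule on made-up variable names:

    # variable names below are hypothetical, for illustration only
    var_names = [
        'input_layer/user_id_embedding/embedding_weights',
        'dnn/hiddenlayer_0/kernel',
        'item_tower/fc/embedding_weights/part_0',
    ]
    embedding_vars = [
        v for v in var_names
        if v.startswith('input_layer') or '/embedding_weights' in v
    ]
    deep_vars = [v for v in var_names if v not in embedding_vars]
    print(embedding_vars)  # first optimizer (embeddings)
    print(deep_vars)       # second optimizer (dense layers)
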
diff --git a/easy_rec/python/model/esmm.py b/easy_rec/python/model/esmm.py
index 1c6901309..50567ae63 100644
--- a/easy_rec/python/model/esmm.py
+++ b/easy_rec/python/model/esmm.py
@@ -12,7 +12,6 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
losses = tf.losses
-metrics = tf.metrics
class ESMM(MultiTaskModel):
@@ -32,7 +31,9 @@ def __init__(self,
self._group_num = len(self._model_config.groups)
self._group_features = []
- if self._group_num > 0:
+ if self.has_backbone:
+ logging.info('use bottom backbone network')
+ elif self._group_num > 0:
logging.info('group_num: {0}'.format(self._group_num))
for group_id in range(self._group_num):
group = self._model_config.groups[group_id]
@@ -174,7 +175,9 @@ def build_predict_graph(self):
Returns:
self._prediction_dict: Prediction result of two tasks.
"""
- if self._group_num > 0:
+ if self.has_backbone:
+ all_fea = self.backbone
+ elif self._group_num > 0:
group_fea_arr = []
# Both towers share the underlying network.
for group_id in range(self._group_num):
diff --git a/easy_rec/python/model/fm.py b/easy_rec/python/model/fm.py
index be51c261b..357628d3e 100644
--- a/easy_rec/python/model/fm.py
+++ b/easy_rec/python/model/fm.py
@@ -5,7 +5,6 @@
import tensorflow as tf
from easy_rec.python.layers import fm
-from easy_rec.python.layers import input_layer
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos.fm_pb2 import FM as FMConfig
@@ -34,13 +33,8 @@ def __init__(self,
def build_input_layer(self, model_config, feature_configs):
# overwrite create input_layer to support wide_output_dim
- self._input_layer = input_layer.InputLayer(
- feature_configs,
- model_config.feature_groups,
- wide_output_dim=model_config.num_class,
- use_embedding_variable=model_config.use_embedding_variable,
- embedding_regularizer=self._emb_reg,
- kernel_regularizer=self._l2_reg)
+ self._wide_output_dim = model_config.num_class
+ super(FM, self).build_input_layer(model_config, feature_configs)
def build_predict_graph(self):
wide_fea = tf.reduce_sum(
diff --git a/easy_rec/python/model/match_model.py b/easy_rec/python/model/match_model.py
new file mode 100644
index 000000000..e9c4d2d44
--- /dev/null
+++ b/easy_rec/python/model/match_model.py
@@ -0,0 +1,357 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+
+import tensorflow as tf
+
+from easy_rec.python.builders import loss_builder
+from easy_rec.python.model.easy_rec_model import EasyRecModel
+from easy_rec.python.protos.loss_pb2 import LossType
+from easy_rec.python.protos.simi_pb2 import Similarity
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+losses = tf.losses
+
+
+class MatchModel(EasyRecModel):
+
+ def __init__(self,
+ model_config,
+ feature_configs,
+ features,
+ labels=None,
+ is_training=False):
+ super(MatchModel, self).__init__(model_config, feature_configs, features,
+ labels, is_training)
+ self._loss_type = self._model_config.loss_type
+ self._num_class = self._model_config.num_class
+
+ if self._loss_type == LossType.CLASSIFICATION:
+ assert self._num_class == 1
+
+ if self._loss_type in [LossType.CLASSIFICATION, LossType.L2_LOSS]:
+ self._is_point_wise = True
+ logging.info('Use point wise dssm.')
+ else:
+ self._is_point_wise = False
+ logging.info('Use list wise dssm.')
+
+ cls_mem = self._model_config.WhichOneof('model')
+ sub_model_config = getattr(self._model_config, cls_mem)
+
+ self._item_ids = None
+ assert sub_model_config is not None, 'sub_model_config undefined: model_cls = %s' % cls_mem
+ if getattr(sub_model_config, 'item_id', '') != '':
+ logging.info('item_id feature is: %s' % sub_model_config.item_id)
+ self._item_ids = features[sub_model_config.item_id]
+
+ def _mask_in_batch(self, logits):
+ batch_size = tf.shape(logits)[0]
+ if getattr(self._model_config, 'ignore_in_batch_neg_sam', False):
+ in_batch = logits[:, :batch_size] - (
+ 1 - tf.diag(tf.ones([batch_size], dtype=tf.float32))) * 1e32
+ return tf.concat([in_batch, logits[:, batch_size:]], axis=1)
+ else:
+ if self._item_ids is not None:
+ mask_in_batch_neg = tf.to_float(
+ tf.equal(self._item_ids[None, :batch_size],
+ self._item_ids[:batch_size, None])) - tf.diag(
+ tf.ones([batch_size], dtype=tf.float32))
+ tf.summary.scalar('in_batch_neg_conflict',
+ tf.reduce_sum(mask_in_batch_neg))
+ return tf.concat([
+ logits[:, :batch_size] - mask_in_batch_neg * 1e32,
+ logits[:, batch_size:]],
+ axis=1) # yapf: disable
+ else:
+ return logits
+
+ def _list_wise_sim(self, user_emb, item_emb):
+ batch_size = tf.shape(user_emb)[0]
+ hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)
+
+ if hard_neg_indices is not None:
+ logging.info('With hard negative examples')
+ noclk_size = tf.shape(hard_neg_indices)[0]
+ # pos_item_emb, neg_item_emb, hard_neg_item_emb = tf.split(
+ # item_emb, [batch_size, -1, noclk_size], axis=0)
+ simple_item_emb, hard_neg_item_emb = tf.split(
+ item_emb, [-1, noclk_size], axis=0)
+ else:
+ # pos_item_emb = item_emb[:batch_size]
+ # neg_item_emb = item_emb[batch_size:]
+ simple_item_emb = item_emb
+
+ # pos_user_item_sim = tf.reduce_sum(
+ # tf.multiply(user_emb, pos_item_emb), axis=1, keep_dims=True)
+ # neg_user_item_sim = tf.matmul(user_emb, tf.transpose(neg_item_emb))
+ # simple_user_item_sim = tf.matmul(user_emb, tf.transpose(simple_item_emb))
+
+ _mode = os.environ['tf.estimator.mode']
+ if _mode == tf.estimator.ModeKeys.PREDICT:
+ simple_user_item_sim = tf.reduce_sum(
+ tf.multiply(user_emb, simple_item_emb), axis=1, keep_dims=True)
+ else:
+ simple_user_item_sim = tf.matmul(user_emb, tf.transpose(simple_item_emb))
+
+ if hard_neg_indices is None:
+ return simple_user_item_sim
+ else:
+ user_emb_expand = tf.gather(user_emb, hard_neg_indices[:, 0])
+ hard_neg_user_item_sim = tf.reduce_sum(
+ tf.multiply(user_emb_expand, hard_neg_item_emb), axis=1)
+ max_num_neg = tf.reduce_max(hard_neg_indices[:, 1]) + 1
+ hard_neg_shape = tf.stack([tf.to_int64(batch_size), max_num_neg])
+ hard_neg_sim = tf.scatter_nd(hard_neg_indices, hard_neg_user_item_sim,
+ hard_neg_shape)
+ hard_neg_mask = tf.scatter_nd(
+ hard_neg_indices,
+ tf.ones_like(hard_neg_user_item_sim, dtype=tf.float32),
+ shape=hard_neg_shape)
+      # set tail positions to -1e32 so that they become zero after exp(x)
+ hard_neg_user_item_sim = hard_neg_sim - (1 - hard_neg_mask) * 1e32
+
+ # user_item_sim = [pos_user_item_sim, neg_user_item_sim]
+ # if hard_neg_indices is not None:
+ # user_item_sim.append(hard_neg_user_item_sim)
+ # return tf.concat(user_item_sim, axis=1)
+
+ return tf.concat([simple_user_item_sim, hard_neg_user_item_sim], axis=1)
+
+ def _point_wise_sim(self, user_emb, item_emb):
+ user_item_sim = tf.reduce_sum(
+ tf.multiply(user_emb, item_emb), axis=1, keep_dims=True)
+ return user_item_sim
+
+ def sim(self, user_emb, item_emb):
+    # Name the outputs of the user tower and the item tower, i.e. the inputs of the
+    # similarity operation.
+    # Explicit names for these nodes are necessary for some online recall systems,
+    # such as BasicEngine, to split the prediction graph into different clusters.
+ user_emb = tf.identity(user_emb, 'user_tower_emb')
+ item_emb = tf.identity(item_emb, 'item_tower_emb')
+
+ if self._is_point_wise:
+ return self._point_wise_sim(user_emb, item_emb)
+ else:
+ return self._list_wise_sim(user_emb, item_emb)
+
+ def norm(self, fea):
+ fea_norm = tf.nn.l2_normalize(fea, axis=-1)
+ return fea_norm
+
+ def build_predict_graph(self):
+ if not self.has_backbone:
+ raise NotImplementedError(
+          'method `build_predict_graph` must be implemented when you do not use a backbone network'
+ )
+ model = self._model_config.WhichOneof('model')
+ assert model == 'model_params', '`model_params` must be configured'
+ model_params = self._model_config.model_params
+ for out in model_params.outputs:
+ self._outputs.append(out)
+
+ output = self.backbone
+
+ user_tower_emb = output[model_params.user_tower_idx_in_output]
+ item_tower_emb = output[model_params.item_tower_idx_in_output]
+
+ if model_params.simi_func == Similarity.COSINE:
+ user_tower_emb = self.norm(user_tower_emb)
+ item_tower_emb = self.norm(item_tower_emb)
+ temperature = model_params.temperature
+ else:
+ temperature = 1.0
+
+ user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
+
+ if model_params.scale_simi:
+ sim_w = tf.get_variable(
+ 'sim_w',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.ones_initializer())
+ sim_b = tf.get_variable(
+ 'sim_b',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.zeros_initializer())
+ y_pred = user_item_sim * tf.abs(sim_w) + sim_b
+ else:
+ y_pred = user_item_sim
+
+ if self._is_point_wise:
+ y_pred = tf.reshape(y_pred, [-1])
+
+ if self._loss_type == LossType.CLASSIFICATION:
+ self._prediction_dict['logits'] = y_pred
+ self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ y_pred = self._mask_in_batch(y_pred)
+ self._prediction_dict['logits'] = y_pred
+ self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
+ else:
+ self._prediction_dict['y'] = y_pred
+
+ self._prediction_dict['user_tower_emb'] = user_tower_emb
+ self._prediction_dict['item_tower_emb'] = item_tower_emb
+ self._prediction_dict['user_emb'] = tf.reduce_join(
+ tf.as_string(user_tower_emb), axis=-1, separator=',')
+ self._prediction_dict['item_emb'] = tf.reduce_join(
+ tf.as_string(item_tower_emb), axis=-1, separator=',')
+
+ return self._prediction_dict
+
+ def build_loss_graph(self):
+ if self._is_point_wise:
+ return self._build_point_wise_loss_graph()
+ else:
+ return self._build_list_wise_loss_graph()
+
+ def _build_list_wise_loss_graph(self):
+ if self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ batch_size = tf.shape(self._prediction_dict['probs'])[0]
+ indices = tf.range(batch_size)
+ indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
+ hit_prob = tf.gather_nd(
+ self._prediction_dict['probs'][:batch_size, :batch_size], indices)
+
+ sample_weights = tf.cast(tf.squeeze(self._sample_weight), tf.float32)
+ self._loss_dict['cross_entropy_loss'] = -tf.reduce_mean(
+ tf.log(hit_prob + 1e-12) *
+ sample_weights) / tf.reduce_mean(sample_weights)
+
+ logging.info('softmax cross entropy loss is used')
+
+ user_features = self._prediction_dict['user_tower_emb']
+ pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
+ pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
+ # if pos_simi < 0, produce loss
+ reg_pos_loss = tf.nn.relu(-pos_simi)
+ self._loss_dict['reg_pos_loss'] = tf.reduce_mean(
+ reg_pos_loss * sample_weights) / tf.reduce_mean(sample_weights)
+
+ # the AMM loss for DAT model
+ if all([
+ k in self._prediction_dict.keys() for k in
+ ['augmented_p_u', 'augmented_p_i', 'augmented_a_u', 'augmented_a_i']
+ ]):
+ self._loss_dict[
+ 'amm_loss_u'] = self._model_config.amm_u_weight * tf.reduce_mean(
+ tf.square(self._prediction_dict['augmented_a_u'] -
+ self._prediction_dict['augmented_p_i'][:batch_size]) *
+ sample_weights) / tf.reduce_mean(sample_weights)
+ self._loss_dict[
+ 'amm_loss_i'] = self._model_config.amm_i_weight * tf.reduce_mean(
+ tf.square(self._prediction_dict['augmented_a_i'][:batch_size] -
+ self._prediction_dict['augmented_p_u']) *
+ sample_weights) / tf.reduce_mean(sample_weights)
+
+ else:
+ raise ValueError('invalid loss type: %s' % str(self._loss_type))
+ return self._loss_dict
+
+ def _build_point_wise_loss_graph(self):
+ label = list(self._labels.values())[0]
+ if self._loss_type == LossType.CLASSIFICATION:
+ pred = self._prediction_dict['logits']
+ loss_name = 'cross_entropy_loss'
+ elif self._loss_type == LossType.L2_LOSS:
+ pred = self._prediction_dict['y']
+ loss_name = 'l2_loss'
+ else:
+ raise ValueError('invalid loss type: %s' % str(self._loss_type))
+
+ kwargs = {'loss_name': loss_name}
+ self._loss_dict[loss_name] = loss_builder.build(
+ self._loss_type,
+ label=label,
+ pred=pred,
+ loss_weight=self._sample_weight,
+ **kwargs)
+
+ # build kd loss
+ kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
+ self._labels, self._feature_dict)
+ self._loss_dict.update(kd_loss_dict)
+ return self._loss_dict
+
+ def build_metric_graph(self, eval_config):
+ if self._is_point_wise:
+ return self._build_point_wise_metric_graph(eval_config)
+ else:
+ return self._build_list_wise_metric_graph(eval_config)
+
+ def _build_list_wise_metric_graph(self, eval_config):
+ from easy_rec.python.core.easyrec_metrics import metrics_tf
+ logits = self._prediction_dict['logits']
+ # label = tf.zeros_like(logits[:, :1], dtype=tf.int64)
+ batch_size = tf.shape(logits)[0]
+ label = tf.cast(tf.range(batch_size), tf.int64)
+
+ indices = tf.range(batch_size)
+ indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
+ pos_item_sim = tf.gather_nd(logits[:batch_size, :batch_size], indices)
+ metric_dict = {}
+ for metric in eval_config.metrics_set:
+ if metric.WhichOneof('metric') == 'recall_at_topk':
+ metric_dict['recall@%d' %
+ metric.recall_at_topk.topk] = metrics_tf.recall_at_k(
+ label, logits, metric.recall_at_topk.topk)
+
+ logits_v2 = tf.concat([pos_item_sim[:, None], logits[:, batch_size:]],
+ axis=1)
+ labels_v2 = tf.zeros_like(logits_v2[:, :1], dtype=tf.int64)
+ metric_dict['recall_neg_sam@%d' %
+ metric.recall_at_topk.topk] = metrics_tf.recall_at_k(
+ labels_v2, logits_v2, metric.recall_at_topk.topk)
+
+ metric_dict['recall_in_batch@%d' %
+ metric.recall_at_topk.topk] = metrics_tf.recall_at_k(
+ label, logits[:, :batch_size],
+ metric.recall_at_topk.topk)
+ else:
+ raise ValueError('invalid metric type: %s' % str(metric))
+ return metric_dict
+
+ def _build_point_wise_metric_graph(self, eval_config):
+ from easy_rec.python.core.easyrec_metrics import metrics_tf
+ metric_dict = {}
+ label = list(self._labels.values())[0]
+ for metric in eval_config.metrics_set:
+ if metric.WhichOneof('metric') == 'auc':
+ assert self._loss_type == LossType.CLASSIFICATION
+ metric_dict['auc'] = metrics_tf.auc(label,
+ self._prediction_dict['probs'])
+ elif metric.WhichOneof('metric') == 'mean_absolute_error':
+ assert self._loss_type == LossType.L2_LOSS
+ metric_dict['mean_absolute_error'] = metrics_tf.mean_absolute_error(
+ tf.to_float(label), self._prediction_dict['y'])
+ else:
+ raise ValueError('invalid metric type: %s' % str(metric))
+ return metric_dict
+
+ def get_outputs(self):
+ if not self.has_backbone:
+ raise NotImplementedError(
+ 'could not call get_outputs on abstract class MatchModel')
+ if self._loss_type == LossType.CLASSIFICATION:
+ return [
+ 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
+ 'item_tower_emb'
+ ]
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ self._prediction_dict['logits'] = tf.squeeze(
+ self._prediction_dict['logits'], axis=-1)
+ self._prediction_dict['probs'] = tf.nn.sigmoid(
+ self._prediction_dict['logits'])
+ return [
+ 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
+ 'item_tower_emb'
+ ]
+ elif self._loss_type == LossType.L2_LOSS:
+ return ['y', 'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb']
+ else:
+ raise ValueError('invalid loss type: %s' % str(self._loss_type))
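# Not part of the diff: a minimal numpy sketch of the list-wise (in-batch
# softmax) loss that MatchModel._build_list_wise_loss_graph computes above.
# Every other item in the batch serves as a negative, the positive sits on the
# diagonal of the user-item similarity matrix, and the loss is the negative
# log-probability of hitting that diagonal. Names here are illustrative only.
import numpy as np

def in_batch_softmax_loss(user_emb, item_emb):
    # user_emb: [B, D]; item_emb: [B, D] (positive item of each row)
    logits = user_emb @ item_emb.T                        # [B, B] similarity matrix
    logits -= logits.max(axis=1, keepdims=True)           # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    hit_prob = np.diag(probs)                             # prob of the true item
    return -np.log(hit_prob + 1e-12).mean()

rng = np.random.RandomState(0)
print(in_batch_softmax_loss(rng.randn(4, 8), rng.randn(4, 8)))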
diff --git a/easy_rec/python/model/mind.py b/easy_rec/python/model/mind.py
index a749e6dfe..0e47f79d8 100644
--- a/easy_rec/python/model/mind.py
+++ b/easy_rec/python/model/mind.py
@@ -7,7 +7,7 @@
from easy_rec.python.compat import regularizers
from easy_rec.python.layers import dnn
from easy_rec.python.layers.capsule_layer import CapsuleLayer
-from easy_rec.python.model.easy_rec_model import EasyRecModel
+from easy_rec.python.model.match_model import MatchModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.mind_pb2 import MIND as MINDConfig
from easy_rec.python.protos.simi_pb2 import Similarity
@@ -16,10 +16,9 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
losses = tf.losses
-metrics = tf.metrics
-class MIND(EasyRecModel):
+class MIND(MatchModel):
def __init__(self,
model_config,
@@ -29,13 +28,11 @@ def __init__(self,
is_training=False):
super(MIND, self).__init__(model_config, feature_configs, features, labels,
is_training)
- self._loss_type = self._model_config.loss_type
- self._num_class = self._model_config.num_class
assert self._model_config.WhichOneof('model') == 'mind', \
'invalid model config: %s' % self._model_config.WhichOneof('model')
self._model_config = self._model_config.mind
- self._hist_seq_features = self._input_layer(
+ self._hist_seq_features, _, _ = self._input_layer(
self._feature_dict, 'hist', is_combine=False)
self._user_features, _ = self._input_layer(self._feature_dict, 'user')
self._item_features, _ = self._input_layer(self._feature_dict, 'item')
@@ -44,41 +41,37 @@ def __init__(self,
self.user_dnn = copy_obj(self._model_config.user_dnn)
# copy_obj so that any modification will not affect original config
self.item_dnn = copy_obj(self._model_config.item_dnn)
+ # copy obj so that any modification will not affect original config
+ self.concat_dnn = copy_obj(self._model_config.concat_dnn)
self._l2_reg = regularizers.l2_regularizer(
self._model_config.l2_regularization)
- if self._labels is not None:
- self._labels = list(self._labels.values())
- if self._loss_type == LossType.CLASSIFICATION:
- self._labels[0] = tf.cast(self._labels[0], tf.int64)
- elif self._loss_type == LossType.L2_LOSS:
- self._labels[0] = tf.cast(self._labels[0], tf.float32)
-
- if self._loss_type == LossType.CLASSIFICATION:
- assert self._num_class == 1
-
- def sim(self, user_emb, item_emb):
- user_item_sim = tf.reduce_sum(
- tf.multiply(user_emb, item_emb), axis=1, keep_dims=True)
- return user_item_sim
-
- def norm(self, fea):
- fea_norm = tf.norm(fea, axis=-1, keepdims=True)
- return tf.div(fea, tf.maximum(fea_norm, 1e-12))
-
def build_predict_graph(self):
capsule_layer = CapsuleLayer(self._model_config.capsule_config,
self._is_training)
- time_id_fea = [
- x[0] for x in self._hist_seq_features if 'time_id/' in x[0].name
- ]
+ if self._model_config.time_id_fea:
+ time_id_fea = [
+ x[0]
+ for x in self._hist_seq_features
+ if self._model_config.time_id_fea in x[0].name
+ ]
+      logging.info('time_id_fea is set (%s), found num: %d' %
+ (self._model_config.time_id_fea, len(time_id_fea)))
+ else:
+ time_id_fea = []
time_id_fea = time_id_fea[0] if len(time_id_fea) > 0 else None
- hist_seq_feas = [
- x[0] for x in self._hist_seq_features if 'time_id/' not in x[0].name
- ]
+ if time_id_fea is not None:
+ hist_seq_feas = [
+ x[0]
+ for x in self._hist_seq_features
+ if self._model_config.time_id_fea not in x[0].name
+ ]
+ else:
+ hist_seq_feas = [x[0] for x in self._hist_seq_features]
+
# it is assumed that all hist have the same length
hist_seq_len = self._hist_seq_features[0][1]
@@ -107,32 +100,59 @@ def build_predict_graph(self):
time_id_fea = tf.minimum(time_id_fea, time_id_mask[:, :, None])
hist_seq_feas = hist_seq_feas * tf.nn.softmax(time_id_fea, axis=1)
+ tf.summary.histogram('hist_seq_len', hist_seq_len)
+
# batch_size x max_k x high_capsule_dim
high_capsules, num_high_capsules = capsule_layer(hist_seq_feas,
hist_seq_len)
- # concatenate with user features
- user_features = tf.tile(
- tf.expand_dims(self._user_features, axis=1),
- [1, tf.shape(high_capsules)[1], 1])
- user_features = tf.concat([high_capsules, user_features], axis=2)
- num_user_dnn_layer = len(self.user_dnn.hidden_units)
- last_user_hidden = self.user_dnn.hidden_units.pop()
+
+ tf.summary.histogram('num_high_capsules', num_high_capsules)
+
+ # high_capsules = tf.layers.batch_normalization(
+ # high_capsules, training=self._is_training,
+ # trainable=True, name='capsule_bn')
+ # high_capsules = high_capsules * 0.1
+
+ tf.summary.scalar('high_capsules_norm',
+ tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
+ tf.summary.scalar('num_high_capsules',
+ tf.reduce_mean(tf.to_float(num_high_capsules)))
+
+ user_features = tf.layers.batch_normalization(
+ self._user_features,
+ training=self._is_training,
+ trainable=True,
+ name='user_fea_bn')
user_dnn = dnn.DNN(self.user_dnn, self._l2_reg, 'user_dnn',
self._is_training)
user_features = user_dnn(user_features)
- user_features = tf.layers.dense(
- inputs=user_features,
- units=last_user_hidden,
+
+ tf.summary.scalar('user_features_norm',
+ tf.reduce_mean(tf.norm(self._user_features, axis=-1)))
+
+ # concatenate with user features
+ user_features_tile = tf.tile(user_features[:, None, :],
+ [1, tf.shape(high_capsules)[1], 1])
+ user_interests = tf.concat([high_capsules, user_features_tile], axis=2)
+
+ num_concat_dnn_layer = len(self.concat_dnn.hidden_units)
+ last_hidden = self.concat_dnn.hidden_units.pop()
+ concat_dnn = dnn.DNN(self.concat_dnn, self._l2_reg, 'concat_dnn',
+ self._is_training)
+ user_interests = concat_dnn(user_interests)
+ user_interests = tf.layers.dense(
+ inputs=user_interests,
+ units=last_hidden,
kernel_regularizer=self._l2_reg,
- name='user_dnn/dnn_%d' % (num_user_dnn_layer - 1))
+ name='concat_dnn/dnn_%d' % (num_concat_dnn_layer - 1))
num_item_dnn_layer = len(self.item_dnn.hidden_units)
last_item_hidden = self.item_dnn.hidden_units.pop()
item_dnn = dnn.DNN(self.item_dnn, self._l2_reg, 'item_dnn',
self._is_training)
- item_feature = item_dnn(self._item_features)
- item_feature = tf.layers.dense(
- inputs=item_feature,
+ item_tower_emb = item_dnn(self._item_features)
+ item_tower_emb = tf.layers.dense(
+ inputs=item_tower_emb,
units=last_item_hidden,
kernel_regularizer=self._l2_reg,
name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))
@@ -142,102 +162,284 @@ def build_predict_graph(self):
]
if self._model_config.simi_func == Similarity.COSINE:
- item_feature = self.norm(item_feature)
- user_features = self.norm(user_features)
+ item_tower_emb = self.norm(item_tower_emb)
+ user_interests = self.norm(user_interests)
# label guided attention
# attention item features on high capsules vector
- simi = tf.einsum('bhe,be->bh', user_features, item_feature)
- simi = tf.pow(simi, self._model_config.simi_pow)
+ batch_size = tf.shape(user_interests)[0]
+ pos_item_fea = item_tower_emb[:batch_size]
+ simi = tf.einsum('bhe,be->bh', user_interests, pos_item_fea)
+ tf.summary.histogram('interest_item_simi/pre_scale',
+ tf.reduce_max(simi, axis=1))
+ # simi = tf.Print(simi, [tf.reduce_max(simi, axis=1), tf.reduce_min(simi, axis=1)], message='simi_max_0')
+ # simi = tf.pow(simi, self._model_config.simi_pow)
+ simi = simi * self._model_config.simi_pow
+ tf.summary.histogram('interest_item_simi/scaled',
+ tf.reduce_max(simi, axis=1))
+ # simi = tf.Print(simi, [tf.reduce_max(simi, axis=1), tf.reduce_min(simi, axis=1)], message='simi_max')
simi_mask = tf.sequence_mask(num_high_capsules,
self._model_config.capsule_config.max_k)
- user_features = user_features * tf.to_float(simi_mask[:, :, None])
- self._prediction_dict['user_features'] = user_features
+ user_interests = user_interests * tf.to_float(simi_mask[:, :, None])
+ self._prediction_dict['user_interests'] = user_interests
max_thresh = (tf.cast(simi_mask, tf.float32) * 2 - 1) * 1e32
simi = tf.minimum(simi, max_thresh)
simi = tf.nn.softmax(simi, axis=1)
- simi = tf.stop_gradient(simi)
- user_tower_emb = tf.einsum('bhe,bh->be', user_features, simi)
+ tf.summary.histogram('interest_item_simi/softmax',
+ tf.reduce_max(simi, axis=1))
+
+ if self._model_config.simi_pow >= 100:
+ logging.info(
+          'simi_pow=%d, switching to argmax: only the most similar interest is used to calculate the loss.'
+ % self._model_config.simi_pow)
+ simi_max_id = tf.argmax(simi, axis=1)
+ simi = tf.one_hot(simi_max_id, tf.shape(simi)[1], dtype=tf.float32)
+
+ user_tower_emb = tf.einsum('bhe,bh->be', user_interests, simi)
# calculate similarity between user_tower_emb and item_tower_emb
- item_tower_emb = item_feature
user_item_sim = self.sim(user_tower_emb, item_tower_emb)
- sim_w = tf.get_variable(
- 'sim_w',
- dtype=tf.float32,
- shape=(1, 1),
- initializer=tf.ones_initializer())
- sim_b = tf.get_variable(
- 'sim_b',
- dtype=tf.float32,
- shape=(1),
- initializer=tf.zeros_initializer())
- y_pred = tf.matmul(user_item_sim, tf.abs(sim_w)) + sim_b
- y_pred = tf.reshape(y_pred, [-1])
+ if self._model_config.scale_simi:
+ sim_w = tf.get_variable(
+ 'sim_w',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.ones_initializer())
+ sim_b = tf.get_variable(
+ 'sim_b',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.zeros_initializer())
+ y_pred = user_item_sim * tf.abs(sim_w) + sim_b
+ else:
+ y_pred = user_item_sim
+
+ if self._is_point_wise:
+ y_pred = tf.reshape(y_pred, [-1])
if self._loss_type == LossType.CLASSIFICATION:
- self._prediction_dict['logits'] = tf.nn.sigmoid(y_pred)
+ self._prediction_dict['logits'] = y_pred
+ self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ y_pred = self._mask_in_batch(y_pred)
+ self._prediction_dict['logits'] = y_pred
+ self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
else:
self._prediction_dict['y'] = y_pred
+ self._prediction_dict['high_capsules'] = high_capsules
+ self._prediction_dict['user_interests'] = user_interests
+ self._prediction_dict['user_tower_emb'] = user_tower_emb
+ self._prediction_dict['item_tower_emb'] = item_tower_emb
self._prediction_dict['user_emb'] = tf.reduce_join(
- tf.reduce_join(tf.as_string(user_features), axis=-1, separator=','),
+ tf.reduce_join(tf.as_string(user_interests), axis=-1, separator=','),
axis=-1,
separator='|')
self._prediction_dict['user_emb_num'] = num_high_capsules
self._prediction_dict['item_emb'] = tf.reduce_join(
tf.as_string(item_tower_emb), axis=-1, separator=',')
+
+ if self._labels is not None:
+ # for summary purpose
+ batch_simi, batch_capsule_simi = self._build_interest_simi()
+ # self._prediction_dict['probs'] = tf.Print(self._prediction_dict['probs'],
+ # [batch_simi, batch_capsule_simi], message='batch_simi')
+ self._prediction_dict['interests_simi'] = batch_simi
return self._prediction_dict
def build_loss_graph(self):
- if self._loss_type == LossType.CLASSIFICATION:
- logging.info('log loss is used')
- loss = losses.log_loss(self._labels[0], self._prediction_dict['logits'])
- self._loss_dict['cross_entropy_loss'] = loss
- elif self._loss_type == LossType.L2_LOSS:
- logging.info('l2 loss is used')
- loss = tf.reduce_mean(
- tf.square(self._labels[0] - self._prediction_dict['y']))
- self._loss_dict['l2_loss'] = loss
- else:
- raise ValueError('invalid loss type: %s' % str(self._loss_type))
- return self._loss_dict
-
- def _build_interest_metric(self):
- user_features = self._prediction_dict['user_features']
- user_features = self.norm(user_features)
- user_feature_num = self._prediction_dict['user_emb_num']
-
- user_feature_sum_sqr = tf.square(tf.reduce_sum(user_features, axis=1))
- user_feature_sqr_sum = tf.reduce_sum(tf.square(user_features), axis=1)
- simi = user_feature_sum_sqr - user_feature_sqr_sum
-
- simi = tf.reduce_sum(
- simi, axis=1) / tf.maximum(
- tf.to_float(user_feature_num * (user_feature_num - 1)), 1.0)
- user_feature_num = tf.reduce_sum(tf.to_float(user_feature_num > 1))
- return metrics.mean(tf.reduce_sum(simi) / tf.maximum(user_feature_num, 1.0))
+ loss_dict = super(MIND, self).build_loss_graph()
+ if self._model_config.max_interests_simi < 1.0:
+ loss_dict['reg_interest_simi'] = tf.nn.relu(
+ self._prediction_dict['interests_simi'] -
+ self._model_config.max_interests_simi)
+ return loss_dict
+
+ def _build_interest_simi(self):
+ user_emb_num = self._prediction_dict['user_emb_num']
+ high_capsule_mask = tf.sequence_mask(
+ user_emb_num, self._model_config.capsule_config.max_k)
+
+ user_interests = self._prediction_dict['user_interests']
+ high_capsule_mask = tf.to_float(high_capsule_mask[:, :, None])
+ user_interests = self.norm(user_interests) * high_capsule_mask
+
+ user_feature_sum_sqr = tf.square(tf.reduce_sum(user_interests, axis=1))
+ user_feature_sqr_sum = tf.reduce_sum(tf.square(user_interests), axis=1)
+ interest_simi = user_feature_sum_sqr - user_feature_sqr_sum
+
+ high_capsules = self._prediction_dict['high_capsules']
+ high_capsules = self.norm(high_capsules) * high_capsule_mask
+ high_capsule_sum_sqr = tf.square(tf.reduce_sum(high_capsules, axis=1))
+ high_capsule_sqr_sum = tf.reduce_sum(tf.square(high_capsules), axis=1)
+ high_capsule_simi = high_capsule_sum_sqr - high_capsule_sqr_sum
+
+ # normalize by interest number
+ interest_div = tf.maximum(
+ tf.to_float(user_emb_num * (user_emb_num - 1)), 1.0)
+ interest_simi = tf.reduce_sum(interest_simi, axis=1) / interest_div
+
+ high_capsule_simi = tf.reduce_sum(high_capsule_simi, axis=1) / interest_div
+
+ # normalize by batch_size
+ multi_interest = tf.to_float(user_emb_num > 1)
+ sum_interest_simi = tf.reduce_sum(
+ (interest_simi + 1) * multi_interest) / 2.0
+ sum_div = tf.maximum(tf.reduce_sum(multi_interest), 1.0)
+ avg_interest_simi = sum_interest_simi / sum_div
+
+ sum_capsule_simi = tf.reduce_sum(
+ (high_capsule_simi + 1) * multi_interest) / 2.0
+ avg_capsule_simi = sum_capsule_simi / sum_div
+
+ tf.summary.scalar('interest_similarity', avg_interest_simi)
+ tf.summary.scalar('capsule_similarity', avg_capsule_simi)
+ return avg_interest_simi, avg_capsule_simi
def build_metric_graph(self, eval_config):
- metric_dict = {}
+ from easy_rec.python.core.easyrec_metrics import metrics_tf as metrics
+ # build interest metric
+ interest_simi, capsule_simi = self._build_interest_simi()
+ metric_dict = {
+ 'interest_similarity': metrics.mean(interest_simi),
+ 'capsule_similarity': metrics.mean(capsule_simi)
+ }
+ if self._is_point_wise:
+ metric_dict.update(self._build_point_wise_metric_graph(eval_config))
+ return metric_dict
+
+ recall_at_topks = []
for metric in eval_config.metrics_set:
- if metric.WhichOneof('metric') == 'auc':
- assert self._loss_type == LossType.CLASSIFICATION
- metric_dict['auc'] = metrics.auc(self._labels[0],
- self._prediction_dict['logits'])
- elif metric.WhichOneof('metric') == 'mean_absolute_error':
- assert self._loss_type == LossType.L2_LOSS
- metric_dict['mean_absolute_error'] = metrics.mean_absolute_error(
- self._labels[0], self._prediction_dict['y'])
- metric_dict['interest_similarity'] = self._build_interest_metric()
+ if metric.WhichOneof('metric') == 'recall_at_topk':
+ assert self._loss_type in [
+ LossType.CLASSIFICATION, LossType.SOFTMAX_CROSS_ENTROPY
+ ]
+ if metric.recall_at_topk.topk not in recall_at_topks:
+ recall_at_topks.append(metric.recall_at_topk.topk)
+
+ # compute interest recall
+ # [batch_size, num_interests, embed_dim]
+ user_interests = self._prediction_dict['user_interests']
+ # [?, embed_dim]
+ item_tower_emb = self._prediction_dict['item_tower_emb']
+ batch_size = tf.shape(user_interests)[0]
+ # [?, 2] first dimension is the sample_id in batch
+ # second dimension is the neg_id with respect to the sample
+ hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)
+
+ if hard_neg_indices is not None:
+ logging.info('With hard negative examples')
+ noclk_size = tf.shape(hard_neg_indices)[0]
+ simple_item_emb, hard_neg_item_emb = tf.split(
+ item_tower_emb, [-1, noclk_size], axis=0)
+ else:
+ simple_item_emb = item_tower_emb
+ hard_neg_item_emb = None
+
+ # batch_size num_interest sample_neg_num
+ simple_item_sim = tf.einsum('bhe,ne->bhn', user_interests, simple_item_emb)
+ # batch_size sample_neg_num
+ simple_item_sim = tf.reduce_max(simple_item_sim, axis=1)
+ simple_lbls = tf.cast(tf.range(tf.shape(user_interests)[0]), tf.int64)
+
+ # labels = tf.zeros_like(logits[:, :1], dtype=tf.int64)
+ pos_indices = tf.range(batch_size)
+ pos_indices = tf.concat([pos_indices[:, None], pos_indices[:, None]],
+ axis=1)
+ pos_item_sim = tf.gather_nd(simple_item_sim[:batch_size, :batch_size],
+ pos_indices)
+
+ simple_item_sim_v2 = tf.concat(
+ [pos_item_sim[:, None], simple_item_sim[:, batch_size:]], axis=1)
+ simple_lbls_v2 = tf.zeros_like(simple_item_sim_v2[:, :1], dtype=tf.int64)
+
+ for topk in recall_at_topks:
+ metric_dict['interests_recall@%d' % topk] = metrics.recall_at_k(
+ labels=simple_lbls,
+ predictions=simple_item_sim,
+ k=topk,
+ name='interests_recall_at_%d' % topk)
+ metric_dict['interests_neg_sam_recall@%d' % topk] = metrics.recall_at_k(
+ labels=simple_lbls_v2,
+ predictions=simple_item_sim_v2,
+ k=topk,
+ name='interests_recall_neg_sam_at_%d' % topk)
+
+ logits = self._prediction_dict['logits']
+ pos_item_logits = tf.gather_nd(logits[:batch_size, :batch_size],
+ pos_indices)
+ logits_v2 = tf.concat([pos_item_logits[:, None], logits[:, batch_size:]],
+ axis=1)
+ labels_v2 = tf.zeros_like(logits_v2[:, :1], dtype=tf.int64)
+
+ for topk in recall_at_topks:
+ metric_dict['recall@%d' % topk] = metrics.recall_at_k(
+ labels=simple_lbls,
+ predictions=logits,
+ k=topk,
+ name='recall_at_%d' % topk)
+ metric_dict['recall_neg_sam@%d' % topk] = metrics.recall_at_k(
+ labels=labels_v2,
+ predictions=logits_v2,
+ k=topk,
+ name='recall_neg_sam_at_%d' % topk)
+ eval_logits = logits[:, :batch_size]
+ eval_logits = tf.cond(
+ batch_size < topk, lambda: tf.pad(
+ eval_logits, [[0, 0], [0, topk - batch_size]],
+ mode='CONSTANT',
+ constant_values=-1e32,
+ name='pad_eval_logits'), lambda: eval_logits)
+ metric_dict['recall_in_batch@%d' % topk] = metrics.recall_at_k(
+ labels=simple_lbls,
+ predictions=eval_logits,
+ k=topk,
+ name='recall_in_batch_at_%d' % topk)
+
+ # batch_size num_interest
+ if hard_neg_indices is not None:
+ hard_neg_user_emb = tf.gather(user_interests, hard_neg_indices[:, 0])
+ hard_neg_sim = tf.einsum('nhe,ne->nh', hard_neg_user_emb,
+ hard_neg_item_emb)
+ hard_neg_sim = tf.reduce_max(hard_neg_sim, axis=1)
+ max_num_neg = tf.reduce_max(hard_neg_indices[:, 1]) + 1
+ hard_neg_shape = tf.stack([tf.to_int64(batch_size), max_num_neg])
+ hard_neg_mask = tf.scatter_nd(
+ hard_neg_indices,
+ tf.ones_like(hard_neg_sim, dtype=tf.float32),
+ shape=hard_neg_shape)
+ hard_neg_sim = tf.scatter_nd(hard_neg_indices, hard_neg_sim,
+ hard_neg_shape)
+ hard_neg_sim = hard_neg_sim - (1 - hard_neg_mask) * 1e32
+
+ hard_logits = tf.concat([pos_item_logits[:, None], hard_neg_sim], axis=1)
+ hard_lbls = tf.zeros_like(hard_logits[:, :1], dtype=tf.int64)
+ metric_dict['hard_neg_acc'] = metrics.accuracy(
+ hard_lbls, tf.argmax(hard_logits, axis=1))
+
return metric_dict
def get_outputs(self):
if self._loss_type == LossType.CLASSIFICATION:
- return ['logits', 'user_emb', 'item_emb']
+ return [
+ 'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num',
+ 'user_interests', 'item_tower_emb'
+ ]
+ elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
+ self._prediction_dict['logits'] = tf.squeeze(
+ self._prediction_dict['logits'], axis=-1)
+ self._prediction_dict['probs'] = tf.nn.sigmoid(
+ self._prediction_dict['logits'])
+ return [
+ 'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num',
+ 'user_interests', 'item_tower_emb'
+ ]
elif self._loss_type == LossType.L2_LOSS:
- return ['y', 'user_emb', 'item_emb']
+ return [
+ 'y', 'user_emb', 'item_emb', 'user_emb_num', 'user_interests',
+ 'item_tower_emb'
+ ]
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))
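# Not part of the diff: a hedged numpy sketch of the interest-similarity
# regularizer computed in MIND._build_interest_simi above. The sum-square minus
# square-sum identity yields the sum over all ordered pairs <u_i, u_j>, which,
# divided by k*(k-1), is the average pairwise cosine similarity of a user's
# (already L2-normalized) interest vectors. Names are illustrative only.
import numpy as np

def avg_pairwise_cosine(interests):
    # interests: [num_interests, dim], assumed L2-normalized
    sum_sqr = np.square(interests.sum(axis=0))   # (sum_i u_i)^2 per dim
    sqr_sum = np.square(interests).sum(axis=0)   # sum_i u_i^2 per dim
    k = interests.shape[0]
    return (sum_sqr - sqr_sum).sum() / max(k * (k - 1), 1)

u = np.random.randn(3, 4)
u = u / np.linalg.norm(u, axis=1, keepdims=True)
print(avg_pairwise_cosine(u))   # in [-1, 1]; MIND penalizes values above max_interests_simi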
diff --git a/easy_rec/python/model/mmoe.py b/easy_rec/python/model/mmoe.py
index acf1d6d59..3cc644f6d 100644
--- a/easy_rec/python/model/mmoe.py
+++ b/easy_rec/python/model/mmoe.py
@@ -26,7 +26,10 @@ def __init__(self,
self._model_config = self._model_config.mmoe
assert isinstance(self._model_config, MMoEConfig)
- self._features, _ = self._input_layer(self._feature_dict, 'all')
+ if self.has_backbone:
+ self._features = self.backbone
+ else:
+ self._features, _ = self._input_layer(self._feature_dict, 'all')
self._init_towers(self._model_config.task_towers)
def build_predict_graph(self):
diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py
index 0cbf3340c..aa102104c 100644
--- a/easy_rec/python/model/multi_task_model.py
+++ b/easy_rec/python/model/multi_task_model.py
@@ -1,12 +1,16 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
+from collections import OrderedDict
import tensorflow as tf
from easy_rec.python.builders import loss_builder
+from easy_rec.python.layers.dnn import DNN
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos import tower_pb2
+from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel
+from easy_rec.python.protos.loss_pb2 import LossType
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -26,6 +30,75 @@ def __init__(self,
self._task_num = None
self._label_name_dict = {}
+ def build_predict_graph(self):
+ if not self.has_backbone:
+ raise NotImplementedError(
+          'method `build_predict_graph` must be implemented when the backbone network does not exist'
+ )
+ model = self._model_config.WhichOneof('model')
+ assert model == 'model_params', '`model_params` must be configured'
+ config = self._model_config.model_params
+ for out in config.outputs:
+ self._outputs.append(out)
+
+ self._init_towers(config.task_towers)
+
+ backbone = self.backbone
+ if type(backbone) in (list, tuple):
+ if len(backbone) != len(config.task_towers):
+ raise ValueError(
+ 'The number of backbone outputs and task towers must be equal')
+ task_input_list = backbone
+ else:
+ task_input_list = [backbone] * len(config.task_towers)
+
+ tower_features = {}
+ for i, task_tower_cfg in enumerate(config.task_towers):
+ tower_name = task_tower_cfg.tower_name
+ with tf.name_scope(tower_name):
+ if task_tower_cfg.HasField('dnn'):
+ tower_dnn = DNN(
+ task_tower_cfg.dnn,
+ self._l2_reg,
+ name=tower_name,
+ is_training=self._is_training)
+ tower_output = tower_dnn(task_input_list[i])
+ else:
+ tower_output = task_input_list[i]
+ tower_features[tower_name] = tower_output
+
+ tower_outputs = {}
+ relation_features = {}
+ # bayes network
+ for task_tower_cfg in config.task_towers:
+ tower_name = task_tower_cfg.tower_name
+ with tf.name_scope(tower_name):
+ if task_tower_cfg.HasField('relation_dnn'):
+ relation_dnn = DNN(
+ task_tower_cfg.relation_dnn,
+ self._l2_reg,
+ name=tower_name + '/relation_dnn',
+ is_training=self._is_training)
+ tower_inputs = [tower_features[tower_name]]
+ for relation_tower_name in task_tower_cfg.relation_tower_names:
+ tower_inputs.append(relation_features[relation_tower_name])
+ relation_input = tf.concat(
+ tower_inputs, axis=-1, name=tower_name + '/relation_input')
+ relation_fea = relation_dnn(relation_input)
+ relation_features[tower_name] = relation_fea
+ else:
+ relation_fea = tower_features[tower_name]
+
+ output_logits = tf.layers.dense(
+ relation_fea,
+ task_tower_cfg.num_class,
+ kernel_regularizer=self._l2_reg,
+ name=tower_name + '/output')
+ tower_outputs[tower_name] = output_logits
+
+ self._add_to_prediction_dict(tower_outputs)
+ return self._prediction_dict
+
def _init_towers(self, task_tower_configs):
"""Init task towers."""
self._task_towers = task_tower_configs
@@ -51,33 +124,88 @@ def _init_towers(self, task_tower_configs):
def _add_to_prediction_dict(self, output):
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
- self._prediction_dict.update(
- self._output_to_prediction_impl(
- output[tower_name],
- loss_type=task_tower_cfg.loss_type,
- num_class=task_tower_cfg.num_class,
- suffix='_%s' % tower_name))
+ if len(task_tower_cfg.losses) == 0:
+ self._prediction_dict.update(
+ self._output_to_prediction_impl(
+ output[tower_name],
+ loss_type=task_tower_cfg.loss_type,
+ num_class=task_tower_cfg.num_class,
+ suffix='_%s' % tower_name))
+ else:
+ for loss in task_tower_cfg.losses:
+ self._prediction_dict.update(
+ self._output_to_prediction_impl(
+ output[tower_name],
+ loss_type=loss.loss_type,
+ num_class=task_tower_cfg.num_class,
+ suffix='_%s' % tower_name))
def build_metric_graph(self, eval_config):
"""Build metric graph for multi task model."""
- metric_dict = {}
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
for metric in task_tower_cfg.metrics_set:
- metric_dict.update(
+ loss_types = {task_tower_cfg.loss_type}
+ if len(task_tower_cfg.losses) > 0:
+ loss_types = {loss.loss_type for loss in task_tower_cfg.losses}
+ self._metric_dict.update(
self._build_metric_impl(
metric,
- loss_type=task_tower_cfg.loss_type,
+ loss_type=loss_types,
label_name=self._label_name_dict[tower_name],
num_class=task_tower_cfg.num_class,
suffix='_%s' % tower_name))
- return metric_dict
+ return self._metric_dict
+
+ def build_loss_weight(self):
+ loss_weights = OrderedDict()
+ num_loss = 0
+ for task_tower_cfg in self._task_towers:
+ tower_name = task_tower_cfg.tower_name
+ losses = task_tower_cfg.losses
+ n = len(losses)
+ if n > 0:
+ loss_weights[tower_name] = [
+ loss.weight * task_tower_cfg.weight for loss in losses
+ ]
+ num_loss += n
+ else:
+ loss_weights[tower_name] = [task_tower_cfg.weight]
+ num_loss += 1
+
+ strategy = self._base_model_config.loss_weight_strategy
+ if strategy == self._base_model_config.Random:
+ weights = tf.random_normal([num_loss])
+ weights = tf.nn.softmax(weights)
+ i = 0
+ for k, v in loss_weights.items():
+ n = len(v)
+ loss_weights[k] = weights[i:i + n]
+ i += n
+ return loss_weights
+
+ def get_learnt_loss(self, loss_type, name, value):
+ strategy = self._base_model_config.loss_weight_strategy
+ if strategy == self._base_model_config.Uncertainty:
+ uncertainty = tf.Variable(
+ 0, name='%s_loss_weight' % name, dtype=tf.float32)
+ tf.summary.scalar('loss/%s_uncertainty' % name, uncertainty)
+ if loss_type in {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+ return 0.5 * tf.exp(-uncertainty) * value + 0.5 * uncertainty
+ else:
+ return tf.exp(-uncertainty) * value + 0.5 * uncertainty
+ else:
+ strategy_name = EasyRecModel.LossWeightStrategy.Name(strategy)
+ raise ValueError('Unsupported loss weight strategy: ' + strategy_name)
def build_loss_graph(self):
"""Build loss graph for multi task model."""
+ task_loss_weights = self.build_loss_weight()
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
- loss_weight = task_tower_cfg.weight * self._sample_weight
+ loss_weight = 1.0
+ if task_tower_cfg.use_sample_weight:
+ loss_weight *= self._sample_weight
if hasattr(task_tower_cfg, 'task_space_indicator_label') and \
task_tower_cfg.HasField('task_space_indicator_label'):
@@ -87,27 +215,89 @@ def build_loss_graph(self):
task_tower_cfg.in_task_space_weight * in_task_space +
task_tower_cfg.out_task_space_weight * (1 - in_task_space))
- self._loss_dict.update(
- self._build_loss_impl(
- task_tower_cfg.loss_type,
+ if task_tower_cfg.HasField('task_space_indicator_name') and \
+ task_tower_cfg.HasField('task_space_indicator_value'):
+ in_task_space = tf.to_float(
+ tf.equal(
+ self._feature_dict[task_tower_cfg.task_space_indicator_name],
+ task_tower_cfg.task_space_indicator_value))
+ loss_weight = loss_weight * (
+ task_tower_cfg.in_task_space_weight * in_task_space +
+ task_tower_cfg.out_task_space_weight * (1 - in_task_space))
+
+ task_loss_weight = task_loss_weights[tower_name]
+ loss_dict = {}
+ losses = task_tower_cfg.losses
+ if len(losses) == 0:
+ loss_dict = self._build_loss_impl(
+ task_tower_cfg.loss_type,
+ label_name=self._label_name_dict[tower_name],
+ loss_weight=loss_weight,
+ num_class=task_tower_cfg.num_class,
+ suffix='_%s' % tower_name)
+ for loss_name in loss_dict.keys():
+ loss_dict[loss_name] = loss_dict[loss_name] * task_loss_weight[0]
+ else:
+ calibrate_loss = []
+ for loss in losses:
+ if loss.loss_type == LossType.ORDER_CALIBRATE_LOSS:
+ y_t = self._prediction_dict['probs_%s' % tower_name]
+ for relation_tower_name in task_tower_cfg.relation_tower_names:
+ y_rt = self._prediction_dict['probs_%s' % relation_tower_name]
+ cali_loss = tf.reduce_mean(tf.nn.relu(y_t - y_rt))
+ calibrate_loss.append(cali_loss * loss.weight)
+ logging.info('calibrate loss: %s -> %s' %
+ (relation_tower_name, tower_name))
+ continue
+ loss_param = loss.WhichOneof('loss_param')
+ if loss_param is not None:
+ loss_param = getattr(loss, loss_param)
+ loss_ops = self._build_loss_impl(
+ loss.loss_type,
label_name=self._label_name_dict[tower_name],
loss_weight=loss_weight,
num_class=task_tower_cfg.num_class,
- suffix='_%s' % tower_name))
+ suffix='_%s' % tower_name,
+ loss_name=loss.loss_name,
+ loss_param=loss_param)
+ for i, loss_name in enumerate(loss_ops):
+ loss_value = loss_ops[loss_name]
+ if loss.learn_loss_weight:
+ loss_dict[loss_name] = self.get_learnt_loss(
+ loss.loss_type, loss_name, loss_value)
+ else:
+ loss_dict[loss_name] = loss_value * task_loss_weight[i]
+ if calibrate_loss:
+ cali_loss = tf.add_n(calibrate_loss)
+ loss_dict['order_calibrate_loss'] = cali_loss
+ tf.summary.scalar('loss/order_calibrate_loss', cali_loss)
+ self._loss_dict.update(loss_dict)
kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
- self._labels)
+ self._labels, self._feature_dict)
self._loss_dict.update(kd_loss_dict)
return self._loss_dict
def get_outputs(self):
outputs = []
+ if self._outputs:
+ outputs.extend(self._outputs)
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
- outputs.extend(
- self._get_outputs_impl(
- task_tower_cfg.loss_type,
- task_tower_cfg.num_class,
- suffix='_%s' % tower_name))
- return outputs
+ if len(task_tower_cfg.losses) == 0:
+ outputs.extend(
+ self._get_outputs_impl(
+ task_tower_cfg.loss_type,
+ task_tower_cfg.num_class,
+ suffix='_%s' % tower_name))
+ else:
+ for loss in task_tower_cfg.losses:
+ if loss.loss_type == LossType.ORDER_CALIBRATE_LOSS:
+ continue
+ outputs.extend(
+ self._get_outputs_impl(
+ loss.loss_type,
+ task_tower_cfg.num_class,
+ suffix='_%s' % tower_name))
+ return list(set(outputs))
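# Not part of the diff: a small numpy sketch of the 'Uncertainty' loss-weight
# strategy implemented in get_learnt_loss above (homoscedastic uncertainty
# weighting): each task loss is scaled by exp(-s) for a learned log-variance s
# and regularized by 0.5 * s; regression losses additionally get a 0.5 factor.
# In the real model the s values are trainable variables; here they are plain
# floats and all names are illustrative only.
import numpy as np

def uncertainty_weighted(losses, log_vars, is_regression):
    # losses, log_vars, is_regression: parallel lists, one entry per task loss
    total = 0.0
    for loss, s, reg in zip(losses, log_vars, is_regression):
        scale = 0.5 if reg else 1.0
        total += scale * np.exp(-s) * loss + 0.5 * s
    return total

print(uncertainty_weighted([0.7, 1.2], [0.0, -0.3], [False, True]))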
diff --git a/easy_rec/python/model/multi_tower.py b/easy_rec/python/model/multi_tower.py
index 9904c9eee..5cdd89ba5 100644
--- a/easy_rec/python/model/multi_tower.py
+++ b/easy_rec/python/model/multi_tower.py
@@ -22,8 +22,8 @@ def __init__(self,
is_training=False):
super(MultiTower, self).__init__(model_config, feature_configs, features,
labels, is_training)
- assert self._model_config.WhichOneof('model') == 'multi_tower', \
- 'invalid model config: %s' % self._model_config.WhichOneof('model')
+ assert self._model_config.WhichOneof('model') == 'multi_tower', (
+ 'invalid model config: %s' % self._model_config.WhichOneof('model'))
self._model_config = self._model_config.multi_tower
assert isinstance(self._model_config, MultiTowerConfig)
diff --git a/easy_rec/python/model/multi_tower_bst.py b/easy_rec/python/model/multi_tower_bst.py
index 6d93ebeda..4cbc9fd29 100644
--- a/easy_rec/python/model/multi_tower_bst.py
+++ b/easy_rec/python/model/multi_tower_bst.py
@@ -28,7 +28,10 @@ def __init__(self,
super(MultiTowerBST, self).__init__(model_config, feature_configs, features,
labels, is_training)
self._seq_input_layer = seq_input_layer.SeqInputLayer(
- feature_configs, model_config.seq_att_groups)
+ feature_configs,
+ model_config.seq_att_groups,
+ embedding_regularizer=self._emb_reg,
+ ev_params=self._global_ev_params)
assert self._model_config.WhichOneof('model') == 'multi_tower', \
'invalid model config: %s' % self._model_config.WhichOneof('model')
self._model_config = self._model_config.multi_tower
@@ -58,10 +61,18 @@ def __init__(self,
self._bst_tower_features.append(tower_feature)
def dnn_net(self, net, dnn_units, name):
+ dnn_units_len = len(dnn_units)
with tf.variable_scope(name_or_scope=name, reuse=tf.AUTO_REUSE):
for idx, units in enumerate(dnn_units):
- net = tf.layers.dense(
- net, units=units, activation=tf.nn.relu, name='%s_%d' % (name, idx))
+ if idx + 1 < dnn_units_len:
+ net = tf.layers.dense(
+ net,
+ units=units,
+ activation=tf.nn.relu,
+ name='%s_%d' % (name, idx))
+ else:
+ net = tf.layers.dense(
+ net, units=units, activation=None, name='%s_%d' % (name, idx))
return net
def attention_net(self, net, dim, cur_seq_len, seq_size, name):
@@ -73,7 +84,8 @@ def attention_net(self, net, dim, cur_seq_len, seq_size, name):
hist_mask = tf.sequence_mask(
cur_seq_len, maxlen=seq_size - 1) # [B, seq_size-1]
- cur_id_mask = tf.ones([tf.shape(hist_mask)[0], 1], dtype=tf.bool) # [B, 1]
+ cur_id_mask = tf.ones(
+ tf.stack([tf.shape(hist_mask)[0], 1]), dtype=tf.bool) # [B, 1]
mask = tf.concat([hist_mask, cur_id_mask], axis=1) # [B, seq_size]
masks = tf.reshape(tf.tile(mask, [1, seq_size]),
(-1, seq_size, seq_size)) # [B, seq_size, seq_size]
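# Not part of the diff: a hedged numpy sketch of the attention mask built in
# attention_net above. History positions are masked by the sequence length, the
# current (target) item is always kept, and the [B, seq] mask is tiled to the
# [B, seq, seq] attention shape. Names here are illustrative only.
import numpy as np

def build_bst_mask(cur_seq_len, seq_size):
    batch = len(cur_seq_len)
    hist_mask = np.arange(seq_size - 1)[None, :] < np.asarray(cur_seq_len)[:, None]
    cur_id_mask = np.ones((batch, 1), dtype=bool)             # target item is always valid
    mask = np.concatenate([hist_mask, cur_id_mask], axis=1)   # [B, seq_size]
    return np.tile(mask[:, None, :], (1, seq_size, 1))        # [B, seq_size, seq_size]

print(build_bst_mask([2, 4], seq_size=5).astype(int))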
diff --git a/easy_rec/python/model/multi_tower_din.py b/easy_rec/python/model/multi_tower_din.py
index afd473f3d..e586da1cf 100644
--- a/easy_rec/python/model/multi_tower_din.py
+++ b/easy_rec/python/model/multi_tower_din.py
@@ -26,7 +26,10 @@ def __init__(self,
super(MultiTowerDIN, self).__init__(model_config, feature_configs, features,
labels, is_training)
self._seq_input_layer = seq_input_layer.SeqInputLayer(
- feature_configs, model_config.seq_att_groups)
+ feature_configs,
+ model_config.seq_att_groups,
+ embedding_regularizer=self._emb_reg,
+ ev_params=self._global_ev_params)
assert self._model_config.WhichOneof('model') == 'multi_tower', \
'invalid model config: %s' % self._model_config.WhichOneof('model')
self._model_config = self._model_config.multi_tower
@@ -49,8 +52,9 @@ def __init__(self,
for tower_id in range(self._din_tower_num):
tower = self._model_config.din_towers[tower_id]
tower_feature = self._seq_input_layer(self._feature_dict, tower.input)
- regularizers.apply_regularization(
- self._emb_reg, weights_list=[tower_feature['key']])
+
+    # regularization for the sequence feature 'key' is now applied inside seq_input_layer.
+
regularizers.apply_regularization(
self._emb_reg, weights_list=[tower_feature['hist_seq_emb']])
self._din_tower_features.append(tower_feature)
@@ -70,7 +74,13 @@ def din(self, dnn_config, deep_fea, name):
[cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
axis=-1) # (B, seq_max_len, emb_dim*4)
- din_layer = dnn.DNN(dnn_config, self._l2_reg, name, self._is_training)
+ din_layer = dnn.DNN(
+ dnn_config,
+ self._l2_reg,
+ name,
+ self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
din_net = din_layer(din_net)
scores = tf.reshape(din_net, [-1, 1, seq_max_len]) # (B, 1, ?)
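# Not part of the diff: a hedged numpy sketch of the DIN attention input
# assembled in din() above. The candidate item embedding is tiled over the
# behaviour sequence and combined with the raw, difference and element-wise
# product features before the attention MLP. Names are illustrative only.
import numpy as np

def din_attention_input(cur_id, hist_seq):
    # cur_id: [B, D] candidate item embedding; hist_seq: [B, T, D] behaviour sequence
    cur_ids = np.repeat(cur_id[:, None, :], hist_seq.shape[1], axis=1)   # [B, T, D]
    return np.concatenate(
        [cur_ids, hist_seq, cur_ids - hist_seq, cur_ids * hist_seq], axis=-1)  # [B, T, 4D]

x = din_attention_input(np.random.randn(2, 8), np.random.randn(2, 5, 8))
print(x.shape)  # (2, 5, 32); a small MLP then maps the last axis to one attention score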
diff --git a/easy_rec/python/model/multi_tower_recall.py b/easy_rec/python/model/multi_tower_recall.py
new file mode 100644
index 000000000..8f576944e
--- /dev/null
+++ b/easy_rec/python/model/multi_tower_recall.py
@@ -0,0 +1,68 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.model.rank_model import RankModel
+
+from easy_rec.python.protos.multi_tower_recall_pb2 import MultiTowerRecall as MultiTowerRecallConfig # NOQA
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class MultiTowerRecall(RankModel):
+
+ def __init__(self,
+ model_config,
+ feature_configs,
+ features,
+ labels=None,
+ is_training=False):
+ super(MultiTowerRecall, self).__init__(model_config, feature_configs,
+ features, labels, is_training)
+ assert self._model_config.WhichOneof('model') == 'multi_tower_recall', (
+ 'invalid model config: %s' % self._model_config.WhichOneof('model'))
+ self._model_config = self._model_config.multi_tower_recall
+ assert isinstance(self._model_config, MultiTowerRecallConfig)
+
+ self.user_tower_feature, _ = self._input_layer(self._feature_dict, 'user')
+ self.item_tower_feature, _ = self._input_layer(self._feature_dict, 'item')
+
+ def build_predict_graph(self):
+
+ user_tower_feature = self.user_tower_feature
+ batch_size = tf.shape(user_tower_feature)[0]
+ pos_item_feature = self.item_tower_feature[:batch_size]
+ neg_item_feature = self.item_tower_feature[batch_size:]
+ item_tower_feature = tf.concat([
+ pos_item_feature[:, tf.newaxis, :],
+ tf.tile(
+ neg_item_feature[tf.newaxis, :, :], multiples=[batch_size, 1, 1])
+ ],
+ axis=1) # noqa: E126
+
+ user_dnn = dnn.DNN(self._model_config.user_tower.dnn, self._l2_reg,
+ 'user_dnn', self._is_training)
+ user_tower_emb = user_dnn(user_tower_feature)
+
+ item_dnn = dnn.DNN(self._model_config.item_tower.dnn, self._l2_reg,
+ 'item_dnn', self._is_training)
+ item_tower_emb = item_dnn(item_tower_feature)
+ item_tower_emb = tf.reshape(item_tower_emb, tf.shape(user_tower_emb))
+
+ tower_fea_arr = []
+ tower_fea_arr.append(user_tower_emb)
+ tower_fea_arr.append(item_tower_emb)
+
+ all_fea = tf.concat(tower_fea_arr, axis=-1)
+ final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+ 'final_dnn', self._is_training)
+ all_fea = final_dnn_layer(all_fea)
+ output = tf.layers.dense(all_fea, 1, name='output')
+ output = output[:, 0]
+
+ self._add_to_prediction_dict(output)
+
+ return self._prediction_dict
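# Not part of the diff: a hedged numpy sketch of how MultiTowerRecall above
# pairs each user with its own positive item plus the shared sampled negatives:
# the negatives are tiled across the batch and stacked behind the per-row
# positive before going through the item tower. Names are illustrative only.
import numpy as np

def pair_items(item_features, batch_size):
    # item_features: [batch_size + num_neg, D] -> [batch_size, 1 + num_neg, D]
    pos, neg = item_features[:batch_size], item_features[batch_size:]
    neg_tiled = np.repeat(neg[None, :, :], batch_size, axis=0)
    return np.concatenate([pos[:, None, :], neg_tiled], axis=1)

items = np.random.randn(4 + 3, 8)             # 4 positives (one per user), 3 shared negatives
print(pair_items(items, batch_size=4).shape)  # (4, 4, 8)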
diff --git a/easy_rec/python/model/pdn.py b/easy_rec/python/model/pdn.py
new file mode 100644
index 000000000..7325beb1c
--- /dev/null
+++ b/easy_rec/python/model/pdn.py
@@ -0,0 +1,203 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.model.match_model import MatchModel
+from easy_rec.python.protos.simi_pb2 import Similarity
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+losses = tf.losses
+metrics = tf.metrics
+
+
+class PDN(MatchModel):
+
+ def __init__(self,
+ model_config,
+ feature_configs,
+ features,
+ labels=None,
+ is_training=False):
+ super(PDN, self).__init__(model_config, feature_configs, features, labels,
+ is_training)
+ assert self._model_config.WhichOneof('model') == 'pdn', \
+ 'invalid model config: %s' % self._model_config.WhichOneof('model')
+ self._model_config = self._model_config.pdn
+
+ self._user_features, _ = self._input_layer(self._feature_dict, 'user')
+ self._item_features, _ = self._input_layer(self._feature_dict, 'item')
+
+ if self._input_layer.has_group('bias'):
+ self._bias_features, _ = self._input_layer(self._feature_dict, 'bias')
+ else:
+ self._bias_features = None
+
+ self._u2i_seq, self._seq_len = self._get_seq_features('u2i_seq')
+ self._i_seq, _ = self._get_seq_features('i_seq')
+ self._i2i_seq, _ = self._get_seq_features('i2i_seq')
+
+ def build_predict_graph(self):
+ trigger_out = self._build_trigger_net()
+ sim_out = self._build_similarity_net()
+ logits = tf.multiply(sim_out, trigger_out)
+
+ seq_mask = tf.to_float(
+ tf.sequence_mask(self._seq_len,
+ tf.shape(sim_out)[1]))
+ logits = tf.reduce_sum(logits * seq_mask[:, :, None], axis=1)
+
+ direct_logits = self._build_direct_net()
+ if direct_logits is not None:
+ logits += direct_logits
+
+ bias_logits = self._build_bias_net()
+ if bias_logits is not None:
+ logits += bias_logits
+
+ logits = tf.squeeze(logits, axis=1)
+ probs = 1 - tf.exp(-logits) # map [0, inf) to [0, 1)
+
+ self._prediction_dict['probs'] = probs
+ self._prediction_dict['logits'] = tf.log(
+ tf.clip_by_value(probs, 1e-8, 1 - 1e-8))
+ return self._prediction_dict
+
+ def _get_seq_features(self, name):
+ seqs, _, _ = self._input_layer(self._feature_dict, name, is_combine=False)
+ seq_len = seqs[0][1]
+ seq = tf.concat([x[0] for x in seqs], axis=2)
+ return seq, seq_len
+
+ def _build_trigger_net(self):
+ user_dnn_layer = dnn.DNN(self._model_config.user_dnn, self._l2_reg,
+ 'user_dnn', self._is_training)
+ user_fea = user_dnn_layer(self._user_features)
+
+ trigger_seq = tf.concat([self._u2i_seq, self._i_seq], axis=2)
+ u2i_dnn_layer = dnn.DNN(self._model_config.u2i_dnn, self._l2_reg, 'u2i_dnn',
+ self._is_training)
+ trigger_seq_fea = u2i_dnn_layer(trigger_seq)
+
+ trigger_merge_fea = trigger_seq_fea + user_fea[:, None, :]
+ trigger_dnn_layer = dnn.DNN(
+ self._model_config.trigger_dnn,
+ self._l2_reg,
+ 'trigger_dnn',
+ self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
+
+ # output: N x seq_len x d, d is usually set to 1
+ trigger_out = trigger_dnn_layer(trigger_merge_fea)
+ # exp(x): map (-inf, inf) to (0, inf)
+ trigger_out = tf.exp(trigger_out)
+
+ self._prediction_dict['trigger_out'] = tf.reduce_join(
+ tf.reduce_join(
+ tf.as_string(trigger_out, precision=4, shortest=True),
+ axis=2,
+ separator=','),
+ axis=1,
+ separator=';')
+ return trigger_out
+
+ def _build_similarity_net(self):
+ item_dnn_layer = dnn.DNN(self._model_config.item_dnn, self._l2_reg,
+ 'item_dnn', self._is_training)
+ item_fea = item_dnn_layer(self._item_features)
+
+ sim_side_dnn_layer = dnn.DNN(self._model_config.i2i_dnn, self._l2_reg,
+ 'i2i_dnn', self._is_training)
+ sim_seq_fea = sim_side_dnn_layer(self._i_seq)
+
+ sim_seq_cross = sim_seq_fea * item_fea[:, None, :]
+
+ item_fea_tile = tf.tile(item_fea[:, None, :],
+ [1, tf.shape(sim_seq_fea)[1], 1])
+
+ sim_seq_concat = tf.concat(
+ [sim_seq_cross, sim_seq_cross, self._i2i_seq, item_fea_tile], axis=2)
+ sim_dnn_layer = dnn.DNN(
+ self._model_config.sim_dnn,
+ self._l2_reg,
+ 'sim_dnn',
+ self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
+ # output: N x seq_len x 1
+ sim_out = sim_dnn_layer(sim_seq_concat)
+ # exp(x): map (-inf, inf) to (0, inf)
+ sim_out = tf.exp(sim_out)
+
+ self._prediction_dict['sim_out'] = tf.reduce_join(
+ tf.reduce_join(
+ tf.as_string(sim_out, precision=4, shortest=True),
+ axis=2,
+ separator=','),
+ axis=1,
+ separator=';')
+ return sim_out
+
+ def _build_direct_net(self):
+ if self._model_config.HasField('direct_user_dnn') and \
+ self._model_config.HasField('direct_item_dnn'):
+ direct_user_layer = dnn.DNN(
+          self._model_config.direct_user_dnn,
+          self._l2_reg,
+          'direct_user_dnn',
+          self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
+ direct_user_out = direct_user_layer(self._user_features)
+ direct_item_layer = dnn.DNN(
+          self._model_config.direct_item_dnn,
+          self._l2_reg,
+          'direct_item_dnn',
+          self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
+ direct_item_out = direct_item_layer(self._item_features)
+
+ if self._model_config.simi_func == Similarity.COSINE:
+ direct_user_out = self.norm(direct_user_out)
+ direct_item_out = self.norm(direct_item_out)
+
+ self._prediction_dict['direct_user_embedding'] = direct_user_out
+ self._prediction_dict['direct_item_embedding'] = direct_item_out
+ direct_logits = tf.reduce_sum(direct_user_out * direct_item_out, axis=1)
+
+ if self._model_config.scale_simi:
+ sim_w = tf.get_variable(
+ 'direct_net/sim_w',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.ones_initializer())
+ sim_b = tf.get_variable(
+ 'direct_net/sim_b',
+ dtype=tf.float32,
+ shape=(1),
+ initializer=tf.zeros_initializer())
+ direct_logits = direct_logits * tf.abs(sim_w) + sim_b
+
+ return tf.nn.softplus(direct_logits)
+ else:
+ return None
+
+ def _build_bias_net(self):
+ if self._model_config.HasField('bias_dnn'):
+ assert self._bias_features is not None, 'bias group must be defined'
+ bias_dnn_layer = dnn.DNN(
+ self._model_config.bias_dnn,
+ self._l2_reg,
+ 'bias_dnn',
+ self._is_training,
+ last_layer_no_activation=True,
+ last_layer_no_batch_norm=True)
+ bias_logits = bias_dnn_layer(self._bias_features)
+ return tf.nn.softplus(bias_logits)
+ else:
+ return None
+
+ def get_outputs(self):
+ return ['logits', 'probs', 'trigger_out', 'sim_out']
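# Not part of the diff: a hedged numpy sketch of PDN's score combination above.
# Trigger and similarity scores are kept non-negative (exp / softplus), the
# per-behaviour products are summed over the valid sequence positions, and the
# accumulated score is mapped from [0, inf) to [0, 1) via probs = 1 - exp(-logits).
# Names are illustrative only.
import numpy as np

def pdn_prob(trigger_raw, sim_raw, seq_mask, direct_raw=0.0, bias_raw=0.0):
    # trigger_raw, sim_raw: [T] raw DNN outputs per behaviour; seq_mask: [T] 0/1 validity
    path_score = (np.exp(trigger_raw) * np.exp(sim_raw) * seq_mask).sum()
    logits = path_score + np.log1p(np.exp(direct_raw)) + np.log1p(np.exp(bias_raw))  # softplus terms
    return 1.0 - np.exp(-logits)

print(pdn_prob(np.array([0.2, -0.1, 0.0]), np.array([0.5, 0.3, 0.0]),
               np.array([1.0, 1.0, 0.0])))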
diff --git a/easy_rec/python/model/ple.py b/easy_rec/python/model/ple.py
index f3ad71215..e04781bcd 100644
--- a/easy_rec/python/model/ple.py
+++ b/easy_rec/python/model/ple.py
@@ -27,7 +27,10 @@ def __init__(self,
self._layer_nums = len(self._model_config.extraction_networks)
self._task_nums = len(self._model_config.task_towers)
- self._features, _ = self._input_layer(self._feature_dict, 'all')
+ if self.has_backbone:
+ self._features = self.backbone
+ else:
+ self._features, _ = self._input_layer(self._feature_dict, 'all')
self._init_towers(self._model_config.task_towers)
def gate(self, selector_fea, vec_feas, name):
diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py
index b05bac879..dc3771daf 100644
--- a/easy_rec/python/model/rank_model.py
+++ b/easy_rec/python/model/rank_model.py
@@ -1,12 +1,15 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
import tensorflow as tf
+from tensorflow.python.ops import math_ops
from easy_rec.python.builders import loss_builder
-from easy_rec.python.core import metrics as metrics_lib
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType
-from easy_rec.python.utils import pai_util
+
+from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_pred # NOQA
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -24,9 +27,31 @@ def __init__(self,
labels, is_training)
self._loss_type = self._model_config.loss_type
self._num_class = self._model_config.num_class
-
+ self._losses = self._model_config.losses
if self._labels is not None:
- self._label_name = list(self._labels.keys())[0]
+ if model_config.HasField('label_name'):
+ self._label_name = model_config.label_name
+ else:
+ self._label_name = list(self._labels.keys())[0]
+ self._outputs = []
+
+ def build_predict_graph(self):
+ if not self.has_backbone:
+ raise NotImplementedError(
+          'method `build_predict_graph` must be implemented when backbone network does not exist'
+ )
+ model = self._model_config.WhichOneof('model')
+ assert model == 'model_params', '`model_params` must be configured'
+ config = self._model_config.model_params
+ for out in config.outputs:
+ self._outputs.append(out)
+
+ output = self.backbone
+ if int(output.shape[-1]) != self._num_class:
+ logging.info('add head logits layer for rank model')
+ output = tf.layers.dense(output, self._num_class, name='output')
+ self._add_to_prediction_dict(output)
+ return self._prediction_dict
def _output_to_prediction_impl(self,
output,
@@ -34,16 +59,52 @@ def _output_to_prediction_impl(self,
num_class=1,
suffix=''):
prediction_dict = {}
- if loss_type == LossType.CLASSIFICATION:
+ binary_loss_type = {
+ LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
+ LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
+ LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
+ LossType.PAIRWISE_LOGISTIC_LOSS, LossType.BINARY_CROSS_ENTROPY_LOSS,
+ LossType.LISTWISE_DISTILL_LOSS
+ }
+ if loss_type in binary_loss_type:
+ assert num_class == 1, 'num_class must be 1 when loss type is %s' % loss_type.name
+ output = tf.squeeze(output, axis=1)
+ probs = tf.sigmoid(output)
+ tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
+ prediction_dict['logits' + suffix] = output
+ prediction_dict['probs' + suffix] = probs
+ elif loss_type == LossType.JRC_LOSS:
+ assert num_class == 2, 'num_class must be 2 when loss type is JRC_LOSS'
+ probs = tf.nn.softmax(output, axis=1)
+ tf.summary.scalar('prediction/probs', tf.reduce_mean(probs[:, 1]))
+ prediction_dict['logits' + suffix] = output
+ prediction_dict['pos_logits' + suffix] = output[:, 1]
+ prediction_dict['probs' + suffix] = probs[:, 1]
+ elif loss_type == LossType.ZILN_LOSS:
+ assert num_class == 3, 'num_class must be 3 when loss type is ZILN_LOSS'
+ probs, preds = zero_inflated_lognormal_pred(output)
+ tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
+ tf.summary.scalar('prediction/y', tf.reduce_mean(preds))
+ prediction_dict['logits' + suffix] = output
+ prediction_dict['probs' + suffix] = probs
+ prediction_dict['y' + suffix] = preds
+ elif loss_type == LossType.CLASSIFICATION:
if num_class == 1:
output = tf.squeeze(output, axis=1)
probs = tf.sigmoid(output)
+ tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
prediction_dict['logits' + suffix] = output
prediction_dict['probs' + suffix] = probs
else:
probs = tf.nn.softmax(output, axis=1)
prediction_dict['logits' + suffix] = output
+ prediction_dict['logits' + suffix + '_1'] = output[:, 1]
prediction_dict['probs' + suffix] = probs
+ prediction_dict['probs' + suffix + '_1'] = probs[:, 1]
+ prediction_dict['logits' + suffix + '_y'] = math_ops.reduce_max(
+ output, axis=1)
+ prediction_dict['probs' + suffix + '_y'] = math_ops.reduce_max(
+ probs, axis=1)
prediction_dict['y' + suffix] = tf.argmax(output, axis=1)
elif loss_type == LossType.L2_LOSS:
output = tf.squeeze(output, axis=1)
@@ -54,44 +115,174 @@ def _output_to_prediction_impl(self,
return prediction_dict
def _add_to_prediction_dict(self, output):
- self._prediction_dict.update(
- self._output_to_prediction_impl(
- output, loss_type=self._loss_type, num_class=self._num_class))
+ if len(self._losses) == 0:
+ prediction_dict = self._output_to_prediction_impl(
+ output, loss_type=self._loss_type, num_class=self._num_class)
+ self._prediction_dict.update(prediction_dict)
+ else:
+ for loss in self._losses:
+ prediction_dict = self._output_to_prediction_impl(
+ output, loss_type=loss.loss_type, num_class=self._num_class)
+ self._prediction_dict.update(prediction_dict)
+
+ def build_rtp_output_dict(self):
+ """Forward tensor as `rank_predict`, which is a special node for RTP."""
+ outputs = {}
+ outputs.update(super(RankModel, self).build_rtp_output_dict())
+ rank_predict = None
+ try:
+ op = tf.get_default_graph().get_operation_by_name('rank_predict')
+ if len(op.outputs) != 1:
+ raise ValueError(
+ ('failed to build RTP rank_predict output: op {}[{}] has output ' +
+ 'size {}, however 1 is expected.').format(op.name, op.type,
+ len(op.outputs)))
+ rank_predict = op.outputs[0]
+ except KeyError:
+ forwarded = None
+ loss_types = {self._loss_type}
+ if len(self._losses) > 0:
+ loss_types = {loss.loss_type for loss in self._losses}
+ binary_loss_set = {
+ LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
+ LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
+ LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
+ LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
+ LossType.LISTWISE_RANK_LOSS
+ }
+ if loss_types & binary_loss_set:
+ if 'probs' in self._prediction_dict:
+ forwarded = self._prediction_dict['probs']
+ else:
+ raise ValueError(
+ 'failed to build RTP rank_predict output: classification model ' +
+ "expect 'probs' prediction, which is not found. Please check if" +
+ ' build_predict_graph() is called.')
+ elif loss_types & {
+ LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
+ }:
+ if 'y' in self._prediction_dict:
+ forwarded = self._prediction_dict['y']
+ else:
+ raise ValueError(
+            'failed to build RTP rank_predict output: regression model expect '
+ +
+ "'y' prediction, which is not found. Please check if build_predic"
+ + 't_graph() is called.')
+ else:
+ logging.warning(
+ 'failed to build RTP rank_predict: unsupported loss type {}'.format(
+ loss_types))
+ if forwarded is not None:
+ rank_predict = tf.identity(forwarded, name='rank_predict')
+ if rank_predict is not None:
+ outputs['rank_predict'] = rank_predict
+ return outputs
def _build_loss_impl(self,
loss_type,
label_name,
loss_weight=1.0,
num_class=1,
- suffix=''):
+ suffix='',
+ loss_name='',
+ loss_param=None):
loss_dict = {}
- if loss_type == LossType.CLASSIFICATION:
- loss_name = 'cross_entropy_loss' + suffix
+ binary_loss_type = {
+ LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
+ LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
+ LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
+ LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS,
+ LossType.LISTWISE_DISTILL_LOSS, LossType.ZILN_LOSS
+ }
+ if loss_type in {
+ LossType.CLASSIFICATION, LossType.BINARY_CROSS_ENTROPY_LOSS
+ }:
+ loss_name = loss_name if loss_name else 'cross_entropy_loss' + suffix
+ pred = self._prediction_dict['logits' + suffix]
+ elif loss_type in binary_loss_type:
+ if not loss_name:
+ loss_name = LossType.Name(loss_type).lower() + suffix
+ else:
+ loss_name = loss_name + suffix
pred = self._prediction_dict['logits' + suffix]
elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
- loss_name = 'l2_loss' + suffix
+ loss_name = loss_name if loss_name else 'l2_loss' + suffix
pred = self._prediction_dict['y' + suffix]
else:
raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))
- loss_dict[loss_name] = loss_builder.build(loss_type,
- self._labels[label_name], pred,
- loss_weight, num_class)
+ tf.summary.scalar('labels/%s' % label_name,
+ tf.reduce_mean(tf.to_float(self._labels[label_name])))
+ kwargs = {'loss_name': loss_name}
+ if loss_param is not None:
+ if hasattr(loss_param, 'session_name'):
+ kwargs['session_ids'] = self._feature_dict[loss_param.session_name]
+ loss_dict[loss_name] = loss_builder.build(
+ loss_type,
+ self._labels[label_name],
+ pred,
+ loss_weight,
+ num_class,
+ loss_param=loss_param,
+ **kwargs)
return loss_dict
def build_loss_graph(self):
- self._loss_dict.update(
- self._build_loss_impl(
+ loss_dict = {}
+ with tf.name_scope('loss'):
+ if len(self._losses) == 0:
+ loss_dict = self._build_loss_impl(
self._loss_type,
label_name=self._label_name,
loss_weight=self._sample_weight,
- num_class=self._num_class))
-
- # build kd loss
- kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
- self._labels)
- self._loss_dict.update(kd_loss_dict)
-
+ num_class=self._num_class)
+ else:
+ strategy = self._base_model_config.loss_weight_strategy
+ loss_weight = [1.0]
+ if strategy == self._base_model_config.Random and len(self._losses) > 1:
+ weights = tf.random_normal([len(self._losses)])
+ loss_weight = tf.nn.softmax(weights)
+ for i, loss in enumerate(self._losses):
+ loss_param = loss.WhichOneof('loss_param')
+ if loss_param is not None:
+ loss_param = getattr(loss, loss_param)
+ loss_ops = self._build_loss_impl(
+ loss.loss_type,
+ label_name=self._label_name,
+ loss_weight=self._sample_weight,
+ num_class=self._num_class,
+ loss_name=loss.loss_name,
+ loss_param=loss_param)
+ for loss_name, loss_value in loss_ops.items():
+ if strategy == self._base_model_config.Fixed:
+ loss_dict[loss_name] = loss_value * loss.weight
+ elif strategy == self._base_model_config.Uncertainty:
+ if loss.learn_loss_weight:
+ uncertainty = tf.Variable(
+ 0, name='%s_loss_weight' % loss_name, dtype=tf.float32)
+ tf.summary.scalar('%s_uncertainty' % loss_name, uncertainty)
+ if loss.loss_type in {
+ LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS
+ }:
+ loss_dict[loss_name] = 0.5 * tf.exp(
+ -uncertainty) * loss_value + 0.5 * uncertainty
+ else:
+ loss_dict[loss_name] = tf.exp(
+ -uncertainty) * loss_value + 0.5 * uncertainty
+ else:
+ loss_dict[loss_name] = loss_value * loss.weight
+ elif strategy == self._base_model_config.Random:
+ loss_dict[loss_name] = loss_value * loss_weight[i]
+ else:
+ raise ValueError('Unsupported loss weight strategy: ' +
+ strategy.Name)
+ self._loss_dict.update(loss_dict)
+ # build kd loss
+ kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
+ self._labels,
+ self._feature_dict)
+ self._loss_dict.update(kd_loss_dict)
return self._loss_dict
def _build_metric_impl(self,
@@ -100,32 +291,47 @@ def _build_metric_impl(self,
label_name,
num_class=1,
suffix=''):
+ if not isinstance(loss_type, set):
+ loss_type = {loss_type}
+ from easy_rec.python.core.easyrec_metrics import metrics_tf
+ from easy_rec.python.core import metrics as metrics_lib
+ binary_loss_set = {
+ LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
+ LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
+ LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
+ LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
+ LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS
+ }
metric_dict = {}
if metric.WhichOneof('metric') == 'auc':
- assert loss_type == LossType.CLASSIFICATION
-
- if num_class == 1:
+ assert loss_type & binary_loss_set
+ if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
label = tf.to_int64(self._labels[label_name])
- metric_dict['auc' + suffix] = tf.metrics.auc(
+ metric_dict['auc' + suffix] = metrics_tf.auc(
label,
self._prediction_dict['probs' + suffix],
num_thresholds=metric.auc.num_thresholds)
elif num_class == 2:
label = tf.to_int64(self._labels[label_name])
- metric_dict['auc' + suffix] = tf.metrics.auc(
+ metric_dict['auc' + suffix] = metrics_tf.auc(
label,
self._prediction_dict['probs' + suffix][:, 1],
num_thresholds=metric.auc.num_thresholds)
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'gauc':
- assert loss_type == LossType.CLASSIFICATION
- if num_class == 1:
+ assert loss_type & binary_loss_set
+ if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
label = tf.to_int64(self._labels[label_name])
+ uids = self._feature_dict[metric.gauc.uid_field]
+ if isinstance(uids, tf.sparse.SparseTensor):
+ uids = tf.sparse_to_dense(
+ uids.indices, uids.dense_shape, uids.values, default_value='')
+ uids = tf.reshape(uids, [-1])
metric_dict['gauc' + suffix] = metrics_lib.gauc(
label,
self._prediction_dict['probs' + suffix],
- uids=self._feature_dict[metric.gauc.uid_field],
+ uids=uids,
reduction=metric.gauc.reduction)
elif num_class == 2:
label = tf.to_int64(self._labels[label_name])
@@ -137,17 +343,17 @@ def _build_metric_impl(self,
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'session_auc':
- assert loss_type == LossType.CLASSIFICATION
- if num_class == 1:
+ assert loss_type & binary_loss_set
+ if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
label = tf.to_int64(self._labels[label_name])
- metric_dict['gauc' + suffix] = metrics_lib.session_auc(
+ metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
label,
self._prediction_dict['probs' + suffix],
session_ids=self._feature_dict[metric.session_auc.session_id_field],
reduction=metric.session_auc.reduction)
elif num_class == 2:
label = tf.to_int64(self._labels[label_name])
- metric_dict['gauc' + suffix] = metrics_lib.session_auc(
+ metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
label,
self._prediction_dict['probs' + suffix][:, 1],
session_ids=self._feature_dict[metric.session_auc.session_id_field],
@@ -155,223 +361,125 @@ def _build_metric_impl(self,
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'max_f1':
- assert loss_type == LossType.CLASSIFICATION
- if num_class == 1:
+ assert loss_type & binary_loss_set
+ if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
label = tf.to_int64(self._labels[label_name])
- metric_dict['f1' + suffix] = metrics_lib.max_f1(
+ metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
label, self._prediction_dict['logits' + suffix])
elif num_class == 2:
label = tf.to_int64(self._labels[label_name])
- metric_dict['f1' + suffix] = metrics_lib.max_f1(
+ metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
label, self._prediction_dict['logits' + suffix][:, 1])
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'recall_at_topk':
- assert loss_type == LossType.CLASSIFICATION
+ assert loss_type & binary_loss_set
assert num_class > 1
label = tf.to_int64(self._labels[label_name])
- metric_dict['recall_at_topk' + suffix] = tf.metrics.recall_at_k(
+ metric_dict['recall_at_topk' + suffix] = metrics_tf.recall_at_k(
label, self._prediction_dict['logits' + suffix],
metric.recall_at_topk.topk)
elif metric.WhichOneof('metric') == 'mean_absolute_error':
label = tf.to_float(self._labels[label_name])
- if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
- metric_dict['mean_absolute_error' +
- suffix] = tf.metrics.mean_absolute_error(
- label, self._prediction_dict['y' + suffix])
- elif loss_type == LossType.CLASSIFICATION and num_class == 1:
- metric_dict['mean_absolute_error' +
- suffix] = tf.metrics.mean_absolute_error(
- label, self._prediction_dict['probs' + suffix])
- else:
- assert False, 'mean_absolute_error is not supported for this model'
- elif metric.WhichOneof('metric') == 'mean_squared_error':
- label = tf.to_float(self._labels[label_name])
- if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
- metric_dict['mean_squared_error' +
- suffix] = tf.metrics.mean_squared_error(
- label, self._prediction_dict['y' + suffix])
- elif loss_type == LossType.CLASSIFICATION and num_class == 1:
- metric_dict['mean_squared_error' +
- suffix] = tf.metrics.mean_squared_error(
- label, self._prediction_dict['probs' + suffix])
- else:
- assert False, 'mean_squared_error is not supported for this model'
- elif metric.WhichOneof('metric') == 'root_mean_squared_error':
- label = tf.to_float(self._labels[label_name])
- if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
- metric_dict['root_mean_squared_error' +
- suffix] = tf.metrics.root_mean_squared_error(
- label, self._prediction_dict['y' + suffix])
- elif loss_type == LossType.CLASSIFICATION and num_class == 1:
- metric_dict['root_mean_squared_error' +
- suffix] = tf.metrics.root_mean_squared_error(
- label, self._prediction_dict['probs' + suffix])
- else:
- assert False, 'root_mean_squared_error is not supported for this model'
- elif metric.WhichOneof('metric') == 'accuracy':
- assert loss_type == LossType.CLASSIFICATION
- assert num_class > 1
- label = tf.to_int64(self._labels[label_name])
- metric_dict['accuracy' + suffix] = tf.metrics.accuracy(
- label, self._prediction_dict['y' + suffix])
- return metric_dict
-
- def _build_distribute_metric_impl(self,
- metric,
- loss_type,
- label_name,
- num_class=1,
- suffix=''):
- if pai_util.is_on_pai():
- from easy_rec.python.core import metrics_impl_pai as distribute_metrics_tf
- else:
- from easy_rec.python.core import metrics_impl_tf as distribute_metrics_tf
- metric_dict = {}
- if metric.WhichOneof('metric') == 'auc':
- assert loss_type == LossType.CLASSIFICATION
- if num_class == 1:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['auc' + suffix] = distribute_metrics_tf.auc(
- label, self._prediction_dict['probs' + suffix])
- elif num_class == 2:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['auc' + suffix] = distribute_metrics_tf.auc(
- label, self._prediction_dict['probs' + suffix][:, 1])
- else:
- raise ValueError('Wrong class number')
- elif metric.WhichOneof('metric') == 'gauc':
- assert loss_type == LossType.CLASSIFICATION
- if num_class == 1:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['gauc' + suffix] = distribute_metrics_tf.gauc(
- label,
- self._prediction_dict['probs' + suffix],
- uids=self._feature_dict[metric.gauc.uid_field],
- reduction=metric.gauc.reduction)
- elif num_class == 2:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['gauc' + suffix] = distribute_metrics_tf.gauc(
- label,
- self._prediction_dict['probs' + suffix][:, 1],
- uids=self._feature_dict[metric.gauc.uid_field],
- reduction=metric.gauc.reduction)
- else:
- raise ValueError('Wrong class number')
- elif metric.WhichOneof('metric') == 'session_auc':
- assert loss_type == LossType.CLASSIFICATION
- if num_class == 1:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['gauc' + suffix] = distribute_metrics_tf.session_auc(
- label,
- self._prediction_dict['probs' + suffix],
- session_ids=self._feature_dict[metric.session_auc.session_id_field],
- reduction=metric.session_auc.reduction)
- elif num_class == 2:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['gauc' + suffix] = distribute_metrics_tf.session_auc(
- label,
- self._prediction_dict['probs' + suffix][:, 1],
- session_ids=self._feature_dict[metric.session_auc.session_id_field],
- reduction=metric.session_auc.reduction)
- else:
- raise ValueError('Wrong class number')
- elif metric.WhichOneof('metric') == 'max_f1':
- assert loss_type == LossType.CLASSIFICATION
- if num_class == 1:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['f1' + suffix] = distribute_metrics_tf.max_f1(
- label, self._prediction_dict['logits' + suffix])
- elif num_class == 2:
- label = tf.to_int64(self._labels[label_name])
- metric_dict['f1' + suffix] = distribute_metrics_tf.max_f1(
- label, self._prediction_dict['logits' + suffix][:, 1])
- else:
- raise ValueError('Wrong class number')
- elif metric.WhichOneof('metric') == 'recall_at_topk':
- assert loss_type == LossType.CLASSIFICATION
- assert num_class > 1
- label = tf.to_int64(self._labels[label_name])
- metric_dict['recall_at_topk' +
- suffix] = distribute_metrics_tf.recall_at_k(
- label, self._prediction_dict['logits' + suffix],
- metric.recall_at_topk.topk)
- elif metric.WhichOneof('metric') == 'mean_absolute_error':
- label = tf.to_float(self._labels[label_name])
- if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
+ if loss_type & {
+ LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
+ }:
metric_dict['mean_absolute_error' +
- suffix] = distribute_metrics_tf.mean_absolute_error(
+ suffix] = metrics_tf.mean_absolute_error(
label, self._prediction_dict['y' + suffix])
- elif loss_type == LossType.CLASSIFICATION and num_class == 1:
+ elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
metric_dict['mean_absolute_error' +
- suffix] = distribute_metrics_tf.mean_absolute_error(
+ suffix] = metrics_tf.mean_absolute_error(
label, self._prediction_dict['probs' + suffix])
else:
assert False, 'mean_absolute_error is not supported for this model'
elif metric.WhichOneof('metric') == 'mean_squared_error':
label = tf.to_float(self._labels[label_name])
- if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
+ if loss_type & {
+ LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
+ }:
metric_dict['mean_squared_error' +
- suffix] = distribute_metrics_tf.mean_squared_error(
+ suffix] = metrics_tf.mean_squared_error(
label, self._prediction_dict['y' + suffix])
- elif loss_type == LossType.CLASSIFICATION and num_class == 1:
+ elif num_class == 1 and loss_type & binary_loss_set:
metric_dict['mean_squared_error' +
- suffix] = distribute_metrics_tf.mean_squared_error(
+ suffix] = metrics_tf.mean_squared_error(
label, self._prediction_dict['probs' + suffix])
else:
assert False, 'mean_squared_error is not supported for this model'
elif metric.WhichOneof('metric') == 'root_mean_squared_error':
label = tf.to_float(self._labels[label_name])
- if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
+ if loss_type & {
+ LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
+ }:
metric_dict['root_mean_squared_error' +
- suffix] = distribute_metrics_tf.root_mean_squared_error(
+ suffix] = metrics_tf.root_mean_squared_error(
label, self._prediction_dict['y' + suffix])
- elif loss_type == LossType.CLASSIFICATION and num_class == 1:
+ elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
metric_dict['root_mean_squared_error' +
- suffix] = distribute_metrics_tf.root_mean_squared_error(
+ suffix] = metrics_tf.root_mean_squared_error(
label, self._prediction_dict['probs' + suffix])
else:
assert False, 'root_mean_squared_error is not supported for this model'
elif metric.WhichOneof('metric') == 'accuracy':
- assert loss_type == LossType.CLASSIFICATION
+ assert loss_type & {LossType.CLASSIFICATION}
assert num_class > 1
label = tf.to_int64(self._labels[label_name])
- metric_dict['accuracy' + suffix] = distribute_metrics_tf.accuracy(
+ metric_dict['accuracy' + suffix] = metrics_tf.accuracy(
label, self._prediction_dict['y' + suffix])
return metric_dict
def build_metric_graph(self, eval_config):
- metric_dict = {}
+ loss_types = {self._loss_type}
+ if len(self._losses) > 0:
+ loss_types = {loss.loss_type for loss in self._losses}
for metric in eval_config.metrics_set:
- metric_dict.update(
+ self._metric_dict.update(
self._build_metric_impl(
metric,
- loss_type=self._loss_type,
- label_name=self._label_name,
- num_class=self._num_class))
- return metric_dict
-
- def build_distribute_metric_graph(self, eval_config):
- metric_dict = {}
- for metric in eval_config.metrics_set:
- metric_dict.update(
- self._build_distribute_metric_impl(
- metric,
- loss_type=self._loss_type,
+ loss_type=loss_types,
label_name=self._label_name,
num_class=self._num_class))
- return metric_dict
+ return self._metric_dict
def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
+ binary_loss_set = {
+ LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
+ LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
+ LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
+ LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS
+ }
+ if loss_type in binary_loss_set:
+ return ['probs' + suffix, 'logits' + suffix]
+ if loss_type == LossType.JRC_LOSS:
+ return ['probs' + suffix, 'pos_logits' + suffix]
+ if loss_type == LossType.ZILN_LOSS:
+ return ['probs' + suffix, 'y' + suffix, 'logits' + suffix]
if loss_type == LossType.CLASSIFICATION:
if num_class == 1:
return ['probs' + suffix, 'logits' + suffix]
else:
- return ['y' + suffix, 'probs' + suffix, 'logits' + suffix]
+ return [
+ 'y' + suffix, 'probs' + suffix, 'logits' + suffix,
+ 'probs' + suffix + '_y', 'logits' + suffix + '_y',
+ 'probs' + suffix + '_1', 'logits' + suffix + '_1'
+ ]
elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
return ['y' + suffix]
else:
raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))
def get_outputs(self):
- return self._get_outputs_impl(self._loss_type, self._num_class)
+ if len(self._losses) == 0:
+ outputs = self._get_outputs_impl(self._loss_type, self._num_class)
+ if self._outputs:
+ outputs.extend(self._outputs)
+ return list(set(outputs))
+
+ all_outputs = []
+ if self._outputs:
+ all_outputs.extend(self._outputs)
+ for loss in self._losses:
+ outputs = self._get_outputs_impl(loss.loss_type, self._num_class)
+ all_outputs.extend(outputs)
+ return list(set(all_outputs))
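
The `Uncertainty` branch of `build_loss_graph` above matches the homoscedastic-uncertainty weighting of Kendall et al. (2018): each loss gets a learnable log-variance term `s`, combined as `0.5 * exp(-s) * loss + 0.5 * s` for regression-style losses and `exp(-s) * loss + 0.5 * s` otherwise. A minimal standalone sketch of that arithmetic (not part of the diff; `raw_losses` and the regression flags are hypothetical inputs):

```python
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def uncertainty_weighted(raw_losses, is_regression):
  """Combine task losses with learnable log-variance weights s = log(sigma^2)."""
  total = 0.0
  for i, (loss, regress) in enumerate(zip(raw_losses, is_regression)):
    s = tf.Variable(0.0, name='loss_weight_%d' % i, dtype=tf.float32)
    if regress:
      total += 0.5 * tf.exp(-s) * loss + 0.5 * s  # L2 / SIGMOID_L2 style losses
    else:
      total += tf.exp(-s) * loss + 0.5 * s  # classification style losses
  return total
```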
diff --git a/easy_rec/python/model/rocket_launching.py b/easy_rec/python/model/rocket_launching.py
index 1baeb2b98..aea29bf52 100755
--- a/easy_rec/python/model/rocket_launching.py
+++ b/easy_rec/python/model/rocket_launching.py
@@ -12,7 +12,6 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
-metrics = tf.metrics
class RocketLaunching(RankModel):
diff --git a/easy_rec/python/model/simple_multi_task.py b/easy_rec/python/model/simple_multi_task.py
index b4c0613bc..05dd7a773 100644
--- a/easy_rec/python/model/simple_multi_task.py
+++ b/easy_rec/python/model/simple_multi_task.py
@@ -27,7 +27,10 @@ def __init__(self,
self._model_config = self._model_config.simple_multi_task
assert isinstance(self._model_config, SimpleMultiTaskConfig)
- self._features, _ = self._input_layer(self._feature_dict, 'all')
+ if self.has_backbone:
+ self._features = self.backbone
+ else:
+ self._features, _ = self._input_layer(self._feature_dict, 'all')
self._init_towers(self._model_config.task_towers)
def build_predict_graph(self):
diff --git a/easy_rec/python/model/uniter.py b/easy_rec/python/model/uniter.py
new file mode 100644
index 000000000..40dfc8cb1
--- /dev/null
+++ b/easy_rec/python/model/uniter.py
@@ -0,0 +1,46 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.layers import uniter
+from easy_rec.python.model.rank_model import RankModel
+
+from easy_rec.python.protos.uniter_pb2 import Uniter as UNITERConfig # NOQA
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class Uniter(RankModel):
+ """UNITER: UNiversal Image-TExt Representation Learning.
+
+ See the original paper:
+ https://arxiv.org/abs/1909.11740
+ """
+
+ def __init__(self,
+ model_config,
+ feature_configs,
+ features,
+ labels=None,
+ is_training=False):
+ super(Uniter, self).__init__(model_config, feature_configs, features,
+ labels, is_training)
+ assert self._model_config.WhichOneof('model') == 'uniter', (
+ 'invalid model config: %s' % self._model_config.WhichOneof('model'))
+
+ self._uniter_layer = uniter.Uniter(model_config, feature_configs, features,
+ self._model_config.uniter.config,
+ self._input_layer)
+ self._model_config = self._model_config.uniter
+
+ def build_predict_graph(self):
+ hidden = self._uniter_layer(self._is_training, l2_reg=self._l2_reg)
+ final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+ 'final_dnn', self._is_training)
+ all_fea = final_dnn_layer(hidden)
+
+ final = tf.layers.dense(all_fea, self._num_class, name='output')
+ self._add_to_prediction_dict(final)
+ return self._prediction_dict
diff --git a/easy_rec/python/model/wide_and_deep.py b/easy_rec/python/model/wide_and_deep.py
index 119af575c..48b620bd7 100755
--- a/easy_rec/python/model/wide_and_deep.py
+++ b/easy_rec/python/model/wide_and_deep.py
@@ -5,7 +5,6 @@
import tensorflow as tf
from easy_rec.python.layers import dnn
-from easy_rec.python.layers import input_layer
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos.wide_and_deep_pb2 import WideAndDeep as WideAndDeepConfig # NOQA
@@ -36,17 +35,11 @@ def __init__(self,
def build_input_layer(self, model_config, feature_configs):
# overwrite create input_layer to support wide_output_dim
has_final = len(model_config.wide_and_deep.final_dnn.hidden_units) > 0
- wide_output_dim = model_config.wide_and_deep.wide_output_dim
+ self._wide_output_dim = model_config.wide_and_deep.wide_output_dim
if not has_final:
model_config.wide_and_deep.wide_output_dim = model_config.num_class
- wide_output_dim = model_config.num_class
- self._input_layer = input_layer.InputLayer(
- feature_configs,
- model_config.feature_groups,
- wide_output_dim=wide_output_dim,
- use_embedding_variable=model_config.use_embedding_variable,
- embedding_regularizer=self._emb_reg,
- kernel_regularizer=self._l2_reg)
+ self._wide_output_dim = model_config.num_class
+ super(WideAndDeep, self).build_input_layer(model_config, feature_configs)
def build_predict_graph(self):
wide_fea = tf.add_n(self._wide_features)
@@ -86,23 +79,43 @@ def build_predict_graph(self):
return self._prediction_dict
- def get_grouped_vars(self):
+ def get_grouped_vars(self, opt_num):
"""Group the vars into different optimization groups.
Each group will be optimized by a separate optimizer.
+ Args:
+ opt_num: number of optimizers from easyrec config.
+
Return:
list of list of variables.
"""
- assert len(self._model_config.final_dnn.hidden_units) == 0, \
- 'if use different optimizers for wide group and deep group, '\
- + ' final_dnn should not be set.'
- wide_vars = []
- deep_vars = []
- for tmp_var in tf.trainable_variables():
- if tmp_var.name.startswith('input_layer') and \
- (not tmp_var.name.startswith('input_layer_1')):
- wide_vars.append(tmp_var)
- else:
- deep_vars.append(tmp_var)
- return [wide_vars, deep_vars]
+ assert opt_num <= 3, 'could only support 2 or 3 optimizers, ' + \
+    'if opt_num = 2, one for the wide, and one for the others, ' + \
+ 'if opt_num = 3, one for the wide, second for the deep embeddings, ' + \
+ 'and third for the other layers.'
+
+ if opt_num == 2:
+ wide_vars = []
+ deep_vars = []
+ for tmp_var in tf.trainable_variables():
+ if tmp_var.name.startswith('input_layer') and \
+ (not tmp_var.name.startswith('input_layer_1')):
+ wide_vars.append(tmp_var)
+ else:
+ deep_vars.append(tmp_var)
+ return [wide_vars, deep_vars]
+ elif opt_num == 3:
+ wide_vars = []
+ embedding_vars = []
+ deep_vars = []
+ for tmp_var in tf.trainable_variables():
+ if tmp_var.name.startswith('input_layer') and \
+ (not tmp_var.name.startswith('input_layer_1')):
+ wide_vars.append(tmp_var)
+ elif tmp_var.name.startswith(
+ 'input_layer') or '/embedding_weights' in tmp_var.name:
+ embedding_vars.append(tmp_var)
+ else:
+ deep_vars.append(tmp_var)
+ return [wide_vars, embedding_vars, deep_vars]
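
For reference, a standalone sketch of the name-prefix grouping that `get_grouped_vars(opt_num=3)` performs; the variable names below are made up for illustration, and the rules mirror the branches in the diff above:

```python
def group_by_name(var_names):
  """Split variable names into wide / embedding / deep groups by name prefix."""
  wide, embed, deep = [], [], []
  for name in var_names:
    if name.startswith('input_layer') and not name.startswith('input_layer_1'):
      wide.append(name)  # wide-side embeddings
    elif name.startswith('input_layer') or '/embedding_weights' in name:
      embed.append(name)  # deep-side embeddings
    else:
      deep.append(name)  # dense layers
  return [wide, embed, deep]


print(group_by_name([
    'input_layer/fea_a/embedding_weights:0',    # -> wide
    'input_layer_1/fea_b/embedding_weights:0',  # -> embed
    'dnn/hidden_0/kernel:0',                    # -> deep
]))
```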
diff --git a/easy_rec/python/ops/1.12/incr_record.so b/easy_rec/python/ops/1.12/incr_record.so
new file mode 100755
index 000000000..821391e7b
Binary files /dev/null and b/easy_rec/python/ops/1.12/incr_record.so differ
diff --git a/easy_rec/python/ops/1.12/kafka.so b/easy_rec/python/ops/1.12/kafka.so
index d5b33cc46..fef4351b0 100755
Binary files a/easy_rec/python/ops/1.12/kafka.so and b/easy_rec/python/ops/1.12/kafka.so differ
diff --git a/easy_rec/python/ops/1.12/libcustom_ops.so b/easy_rec/python/ops/1.12/libcustom_ops.so
new file mode 100755
index 000000000..6d094f598
Binary files /dev/null and b/easy_rec/python/ops/1.12/libcustom_ops.so differ
diff --git a/easy_rec/python/ops/1.12/libembed_op.so b/easy_rec/python/ops/1.12/libembed_op.so
index 5f46ee7f8..8ed91452e 100644
Binary files a/easy_rec/python/ops/1.12/libembed_op.so and b/easy_rec/python/ops/1.12/libembed_op.so differ
diff --git a/easy_rec/python/ops/1.12/librdkafka++.so.1 b/easy_rec/python/ops/1.12/librdkafka++.so.1
new file mode 100755
index 000000000..8a448378c
Binary files /dev/null and b/easy_rec/python/ops/1.12/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/1.12/librdkafka.so.1 b/easy_rec/python/ops/1.12/librdkafka.so.1
new file mode 100755
index 000000000..c7ab65e96
Binary files /dev/null and b/easy_rec/python/ops/1.12/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/1.12/libstr_avx_op.so b/easy_rec/python/ops/1.12/libstr_avx_op.so
new file mode 100755
index 000000000..8544d120c
Binary files /dev/null and b/easy_rec/python/ops/1.12/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/1.12_pai/__init__.py b/easy_rec/python/ops/1.12_pai/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/easy_rec/python/ops/1.12_pai/incr_record.so b/easy_rec/python/ops/1.12_pai/incr_record.so
new file mode 100755
index 000000000..ab607d79a
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/incr_record.so differ
diff --git a/easy_rec/python/ops/1.12_pai/kafka.so b/easy_rec/python/ops/1.12_pai/kafka.so
new file mode 100755
index 000000000..6603a0938
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/kafka.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libcustom_ops.so b/easy_rec/python/ops/1.12_pai/libcustom_ops.so
new file mode 100755
index 000000000..2676f0a24
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libcustom_ops.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libembed_op.so b/easy_rec/python/ops/1.12_pai/libembed_op.so
new file mode 100755
index 000000000..6def7a5ca
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libembed_op.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libhiredis.so.1.0.0 b/easy_rec/python/ops/1.12_pai/libhiredis.so.1.0.0
new file mode 100644
index 000000000..63ae04d40
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libhiredis.so.1.0.0 differ
diff --git a/easy_rec/python/ops/1.12_pai/libkafka.so b/easy_rec/python/ops/1.12_pai/libkafka.so
new file mode 100755
index 000000000..566ce198b
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libkafka.so differ
diff --git a/easy_rec/python/ops/1.12_pai/librdkafka++.so.1 b/easy_rec/python/ops/1.12_pai/librdkafka++.so.1
new file mode 100755
index 000000000..8a448378c
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/1.12_pai/librdkafka.so.1 b/easy_rec/python/ops/1.12_pai/librdkafka.so.1
new file mode 100755
index 000000000..c7ab65e96
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/1.12_pai/libredis++.so b/easy_rec/python/ops/1.12_pai/libredis++.so
new file mode 100644
index 000000000..cadfccc27
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libredis++.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libredis++.so.1 b/easy_rec/python/ops/1.12_pai/libredis++.so.1
new file mode 100644
index 000000000..cadfccc27
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libredis++.so.1 differ
diff --git a/easy_rec/python/ops/1.12_pai/libredis++.so.1.2.3 b/easy_rec/python/ops/1.12_pai/libredis++.so.1.2.3
new file mode 100644
index 000000000..cadfccc27
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libredis++.so.1.2.3 differ
diff --git a/easy_rec/python/ops/1.12_pai/libstr_avx_op.so b/easy_rec/python/ops/1.12_pai/libstr_avx_op.so
new file mode 100755
index 000000000..8544d120c
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/1.12_pai/libwrite_sparse_kv.so b/easy_rec/python/ops/1.12_pai/libwrite_sparse_kv.so
new file mode 100755
index 000000000..d50ee8edc
Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libwrite_sparse_kv.so differ
diff --git a/easy_rec/python/ops/1.15/incr_record.so b/easy_rec/python/ops/1.15/incr_record.so
new file mode 100755
index 000000000..a548b9c9c
Binary files /dev/null and b/easy_rec/python/ops/1.15/incr_record.so differ
diff --git a/easy_rec/python/ops/1.15/kafka.so b/easy_rec/python/ops/1.15/kafka.so
index 3ba64834a..6886446d8 100755
Binary files a/easy_rec/python/ops/1.15/kafka.so and b/easy_rec/python/ops/1.15/kafka.so differ
diff --git a/easy_rec/python/ops/1.15/libcustom_ops.so b/easy_rec/python/ops/1.15/libcustom_ops.so
new file mode 100755
index 000000000..5023cfe47
Binary files /dev/null and b/easy_rec/python/ops/1.15/libcustom_ops.so differ
diff --git a/easy_rec/python/ops/1.15/libembed_op.so b/easy_rec/python/ops/1.15/libembed_op.so
index 69d396100..970ae21ae 100755
Binary files a/easy_rec/python/ops/1.15/libembed_op.so and b/easy_rec/python/ops/1.15/libembed_op.so differ
diff --git a/easy_rec/python/ops/1.15/librdkafka++.so b/easy_rec/python/ops/1.15/librdkafka++.so
new file mode 100755
index 000000000..969f8ab1d
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka++.so differ
diff --git a/easy_rec/python/ops/1.15/librdkafka++.so.1 b/easy_rec/python/ops/1.15/librdkafka++.so.1
new file mode 100755
index 000000000..969f8ab1d
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/1.15/librdkafka.so b/easy_rec/python/ops/1.15/librdkafka.so
new file mode 100755
index 000000000..c83248971
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka.so differ
diff --git a/easy_rec/python/ops/1.15/librdkafka.so.1 b/easy_rec/python/ops/1.15/librdkafka.so.1
new file mode 100755
index 000000000..c83248971
Binary files /dev/null and b/easy_rec/python/ops/1.15/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/1.15/libstr_avx_op.so b/easy_rec/python/ops/1.15/libstr_avx_op.so
new file mode 100755
index 000000000..4237e9820
Binary files /dev/null and b/easy_rec/python/ops/1.15/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/2.12/libcustom_ops.so b/easy_rec/python/ops/2.12/libcustom_ops.so
new file mode 100755
index 000000000..4739657eb
Binary files /dev/null and b/easy_rec/python/ops/2.12/libcustom_ops.so differ
diff --git a/easy_rec/python/ops/2.12/libload_embed.so b/easy_rec/python/ops/2.12/libload_embed.so
new file mode 100755
index 000000000..3d71d6b5e
Binary files /dev/null and b/easy_rec/python/ops/2.12/libload_embed.so differ
diff --git a/easy_rec/python/ops/2.12/libstr_avx_op.so b/easy_rec/python/ops/2.12/libstr_avx_op.so
new file mode 100755
index 000000000..d438cc4a1
Binary files /dev/null and b/easy_rec/python/ops/2.12/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/DeepRec/incr_record.so b/easy_rec/python/ops/DeepRec/incr_record.so
new file mode 100755
index 000000000..fd8f73a48
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/incr_record.so differ
diff --git a/easy_rec/python/ops/DeepRec/kafka.so b/easy_rec/python/ops/DeepRec/kafka.so
new file mode 100755
index 000000000..ec0d5b9f0
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/kafka.so differ
diff --git a/easy_rec/python/ops/DeepRec/libcustom_ops.so b/easy_rec/python/ops/DeepRec/libcustom_ops.so
new file mode 100755
index 000000000..6fac5578f
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/libcustom_ops.so differ
diff --git a/easy_rec/python/ops/DeepRec/libembed_op.so b/easy_rec/python/ops/DeepRec/libembed_op.so
new file mode 100755
index 000000000..58975bd6f
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/libembed_op.so differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka++.so b/easy_rec/python/ops/DeepRec/librdkafka++.so
new file mode 100755
index 000000000..d9a8463e0
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka++.so differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka++.so.1 b/easy_rec/python/ops/DeepRec/librdkafka++.so.1
new file mode 100755
index 000000000..d9a8463e0
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka++.so.1 differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka.so b/easy_rec/python/ops/DeepRec/librdkafka.so
new file mode 100755
index 000000000..431eeb3cf
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka.so differ
diff --git a/easy_rec/python/ops/DeepRec/librdkafka.so.1 b/easy_rec/python/ops/DeepRec/librdkafka.so.1
new file mode 100755
index 000000000..431eeb3cf
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/librdkafka.so.1 differ
diff --git a/easy_rec/python/ops/DeepRec/libstr_avx_op.so b/easy_rec/python/ops/DeepRec/libstr_avx_op.so
new file mode 100755
index 000000000..bb8d36306
Binary files /dev/null and b/easy_rec/python/ops/DeepRec/libstr_avx_op.so differ
diff --git a/easy_rec/python/ops/build_ops.sh b/easy_rec/python/ops/build_ops.sh
new file mode 100755
index 000000000..985d74451
--- /dev/null
+++ b/easy_rec/python/ops/build_ops.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/bash
+TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
+TF_LFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
+TF_ABI=$(python -c 'import tensorflow as tf; print(str(tf.sysconfig.CXX11_ABI_FLAG if "CXX11_ABI_FLAG" in dir(tf.sysconfig) else 0))')
+echo "tensorflow include path: $TF_INC"
+echo "tensorflow link flags: $TF_LFLAGS"
+echo "CXX11_ABI_FLAG=$TF_ABI"
+
+script_path=`readlink -f $0`
+ops_dir=`dirname $script_path`
+ops_src_dir=${ops_dir}/src
+
+ops_bin_dir=`python -c "import easy_rec; print(easy_rec.get_ops_dir())" |tail -1`
+
+if [ -z "$ops_bin_dir" ]
+then
+ echo "could not determine ops_bin_dir"
+ exit 1
+fi
+
+if [ ! -e $ops_bin_dir ]
+then
+ mkdir -p $ops_bin_dir
+fi
+
+ops_bin=${ops_bin_dir}/libload_embed.so
+
+g++ -D_GLIBCXX_USE_CXX11_ABI=$TF_ABI -shared -O3 -DNDEBUG -Wl,-rpath,'$ORIGIN' -fpermissive -mfma -fopenmp ${ops_src_dir}/load_kv_embed.cc ${ops_src_dir}/load_dense_embed.cc -o ${ops_bin} -fPIC -I $TF_INC $TF_LFLAGS -L/lib64
+
+python -c "import tensorflow as tf; tf.load_op_library('$ops_bin')"
+err_code=$?
+if [ $err_code -ne 0 ]
+then
+ echo "build failed"
+ exit $err_code
+fi
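
Once the script above has produced `libload_embed.so`, the library can be loaded from Python in the same way the script's final check does; a small sketch, assuming the file landed in the directory reported by `easy_rec.get_ops_dir()`:

```python
import os

import tensorflow as tf

import easy_rec

# path assumed to match the ops_bin target used in build_ops.sh
ops_bin = os.path.join(easy_rec.get_ops_dir(), 'libload_embed.so')
load_embed_module = tf.load_op_library(ops_bin)
print([x for x in dir(load_embed_module) if not x.startswith('_')])
```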
diff --git a/easy_rec/python/ops/gen_kafka_ops.py b/easy_rec/python/ops/gen_kafka_ops.py
new file mode 100644
index 000000000..16bba500d
--- /dev/null
+++ b/easy_rec/python/ops/gen_kafka_ops.py
@@ -0,0 +1,193 @@
+"""Python wrappers around TensorFlow ops.
+
+This file is MACHINE GENERATED! Do not edit.
+Original C++ source file: kafka_ops_deprecated.cc
+"""
+
+import logging
+import os
+import traceback
+
+import six as _six
+import tensorflow as tf
+from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
+from tensorflow.python.eager import context as _context
+from tensorflow.python.eager import core as _core
+from tensorflow.python.eager import execute as _execute
+# Needed to trigger the call to _set_call_cpp_shape_fn.
+from tensorflow.python.framework import dtypes as _dtypes
+from tensorflow.python.framework import ops as _ops
+from tensorflow.python.util.tf_export import tf_export
+
+import easy_rec
+
+kafka_module = None
+if easy_rec.ops_dir is not None:
+ kafka_ops_path = os.path.join(easy_rec.ops_dir, 'kafka.so')
+ if os.path.exists(kafka_ops_path):
+ try:
+ kafka_module = tf.load_op_library(kafka_ops_path)
+ except Exception:
+ logging.warning('load %s failed: %s' %
+ (kafka_ops_path, traceback.format_exc()))
+
+
+@tf_export('io_kafka_dataset_v2')
+def io_kafka_dataset_v2(topics,
+ servers,
+ group,
+ eof,
+ timeout,
+ config_global,
+ config_topic,
+ message_key,
+ message_offset,
+ name=None):
+ """Creates a dataset that emits the messages of one or more Kafka topics.
+
+ Args:
+ topics: A `Tensor` of type `string`.
+ A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset].
+ servers: A `Tensor` of type `string`. A list of bootstrap servers.
+ group: A `Tensor` of type `string`. The consumer group id.
+ eof: A `Tensor` of type `bool`.
+ If True, the kafka reader will stop on EOF.
+ timeout: A `Tensor` of type `int64`.
+ The timeout value for the Kafka Consumer to wait
+ (in millisecond).
+ config_global: A `Tensor` of type `string`.
+ A `tf.string` tensor containing global configuration
+ properties in [Key=Value] format,
+ eg. ["enable.auto.commit=false", "heartbeat.interval.ms=2000"],
+ please refer to 'Global configuration properties' in librdkafka doc.
+ config_topic: A `Tensor` of type `string`.
+ A `tf.string` tensor containing topic configuration
+ properties in [Key=Value] format, eg. ["auto.offset.reset=earliest"],
+ please refer to 'Topic configuration properties' in librdkafka doc.
+ message_key: A `Tensor` of type `bool`.
+ message_offset: A `Tensor` of type `bool`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `variant`.
+ """
+ return kafka_module.io_kafka_dataset_v2(
+ topics=topics,
+ servers=servers,
+ group=group,
+ eof=eof,
+ timeout=timeout,
+ config_global=config_global,
+ config_topic=config_topic,
+ message_key=message_key,
+ message_offset=message_offset,
+ name=name)
+
+
+def io_kafka_dataset_eager_fallback(topics,
+ servers,
+ group,
+ eof,
+ timeout,
+ config_global,
+ config_topic,
+ message_key,
+ message_offset,
+ name=None,
+ ctx=None):
+ """This is the slowpath function for Eager mode.
+
+ This is for function io_kafka_dataset
+ """
+ _ctx = ctx if ctx else _context.context()
+ topics = _ops.convert_to_tensor(topics, _dtypes.string)
+ servers = _ops.convert_to_tensor(servers, _dtypes.string)
+ group = _ops.convert_to_tensor(group, _dtypes.string)
+ eof = _ops.convert_to_tensor(eof, _dtypes.bool)
+ timeout = _ops.convert_to_tensor(timeout, _dtypes.int64)
+ config_global = _ops.convert_to_tensor(config_global, _dtypes.string)
+ config_topic = _ops.convert_to_tensor(config_topic, _dtypes.string)
+ message_key = _ops.convert_to_tensor(message_key, _dtypes.bool)
+ message_offset = _ops.convert_to_tensor(message_offset, _dtypes.bool)
+ _inputs_flat = [
+ topics, servers, group, eof, timeout, config_global, config_topic,
+ message_key, message_offset
+ ]
+ _attrs = None
+ _result = _execute.execute(
+ b'IOKafkaDataset',
+ 1,
+ inputs=_inputs_flat,
+ attrs=_attrs,
+ ctx=_ctx,
+ name=name)
+ _execute.record_gradient('IOKafkaDataset', _inputs_flat, _attrs, _result,
+ name)
+ _result, = _result
+ return _result
+
+
+@tf_export('io_write_kafka_v2')
+def io_write_kafka_v2(message, topic, servers, name=None):
+ r"""TODO: add doc.
+
+ Args:
+ message: A `Tensor` of type `string`.
+ topic: A `Tensor` of type `string`.
+ servers: A `Tensor` of type `string`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `string`.
+ """
+ _ctx = _context._context
+ if _ctx is None or not _ctx._eager_context.is_eager:
+ _op = kafka_module.io_write_kafka_v2(
+ message=message, topic=topic, servers=servers, name=name)
+ _result = _op.outputs[:]
+ _inputs_flat = _op.inputs
+ _attrs = None
+ _execute.record_gradient('IOWriteKafka', _inputs_flat, _attrs, _result,
+ name)
+ _result, = _result
+ return _result
+
+ else:
+ try:
+ _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
+ _ctx._context_handle, _ctx._eager_context.device_name, 'IOWriteKafka',
+ name, _ctx._post_execution_callbacks, message, topic, servers)
+ return _result
+ except _core._FallbackException:
+ return io_write_kafka_eager_fallback(
+ message, topic, servers, name=name, ctx=_ctx)
+ except _core._NotOkStatusException as e:
+ if name is not None:
+ message = e.message + ' name: ' + name
+ else:
+ message = e.message
+ _six.raise_from(_core._status_to_exception(e.code, message), None)
+
+
+def io_write_kafka_eager_fallback(message, topic, servers, name=None, ctx=None):
+ """This is the slowpath function for Eager mode.
+
+ This is for function io_write_kafka
+ """
+ _ctx = ctx if ctx else _context.context()
+ message = _ops.convert_to_tensor(message, _dtypes.string)
+ topic = _ops.convert_to_tensor(topic, _dtypes.string)
+ servers = _ops.convert_to_tensor(servers, _dtypes.string)
+ _inputs_flat = [message, topic, servers]
+ _attrs = None
+ _result = _execute.execute(
+ b'IOWriteKafka',
+ 1,
+ inputs=_inputs_flat,
+ attrs=_attrs,
+ ctx=_ctx,
+ name=name)
+ _execute.record_gradient('IOWriteKafka', _inputs_flat, _attrs, _result, name)
+ _result, = _result
+ return _result
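
A hedged usage sketch for the two wrappers above; the broker, topic, and group values are placeholders, and both calls assume `kafka.so` was loaded successfully so that `kafka_module` is not None:

```python
from easy_rec.python.ops import gen_kafka_ops

# write a single message (placeholder broker/topic)
write_op = gen_kafka_ops.io_write_kafka_v2(
    message='hello', topic='demo_topic', servers='localhost:9092')

# dataset variant reading from topic:partition:offset subscriptions
dataset_variant = gen_kafka_ops.io_kafka_dataset_v2(
    topics=['demo_topic:0:0'],
    servers='localhost:9092',
    group='demo_group',
    eof=True,
    timeout=1000,
    config_global=['enable.auto.commit=false'],
    config_topic=['auto.offset.reset=earliest'],
    message_key=False,
    message_offset=False)
```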
diff --git a/easy_rec/python/ops/gen_str_avx_op.py b/easy_rec/python/ops/gen_str_avx_op.py
new file mode 100644
index 000000000..d022d52cb
--- /dev/null
+++ b/easy_rec/python/ops/gen_str_avx_op.py
@@ -0,0 +1,28 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+
+import tensorflow as tf
+from tensorflow.python.ops import string_ops
+
+import easy_rec
+from easy_rec.python.utils import constant
+
+try:
+ str_avx_op_path = os.path.join(easy_rec.ops_dir, 'libstr_avx_op.so')
+ str_avx_op = tf.load_op_library(str_avx_op_path)
+ logging.info('load avx string_split op from %s succeed' % str_avx_op_path)
+except Exception as ex:
+ logging.warning('load avx string_split op failed: %s' % str(ex))
+ str_avx_op = None
+
+
+def str_split_by_chr(input_str, sep, skip_empty):
+ if constant.has_avx_str_split() and str_avx_op is not None:
+ assert len(sep) == 1, \
+ 'invalid data_config.separator(%s) len(%d) != 1' % (
+ sep, len(sep))
+ return str_avx_op.avx512_string_split(input_str, sep, skip_empty=skip_empty)
+ else:
+ return string_ops.string_split(input_str, sep, skip_empty=skip_empty)
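
A short usage sketch for `str_split_by_chr`; the inputs are placeholders, and the separator must be a single character when the AVX kernel is active:

```python
import tensorflow as tf

from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

lines = tf.constant(['a,b,c', 'd,e'])
# returns a SparseTensor of tokens, falling back to string_ops.string_split
# when the AVX op is unavailable
tokens = str_split_by_chr(lines, ',', skip_empty=True)
```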
diff --git a/easy_rec/python/ops/incr_record.py b/easy_rec/python/ops/incr_record.py
new file mode 100644
index 000000000..b4bad11e4
--- /dev/null
+++ b/easy_rec/python/ops/incr_record.py
@@ -0,0 +1,30 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+
+import tensorflow as tf
+
+import easy_rec
+
+try:
+ op_path = os.path.join(easy_rec.ops_dir, 'incr_record.so')
+ op = tf.load_op_library(op_path)
+ get_sparse_indices = op.get_sparse_indices
+ set_sparse_indices = op.set_sparse_indices
+ if 'kv_resource_incr_gather' in dir(op):
+ kv_resource_incr_gather = getattr(op, 'kv_resource_incr_gather')
+ else:
+ kv_resource_incr_gather = None
+except ImportError as ex:
+ get_sparse_indices = None
+ set_sparse_indices = None
+ kv_resource_incr_gather = None
+ logging.warning('failed to import gen_io_ops.collect_sparse_indices: %s' %
+ str(ex))
+except Exception as ex:
+ get_sparse_indices = None
+ set_sparse_indices = None
+ kv_resource_incr_gather = None
+ logging.warning('failed to import gen_io_ops.collect_sparse_indices: %s' %
+ str(ex))
diff --git a/easy_rec/python/ops/src/load_dense_embed.cc b/easy_rec/python/ops/src/load_dense_embed.cc
new file mode 100644
index 000000000..f35bce4fd
--- /dev/null
+++ b/easy_rec/python/ops/src/load_dense_embed.cc
@@ -0,0 +1,158 @@
+#include <dirent.h>   // opendir / readdir / closedir
+#include <cstdlib>    // std::atoi
+#include <cstring>    // memcpy / memset
+
+#include <algorithm>  // std::sort
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/framework/resource_handle.h"
+
+namespace tensorflow {
+
+class LoadEmbedOp: public OpKernel {
+ public:
+ explicit LoadEmbedOp(OpKernelConstruction* context)
+ : OpKernel(context)
+ {
+ OP_REQUIRES_OK(context, context->GetAttr("task_index", &task_index_));
+ OP_REQUIRES_OK(context, context->GetAttr("task_num", &task_num_));
+ OP_REQUIRES_OK(context, context->GetAttr("embed_dim", &embed_dim_));
+ OP_REQUIRES_OK(context, context->GetAttr("embed_part_size", &embed_part_size_));
+ OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
+ }
+
+ int get_embed_part_id(const std::string & embed_file_path) const {
+ // embed-input_layer__all_fea__embedding_weights:0_part-0.bin
+ size_t tmp_pos = embed_file_path.rfind('-', embed_file_path.size() - 5);
+ if (tmp_pos == std::string::npos) {
+ LOG(ERROR) << "'-' is not found in embed_file_path=" << embed_file_path;
+ return -1;
+ }
+ std::string token = embed_file_path.substr(tmp_pos + 1,
+ embed_file_path.size() - 4);
+ return std::atoi(token.c_str());
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* file_name_t = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("ckpt_path", &file_name_t));
+
+    tstring file_name = file_name_t->flat<tstring>()(0);
+ tstring folder = file_name + "-embedding/";
+ tstring prefix = var_name_ + "-part-";
+
+ LOG(INFO) << "task[" << task_index_ << "] file_name=" << file_name
+ << " folder=" << folder << " prefix=" << prefix;
+
+ DIR* pdir = opendir(folder.c_str());
+ struct dirent* ent = nullptr;
+
+    std::vector<std::string> embed_files;
+ while((ent = readdir(pdir))) {
+ if (ent->d_type & DT_REG) {
+ std::string name = ent->d_name;
+ if (name.find(prefix) == std::string::npos) {
+ continue;
+ }
+ if (name.find(".bin") != std::string::npos) {
+ std::string embed_path = folder + name;
+ embed_files.push_back(embed_path);
+ }
+ }
+ }
+ ::closedir(pdir);
+
+ std::sort(embed_files.begin(), embed_files.end());
+
+ // output shape
+ TensorShape val_output_shape({embed_part_size_, embed_dim_});
+ Tensor * out_vals_t = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("vals", val_output_shape, &out_vals_t));
+
+ float * out_val_ptr = (float *)out_vals_t->tensor_data().data();
+ const int part_embed_flt_cnt = embed_part_size_ * embed_dim_;
+ // memset(out_val_ptr, 0, sizeof(float) * part_embed_flt_cnt);
+
+ const int total_embed_cnt = embed_part_size_ * task_num_;
+ const int embed_part_cnt_o = embed_files.size();
+ int part_update_cnt = 0;
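+    // Re-shard the checkpointed embedding: global row id
+    //   row = embed_id_o * embed_part_cnt_o + part_id_o
+    // is assigned to task (row % task_num_) at local offset (row / task_num_).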
+ for(const auto & embed_file : embed_files) {
+ LOG(INFO) << "task[" << task_index_ << "] will load embed_file: " << embed_file;
+ std::ifstream fin(embed_file.c_str());
+ fin.seekg(0, fin.end);
+ const size_t file_len = fin.tellg();
+ fin.seekg(0, fin.beg);
+
+ const size_t embed_flt_cnt_o = file_len / sizeof(float);
+      std::vector<float> part_embed_o(embed_flt_cnt_o);
+ fin.read((char *)(part_embed_o.data()), file_len);
+ fin.close();
+
+ const int part_id_o = get_embed_part_id(embed_file);
+ const size_t embed_id_cnt_o = embed_flt_cnt_o / embed_dim_;
+ for(int embed_id_o = 0; embed_id_o < embed_id_cnt_o; ++embed_id_o) {
+ const int part_id_n = embed_id_o * embed_part_cnt_o + part_id_o;
+ if ((part_id_n % task_num_) == task_index_ &&
+ part_id_n < total_embed_cnt) {
+ const int embed_id_n = part_id_n / task_num_;
+ memcpy(out_val_ptr + embed_id_n * embed_dim_,
+ &part_embed_o[embed_id_o * embed_dim_],
+ sizeof(float) * embed_dim_);
+ part_update_cnt++;
+ }
+ }
+ }
+
+ LOG(INFO) << "task[" << task_index_ << "] embed_part_size="
+ << embed_part_size_ << " part_update_cnt="
+ << part_update_cnt;
+ OP_REQUIRES(ctx, (part_update_cnt == embed_part_size_ ||
+ part_update_cnt + 1 == embed_part_size_),
+ errors::InvalidArgument(
+ "part_update_cnt or part_update_cnt + 1 should be equal to "
+ "embed_part_size_, but are: ", part_update_cnt,
+ " and ", embed_part_size_));
+
+ if (part_update_cnt < embed_part_size_) {
+ memset(out_val_ptr + (part_embed_flt_cnt - embed_dim_),
+ 0, sizeof(float) * embed_dim_);
+ }
+ }
+
+ private:
+ int task_index_;
+ int task_num_;
+ int embed_dim_;
+ int embed_part_size_;
+ string var_name_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("LoadEmbed").Device(DEVICE_CPU), LoadEmbedOp);
+
+REGISTER_OP("LoadEmbed")
+ .Attr("task_index: int")
+ .Attr("task_num: int")
+ .Attr("embed_dim: int")
+ .Attr("embed_part_size: int")
+ .Attr("var_name: string")
+ .Input("ckpt_path: string")
+ .Output("vals: float32")
+ .SetIsStateful();
+
+} // end namespace tensorflow
diff --git a/easy_rec/python/ops/src/load_kv_embed.cc b/easy_rec/python/ops/src/load_kv_embed.cc
new file mode 100644
index 000000000..1dc97e1fc
--- /dev/null
+++ b/easy_rec/python/ops/src/load_kv_embed.cc
@@ -0,0 +1,190 @@
+#include <dirent.h>   // opendir / readdir / closedir
+#include <cstdint>    // int64_t
+#include <cstdio>     // snprintf
+
+#include <algorithm>  // std::shuffle
+#include <fstream>
+#include <random>     // std::random_device / std::mt19937
+#include <stdexcept>  // std::runtime_error
+#include <string>
+#include <utility>    // std::pair
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/framework/resource_handle.h"
+
+namespace tensorflow {
+
+
+class LoadKVEmbedOp: public OpKernel {
+ public:
+ explicit LoadKVEmbedOp(OpKernelConstruction* context)
+ : OpKernel(context)
+ {
+ OP_REQUIRES_OK(context, context->GetAttr("task_index", &task_index_));
+ OP_REQUIRES_OK(context, context->GetAttr("task_num", &task_num_));
+ OP_REQUIRES_OK(context, context->GetAttr("embed_dim", &embed_dim_));
+ OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* file_name_t = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("ckpt_path", &file_name_t));
+
+    tstring file_name = file_name_t->flat<tstring>()(0);
+
+ tstring folder = file_name + "-embedding/";
+
+ tstring prefix = var_name_ + "-part-";
+
+ LOG(INFO) << "file_name=" << file_name << " folder=" << folder << " prefix=" << prefix;
+
+ DIR* pdir = opendir(folder.c_str());
+ struct dirent* ent = nullptr;
+
+ std::vector<int64_t*> key_ptr_vec;
+ std::vector<float*> val_ptr_vec;
+ std::vector<size_t> key_num_vec;
+ int all_worker_total_keys = 0;
+ while((ent = readdir(pdir))) {
+ if (ent->d_type & DT_REG) {
+ std::string name = ent->d_name;
+ if (name.find(prefix) == std::string::npos) {
+ continue;
+ }
+ if (name.find(".key") != std::string::npos) {
+ std::string key_path = folder + name;
+ LOG(INFO) << "load keys from " << key_path;
+ std::ifstream fin(key_path.c_str(), std::ifstream::binary);
+ fin.seekg(0, fin.end);
+ size_t file_len = fin.tellg();
+ fin.seekg(0, fin.beg);
+ const size_t key_num = file_len / sizeof(int64_t);
+ key_num_vec.push_back(key_num);
+ int64_t * key_buf = new int64_t[key_num];
+ fin.read((char *)key_buf, file_len);
+ fin.close();
+ key_ptr_vec.push_back(key_buf);
+
+ LOG(INFO) << "load keys from " << key_path << " key_num=" << key_num;
+
+ std::string val_path = key_path.substr(0, key_path.size()-4) + ".val";
+ LOG(INFO) << "load vals from " << val_path;
+ fin.open(val_path.c_str(), std::ifstream::binary);
+ if (! fin) {
+ char err_msg_buf[1024];
+ snprintf(err_msg_buf, 1024, "error: file does not exists: %s",
+ val_path.c_str());
+ LOG(ERROR) << err_msg_buf;
+ throw std::runtime_error(err_msg_buf);
+ }
+ fin.seekg(0, fin.end);
+ file_len = fin.tellg();
+ if (file_len != key_num * embed_dim_ * sizeof(float)) {
+ fin.close();
+ char err_msg_buf[1024];
+ snprintf(err_msg_buf, 1024,
+ "error: key_num[%ld] does not match with val_num[%ld], embed_dim=[%d]",
+ key_num, file_len / sizeof(float), embed_dim_);
+ LOG(ERROR) << err_msg_buf;
+ throw std::runtime_error(err_msg_buf);
+ }
+ fin.seekg(0, fin.beg);
+ float * val_buf = new float[key_num * embed_dim_];
+ fin.read((char *)val_buf, file_len);
+ fin.close();
+ val_ptr_vec.push_back(val_buf);
+
+ all_worker_total_keys += key_num;
+ LOG(INFO) << "all_worker_total_keys=" << all_worker_total_keys;
+ }
+ }
+ }
+ closedir(pdir);
+
+ // filter key by index
+ const int vec_num = key_num_vec.size();
+ std::vector<std::pair<int, int> > sel_ids;
+ sel_ids.reserve(all_worker_total_keys / task_num_);
+ int total_keys = 0;
+ for(int i = 0; i < key_ptr_vec.size(); ++i) {
+ const int64_t * key_ptr = key_ptr_vec[i];
+ const int key_num = key_num_vec[i];
+ for(int j = 0; j < key_num; ++j) {
+ int assign_id = key_ptr[j] % task_num_;
+ if (assign_id < 0) {
+ assign_id += task_num_;
+ }
+ if (assign_id == task_index_) {
+ total_keys++;
+ sel_ids.push_back(std::pair<int, int>(i, j));
+ }
+ }
+ }
+
+ LOG(INFO) << "task[" << task_index_ << "/" << task_num_
+ << "] all_worker_total_keys=" << all_worker_total_keys
+ << " load_part_num=" << vec_num
+ << " total_keys=" << total_keys << " embed_dim=" << embed_dim_;
+
+ // output shape
+ TensorShape key_output_shape({total_keys});
+ Tensor * out_keys_t = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("keys", key_output_shape, &out_keys_t));
+ TensorShape val_output_shape({total_keys, embed_dim_});
+ Tensor * out_vals_t = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("vals", val_output_shape, &out_vals_t));
+
+ {
+ std::random_device rd;
+ std::mt19937 g(rd());
+ std::shuffle(sel_ids.begin(), sel_ids.end(), g);
+ }
+
+ int64_t * key_ptr = (int64_t*)out_keys_t->tensor_data().data();
+ float * val_ptr = (float*)out_vals_t->tensor_data().data();
+ for(auto iter = sel_ids.begin(); iter != sel_ids.end(); ++iter) {
+ const int64_t * src_key_ptr = key_ptr_vec[iter->first] + iter->second;
+ const float * src_val_ptr = val_ptr_vec[iter->first] + iter->second * embed_dim_;
+ key_ptr[0] = src_key_ptr[0];
+ memcpy(val_ptr, src_val_ptr, sizeof(float) * embed_dim_);
+ key_ptr += 1;
+ val_ptr += embed_dim_;
+ }
+
+ for(int i = 0; i < vec_num; ++i) {
+ delete [] key_ptr_vec[i];
+ delete [] val_ptr_vec[i];
+ }
+ }
+
+ private:
+ int task_index_;
+ int task_num_;
+ int embed_dim_;
+ string var_name_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("LoadKVEmbed").Device(DEVICE_CPU), LoadKVEmbedOp);
+
+REGISTER_OP("LoadKVEmbed")
+ .Attr("task_index: int")
+ .Attr("task_num: int")
+ .Attr("embed_dim: int")
+ .Attr("var_name: string")
+ .Input("ckpt_path: string")
+ .Output("keys: int64")
+ .Output("vals: float32")
+ .SetIsStateful();
+
+} // end namespace tensorflow
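
The op above assigns each key to exactly one worker via key % task_num. A standalone Python sketch of that routing rule (the helper name is ours, for illustration only):

def belongs_to_worker(key, task_index, task_num):
  # In C++ the remainder of a negative key can be negative, hence the fix-up
  # in the kernel above; Python's % is already non-negative, so the branch
  # below only documents the intent.
  assign_id = key % task_num
  if assign_id < 0:
    assign_id += task_num
  return assign_id == task_index

keys = [101, -7, 42, 9]
print([k for k in keys if belongs_to_worker(k, task_index=1, task_num=4)])  # [101, -7, 9]
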
diff --git a/easy_rec/python/predict.py b/easy_rec/python/predict.py
index 46633708e..ced1a7573 100644
--- a/easy_rec/python/predict.py
+++ b/easy_rec/python/predict.py
@@ -8,8 +8,17 @@
import tensorflow as tf
from tensorflow.python.lib.io import file_io
-from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.inference.csv_predictor import CSVPredictor
+from easy_rec.python.inference.hive_predictor import HivePredictor
+from easy_rec.python.inference.parquet_predictor import ParquetPredictor
+from easy_rec.python.inference.parquet_predictor_v2 import ParquetPredictorV2
from easy_rec.python.main import predict
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import numpy_utils
+from easy_rec.python.utils.hive_utils import HiveUtils
+
+from easy_rec.python.inference.hive_parquet_predictor import HiveParquetPredictor # NOQA
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -18,12 +27,11 @@
format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
level=logging.INFO)
-tf.app.flags.DEFINE_string(
- 'input_path', None, 'predict data path, if specified will '
- 'override pipeline_config.eval_input_path')
+tf.app.flags.DEFINE_string('input_path', None, 'predict data path')
tf.app.flags.DEFINE_string('output_path', None, 'path to save predict result')
tf.app.flags.DEFINE_integer('batch_size', 1024, help='batch size')
-
+tf.app.flags.DEFINE_bool('with_header', False,
+ 'whether the input csv file has header')
# predict by checkpoint
tf.app.flags.DEFINE_string('pipeline_config_path', None,
'Path to pipeline config '
@@ -42,18 +50,84 @@
tf.app.flags.DEFINE_string(
'output_cols', 'ALL_COLUMNS',
'output columns, such as: score float. multiple columns are separated by ,')
-tf.app.flags.DEFINE_string('input_sep', ',', 'separator of predict result file')
tf.app.flags.DEFINE_string('output_sep', chr(1),
'separator of predict result file')
-
+tf.app.flags.DEFINE_string('selected_cols', None, '')
+tf.app.flags.DEFINE_string('fg_json_path', '', '')
+tf.app.flags.DEFINE_string('ds_vector_recall', '', '')
+tf.app.flags.DEFINE_string('input_type', '', 'data_config.input_type')
FLAGS = tf.app.flags.FLAGS
+input_class_map = {y: x for x, y in DatasetConfig.InputType.items()}
+input_class_map_r = {x: y for x, y in DatasetConfig.InputType.items()}
+
+
+def get_input_type(input_type, data_config):
+ if input_type:
+ return input_class_map[input_type]
+ return data_config.input_type
+
def main(argv):
if FLAGS.saved_model_dir:
logging.info('Predict by saved_model.')
- predictor = Predictor(FLAGS.saved_model_dir)
+ if FLAGS.pipeline_config_path:
+ pipeline_config_path = FLAGS.pipeline_config_path
+ else:
+ pipeline_config_path = config_util.search_pipeline_config(
+ FLAGS.saved_model_dir)
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+ data_config = pipeline_config.data_config
+ input_type = get_input_type(FLAGS.input_type, data_config)
+ if input_type in [data_config.HiveParquetInput, data_config.HiveInput]:
+ all_cols, all_col_types = HiveUtils(
+ data_config=pipeline_config.data_config,
+ hive_config=pipeline_config.hive_train_input).get_all_cols(
+ FLAGS.input_path)
+ if input_type == DatasetConfig.HiveParquetInput:
+ predictor = HiveParquetPredictor(
+ FLAGS.saved_model_dir,
+ pipeline_config.data_config,
+ fg_json_path=FLAGS.fg_json_path,
+ hive_config=pipeline_config.hive_train_input,
+ output_sep=FLAGS.output_sep,
+ all_cols=all_cols,
+ all_col_types=all_col_types)
+ else:
+ predictor = HivePredictor(
+ FLAGS.saved_model_dir,
+ pipeline_config.data_config,
+ fg_json_path=FLAGS.fg_json_path,
+ hive_config=pipeline_config.hive_train_input,
+ output_sep=FLAGS.output_sep,
+ all_cols=all_cols,
+ all_col_types=all_col_types)
+ elif input_type in [data_config.ParquetInput, data_config.ParquetInputV2]:
+ predictor_cls = ParquetPredictor
+ if input_type == data_config.ParquetInputV2:
+ predictor_cls = ParquetPredictorV2
+ predictor = predictor_cls(
+ FLAGS.saved_model_dir,
+ pipeline_config.data_config,
+ ds_vector_recall=FLAGS.ds_vector_recall,
+ fg_json_path=FLAGS.fg_json_path,
+ selected_cols=FLAGS.selected_cols,
+ output_sep=FLAGS.output_sep,
+ pipeline_config=pipeline_config)
+ elif input_type == data_config.CSVInput:
+ predictor = CSVPredictor(
+ FLAGS.saved_model_dir,
+ pipeline_config.data_config,
+ FLAGS.with_header,
+ ds_vector_recall=FLAGS.ds_vector_recall,
+ fg_json_path=FLAGS.fg_json_path,
+ selected_cols=FLAGS.selected_cols,
+ output_sep=FLAGS.output_sep)
+ else:
+ assert False, 'invalid input type: %s' % input_class_map_r[input_type]
+
logging.info('input_path = %s, output_path = %s' %
(FLAGS.input_path, FLAGS.output_path))
if 'TF_CONFIG' in os.environ:
@@ -68,10 +142,9 @@ def main(argv):
FLAGS.output_path,
reserved_cols=FLAGS.reserved_cols,
output_cols=FLAGS.output_cols,
+ batch_size=FLAGS.batch_size,
slice_id=task_index,
- slice_num=worker_num,
- input_sep=FLAGS.input_sep,
- output_sep=FLAGS.output_sep)
+ slice_num=worker_num)
else:
logging.info('Predict by checkpoint_path.')
assert FLAGS.model_dir or FLAGS.pipeline_config_path, 'At least one of model_dir and pipeline_config_path exists.'
@@ -90,7 +163,7 @@ def main(argv):
logging.info('will save predict result to %s' % FLAGS.output_path)
with tf.gfile.GFile(FLAGS.output_path, 'wb') as fout:
for k in pred_result:
- fout.write(str(k).replace("u'", '"').replace("'", '"') + '\n')
+ fout.write(json.dumps(k, cls=numpy_utils.NumpyEncoder) + '\n')
if __name__ == '__main__':
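
For reference, a hedged sketch of driving the CSVInput branch above directly, without tf.app.flags; the paths and the exact predict() signature are assumptions inferred from this diff:

from easy_rec.python.inference.csv_predictor import CSVPredictor
from easy_rec.python.utils import config_util

saved_model_dir = 'experiments/export/final'  # illustrative path
pipeline_config = config_util.get_configs_from_pipeline_file(
    config_util.search_pipeline_config(saved_model_dir), False)

predictor = CSVPredictor(
    saved_model_dir,
    pipeline_config.data_config,
    True,  # with_header, passed positionally as in the branch above
    output_sep=chr(1))
predictor.predict(
    'data/test.csv',  # illustrative input
    'output/pred.csv',
    reserved_cols='ALL_COLUMNS',
    output_cols='ALL_COLUMNS',
    batch_size=1024,
    slice_id=0,
    slice_num=1)
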
diff --git a/easy_rec/python/protos/backbone.proto b/easy_rec/python/protos/backbone.proto
new file mode 100644
index 000000000..b86da78ec
--- /dev/null
+++ b/easy_rec/python/protos/backbone.proto
@@ -0,0 +1,119 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+import "easy_rec/python/protos/keras_layer.proto";
+
+message InputLayer {
+ optional bool do_batch_norm = 1;
+ optional bool do_layer_norm = 2;
+ optional float dropout_rate = 3;
+ optional float feature_dropout_rate = 4;
+ optional bool only_output_feature_list = 5;
+ optional bool only_output_3d_tensor = 6;
+ optional bool output_2d_tensor_and_feature_list = 7;
+ optional bool output_seq_and_normal_feature = 8;
+ optional uint32 wide_output_dim = 9;
+ optional bool concat_seq_feature = 10 [default = true];
+}
+
+message RawInputLayer {
+}
+
+message EmbeddingLayer {
+ required uint32 embedding_dim = 1;
+ optional uint32 vocab_size = 2;
+ optional string combiner = 3 [default = 'weight'];
+ optional bool concat = 4 [default = true];
+}
+
+message Lambda {
+ required string expression = 1;
+}
+
+message Input {
+ oneof name {
+ string feature_group_name = 1;
+ string block_name = 2;
+ string package_name = 3;
+ bool use_package_input = 4;
+ }
+ optional string input_fn = 11;
+ optional string input_slice = 12;
+ optional bool ignore_input = 13 [default = false];
+ optional InputLayer reset_input = 14;
+ optional string package_input = 15;
+ optional string package_input_fn = 16;
+}
+
+message RecurrentLayer {
+ required uint32 num_steps = 1 [default = 1];
+ optional uint32 fixed_input_index = 2;
+ required KerasLayer keras_layer = 3;
+}
+
+message RepeatLayer {
+ required uint32 num_repeat = 1 [default = 1];
+ // by default, output the list of multiple outputs
+ optional int32 output_concat_axis = 2;
+ required KerasLayer keras_layer = 3;
+ optional string input_slice = 4;
+ optional string input_fn = 5;
+}
+
+message Layer {
+ oneof layer {
+ Lambda lambda = 1;
+ KerasLayer keras_layer = 2;
+ RecurrentLayer recurrent = 3;
+ RepeatLayer repeat = 4;
+ }
+}
+
+message Block {
+ required string name = 1;
+ // the input names of feature groups or other blocks
+ repeated Input inputs = 2;
+ optional int32 input_concat_axis = 3 [default = -1];
+ optional bool merge_inputs_into_list = 4;
+ optional string extra_input_fn = 5;
+
+ // sequential layers
+ repeated Layer layers = 100;
+
+ // only take effect when there are no layers
+ oneof layer {
+ InputLayer input_layer = 101;
+ Lambda lambda = 102;
+ KerasLayer keras_layer = 103;
+ RecurrentLayer recurrent = 104;
+ RepeatLayer repeat = 105;
+ RawInputLayer raw_input = 106;
+ EmbeddingLayer embedding_layer = 107;
+ }
+}
+
+// a package of blocks for reuse; e.g. called in a contrastive-learning manner
+message BlockPackage {
+ // package name
+ required string name = 1;
+ // a few blocks generating a DAG
+ repeated Block blocks = 2;
+ // the names of output blocks; they will be merged into one tensor
+ repeated string concat_blocks = 3;
+ // the names of output blocks, returned as a list or a single tensor
+ repeated string output_blocks = 4;
+}
+
+message BackboneTower {
+ // a few sub DAGs
+ repeated BlockPackage packages = 1;
+ // a few blocks generating a DAG
+ repeated Block blocks = 2;
+ // the names of output blocks; they will be merged into one tensor
+ repeated string concat_blocks = 3;
+ // the names of output blocks, returned as a list or a single tensor
+ repeated string output_blocks = 4;
+ // optional top mlp layer
+ optional MLP top_mlp = 5;
+}
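
A backbone is configured as a DAG of named blocks. Below is a minimal illustrative BackboneTower in protobuf text format, parsed through the generated module (the module name backbone_pb2 is assumed to follow the usual *_pb2 convention):

from google.protobuf import text_format
from easy_rec.python.protos import backbone_pb2

backbone = text_format.Parse(
    """
    blocks {
      name: 'deep'
      inputs { feature_group_name: 'all' }
      keras_layer {
        class_name: 'MLP'
        mlp { hidden_units: [256, 128] }
      }
    }
    concat_blocks: 'deep'
    top_mlp { hidden_units: 64 }
    """, backbone_pb2.BackboneTower())
print(backbone.blocks[0].name)  # deep
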
diff --git a/easy_rec/python/protos/cmbf.proto b/easy_rec/python/protos/cmbf.proto
new file mode 100644
index 000000000..34e082115
--- /dev/null
+++ b/easy_rec/python/protos/cmbf.proto
@@ -0,0 +1,52 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+
+message CMBFTower {
+ // The number of heads of cross modal fusion layer
+ required uint32 multi_head_num = 1 [default = 1];
+ // The number of heads of image feature learning layer
+ required uint32 image_multi_head_num = 101 [default = 1];
+ // The number of heads of text feature learning layer
+ required uint32 text_multi_head_num = 102 [default = 1];
+ // The dimension of text heads
+ required uint32 text_head_size = 2;
+ // The dimension of image heads
+ required uint32 image_head_size = 3 [default = 64];
+ // The number of patches of the image feature; takes effect when there is only one image feature
+ required uint32 image_feature_patch_num = 4 [default = 1];
+ // Reduce the image feature dimension to this size before the single-modal learning module
+ required uint32 image_feature_dim = 5 [default = 0];
+ // The number of self attention layers for image features
+ required uint32 image_self_attention_layer_num = 6 [default = 0];
+ // The number of self attention layers for text features
+ required uint32 text_self_attention_layer_num = 7 [default = 1];
+ // The number of cross modal layers
+ required uint32 cross_modal_layer_num = 8 [default = 1];
+ // The dimension of image cross modal heads
+ required uint32 image_cross_head_size = 9;
+ // The dimension of text cross modal heads
+ required uint32 text_cross_head_size = 10;
+ // Dropout probability for hidden layers
+ required float hidden_dropout_prob = 11 [default = 0.0];
+ // Dropout probability of the attention probabilities
+ required float attention_probs_dropout_prob = 12 [default = 0.0];
+
+ // Whether to add embeddings for different text sequence features
+ required bool use_token_type = 13 [default = false];
+ // Whether to add position embeddings for the position of each token in the text sequence
+ required bool use_position_embeddings = 14 [default = true];
+ // Maximum sequence length that might ever be used with this model
+ required uint32 max_position_embeddings = 15 [default = 0];
+ // Dropout probability for text sequence embeddings
+ required float text_seq_emb_dropout_prob = 16 [default = 0.1];
+ // dnn layers for other features
+ optional DNN other_feature_dnn = 17;
+}
+
+message CMBF {
+ required CMBFTower config = 1;
+
+ required DNN final_dnn = 2;
+}
diff --git a/easy_rec/python/protos/dat.proto b/easy_rec/python/protos/dat.proto
new file mode 100644
index 000000000..2325fbb87
--- /dev/null
+++ b/easy_rec/python/protos/dat.proto
@@ -0,0 +1,25 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+import "easy_rec/python/protos/simi.proto";
+
+
+message DATTower {
+ required string id = 1;
+ required DNN dnn = 2;
+};
+
+
+message DAT {
+ required DATTower user_tower = 1;
+ required DATTower item_tower = 2;
+ required float l2_regularization = 3 [default = 1e-4];
+ optional Similarity simi_func = 4 [default=COSINE];
+ required bool ignore_in_batch_neg_sam = 5 [default = false];
+ optional float temperature = 6 [default = 1.0];
+ // loss weight for amm_i
+ required float amm_i_weight = 7 [default = 0.5];
+ // loss weight for amm_u
+ required float amm_u_weight = 8 [default = 0.5];
+}
diff --git a/easy_rec/python/protos/data_source.proto b/easy_rec/python/protos/data_source.proto
index a05134d12..76394d0d4 100644
--- a/easy_rec/python/protos/data_source.proto
+++ b/easy_rec/python/protos/data_source.proto
@@ -5,16 +5,39 @@ message KafkaServer {
required string server = 1;
required string topic = 2;
required string group = 3;
- required uint32 partitions = 4;
- repeated uint32 offset = 5;
+ oneof offset {
+ // in json format: {'0':10, '1':20}
+ string offset_info = 40;
+ // offset_time could be two formats:
+ // 1: %Y%m%d %H:%M:%S '20220508 12:00:00'
+ // 2: %s '1651982400'
+ string offset_time = 42;
+ }
+ // kafka global config, such as: fetch.max.bytes=1024
+ repeated string config_global = 5;
+ // kafka topic config, such as: max.partition.fetch.bytes=1024
+ repeated string config_topic = 6;
}
message DatahubServer{
required string akId = 1;
required string akSecret = 2;
- required string region = 3;
+ required string endpoint = 3;
required string project = 4;
required string topic = 5;
- required uint32 shard_num = 6;
- required uint32 life_cycle = 7;
+ oneof offset {
+ // in json format: {"0":{"cursor": ""}, "1":{"cursor":""}}
+ string offset_info = 60;
+ // offset_time could be two formats:
+ // 1: %Y%m%d %H:%M:%S "20220508 12:00:00"
+ // 2: %s "1651982400"
+ string offset_time = 62;
+ }
+}
+
+message BinaryDataInput {
+ // support gfile.Glob
+ repeated string category_path = 1;
+ repeated string dense_path = 2;
+ repeated string label_path = 3;
}
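
To make the new offset oneof concrete, a small illustrative KafkaServer message built in Python (data_source_pb2 is the module generated from this file; values are illustrative):

from easy_rec.python.protos import data_source_pb2

kafka = data_source_pb2.KafkaServer()
kafka.server = 'localhost:9092'
kafka.topic = 'train_log'
kafka.group = 'easy_rec'
# either a per-partition offset map in json format ...
kafka.offset_info = '{"0": 10, "1": 20}'
# ... or a start time; setting offset_time instead would clear offset_info,
# because the two fields share the offset oneof.
# kafka.offset_time = '20220508 12:00:00'
kafka.config_global.append('fetch.max.bytes=1024')
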
diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto
index 2710d2d91..ca19dcd04 100644
--- a/easy_rec/python/protos/dataset.proto
+++ b/easy_rec/python/protos/dataset.proto
@@ -16,6 +16,28 @@ message NegativeSampler {
optional string attr_delimiter = 5 [default=":"];
optional uint32 num_eval_sample = 6 [default=0];
+
+ // only works on DataScience/Local
+ optional string field_delimiter = 7 [default="\001"];
+}
+
+message NegativeSamplerInMemory {
+ // sample data path
+ // itemid weight attrs
+ required string input_path = 1;
+ // number of negative sample
+ required uint32 num_sample = 2;
+ // field names of attrs in train data or eval data
+ repeated string attr_fields = 3;
+ // field name of item_id in train data or eval data
+ required string item_id_field = 4;
+
+ optional string attr_delimiter = 5 [default=":"];
+
+ optional uint32 num_eval_sample = 6 [default=0];
+
+ // only works on DataScience/Local
+ optional string field_delimiter = 7 [default="\001"];
}
// Weighted Random Sampling ItemID not with Edge
@@ -41,6 +63,9 @@ message NegativeSamplerV2 {
optional string attr_delimiter = 8 [default=":"];
optional uint32 num_eval_sample = 9 [default=0];
+
+ // only works on DataScience/Local
+ optional string field_delimiter = 10 [default="\001"];
}
// Weighted Random Sampling ItemID not in Batch and Sampling Hard Edge
@@ -68,6 +93,9 @@ message HardNegativeSampler {
optional string attr_delimiter = 9 [default=":"];
optional uint32 num_eval_sample = 10 [default=0];
+
+ // only works on DataScience/Local
+ optional string field_delimiter = 11 [default="\001"];
}
// Weighted Random Sampling ItemID not with Edge and Sampling Hard Edge
@@ -98,6 +126,9 @@ message HardNegativeSamplerV2 {
optional string attr_delimiter = 10 [default=":"];
optional uint32 num_eval_sample = 11 [default=0];
+
+ // only works on DataScience/Local
+ optional string field_delimiter = 12 [default="\001"];
}
message DatasetConfig {
@@ -119,6 +150,14 @@ message DatasetConfig {
optional string default_val = 3;
optional uint32 input_dim = 4 [default=1];
optional uint32 input_shape = 5 [default = 1];
+ // user-defined function for label. eg: tf.math.log1p, remap_lbl
+ optional string user_define_fn = 6;
+ // user-defined function path. eg: /samples/demo_script/process_lbl.py
+ optional string user_define_fn_path = 7;
+ // output field type of user-defined function.
+ optional FieldType user_define_fn_res_type = 8;
+ // ignore value
+ optional string ignore_val = 9;
}
// set auto_expand_input_fields to true to
@@ -137,6 +176,14 @@ message DatasetConfig {
// are labels have dimension > 1
repeated uint32 label_dim = 42;
+ message LabelFunction {
+ required string label_name = 1;
+ required string label_func = 2;
+ }
+
+ // extra transformation functions that generate new labels
+ repeated LabelFunction extra_label_func = 43;
+
// whether to shuffle data
optional bool shuffle = 5 [default = true];
@@ -153,10 +200,15 @@ message DatasetConfig {
optional uint32 prefetch_size = 7 [default = 32];
// shard dataset to 1/num_workers in distribute mode
- optional bool shard = 8 [default = false];
+ // this param is not used anymore
+ optional bool shard = 801 [default = false];
+
+ // shard by file, not by sample, valid only for CSVInput
+ optional bool file_shard = 802 [default = false];
enum InputType {
// csv format input, could be used in local or hdfs
+ // support .gz compression(but not .tar.gz files)
CSVInput = 10;
// @Depreciated
CSVInputV2 = 11;
@@ -170,13 +222,34 @@ message DatasetConfig {
OdpsInputV3 = 9;
RTPInput = 4;
RTPInputV2 = 5;
- OdpsRTPInput = 6;
+ OdpsRTPInput = 601;
+ OdpsRTPInputV2 = 602;
TFRecordInput = 7;
BatchTFRecordInput = 14;
// for the purpose to debug performance bottleneck of
// input pipelines
DummyInput = 8;
KafkaInput = 13;
+ HiveInput = 16;
+ HiveRTPInput = 17;
+ HiveParquetInput = 18;
+
+ // All features are packed into one field for fast copying to gpu,
+ // and there is no feature preprocessing step; it is assumed that
+ // features are preprocessed before training.
+ // Requirements: python3 and tf2.x due to multiprocessing spawn and
+ // RaggedTensor APIs.
+ ParquetInput = 19;
+
+ // Features are not packed, and are preprocessed separately.
+ // Requirements: python3 and tf2.x due to multiprocessing spawn and
+ // RaggedTensor APIs.
+ ParquetInputV2 = 20;
+
+ // C++ version of the parquet dataset, which is currently only available
+ // with DeepRec.
+ ParquetInputV3 = 21;
+ CriteoInput = 1001;
}
required InputType input_type = 10;
@@ -239,10 +312,16 @@ message DatasetConfig {
// may not be the same as that in csv files.
optional bool with_header = 25 [default = false];
+ repeated string feature_fields = 26;
+
oneof sampler {
NegativeSampler negative_sampler = 101;
NegativeSamplerV2 negative_sampler_v2 = 102;
HardNegativeSampler hard_negative_sampler = 103;
HardNegativeSamplerV2 hard_negative_sampler_v2 = 104;
+ NegativeSamplerInMemory negative_sampler_in_memory = 105;
}
+ optional uint32 eval_batch_size = 1001 [default = 4096];
+
+ optional bool drop_remainder = 1002 [default = false];
}
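
A sketch of how the new NegativeSamplerInMemory fits into data_config, using only fields shown above; values and paths are illustrative:

from easy_rec.python.protos.dataset_pb2 import DatasetConfig

data_config = DatasetConfig()
sampler = data_config.negative_sampler_in_memory
sampler.input_path = 'data/item_gl.txt'  # itemid weight attrs
sampler.num_sample = 1024
sampler.attr_fields.extend(['adgroup_id', 'cate_id', 'brand'])
sampler.item_id_field = 'adgroup_id'
sampler.attr_delimiter = ':'
sampler.field_delimiter = chr(1)  # only works on DataScience/Local
# writing any sub-field above selects the sampler oneof
data_config.eval_batch_size = 8192
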
diff --git a/easy_rec/python/protos/dbmtl.proto b/easy_rec/python/protos/dbmtl.proto
index 57a7733b2..9adff1f62 100644
--- a/easy_rec/python/protos/dbmtl.proto
+++ b/easy_rec/python/protos/dbmtl.proto
@@ -3,8 +3,14 @@ package protos;
import "easy_rec/python/protos/dnn.proto";
import "easy_rec/python/protos/tower.proto";
+import "easy_rec/python/protos/cmbf.proto";
+import "easy_rec/python/protos/uniter.proto";
message DBMTL {
+ // shared bottom cmbf layer
+ optional CMBFTower bottom_cmbf = 101;
+ // shared bottom uniter layer
+ optional UniterTower bottom_uniter = 102;
// shared bottom dnn layer
optional DNN bottom_dnn = 1;
// mmoe expert dnn layer definition
diff --git a/easy_rec/python/protos/dnn.proto b/easy_rec/python/protos/dnn.proto
index 021d34dbb..10fb6631f 100644
--- a/easy_rec/python/protos/dnn.proto
+++ b/easy_rec/python/protos/dnn.proto
@@ -12,3 +12,22 @@ message DNN {
// use batch normalization
optional bool use_bn = 4 [default = true];
}
+
+message MLP {
+ // hidden units for each layer
+ repeated uint32 hidden_units = 1;
+ // ratio of dropout
+ repeated float dropout_ratio = 2;
+ // activation function
+ optional string activation = 3 [default = 'relu'];
+ // use batch normalization
+ optional bool use_bn = 4 [default = true];
+ optional bool use_final_bn = 5 [default = true];
+ optional string final_activation = 6 [default = 'relu'];
+ optional bool use_bias = 7 [default = false];
+ // kernel_initializer
+ optional string initializer = 8 [default = 'he_uniform'];
+ optional bool use_bn_after_activation = 9;
+ optional bool use_final_bias = 10 [default = false];
+ optional bool add_to_outputs = 11 [default = false];
+}
diff --git a/easy_rec/python/protos/dssm.proto b/easy_rec/python/protos/dssm.proto
index ab83e66b1..c0015a28e 100644
--- a/easy_rec/python/protos/dssm.proto
+++ b/easy_rec/python/protos/dssm.proto
@@ -18,4 +18,8 @@ message DSSM {
optional Similarity simi_func = 4 [default=COSINE];
// add a layer for scaling the similarity
optional bool scale_simi = 5 [default = true];
+ optional string item_id = 9;
+ required bool ignore_in_batch_neg_sam = 10 [default = false];
+ // normalize user_tower_embedding and item_tower_embedding
+ optional float temperature = 11 [default = 1.0];
}
diff --git a/easy_rec/python/protos/dssm_senet.proto b/easy_rec/python/protos/dssm_senet.proto
new file mode 100644
index 000000000..ee941104f
--- /dev/null
+++ b/easy_rec/python/protos/dssm_senet.proto
@@ -0,0 +1,27 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+import "easy_rec/python/protos/simi.proto";
+import "easy_rec/python/protos/layer.proto";
+
+message DSSM_SENet_Tower {
+ required string id = 1;
+ required SENet senet = 2;
+ required DNN dnn = 3;
+
+};
+
+
+message DSSM_SENet {
+ required DSSM_SENet_Tower user_tower = 1;
+ required DSSM_SENet_Tower item_tower = 2;
+ required float l2_regularization = 3 [default = 1e-4];
+ optional Similarity simi_func = 4 [default=COSINE];
+ // add a layer for scaling the similarity
+ optional bool scale_simi = 5 [default = true];
+ optional string item_id = 9;
+ required bool ignore_in_batch_neg_sam = 10 [default = false];
+ // normalize user_tower_embedding and item_tower_embedding
+ optional float temperature = 11 [default = 1.0];
+}
diff --git a/easy_rec/python/protos/eas_serving.proto b/easy_rec/python/protos/eas_serving.proto
deleted file mode 100644
index 96b3be297..000000000
--- a/easy_rec/python/protos/eas_serving.proto
+++ /dev/null
@@ -1,62 +0,0 @@
-syntax = "proto3";
-package protos;
-
-message EmbeddingPartData {
- // Shape of the embedding
- repeated int64 shape = 1;
- // Data
- repeated float data = 2 [packed = true];
-}
-
-message Config {
- // e.g. if the input feature is "1005,109;0;93eaba74", the semicolon separates columns,
- // the comma separates the multiple features of a column, and the underscore separates a feature name from its value.
- string column_delim = 1;
- string feature_delim = 2;
-
- // the string hash-bucket algorithm; supports HarmHash (i.e. tf.strings.to_hash_bucket_fast())
- // and SipHash (i.e. tf.strings.to_hash_bucket_strong())
- string hash = 3;
-
- // embedding_name to embedding
- map<string, Embedding> embeddings = 4;
- // the maximal L2-norm of the embedding lookup results
- map<string, float> embedding_max_norm = 5;
- // the combiner strategy of the embedding; supports sum, mean and sqrtn
- map<string, string> embedding_combiner = 6;
-
- Model model = 7;
-}
-
-message Embedding {
- // total number of partitions of this embedding
- int32 partition_num = 1;
- repeated EmbeddingPart parts = 2;
-}
-
-message EmbeddingPart {
- // path of the EmbeddingPartData (*.pb) file
- string embedding_part_path = 1;
- // index of this embedding part
- int32 partition_id = 2;
- // shape of this embedding part (can be read from EmbeddingPartData)
- repeated int64 shape = 3;
- // deploy strategy of this embedding part; supports local and remote deployment
- string deploy_strategy = 4;
-}
-
-message ModelInput {
- string feature_name = 1;
- string embedding_name = 2;
- string placeholder_name = 3;
- string weight_name = 4;
-}
-
-message Model {
- // path of the model, used to load the model
- string model_path = 1;
- // name of the model signature
- string model_signature_name = 2;
- // model input description
- repeated ModelInput model_inputs = 3;
-}
diff --git a/easy_rec/python/protos/easy_rec_model.proto b/easy_rec/python/protos/easy_rec_model.proto
index 6f8ca590d..87cd6754d 100644
--- a/easy_rec/python/protos/easy_rec_model.proto
+++ b/easy_rec/python/protos/easy_rec_model.proto
@@ -1,6 +1,7 @@
syntax = "proto2";
package protos;
+import "easy_rec/python/protos/backbone.proto";
import "easy_rec/python/protos/fm.proto";
import "easy_rec/python/protos/deepfm.proto";
import "easy_rec/python/protos/wide_and_deep.proto";
@@ -16,13 +17,33 @@ import "easy_rec/python/protos/dbmtl.proto";
import "easy_rec/python/protos/ple.proto";
import "easy_rec/python/protos/simple_multi_task.proto";
import "easy_rec/python/protos/dcn.proto";
+import "easy_rec/python/protos/cmbf.proto";
+import "easy_rec/python/protos/uniter.proto";
import "easy_rec/python/protos/autoint.proto";
import "easy_rec/python/protos/mind.proto";
import "easy_rec/python/protos/loss.proto";
import "easy_rec/python/protos/rocket_launching.proto";
import "easy_rec/python/protos/variational_dropout.proto";
+import "easy_rec/python/protos/multi_tower_recall.proto";
+import "easy_rec/python/protos/tower.proto";
+import "easy_rec/python/protos/pdn.proto";
+import "easy_rec/python/protos/dssm_senet.proto";
+import "easy_rec/python/protos/simi.proto";
+import "easy_rec/python/protos/dat.proto";
// for input performance test
message DummyModel {
+}
+
+// configure backbone network common parameters
+message ModelParams {
+ optional float l2_regularization = 1;
+ repeated string outputs = 2;
+ repeated BayesTaskTower task_towers = 3;
+ optional int32 user_tower_idx_in_output = 4 [default = 0];
+ optional int32 item_tower_idx_in_output = 5 [default = 1];
+ optional Similarity simi_func = 6 [default = COSINE];
+ optional float temperature = 7 [default = 1.0];
+ optional bool scale_simi = 8 [default = false];
}
@@ -36,22 +57,46 @@ message KD {
required string soft_label_name = 21;
// default to be logits
optional bool label_is_logits = 22 [default=true];
- // currently only support CROSS_ENTROPY_LOSS and L2_LOSS
required LossType loss_type = 3;
optional float loss_weight = 4 [default=1.0];
- // only for loss_type == CROSS_ENTROPY_LOSS
+ // only for loss_type == CROSS_ENTROPY_LOSS or BINARY_CROSS_ENTROPY_LOSS or KL_DIVERGENCE_LOSS
optional float temperature = 5 [default=1.0];
-
+ // field name for indicating the sample space for this task
+ optional string task_space_indicator_name = 6;
+ // field value for indicating the sample space for this task
+ optional string task_space_indicator_value = 7;
+ // the loss weight for sample in the task space
+ optional float in_task_space_weight = 8 [default = 1.0];
+ // the loss weight for sample out the task space
+ optional float out_task_space_weight = 9 [default = 1.0];
+
+ oneof loss_param {
+ F1ReweighedLoss f1_reweighted_loss = 101;
+ SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102;
+ CircleLoss circle_loss = 103;
+ MultiSimilarityLoss multi_simi_loss = 104;
+ BinaryFocalLoss binary_focal_loss = 105;
+ PairwiseLoss pairwise_loss = 106;
+ PairwiseFocalLoss pairwise_focal_loss = 107;
+ PairwiseLogisticLoss pairwise_logistic_loss = 108;
+ JRCLoss jrc_loss = 109;
+ PairwiseHingeLoss pairwise_hinge_loss = 110;
+ ListwiseRankLoss listwise_rank_loss = 111;
+ ListwiseDistillLoss listwise_distill_loss = 112;
+ }
}
message EasyRecModel {
required string model_class = 1;
+ // just a name for backbone config
+ optional string model_name = 99;
// actually input layers, each layer produce a group of feature
repeated FeatureGroupConfig feature_groups = 2;
// model parameters
oneof model {
+ ModelParams model_params = 100;
DummyModel dummy = 101;
WideAndDeep wide_and_deep = 102;
DeepFM deepfm = 103;
@@ -60,11 +105,17 @@ message EasyRecModel {
DCN dcn = 106;
AutoInt autoint = 107;
DLRM dlrm = 108;
+ CMBF cmbf = 109;
+ Uniter uniter = 110;
+ MultiTowerRecall multi_tower_recall = 200;
DSSM dssm = 201;
MIND mind = 202;
DropoutNet dropoutnet = 203;
CoMetricLearningI2I metric_learning = 204;
+ PDN pdn = 205;
+ DSSM_SENet dssm_senet = 206;
+ DAT dat = 207;
MMoE mmoe = 301;
ESMM esmm = 302;
@@ -84,7 +135,7 @@ message EasyRecModel {
optional uint32 num_class = 10 [default = 1];
- optional bool use_embedding_variable = 11 [default=false];
+ optional EVParams ev_params = 11;
repeated KD kd = 12;
@@ -95,4 +146,16 @@ message EasyRecModel {
optional VariationalDropoutLayer variational_dropout = 14;
repeated Loss losses = 15;
+
+ enum LossWeightStrategy {
+ Fixed = 0;
+ Uncertainty = 1;
+ Random = 2;
+ }
+ required LossWeightStrategy loss_weight_strategy = 16 [default = Fixed];
+
+ optional BackboneTower backbone = 17;
+
+ // label name for rank_model to select one label between multiple labels
+ optional string label_name = 18;
}
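
The extended KD message now supports per-sample-space weighting and the same loss_param oneof as Loss. A minimal illustrative config (the generated module name easy_rec_model_pb2 is assumed; field values are ours):

from easy_rec.python.protos import easy_rec_model_pb2, loss_pb2

kd = easy_rec_model_pb2.KD()
kd.soft_label_name = 'teacher_logits'  # illustrative
kd.label_is_logits = True
kd.loss_type = loss_pb2.KL_DIVERGENCE_LOSS
kd.temperature = 2.0
# distill only on samples that belong to the 'click' space (illustrative)
kd.task_space_indicator_name = 'click'
kd.task_space_indicator_value = '1'
kd.in_task_space_weight = 1.0
kd.out_task_space_weight = 0.1
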
diff --git a/easy_rec/python/protos/export.proto b/easy_rec/python/protos/export.proto
index b5b419118..568ad216b 100644
--- a/easy_rec/python/protos/export.proto
+++ b/easy_rec/python/protos/export.proto
@@ -15,7 +15,7 @@ message ExportConfig {
// type of exporter [final | latest | best | none] when train_and_evaluation
// final: performs a single export in the end of training
// latest: regularly exports the serving graph and checkpoints
- // latest: export the best model according to best_exporter_metric
+ // best: export the best model according to best_exporter_metric
// none: do not perform export
optional string exporter_type = 2 [default = 'final'];
@@ -42,9 +42,22 @@ message ExportConfig {
// multi value field list
optional MultiValueFields multi_value_fields = 10;
+
+ // auto analyze multi value fields
+ optional bool auto_multi_value = 16 [default = false];
+
// is placeholder named by input
optional bool placeholder_named_by_input = 11 [default = false];
// filter out inputs, only keep effective ones
optional bool filter_inputs = 12 [default = true];
+
+ // export the original feature values as string
+ optional bool export_features = 13 [default = false];
+
+ // export the outputs required by RTP
+ optional bool export_rtp_outputs = 14 [default = false];
+
+ // export asset files
+ repeated string asset_files = 15;
}
diff --git a/easy_rec/python/protos/feature_config.proto b/easy_rec/python/protos/feature_config.proto
index 18ef12ea1..8c0c0c214 100644
--- a/easy_rec/python/protos/feature_config.proto
+++ b/easy_rec/python/protos/feature_config.proto
@@ -3,6 +3,7 @@ package protos;
import "easy_rec/python/protos/hyperparams.proto";
import "easy_rec/python/protos/dnn.proto";
+import "easy_rec/python/protos/layer.proto";
enum WideOrDeep {
DEEP = 0;
WIDE = 1;
@@ -15,19 +16,24 @@ message AttentionCombiner {
message MultiHeadAttentionCombiner {
}
-message TextCnnCombiner {
- repeated uint32 filter_sizes = 1;
- repeated uint32 num_filters = 2;
-}
-
message SequenceCombiner {
oneof combiner {
AttentionCombiner attention = 1;
MultiHeadAttentionCombiner multi_head_attention = 2;
- TextCnnCombiner text_cnn = 3;
+ TextCNN text_cnn = 3;
}
}
+message EVParams {
+ optional uint64 filter_freq = 1 [default=0];
+ optional uint64 steps_to_live = 2 [default=0];
+ // use embedding cache, only for sok hybrid embedding
+ optional bool use_cache = 3 [default=false];
+ // for sok hybrid key value embedding
+ optional uint64 init_capacity = 4 [default=8388608];
+ optional uint64 max_capacity = 5 [default=16777216];
+}
+
message FeatureConfig {
enum FeatureType {
IdFeature = 0;
@@ -36,6 +42,17 @@ message FeatureConfig {
ComboFeature = 3;
LookupFeature = 4;
SequenceFeature = 5;
+ ExprFeature = 6;
+ PassThroughFeature = 7;
+ }
+
+ enum FieldType {
+ INT32 = 0;
+ INT64 = 1;
+ STRING = 2;
+ FLOAT = 4;
+ DOUBLE = 5;
+ BOOL = 6;
}
optional string feature_name = 1;
@@ -61,6 +78,8 @@ message FeatureConfig {
// delimeter to separate sequence multi-values
optional string seq_multi_sep = 101;
+ // truncate sequence data to max_seq_len
+ optional uint32 max_seq_len = 102;
optional string vocab_file = 11;
repeated string vocab_list = 12;
@@ -75,7 +94,7 @@ message FeatureConfig {
optional int32 max_partitions = 18 [default = 1];
// combiner
- optional string combiner = 19 [default = 'mean'];
+ optional string combiner = 19 [default = 'sum'];
// embedding initializer
optional Initializer initializer = 20;
@@ -86,18 +105,46 @@ message FeatureConfig {
optional int32 precision = 21 [default = -1];
// normalize raw feature to [0-1]
- optional double min_val = 22 [default=0.0];
- optional double max_val = 23 [default=0.0];
+ optional double min_val = 212 [default=0.0];
+ optional double max_val = 213 [default=0.0];
+
+ // normalization function for raw features:
+ // such as: tf.math.log1p
+ optional string normalizer_fn = 214;
// raw feature of multiple dimensions
optional uint32 raw_input_dim = 24 [default=1];
// sequence feature combiner
optional SequenceCombiner sequence_combiner = 25;
+
+ // sub feature type for sequence feature
+ optional FeatureType sub_feature_type = 26 [default = IdFeature];
+
+ // sequence length
+ optional uint32 sequence_length = 27 [default = 1];
+
+ // for expr feature
+ optional string expression = 30;
+
+ // embedding variable params
+ optional EVParams ev_params = 31;
+
+ // for combo feature:
+ // if not set, use cross_column
+ // otherwise, the input features are first joined
+ // and then passed to categorical_column
+ optional string combo_join_sep = 401 [default = ''];
+ // separator for each input
+ // if not set, combo inputs will not be split
+ repeated string combo_input_seps = 402;
}
message FeatureConfigV2 {
repeated FeatureConfig features = 1 ;
+ // force placing embedding lookup ops on CPU to improve
+ // training and inference efficiency.
+ optional bool embedding_on_cpu = 2 [default=false];
}
message FeatureGroupConfig {
@@ -105,12 +152,14 @@ message FeatureGroupConfig {
repeated string feature_names = 2;
optional WideOrDeep wide_deep = 3 [default = DEEP];
- optional SeqAttGroupConfig sequence_features = 4;
+ repeated SeqAttGroupConfig sequence_features = 4;
+ optional bool negative_sampler = 5 [default = false];
}
message SeqAttMap {
repeated string key = 1;
repeated string hist_seq = 2;
+ repeated string aux_hist_seq = 3;
}
message SeqAttGroupConfig {
@@ -119,4 +168,7 @@ message SeqAttGroupConfig {
optional bool tf_summary = 3 [default = false];
optional DNN seq_dnn = 4;
optional bool allow_key_search = 5 [default = false];
+ optional bool need_key_feature = 6 [default = true];
+ optional bool allow_key_transform = 7 [default = false];
+ optional bool transform_dnn = 8 [default = false];
}
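
A sketch of the new ev_params usage on a single feature, restricted to fields visible in this diff (the generated module name feature_config_pb2 is assumed; values are illustrative):

from easy_rec.python.protos import feature_config_pb2

fea = feature_config_pb2.FeatureConfig()
fea.feature_name = 'user_id'
fea.combiner = 'sum'                 # matches the new default above
fea.ev_params.filter_freq = 2        # frequency threshold for admitting ids
fea.ev_params.steps_to_live = 4000   # evict ids not updated for this many steps
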
diff --git a/easy_rec/python/protos/fm.proto b/easy_rec/python/protos/fm.proto
index c90af8cab..31d8f27d7 100644
--- a/easy_rec/python/protos/fm.proto
+++ b/easy_rec/python/protos/fm.proto
@@ -2,5 +2,6 @@ syntax = "proto2";
package protos;
message FM {
+ optional bool use_variant = 1;
optional float l2_regularization = 5 [default = 1e-4];
}
diff --git a/easy_rec/python/protos/hive_config.proto b/easy_rec/python/protos/hive_config.proto
new file mode 100644
index 000000000..be2d16dbd
--- /dev/null
+++ b/easy_rec/python/protos/hive_config.proto
@@ -0,0 +1,18 @@
+syntax = "proto2";
+package protos;
+
+message HiveConfig {
+ // hive master's ip
+ required string host = 1;
+
+ // hive port
+ required uint32 port = 2 [default = 10000];
+
+ // hive username
+ required string username = 3 [default = 'admin'];
+
+ // hive database
+ required string database = 4 [default = 'default'];
+
+ required string table_name = 5;
+}
diff --git a/easy_rec/python/protos/keras_layer.proto b/easy_rec/python/protos/keras_layer.proto
new file mode 100644
index 000000000..2b8047064
--- /dev/null
+++ b/easy_rec/python/protos/keras_layer.proto
@@ -0,0 +1,41 @@
+syntax = "proto2";
+package protos;
+
+import "google/protobuf/struct.proto";
+import "easy_rec/python/protos/layer.proto";
+import "easy_rec/python/protos/dnn.proto";
+import "easy_rec/python/protos/fm.proto";
+import "easy_rec/python/protos/seq_encoder.proto";
+
+message KerasLayer {
+ required string class_name = 1;
+ oneof params {
+ google.protobuf.Struct st_params = 2;
+ PeriodicEmbedding periodic_embedding = 3;
+ AutoDisEmbedding auto_dis_embedding = 4;
+ NaryDisEmbedding nary_dis_embedding = 21;
+ FM fm = 5;
+ MaskBlock mask_block = 6;
+ MaskNet masknet = 7;
+ SENet senet = 8;
+ Bilinear bilinear = 9;
+ FiBiNet fibinet = 10;
+ MLP mlp = 11;
+ DINEncoder din = 12;
+ BSTEncoder bst = 13;
+ MMoELayer mmoe = 14;
+ SequenceAugment seq_aug = 15;
+ PPNet ppnet = 16;
+ TextCNN text_cnn = 17;
+ HighWayTower highway = 18;
+ OverlapFeature overlap = 19;
+ MappedDotProduct dot_product = 20;
+ Attention attention = 22;
+ MultiHeadAttention multi_head_attention = 23;
+ Transformer transformer = 24;
+ TextEncoder text_encoder = 25;
+ WeightedGate gate = 26;
+ AITMTower aitm = 27;
+ CIN cin=28;
+ }
+}
diff --git a/easy_rec/python/protos/layer.proto b/easy_rec/python/protos/layer.proto
index 78cee0aac..4f45a3d08 100644
--- a/easy_rec/python/protos/layer.proto
+++ b/easy_rec/python/protos/layer.proto
@@ -1,7 +1,150 @@
syntax = "proto2";
package protos;
+import "easy_rec/python/protos/dnn.proto";
+
message HighWayTower {
- required string input = 1;
+ optional string input = 1;
required uint32 emb_size = 2;
+ required string activation = 3 [default = 'relu'];
+ optional float dropout_rate = 4;
+ optional float init_gate_bias = 5 [default = -3.0];
+ optional uint32 num_layers = 6 [default = 1];
+}
+
+message PeriodicEmbedding {
+ required uint32 embedding_dim = 1;
+ required float sigma = 2;
+ optional bool add_linear_layer = 3 [default = true];
+ optional string linear_activation = 4 [default = 'relu'];
+ optional bool output_3d_tensor = 5;
+ optional bool output_tensor_list = 6;
+}
+
+message AutoDisEmbedding {
+ required uint32 embedding_dim = 1;
+ required uint32 num_bins = 2;
+ required float keep_prob = 3 [default = 0.8];
+ required float temperature = 4;
+ optional bool output_3d_tensor = 5;
+ optional bool output_tensor_list = 6;
+}
+
+message NaryDisEmbedding {
+ required uint32 embedding_dim = 1;
+ repeated uint32 carries = 2;
+ optional float multiplier = 3 [default = 1.0];
+ optional string intra_ary_pooling = 4 [default = 'sum'];
+ // for now, inter_ary_pooling is not supported yet
+ optional string inter_ary_pooling = 5 [default = 'concat'];
+ optional bool output_3d_tensor = 6 [default = false];
+ optional bool output_tensor_list = 7;
+ optional uint32 num_replicas = 8 [default = 1];
+}
+
+message SENet {
+ required uint32 reduction_ratio = 1 [default = 4];
+ optional uint32 num_squeeze_group = 2 [default = 2];
+ optional bool use_skip_connection = 3 [default = true];
+ optional bool use_output_layer_norm = 4 [default = true];
+}
+
+message Bilinear {
+ required string type = 1 [default = 'interaction'];
+ required bool use_plus = 2 [default = true];
+ required uint32 num_output_units = 3;
+}
+
+message FiBiNet {
+ optional Bilinear bilinear = 1;
+ required SENet senet = 2;
+ optional MLP mlp = 8;
+}
+
+message MaskBlock {
+ optional float reduction_factor = 1;
+ optional uint32 output_size = 2;
+ optional uint32 aggregation_size = 3;
+ optional bool input_layer_norm = 4 [default = false];
+ optional uint32 projection_dim = 5;
+}
+
+message MaskNet {
+ repeated MaskBlock mask_blocks = 1;
+ required bool use_parallel = 2 [default = true];
+ optional MLP mlp = 3;
+ optional bool input_layer_norm = 4 [default = true];
+}
+
+message MMoELayer {
+ // number of tasks
+ required uint32 num_task = 1;
+ // mmoe expert mlp layer definition
+ optional MLP expert_mlp = 2;
+ // number of mmoe experts
+ optional uint32 num_expert = 3;
+}
+
+// used in CDN model
+message WeightedGate {
+ optional uint32 weight_index = 1 [default = 0];
+ optional MLP mlp = 2;
+}
+
+// used in PPNet
+message GateNN {
+ optional uint32 output_dim = 1;
+ optional uint32 hidden_dim = 2;
+ // activation function
+ optional string activation = 3 [default = 'relu'];
+ // use batch normalization
+ optional bool use_bn = 4 [default = false];
+ optional float dropout_rate = 5;
+}
+
+message PPNet {
+ required MLP mlp = 1;
+ required GateNN gate_params = 2;
+ // run mode: eager, lazy
+ required string mode = 3 [default = 'eager'];
+ optional bool full_gate_input = 4 [default = true];
+}
+
+message TextCNN {
+ repeated uint32 filter_sizes = 1;
+ repeated uint32 num_filters = 2;
+ required uint32 pad_sequence_length = 3;
+ optional string activation = 4 [default = 'relu'];
+ optional MLP mlp = 5;
+}
+
+message OverlapFeature {
+ optional string separator = 1;
+ optional string default_value = 2;
+ repeated string methods = 3;
+ optional string normalize_fn = 4;
+ repeated float boundaries = 5;
+ optional int32 embedding_dim = 6;
+ optional int32 print_first_n = 7 [default = 0];
+ optional int32 summarize = 8;
+}
+
+message MappedDotProduct {
+ optional string separator = 1;
+ optional float default_value = 2;
+ optional string normalize_fn = 3;
+ repeated float boundaries = 4;
+ optional int32 embedding_dim = 5;
+ optional int32 print_first_n = 6 [default = 0];
+ optional int32 summarize = 7;
+}
+
+message AITMTower {
+ optional uint32 project_dim = 1;
+ optional MLP transfer_mlp = 2;
+ optional bool stop_gradient = 3 [default = true];
+}
+
+message CIN {
+ repeated int32 hidden_feature_sizes = 1;
}
diff --git a/easy_rec/python/protos/loss.proto b/easy_rec/python/protos/loss.proto
index 36c82e7d3..4416111a8 100644
--- a/easy_rec/python/protos/loss.proto
+++ b/easy_rec/python/protos/loss.proto
@@ -12,11 +12,39 @@ enum LossType {
MULTI_SIMILARITY_LOSS = 6;
SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING = 7;
PAIR_WISE_LOSS = 8;
+ F1_REWEIGHTED_LOSS = 9;
+ BINARY_FOCAL_LOSS = 10;
+ PAIRWISE_FOCAL_LOSS = 11;
+ PAIRWISE_LOGISTIC_LOSS = 12;
+ PAIRWISE_HINGE_LOSS = 17;
+ JRC_LOSS = 13;
+ ORDER_CALIBRATE_LOSS = 14;
+ BINARY_CROSS_ENTROPY_LOSS = 15;
+ KL_DIVERGENCE_LOSS = 16;
+ LISTWISE_RANK_LOSS = 18;
+ LISTWISE_DISTILL_LOSS = 19;
+ ZILN_LOSS = 20;
}
message Loss {
required LossType loss_type = 1;
- required float weight = 2 [default = 1.0];
+ optional float weight = 2 [default = 1.0];
+ optional string loss_name = 3;
+ optional bool learn_loss_weight = 4 [default = false];
+ oneof loss_param {
+ F1ReweighedLoss f1_reweighted_loss = 101;
+ SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102;
+ CircleLoss circle_loss = 103;
+ MultiSimilarityLoss multi_simi_loss = 104;
+ BinaryFocalLoss binary_focal_loss = 105;
+ PairwiseLoss pairwise_loss = 106;
+ PairwiseFocalLoss pairwise_focal_loss = 107;
+ PairwiseLogisticLoss pairwise_logistic_loss = 108;
+ JRCLoss jrc_loss = 109;
+ PairwiseHingeLoss pairwise_hinge_loss = 110;
+ ListwiseRankLoss listwise_rank_loss = 111;
+ ListwiseDistillLoss listwise_distill_loss = 112;
+ }
};
message SoftmaxCrossEntropyWithNegativeMining {
@@ -37,3 +65,71 @@ message MultiSimilarityLoss {
required float lamb = 3 [default = 1];
required float eps = 4 [default = 0.1];
}
+
+message F1ReweighedLoss {
+ required float f1_beta_square = 1 [default = 1.0];
+ required float label_smoothing = 2 [default = 0];
+}
+
+message BinaryFocalLoss {
+ required float gamma = 1 [default = 2.0];
+ optional float alpha = 2;
+ optional float ohem_ratio = 3 [default = 1.0];
+ optional float label_smoothing = 4 [default = 0];
+}
+
+message PairwiseLoss {
+ required float margin = 1 [default = 0];
+ optional string session_name = 2;
+ optional float temperature = 3 [default = 1.0];
+}
+
+message PairwiseFocalLoss {
+ required float gamma = 1 [default = 2.0];
+ optional float alpha = 2;
+ optional float hinge_margin = 3 [default = 1.0];
+ optional string session_name = 4;
+ optional float ohem_ratio = 5 [default = 1.0];
+ optional float temperature = 6 [default = 1.0];
+}
+
+message PairwiseLogisticLoss {
+ required float temperature = 1 [default = 1.0];
+ optional string session_name = 2;
+ optional float hinge_margin = 3;
+ optional float ohem_ratio = 4 [default = 1.0];
+ optional bool use_label_margin = 5 [default = false];
+}
+
+message PairwiseHingeLoss {
+ required float temperature = 1 [default = 1.0];
+ optional string session_name = 2;
+ optional float margin = 3 [default = 1.0];
+ optional float ohem_ratio = 4 [default = 1.0];
+ optional bool label_is_logits = 5 [default = true];
+ optional bool use_label_margin = 6 [default = true];
+ optional bool use_exponent = 7 [default = false];
+}
+
+message JRCLoss {
+ required string session_name = 1;
+ optional float alpha = 2 [default = 0.5];
+ optional bool same_label_loss = 3 [default = true];
+ required string loss_weight_strategy = 4 [default = 'fixed'];
+}
+
+message ListwiseRankLoss {
+ required float temperature = 1 [default = 1.0];
+ optional string session_name = 2;
+ optional string transform_fn = 3;
+ optional bool label_is_logits = 4 [default = false];
+ optional bool scale_logits = 5 [default = false];
+}
+
+message ListwiseDistillLoss {
+ required float temperature = 1 [default = 1.0];
+ optional string session_name = 2;
+ optional string transform_fn = 3;
+ optional float label_clip_max_value = 4 [default = 512.0];
+ optional bool scale_logits = 5 [default = false];
+}
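
With the new loss_param oneof, a loss can carry its own hyper-parameters. A small illustrative example (loss_pb2 is the module generated from this file; values are ours):

from easy_rec.python.protos import loss_pb2

loss = loss_pb2.Loss()
loss.loss_type = loss_pb2.BINARY_FOCAL_LOSS
loss.weight = 1.0
loss.learn_loss_weight = False
loss.binary_focal_loss.gamma = 2.0   # writing these selects the loss_param oneof
loss.binary_focal_loss.alpha = 0.85
loss.binary_focal_loss.ohem_ratio = 0.8
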
diff --git a/easy_rec/python/protos/mind.proto b/easy_rec/python/protos/mind.proto
index 7868666f5..aab988a94 100644
--- a/easy_rec/python/protos/mind.proto
+++ b/easy_rec/python/protos/mind.proto
@@ -17,6 +17,13 @@ message Capsule {
optional float routing_logits_scale = 5 [default=20];
// routing logits initial stddev
optional float routing_logits_stddev = 6 [default=1.0];
+ // squash power
+ optional float squash_pow = 7 [default=1.0];
+ // output ratio
+ optional float scale_ratio = 8 [default=1.0];
+ // constant interest number
+ // by default, use log(seq_len)
+ optional bool const_caps_num = 9 [default=false];
}
message MIND {
@@ -27,13 +34,15 @@ message MIND {
// preprocessing dnn before entering capsule layer
optional DNN pre_capsule_dnn = 101;
- // dnn layers applied on concated results of
- // capsule output and user_context(none sequence features)
+ // dnn layers applied on user_context (non-sequence features)
required DNN user_dnn = 102;
+ // concat user and capsule dnn
+ required DNN concat_dnn = 103;
+
// method to combine several user sequences
// such as item_ids, category_ids
- optional UserSeqCombineMethod user_seq_combine = 103 [default=SUM];
+ optional UserSeqCombineMethod user_seq_combine = 104 [default=SUM];
// dnn layers applied on item features
required DNN item_dnn = 2;
@@ -44,7 +53,21 @@ message MIND {
// the better
optional float simi_pow = 4 [default=10];
- optional Similarity simi_func = 6 [default=COSINE];
+ optional Similarity simi_func = 5 [default=COSINE];
+
+ // add a layer for scaling the similarity
+ optional bool scale_simi = 6 [default=true];
required float l2_regularization = 7 [default = 1e-4];
+
+ optional string time_id_fea = 8;
+
+ optional string item_id = 9;
+
+ optional bool ignore_in_batch_neg_sam = 10 [default = false];
+
+ // if smaller than 1.0, a loss will be added to
+ // limit the maximal interest similarities, but
+ // in experiments, adding such a loss leads to a low hitrate.
+ optional float max_interests_simi = 11 [default = 1.0];
}
diff --git a/easy_rec/python/protos/multi_tower_recall.proto b/easy_rec/python/protos/multi_tower_recall.proto
new file mode 100644
index 000000000..58b6431db
--- /dev/null
+++ b/easy_rec/python/protos/multi_tower_recall.proto
@@ -0,0 +1,19 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+import "easy_rec/python/protos/simi.proto";
+
+
+message RecallTower {
+ required DNN dnn = 1;
+};
+
+
+message MultiTowerRecall {
+ required RecallTower user_tower = 1;
+ required RecallTower item_tower = 2;
+ required float l2_regularization = 3 [default = 1e-4];
+ required DNN final_dnn = 4;
+ required bool ignore_in_batch_neg_sam = 10 [default = false];
+}
diff --git a/easy_rec/python/protos/optimizer.proto b/easy_rec/python/protos/optimizer.proto
index 4be0c6cc6..b825149f9 100644
--- a/easy_rec/python/protos/optimizer.proto
+++ b/easy_rec/python/protos/optimizer.proto
@@ -15,6 +15,7 @@ message Optimizer {
AdagradOptimizer adagrad_optimizer = 107;
FtrlOptimizer ftrl_optimizer = 108;
AdamAsyncWOptimizer adam_asyncw_optimizer = 109;
+ LazyAdamOptimizer lazy_adam_optimizer = 110;
}
optional bool use_moving_average = 5 [default = false];
optional float moving_average_decay = 6 [default = 0.9999];
@@ -65,10 +66,17 @@ message AdamAsyncWOptimizer {
optional float beta2 = 4 [default = 0.999];
}
+message LazyAdamOptimizer {
+ optional LearningRate learning_rate = 1;
+ optional float beta1 = 3 [default = 0.9];
+ optional float beta2 = 4 [default = 0.999];
+}
+
// Configuration message for the AdagradOptimizer
// See: https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
message AdagradOptimizer {
- optional LearningRate learning_rate = 1;
+ optional LearningRate learning_rate = 1;
+ optional float initial_accumulator_value = 2 [default = 0.1];
}
// Only available on pai-tf, which has better performance than AdamOptimizer
diff --git a/easy_rec/python/protos/pdn.proto b/easy_rec/python/protos/pdn.proto
new file mode 100644
index 000000000..dba0852b9
--- /dev/null
+++ b/easy_rec/python/protos/pdn.proto
@@ -0,0 +1,48 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+import "easy_rec/python/protos/simi.proto";
+
+// requires 3 sequence groups:
+// u2i: user behavior info on the interacted item sequence
+// i_seq: trigger item side info sequence
+// i2i: trigger item and target item co-occurrence info
+
+message PDN {
+ // encode user info
+ required DNN user_dnn = 1;
+ // encode target item info
+ required DNN item_dnn = 2;
+
+ // encode u2i seq info
+ required DNN u2i_dnn = 3;
+
+ // produce trigger score
+ required DNN trigger_dnn = 4;
+
+ // encode trigger item seqs to target item co-occurrence info
+ required DNN i2i_dnn = 5;
+
+ // produce sim score
+ required DNN sim_dnn = 6;
+
+ // direct net user_dnn
+ optional DNN direct_user_dnn = 7;
+
+ // direct net item_dnn
+ optional DNN direct_item_dnn = 8;
+
+ // for direct net, similar to DSSM
+ optional Similarity simi_func = 9 [default=COSINE];
+
+ // for direct net
+ optional bool scale_simi = 10 [default = true];
+
+ // bias net dnn
+ optional DNN bias_dnn = 11;
+
+ optional string item_id = 12;
+
+ optional float l2_regularization = 13 [default=1e-6];
+}
diff --git a/easy_rec/python/protos/pipeline.proto b/easy_rec/python/protos/pipeline.proto
index 09f44c200..1030bb31d 100644
--- a/easy_rec/python/protos/pipeline.proto
+++ b/easy_rec/python/protos/pipeline.proto
@@ -8,6 +8,7 @@ import "easy_rec/python/protos/dataset.proto";
import "easy_rec/python/protos/feature_config.proto";
import "easy_rec/python/protos/easy_rec_model.proto";
import "easy_rec/python/protos/data_source.proto";
+import "easy_rec/python/protos/hive_config.proto";
// EasyRecConfig: the pipeline_config, including all sub configs
@@ -16,11 +17,17 @@ message EasyRecConfig {
string train_input_path = 1;
KafkaServer kafka_train_input = 2;
DatahubServer datahub_train_input = 12;
+ HiveConfig hive_train_input = 101;
+ BinaryDataInput binary_train_input = 102;
+ string parquet_train_input = 103;
}
oneof eval_path {
string eval_input_path = 3;
KafkaServer kafka_eval_input = 4;
DatahubServer datahub_eval_input = 13;
+ HiveConfig hive_eval_input= 201;
+ BinaryDataInput binary_eval_input = 202;
+ string parquet_eval_input = 203;
}
required string model_dir = 5;
@@ -40,5 +47,15 @@ message EasyRecConfig {
optional ExportConfig export_config = 15;
+ // Json file[RTP FG] to define input data and features:
+ // * In easy_rec.python.utils.fg_util.load_fg_json_to_config:
+ // data_config and feature_config will be generated
+ // based on fg_json.
+ // * After generation, a prefix '!' is added:
+ // fg_json_path = '!' + fg_json_path
+ // indicates config update is already done, and should not
+ // be updated anymore. In this way, we make load_fg_json_to_config
+ // function reentrant.
+ // This step is done before edit_config_json to take effect.
optional string fg_json_path = 16;
}
diff --git a/easy_rec/python/protos/predict.proto b/easy_rec/python/protos/predict.proto
new file mode 100644
index 000000000..888b4ef63
--- /dev/null
+++ b/easy_rec/python/protos/predict.proto
@@ -0,0 +1,75 @@
+syntax = "proto3";
+
+package com.alibaba.pairec.processor;
+
+import "easy_rec/python/protos/tf_predict.proto";
+
+// context features
+message ContextFeatures {
+ repeated PBFeature features = 1;
+}
+
+message PBFeature {
+ oneof value {
+ int32 int_feature = 1;
+ int64 long_feature = 2;
+ string string_feature = 3;
+ float float_feature = 4;
+ }
+}
+
+// PBRequest specifies the request for aggregator
+message PBRequest {
+ // debug mode
+ int32 debug_level = 1;
+
+ // user features
+ map<string, PBFeature> user_features = 2;
+
+ // item ids
+ repeated string item_ids = 3;
+
+ // context features for each item
+ map<string, ContextFeatures> context_features = 4;
+
+ int32 faiss_neigh_num = 5;
+}
+
+// return results
+message Results {
+ repeated double scores = 1 [packed = true];
+}
+
+enum StatusCode {
+ OK = 0;
+ INPUT_EMPTY = 1;
+ EXCEPTION = 2;
+}
+
+// PBResponse specifies the response for aggregator
+message PBResponse {
+ // results
+ map<string, Results> results = 1;
+
+ // item features
+ map<string, string> item_features = 2;
+
+ // generate features
+ map<string, string> generate_features = 3;
+
+ // context features
+ map<string, ContextFeatures> context_features = 4;
+
+ string error_msg = 5;
+
+ StatusCode status_code = 6;
+
+ repeated string item_ids = 7;
+ repeated string outputs = 8;
+
+ // all fg input features
+ map<string, string> raw_features = 9;
+
+ // tf output tensors
+ map<string, ArrayProto> tf_outputs = 10;
+}
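As a hedged usage sketch, a PBRequest could be filled as below; the generated module name predict_pb2 is an assumption, and the map value types follow the reconstruction above (PBFeature for user_features, ContextFeatures for context_features).

from easy_rec.python.protos import predict_pb2  # assumed generated module name

req = predict_pb2.PBRequest()
req.debug_level = 0
req.user_features['user_id'].string_feature = 'u_1001'
req.user_features['age'].int_feature = 30
req.item_ids.extend(['item_1', 'item_2'])
# one ContextFeatures entry per item id
req.context_features['item_1'].features.add().float_feature = 0.5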
diff --git a/easy_rec/python/protos/seq_encoder.proto b/easy_rec/python/protos/seq_encoder.proto
new file mode 100644
index 000000000..2edf71fa3
--- /dev/null
+++ b/easy_rec/python/protos/seq_encoder.proto
@@ -0,0 +1,111 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+
+
+message Attention {
+ optional bool use_scale = 1 [default = false];
+ optional bool scale_by_dim = 2 [default = false];
+ optional string score_mode = 3 [default = 'dot'];
+ optional float dropout = 4 [default = 0.0];
+ optional int32 seed = 5;
+ optional bool return_attention_scores = 6 [default = false];
+ optional bool use_causal_mask = 7 [default = false];
+}
+
+message MultiHeadAttention {
+ required uint32 num_heads = 1;
+ required uint32 key_dim = 2;
+ optional uint32 value_dim = 3;
+ optional float dropout = 4 [default = 0.0];
+ optional bool use_bias = 5 [default = true];
+ optional bool return_attention_scores = 6 [default = false];
+ optional bool use_causal_mask = 7 [default = false];
+ // The expected shape of an output tensor, besides the batch
+ // and sequence dims. If not specified, projects back to the query
+ // feature dim (the query input's last dimension).
+ optional uint32 output_shape = 8;
+ // Axes over which the attention is applied.
+ repeated int32 attention_axes = 9;
+ optional string kernel_initializer = 10 [default = 'glorot_uniform'];
+ optional string bias_initializer = 11 [default = 'zeros'];
+}
+
+message Transformer {
+ // Size of the encoder layers and the pooler layer
+ required uint32 hidden_size = 1;
+ // Number of hidden layers in the Transformer encoder
+ required uint32 num_hidden_layers = 2;
+ // Number of attention heads for each attention layer in the Transformer encoder
+ required uint32 num_attention_heads = 3;
+ // The size of the "intermediate" (i.e. feed-forward) layer in the Transformer encoder
+ required uint32 intermediate_size = 4;
+ // The non-linear activation function (function or string) in the encoder and pooler.
+ required string hidden_act = 5 [default = 'relu'];
+ // The dropout probability for all fully connected layers in the embeddings, encoder, and pooler
+ required float hidden_dropout_prob = 6 [default = 0.1];
+ required uint32 vocab_size = 7;
+ // The maximum sequence length that this model might ever be used with
+ required uint32 max_position_embeddings = 8 [default = 512];
+ // Whether to add position embeddings for the position of each token in the text sequence
+ required bool use_position_embeddings = 9 [default = false];
+ // Whether to output all token embeddings; if set to false, only the first token embedding is output
+ required bool output_all_token_embeddings = 10 [default = true];
+ // The dropout ratio for the attention probabilities
+ optional float attention_probs_dropout_prob = 11 [default = 0.0];
+}
+
+message TextEncoder {
+ required Transformer transformer = 1;
+ required string separator = 2 [default = ' '];
+ optional string vocab_file = 3;
+ optional int32 default_token_id = 4 [default = 0];
+}
+
+message BSTEncoder {
+ // Size of the encoder layers and the pooler layer
+ required uint32 hidden_size = 1;
+ // Number of hidden layers in the Transformer encoder
+ required uint32 num_hidden_layers = 2;
+ // Number of attention heads for each attention layer in the Transformer encoder
+ required uint32 num_attention_heads = 3;
+ // The size of the "intermediate" (i.e. feed-forward) layer in the Transformer encoder
+ required uint32 intermediate_size = 4;
+ // The non-linear activation function (function or string) in the encoder and pooler.
+ required string hidden_act = 5 [default = 'gelu']; // "gelu", "relu", "tanh" and "swish" are supported.
+ // The dropout probability for all fully connected layers in the embeddings, encoder, and pooler
+ required float hidden_dropout_prob = 6 [default = 0.1];
+ // The dropout ratio for the attention probabilities
+ required float attention_probs_dropout_prob = 7 [default = 0.1];
+ // The maximum sequence length that this model might ever be used with
+ required uint32 max_position_embeddings = 8 [default = 512];
+ // Whether to add position embeddings for the position of each token in the text sequence
+ required bool use_position_embeddings = 9 [default = true];
+ // The stddev of the truncated_normal_initializer for initializing all weight matrices
+ required float initializer_range = 10 [default = 0.02];
+ // Whether to output all token embeddings; if set to false, only the first token embedding is output
+ required bool output_all_token_embeddings = 11 [default = true];
+ // The position of the target item (options: head, tail, ignore)
+ required string target_item_position = 12 [default = 'head'];
+ // Whether to preserve a position for target
+ required bool reserve_target_position = 13 [default = true];
+}
+
+message DINEncoder {
+ // din attention layer
+ required MLP attention_dnn = 1;
+ // whether to keep target item feature
+ required bool need_target_feature = 2 [default = true];
+ // option: softmax, sigmoid
+ required string attention_normalizer = 3 [default = 'softmax'];
+}
+
+message SequenceAugment {
+ // Fraction of the original sequence to be masked
+ required float mask_rate = 1 [default = 0.6];
+ // Fraction of the original sequence kept after cropping
+ required float crop_rate = 2 [default = 0.2];
+ // Fraction of the original sequence to be reordered
+ required float reorder_rate = 3 [default = 0.6];
+}
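To make the three rates concrete, here is a toy, self-contained sketch of what masking, cropping, and reordering a behavior sequence at those rates could mean; it is illustrative only, not the EasyRec implementation.

import random


def augment(seq, mask_rate=0.6, crop_rate=0.2, reorder_rate=0.6, mask_token=0):
  """Toy illustration of the SequenceAugment rates above."""
  n = len(seq)
  # mask: replace a mask_rate fraction of positions with a mask token
  masked = list(seq)
  for i in random.sample(range(n), int(n * mask_rate)):
    masked[i] = mask_token
  # crop: keep a random contiguous window covering crop_rate of the sequence
  win = max(1, int(n * crop_rate))
  start = random.randint(0, n - win)
  cropped = seq[start:start + win]
  # reorder: shuffle a random contiguous span covering reorder_rate of the sequence
  span = max(1, int(n * reorder_rate))
  start = random.randint(0, n - span)
  reordered = seq[:start] + random.sample(seq[start:start + span], span) + seq[start + span:]
  return masked, cropped, reordered


print(augment(list(range(10))))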
diff --git a/easy_rec/python/protos/tf_predict.proto b/easy_rec/python/protos/tf_predict.proto
new file mode 100644
index 000000000..c95526ab6
--- /dev/null
+++ b/easy_rec/python/protos/tf_predict.proto
@@ -0,0 +1,100 @@
+syntax = "proto3";
+
+package tensorflow.eas;
+option cc_enable_arenas = true;
+
+enum ArrayDataType {
+ // Not a legal value for DataType. Used to indicate a DataType field
+ // has not been set.
+ DT_INVALID = 0;
+
+ // Data types that all computation devices are expected to be
+ // capable to support.
+ DT_FLOAT = 1;
+ DT_DOUBLE = 2;
+ DT_INT32 = 3;
+ DT_UINT8 = 4;
+ DT_INT16 = 5;
+ DT_INT8 = 6;
+ DT_STRING = 7;
+ DT_COMPLEX64 = 8; // Single-precision complex
+ DT_INT64 = 9;
+ DT_BOOL = 10;
+ DT_QINT8 = 11; // Quantized int8
+ DT_QUINT8 = 12; // Quantized uint8
+ DT_QINT32 = 13; // Quantized int32
+ DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
+ DT_QINT16 = 15; // Quantized int16
+ DT_QUINT16 = 16; // Quantized uint16
+ DT_UINT16 = 17;
+ DT_COMPLEX128 = 18; // Double-precision complex
+ DT_HALF = 19;
+ DT_RESOURCE = 20;
+ DT_VARIANT = 21; // Arbitrary C++ data types
+}
+
+// Dimensions of an array
+message ArrayShape {
+ repeated int64 dim = 1 [packed = true];
+}
+
+// Protocol buffer representing an array
+message ArrayProto {
+ // Data Type.
+ ArrayDataType dtype = 1;
+
+ // Shape of the array.
+ ArrayShape array_shape = 2;
+
+ // DT_FLOAT.
+ repeated float float_val = 3 [packed = true];
+
+ // DT_DOUBLE.
+ repeated double double_val = 4 [packed = true];
+
+ // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
+ repeated int32 int_val = 5 [packed = true];
+
+ // DT_STRING.
+ repeated bytes string_val = 6;
+
+ // DT_INT64.
+ repeated int64 int64_val = 7 [packed = true];
+
+ // DT_BOOL.
+ repeated bool bool_val = 8 [packed = true];
+}
+
+// PredictRequest specifies which TensorFlow model to run, as well as
+// how inputs are mapped to tensors and how outputs are filtered before
+// returning to user.
+message PredictRequest {
+ // A named signature to evaluate. If unspecified, the default signature
+ // will be used
+ string signature_name = 1;
+
+ // Input tensors.
+ // Names of input tensor are alias names. The mapping from aliases to real
+ // input tensor names is expected to be stored as named generic signature
+ // under the key "inputs" in the model export.
+ // Each alias listed in a generic signature named "inputs" should be provided
+ // exactly once in order to run the prediction.
+ map<string, ArrayProto> inputs = 2;
+
+ // Output filter.
+ // Names specified are alias names. The mapping from aliases to real output
+ // tensor names is expected to be stored as named generic signature under
+ // the key "outputs" in the model export.
+ // Only tensors specified here will be run/fetched and returned, with the
+ // exception that when none is specified, all tensors specified in the
+ // named signature will be run/fetched and returned.
+ repeated string output_filter = 3;
+
+ int32 debug_level = 100;
+}
+
+// Response for PredictRequest on successful run.
+message PredictResponse {
+ // Output tensors.
+ map<string, ArrayProto> outputs = 1;
+}
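A hedged sketch of building a PredictRequest follows; the generated module name tf_predict_pb2 is an assumption, and the inputs map is treated as map<string, ArrayProto> as reconstructed above.

from easy_rec.python.protos import tf_predict_pb2  # assumed generated module name

req = tf_predict_pb2.PredictRequest()
req.signature_name = 'serving_default'
arr = req.inputs['features']  # ArrayProto value of the inputs map
arr.dtype = tf_predict_pb2.DT_FLOAT
arr.array_shape.dim.extend([1, 3])
arr.float_val.extend([0.1, 0.2, 0.3])
req.output_filter.append('probs')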
diff --git a/easy_rec/python/protos/tower.proto b/easy_rec/python/protos/tower.proto
index c58013d33..aa5622e20 100644
--- a/easy_rec/python/protos/tower.proto
+++ b/easy_rec/python/protos/tower.proto
@@ -26,12 +26,20 @@ message TaskTower {
optional DNN dnn = 6;
// training loss weights
optional float weight = 7 [default = 1.0];
- // label name for indcating the sample space for the task tower
+ // label name for indicating the sample space for the task tower
optional string task_space_indicator_label = 10;
// the loss weight for sample in the task space
optional float in_task_space_weight = 11 [default = 1.0];
// the loss weight for sample out the task space
optional float out_task_space_weight = 12 [default = 1.0];
+ // multiple losses
+ repeated Loss losses = 13;
+ // whether to use sample weight in this tower
+ required bool use_sample_weight = 14 [default = true];
+ // field name for indicating the sample space for this task
+ optional string task_space_indicator_name = 15;
+ // field value for indicating the sample space for this task
+ optional string task_space_indicator_value = 16;
};
@@ -54,7 +62,7 @@ message BayesTaskTower {
optional DNN relation_dnn = 8;
// training loss weights
optional float weight = 9 [default = 1.0];
- // label name for indcating the sample space for the task tower
+ // label name for indicating the sample space for the task tower
optional string task_space_indicator_label = 10;
// the loss weight for sample in the task space
optional float in_task_space_weight = 11 [default = 1.0];
@@ -64,4 +72,12 @@ message BayesTaskTower {
// required uint32 prediction_level = 13;
// prediction weights
// optional float prediction_weight = 14 [default = 1.0];
+ // multiple losses
+ repeated Loss losses = 15;
+ // whether to use sample weight in this tower
+ required bool use_sample_weight = 16 [default = true];
+ // field name for indicating the sample space for this task
+ optional string task_space_indicator_name = 17;
+ // field value for indicating the sample space for this task
+ optional string task_space_indicator_value = 18;
};
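One plausible reading of the new field pair (an assumption, the diff does not spell out the semantics): a sample belongs to this task's space when the input field named by task_space_indicator_name equals task_space_indicator_value. In text format that could look like the snippet below, using the assumed generated module tower_pb2; the values are placeholders.

from google.protobuf import text_format

from easy_rec.python.protos.tower_pb2 import TaskTower  # assumed generated module name

tower = text_format.Merge(
    """
    task_space_indicator_name: 'click'
    task_space_indicator_value: '1'
    in_task_space_weight: 1.0
    out_task_space_weight: 0.1
    use_sample_weight: true
    """, TaskTower())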
diff --git a/easy_rec/python/protos/train.proto b/easy_rec/python/protos/train.proto
index 7cd6181b4..ab3ca4ddc 100644
--- a/easy_rec/python/protos/train.proto
+++ b/easy_rec/python/protos/train.proto
@@ -19,6 +19,62 @@ enum DistributionStrategy {
// multi worker multi gpu mode
// see tf.distribute.experimental.MultiWorkerMirroredStrategy
MultiWorkerMirroredStrategy = 5;
+ // use horovod strategy
+ HorovodStrategy = 6;
+ // support kv embedding, support kv embedding shard
+ SokStrategy = 7;
+ // support embedding shard, requires horovod
+ EmbeddingParallelStrategy = 8;
+}
+
+message IncrementSaveConfig {
+ message Kafka {
+ message Consumer {
+ optional string config_topic = 1;
+ optional string config_global = 2;
+ optional int64 offset = 3 [default=0];
+ optional int32 timeout = 4 [default=600];
+ }
+ required string server = 1;
+ required string topic = 2;
+ required Consumer consumer = 3;
+ }
+
+ message Datahub {
+ message Consumer {
+ optional int64 offset = 1 [default=0];
+ optional int32 timeout = 2 [default=600];
+ }
+ required string akId = 1;
+ required string akSecret = 2;
+ required string region = 3;
+ required string project = 4;
+ required string topic = 5;
+ required Consumer consumer = 6;
+ }
+
+ message File {
+ optional string incr_save_dir = 1 [default="incr_save"];
+ // whether incr_save_dir is relative to model_dir
+ optional bool relative = 2 [default=true];
+ // for online inference, please set storage.mount_path to this mount_path,
+ // otherwise the online service will fail
+ optional string mount_path = 3 [default="/home/admin/docker_ml/workspace/incr_save/"];
+ }
+
+ optional int32 sparse_save_secs = 1 [default=0];
+ optional int32 dense_save_secs = 2 [default=0];
+ optional int32 sparse_save_steps = 3 [default=0];
+ optional int32 dense_save_steps = 4 [default=0];
+
+ // if enabled, incremental updates will be saved to model_dir/incr_save/
+ optional bool debug_save_update = 5 [default=false];
+
+ oneof incr_update {
+ Kafka kafka = 501;
+ Datahub datahub = 502;
+ File fs = 503;
+ }
}
// Message for configuring EasyRecModel training jobs (train.py).
@@ -46,6 +102,12 @@ message TrainConfig {
// In case so, build a SyncReplicateOptimizer
optional bool sync_replicas = 9 [default = true];
+ // only take effect on pai-tf when sync_replicas is set,
+ // options are:
+ // raw, hash, multi_map, list, parallel
+ // in general, multi_map runs faster than other options.
+ optional string sparse_accumulator_type = 901 [default='multi_map'];
+
// Number of training steps between replica startup.
// This flag must be set to 0 if sync_replicas is set to true.
optional float startup_delay_steps = 10 [default = 15];
@@ -101,4 +163,15 @@ message TrainConfig {
// match variable patterns to freeze
repeated string freeze_gradient = 30;
+
+ // increment save config
+ optional IncrementSaveConfig incr_save_config = 31;
+
+ // enable oss stop signal:
+ // training stops when OSS_STOP_SIGNAL is created under model_dir
+ optional bool enable_oss_stop_signal = 32 [default = false];
+
+ // stop training after dead_line time, format:
+ // 20220508 23:59:59
+ optional string dead_line = 33;
}
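For orientation, an incr_save_config block using the File backend might look like the sketch below; the field values are placeholders and train_pb2 as the generated module name is an assumption. The local_incr_test added later in this diff edits exactly this fs.mount_path field.

from google.protobuf import text_format

from easy_rec.python.protos.train_pb2 import IncrementSaveConfig  # assumed module name

incr_cfg = text_format.Merge(
    """
    sparse_save_secs: 30
    dense_save_secs: 300
    fs {
      incr_save_dir: 'incr_save'
      relative: true
      mount_path: '/home/admin/docker_ml/workspace/incr_save/'
    }
    """, IncrementSaveConfig())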
diff --git a/easy_rec/python/protos/uniter.proto b/easy_rec/python/protos/uniter.proto
new file mode 100644
index 000000000..9efc1dc9e
--- /dev/null
+++ b/easy_rec/python/protos/uniter.proto
@@ -0,0 +1,35 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+
+message UniterTower {
+ // Size of the encoder layers and the pooler layer
+ required uint32 hidden_size = 1;
+ // Number of hidden layers in the Transformer encoder
+ required uint32 num_hidden_layers = 2;
+ // Number of attention heads for each attention layer in the Transformer encoder
+ required uint32 num_attention_heads = 3;
+ // The size of the "intermediate" (i.e. feed-forward) layer in the Transformer encoder
+ required uint32 intermediate_size = 4;
+ // The non-linear activation function (function or string) in the encoder and pooler.
+ required string hidden_act = 5 [default = 'gelu']; // "gelu", "relu", "tanh" and "swish" are supported.
+ // The dropout probability for all fully connected layers in the embeddings, encoder, and pooler
+ required float hidden_dropout_prob = 6 [default = 0.1];
+ // The dropout ratio for the attention probabilities
+ required float attention_probs_dropout_prob = 7 [default = 0.1];
+ // The maximum sequence length that this model might ever be used with
+ required uint32 max_position_embeddings = 8 [default = 512];
+ // Whether to add position embeddings for the position of each token in the text sequence
+ required bool use_position_embeddings = 9 [default = true];
+ // The stddev of the truncated_normal_initializer for initializing all weight matrices
+ required float initializer_range = 10 [default = 0.02];
+ // dnn layers for other features
+ optional DNN other_feature_dnn = 11;
+}
+
+message Uniter {
+ required UniterTower config = 1;
+
+ required DNN final_dnn = 2;
+}
diff --git a/easy_rec/python/test/csv_input_test.py b/easy_rec/python/test/csv_input_test.py
index 576b42297..ae0793fa5 100644
--- a/easy_rec/python/test/csv_input_test.py
+++ b/easy_rec/python/test/csv_input_test.py
@@ -2,6 +2,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Define cv_input, the base class for cv tasks."""
+import os
+import unittest
+
import tensorflow as tf
from google.protobuf import text_format
@@ -10,6 +13,7 @@
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
from easy_rec.python.utils.test_utils import RunAsSubprocess
if tf.__version__ >= '2.0':
@@ -264,6 +268,14 @@ def test_csv_input_ex(self):
sess.run(init_op)
feature_dict, label_dict = sess.run([features, labels])
+ @unittest.skipIf('AVX_TEST' not in os.environ,
+ 'Only execute when avx512 instructions are supported')
+ @RunAsSubprocess
+ def test_csv_input_ex_avx(self):
+ constant.enable_avx_str_split()
+ self.test_csv_input_ex()
+ constant.disable_avx_str_split()
+
@RunAsSubprocess
def test_csv_data_ignore_error(self):
data_config_str = """
diff --git a/easy_rec/python/test/dh_local_run.py b/easy_rec/python/test/dh_local_run.py
index a4282f891..be514d382 100644
--- a/easy_rec/python/test/dh_local_run.py
+++ b/easy_rec/python/test/dh_local_run.py
@@ -37,7 +37,8 @@ def test_datahub_train_eval(self):
odps_cmd = OdpsCommand(odps_oss_config)
self._success = test_utils.test_datahub_train_eval(
- '%s/configs/deepfm.config' % odps_oss_config.temp_dir, self._test_dir)
+ '%s/configs/deepfm.config' % odps_oss_config.temp_dir, odps_oss_config,
+ self._test_dir)
odps_cmd.run_list(end)
self.assertTrue(self._success)
@@ -48,8 +49,6 @@ def test_datahub_train_eval(self):
'--odps_config', type=str, default=None, help='odps config path')
parser.add_argument(
'--oss_config', type=str, default=None, help='ossutilconfig path')
- parser.add_argument(
- '--datahub_config', type=str, default=None, help='datahub_config')
parser.add_argument(
'--bucket_name', type=str, default=None, help='test oss bucket name')
parser.add_argument('--arn', type=str, default=None, help='oss rolearn')
@@ -73,8 +72,6 @@ def test_datahub_train_eval(self):
if args.odps_config:
odps_oss_config.load_odps_config(args.odps_config)
os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
- if args.datahub_config:
- odps_oss_config.load_dh_config(args.datahub_config)
if args.oss_config:
odps_oss_config.load_oss_config(args.oss_config)
if args.odpscmd:
@@ -89,7 +86,6 @@ def test_datahub_train_eval(self):
odps_oss_config.arn = args.arn
if args.bucket_name:
odps_oss_config.bucket_name = args.bucket_name
- print(args)
prepare(odps_oss_config)
start = [
'deep_fm/create_external_deepfm_table.sql',
diff --git a/easy_rec/python/test/eval_metric_test.py b/easy_rec/python/test/eval_metric_test.py
index 57d111be1..22e019b6a 100644
--- a/easy_rec/python/test/eval_metric_test.py
+++ b/easy_rec/python/test/eval_metric_test.py
@@ -7,9 +7,6 @@
import tensorflow as tf
from absl.testing import parameterized
-from easy_rec.python.core.metrics import gauc
-from easy_rec.python.core.metrics import max_f1
-from easy_rec.python.core.metrics import session_auc
from easy_rec.python.utils.test_utils import RunAsSubprocess
if tf.__version__ >= '2.0':
@@ -23,6 +20,7 @@ def setUp(self):
@RunAsSubprocess
def test_max_f1(self):
+ from easy_rec.python.core.metrics import max_f1
labels = tf.constant([1, 0, 0, 1], dtype=tf.int32)
probs = tf.constant([0.9, 0.8, 0.7, 0.6], dtype=tf.float32)
f1, f1_update_op = max_f1(labels, probs)
@@ -35,6 +33,7 @@ def test_max_f1(self):
@RunAsSubprocess
def test_gauc_all_negative_label(self):
+ from easy_rec.python.core.metrics import gauc
labels = tf.constant([0, 0, 0, 0], dtype=tf.int32)
probs = tf.constant([0.9, 0.8, 0.7, 0.6], dtype=tf.float32)
uids = tf.constant([1, 1, 1, 1], dtype=tf.int32)
@@ -50,6 +49,7 @@ def test_gauc_all_negative_label(self):
['_reduction_mean_by_positive_num', 'mean_by_positive_num', 0.6]])
@RunAsSubprocess
def test_gauc(self, reduction, expected):
+ from easy_rec.python.core.metrics import gauc
labels = tf.placeholder(dtype=tf.int32, shape=(None,))
probs = tf.placeholder(dtype=tf.float32, shape=(None,))
uids = tf.placeholder(dtype=tf.int32, shape=(None,))
@@ -78,6 +78,7 @@ def test_gauc(self, reduction, expected):
['_reduction_mean_by_positive_num', 'mean_by_positive_num', 0.6]])
@RunAsSubprocess
def test_session_auc(self, reduction, expected):
+ from easy_rec.python.core.metrics import session_auc
labels = tf.placeholder(dtype=tf.int32, shape=(None,))
probs = tf.placeholder(dtype=tf.float32, shape=(None,))
session_ids = tf.placeholder(dtype=tf.int32, shape=(None,))
diff --git a/easy_rec/python/test/export_test.py b/easy_rec/python/test/export_test.py
index 69d339120..23bff2890 100644
--- a/easy_rec/python/test/export_test.py
+++ b/easy_rec/python/test/export_test.py
@@ -10,6 +10,7 @@
import numpy as np
import tensorflow as tf
+from tensorflow.python.platform import gfile
import easy_rec
from easy_rec.python.inference.predictor import Predictor
@@ -17,11 +18,6 @@
from easy_rec.python.utils import test_utils
from easy_rec.python.utils.test_utils import RunAsSubprocess
-if tf.__version__ >= '2.0':
- gfile = tf.compat.v1.gfile
-else:
- gfile = tf.gfile
-
class ExportTest(tf.test.TestCase):
@@ -54,7 +50,7 @@ def _predict_and_check(self,
for key in keys:
val0 = output_res[i][key]
val1 = cmp_result[i][key]
- diff = np.abs(val0 - val1)
+ diff = np.max(np.abs(val0 - val1))
assert diff < tol, \
'too much difference: %.6f for %s, tol=%.6f' \
% (diff, key, tol)
@@ -81,6 +77,10 @@ def test_multi_tower(self):
self._export_test('samples/model_config/multi_tower_export.config',
self._extract_data)
+ def test_filter_input(self):
+ self._export_test('samples/model_config/export_filter_input.config',
+ self._extract_data)
+
def test_mmoe(self):
self._export_test(
'samples/model_config/mmoe_on_taobao.config',
@@ -115,6 +115,7 @@ def test_export_with_asset(self):
--pipeline_config_path %s
--export_dir %s
--asset_files fg.json:samples/model_config/taobao_fg.json
+ --export_done_file ExportDone
""" % (
config_path,
export_dir,
@@ -127,6 +128,7 @@ def test_export_with_asset(self):
export_dir = files[0]
assert gfile.Exists(export_dir + '/assets/taobao_fg.json')
assert gfile.Exists(export_dir + '/assets/pipeline.config')
+ assert gfile.Exists(export_dir + '/ExportDone')
def test_export_with_out_in_ckpt_config(self):
test_dir = test_utils.get_tmp_dir()
@@ -155,6 +157,12 @@ def _post_check_func(pipeline_config):
test_dir=test_dir,
post_check_func=_post_check_func))
+ def test_multi_class_predict(self):
+ self._export_test(
+ 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
+ extract_data_func=self._extract_data,
+ keys=['probs', 'logits', 'probs_y', 'logits_y', 'y'])
+
def _export_test(self,
pipeline_config_path,
extract_data_func=None,
@@ -382,7 +390,7 @@ def test_big_model_embedding_variable_v2_oss_export(self):
pipeline_config_path,
test_data_path,
self._extract_rtp_data,
- total_steps=1000)
+ total_steps=100)
def _test_big_model_export_to_oss(self,
pipeline_config_path,
@@ -436,7 +444,7 @@ def _test_big_model_export_to_oss(self,
--input_path %s
--output_path %s
""" % (config_path, test_data_path, result_path)
- proc = test_utils.run_cmd(predict_cmd % (),
+ proc = test_utils.run_cmd(predict_cmd,
'%s/log_%s.txt' % (test_dir, 'predict'))
proc.wait()
self.assertTrue(proc.returncode == 0)
diff --git a/easy_rec/python/test/fg_test.py b/easy_rec/python/test/fg_test.py
index 6fe163f96..efbce46ea 100644
--- a/easy_rec/python/test/fg_test.py
+++ b/easy_rec/python/test/fg_test.py
@@ -1,6 +1,7 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
+import unittest
import tensorflow as tf
from google.protobuf import text_format
@@ -52,6 +53,18 @@ def test_fg_dtype(self):
'samples/model_config/taobao_fg_test_dtype.config', self._test_dir)
self.assertTrue(self._success)
+ def test_fg_train(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/fg_train.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf('-PAI' not in tf.__version__,
+ 'Only test when pai-tf is used.')
+ def test_fg_train_ev(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/fg_train_ev.config', self._test_dir)
+ self.assertTrue(self._success)
+
if __name__ == '__main__':
tf.test.main()
diff --git a/easy_rec/python/test/hive_input_test.py b/easy_rec/python/test/hive_input_test.py
new file mode 100644
index 000000000..71aeafd4b
--- /dev/null
+++ b/easy_rec/python/test/hive_input_test.py
@@ -0,0 +1,311 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Define cv_input, the base class for cv tasks."""
+import logging
+import os
+import unittest
+
+import tensorflow as tf
+from google.protobuf import text_format
+
+from easy_rec.python.input.hive_input import HiveInput
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
+from easy_rec.python.protos.hive_config_pb2 import HiveConfig
+from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import test_utils
+
+if tf.__version__ >= '2.0':
+ from tensorflow.python.framework.ops import disable_eager_execution
+
+ disable_eager_execution()
+ tf = tf.compat.v1
+
+gfile = tf.gfile
+
+
+class HiveInputTest(tf.test.TestCase):
+
+ def _init_config(self):
+ hive_host = os.environ['hive_host']
+ hive_username = os.environ['hive_username']
+ hive_table_name = os.environ['hive_table_name']
+ hive_hash_fields = os.environ['hive_hash_fields']
+
+ hive_train_input = """
+ host: "{}"
+ username: "{}"
+ table_name: "{}"
+ limit_num: 500
+ hash_fields: "{}"
+ """.format(hive_host, hive_username, hive_table_name, hive_hash_fields)
+ hive_eval_input = """
+ host: "{}"
+ username: "{}"
+ table_name: "{}"
+ limit_num: 500
+ hash_fields: "{}"
+ """.format(hive_host, hive_username, hive_table_name, hive_hash_fields)
+ self.hive_train_input_config = HiveConfig()
+ text_format.Merge(hive_train_input, self.hive_train_input_config)
+
+ self.hive_eval_input_config = HiveConfig()
+ text_format.Merge(hive_eval_input, self.hive_eval_input_config)
+
+ def __init__(self, methodName='HiveInputTest'):
+ super(HiveInputTest, self).__init__(methodName=methodName)
+
+ @unittest.skipIf('hive_host' not in os.environ or
+ 'hive_username' not in os.environ or
+ 'hive_table_name' not in os.environ or
+ 'hive_hash_fields' not in os.environ,
+ """Only execute hive_config var are specified,hive_host、
+ hive_username、hive_table_name、hive_hash_fields is available.""")
+ def test_hive_input(self):
+ self._init_config()
+ data_config_str = """
+ batch_size: 1024
+ label_fields: "label_1"
+ label_fields: "label_2"
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: HiveInput
+ input_fields {
+ input_name:'label_1'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label_2'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "class_of_worker"
+ }
+ input_fields {
+ input_name: "industry_code"
+ }
+ input_fields {
+ input_name: "occupation_code"
+ }
+ input_fields {
+ input_name: "education"
+ }
+ input_fields {
+ input_name: "wage_per_hour"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "enrolled_in_edu_inst_last_wk"
+ }
+ input_fields {
+ input_name: "major_industry"
+ }
+ input_fields {
+ input_name: "major_occupation"
+ }
+ input_fields {
+ input_name: "mace"
+ }
+ input_fields {
+ input_name: "hispanic_origin"
+ }
+ input_fields {
+ input_name: "sex"
+ }
+ input_fields {
+ input_name: "member_of_a_labor_union"
+ }
+ input_fields {
+ input_name: "reason_for_unemployment"
+ }
+ input_fields {
+ input_name: "full_or_part_time_employment_stat"
+ }
+ input_fields {
+ input_name: "capital_gains"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "capital_losses"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "divdends_from_stocks"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "tax_filer_status"
+ }
+ input_fields {
+ input_name: "region_of_previous_residence"
+ }
+ input_fields {
+ input_name: "state_of_previous_residence"
+ }
+ input_fields {
+ input_name: "detailed_household_and_family_stat"
+ }
+ input_fields {
+ input_name: "detailed_household_summary_in_household"
+ }
+ input_fields {
+ input_name: "instance_weight"
+ }
+ input_fields {
+ input_name: "migration_code_change_in_msa"
+ }
+ input_fields {
+ input_name: "migration_code_change_in_reg"
+ }
+ input_fields {
+ input_name: "migration_code_move_within_reg"
+ }
+ input_fields {
+ input_name: "live_in_this_house_1_year_ago"
+ }
+ input_fields {
+ input_name: "migration_prev_res_in_sunbelt"
+ }
+ input_fields {
+ input_name: "num_persons_worked_for_employer"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "family_members_under_18"
+ }
+ input_fields {
+ input_name: "country_of_birth_father"
+ }
+ input_fields {
+ input_name: "country_of_birth_mother"
+ }
+ input_fields {
+ input_name: "country_of_birth_self"
+ }
+ input_fields {
+ input_name: "citizenship"
+ }
+ input_fields {
+ input_name: "own_business_or_self_employed"
+ }
+ input_fields {
+ input_name: "fill_inc_questionnaire_for_veteran_s_admin"
+ }
+ input_fields {
+ input_name: "veterans_benefits"
+ }
+ input_fields {
+ input_name: "weeks_worked_in_year"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "year"
+ }
+ """
+
+ feature_config_str = """
+ input_names: "own_business_or_self_employed"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+ """
+
+ dataset_config = DatasetConfig()
+ text_format.Merge(data_config_str, dataset_config)
+ feature_config = FeatureConfig()
+ text_format.Merge(feature_config_str, feature_config)
+ feature_configs = [feature_config]
+
+ empty_config = FeatureConfig()
+ empty_config.CopyFrom(feature_config)
+ while len(empty_config.input_names) > 0:
+ empty_config.input_names.pop()
+ while len(empty_config.shared_names) > 0:
+ empty_config.shared_names.pop()
+ train_input_fn = HiveInput(dataset_config, feature_configs,
+ self.hive_train_input_config).create_input()
+ dataset = train_input_fn(mode=tf.estimator.ModeKeys.TRAIN)
+ iterator = dataset.make_initializable_iterator()
+ tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
+ features, labels = iterator.get_next()
+ init_op = tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS)
+ gpu_options = tf.GPUOptions(allow_growth=True)
+ session_config = tf.ConfigProto(
+ gpu_options=gpu_options,
+ allow_soft_placement=True,
+ log_device_placement=False)
+ with self.test_session(config=session_config) as sess:
+ sess.run(init_op)
+ feature_dict, label_dict = sess.run([features, labels])
+ for key in feature_dict:
+ print(key, feature_dict[key][:5])
+
+ for key in label_dict:
+ print(key, label_dict[key][:5])
+ return 0
+
+ @unittest.skipIf('hive_host' not in os.environ or
+ 'hive_username' not in os.environ or
+ 'hive_table_name' not in os.environ or
+ 'hive_hash_fields' not in os.environ,
+ """Only execute hive_config var are specified,hive_host、
+ hive_username、hive_table_name、hive_hash_fields is available.""")
+ def test_mmoe(self):
+ pipeline_config_path = 'samples/emr_script/mmoe/mmoe_census_income.config'
+ gpus = test_utils.get_available_gpus()
+ if len(gpus) > 0:
+ test_utils.set_gpu_id(gpus[0])
+ else:
+ test_utils.set_gpu_id(None)
+
+ if not isinstance(pipeline_config_path, EasyRecConfig):
+ logging.info('testing pipeline config %s' % pipeline_config_path)
+ if 'TF_CONFIG' in os.environ:
+ del os.environ['TF_CONFIG']
+
+ if isinstance(pipeline_config_path, EasyRecConfig):
+ pipeline_config = pipeline_config_path
+ else:
+ pipeline_config = test_utils._load_config_for_test(
+ pipeline_config_path, self._test_dir)
+
+ pipeline_config.train_config.train_distribute = 0
+ pipeline_config.train_config.num_gpus_per_worker = 1
+ pipeline_config.train_config.sync_replicas = False
+
+ config_util.save_pipeline_config(pipeline_config, self._test_dir)
+ test_pipeline_config_path = os.path.join(self._test_dir, 'pipeline.config')
+ hyperparam_str = ''
+ train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s %s' % (
+ test_pipeline_config_path, hyperparam_str)
+ proc = test_utils.run_cmd(train_cmd,
+ '%s/log_%s.txt' % (self._test_dir, 'master'))
+ proc.wait()
+ if proc.returncode != 0:
+ logging.error('train %s failed' % test_pipeline_config_path)
+ return 1
+ return 0
+
+ def setUp(self):
+ logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))
+ self._test_dir = test_utils.get_tmp_dir()
+ self._success = True
+ logging.info('test dir: %s' % self._test_dir)
+
+ def tearDown(self):
+ test_utils.set_gpu_id(None)
+ if self._success:
+ test_utils.clean_up(self._test_dir)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/easy_rec/python/test/hpo_test.py b/easy_rec/python/test/hpo_test.py
index d3fc60a71..a570d87d0 100644
--- a/easy_rec/python/test/hpo_test.py
+++ b/easy_rec/python/test/hpo_test.py
@@ -9,18 +9,19 @@
import numpy as np
import tensorflow as tf
+from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
from easy_rec.python.utils import config_util
from easy_rec.python.utils import hpo_util
from easy_rec.python.utils import test_utils
if tf.__version__ >= '2.0':
- gfile = tf.compat.v1.gfile
+ import tensorflow.io.gfile as gfile
from tensorflow.core.protobuf import config_pb2
ConfigProto = config_pb2.ConfigProto
GPUOptions = config_pb2.GPUOptions
else:
- gfile = tf.gfile
+ from tensorflow.python.platform import gfile
GPUOptions = tf.GPUOptions
ConfigProto = tf.ConfigProto
@@ -198,6 +199,22 @@ def test_edit_config_v12(self):
assert len(tmp_fea.boundaries) == 25
assert np.abs(tmp_fea.boundaries[1] - 21.0) < 1e-5
+ def test_edit_config_v13(self):
+ tmp_file = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config'
+ tmp_config = config_util.get_configs_from_pipeline_file(tmp_file)
+ tmp_file = 'samples/hpo/hpo_param_v13.json'
+ tmp_config = config_util.edit_config(tmp_config, self.load_config(tmp_file))
+ assert not tmp_config.export_config.multi_placeholder
+
+ def test_edit_config_v14(self):
+ tmp_file = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config'
+ tmp_config = config_util.get_configs_from_pipeline_file(tmp_file)
+ tmp_file = 'samples/hpo/hpo_param_v14.json'
+ tmp_config = config_util.edit_config(tmp_config, self.load_config(tmp_file))
+ for tmp_fea in tmp_config.feature_configs:
+ if tmp_fea.input_names[0] == 'hour':
+ assert tmp_fea.feature_type == FeatureConfig.RawFeature
+
def test_save_eval_metrics_with_env(self):
os.environ['TF_CONFIG'] = """
{ "cluster": {
diff --git a/easy_rec/python/test/kafka_test.py b/easy_rec/python/test/kafka_test.py
new file mode 100644
index 000000000..f0da2d5d5
--- /dev/null
+++ b/easy_rec/python/test/kafka_test.py
@@ -0,0 +1,373 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import json
+import logging
+import os
+import threading
+import time
+import traceback
+import unittest
+
+import numpy as np
+import six
+import tensorflow as tf
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.input.kafka_dataset import KafkaDataset
+from easy_rec.python.utils import numpy_utils
+from easy_rec.python.utils import test_utils
+
+try:
+ import kafka
+ from kafka import KafkaProducer, KafkaAdminClient
+ from kafka.admin import NewTopic
+except ImportError:
+ logging.warning('kafka-python is not installed: %s' % traceback.format_exc())
+
+
+class KafkaTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._success = True
+ self._test_dir = test_utils.get_tmp_dir()
+ if self._testMethodName == 'test_session':
+ self._kafka_server_proc = None
+ self._zookeeper_proc = None
+ return
+
+ logging.info('Testing %s.%s, test_dir=%s' %
+ (type(self).__name__, self._testMethodName, self._test_dir))
+ self._log_dir = os.path.join(self._test_dir, 'logs')
+ if not gfile.IsDirectory(self._log_dir):
+ gfile.MakeDirs(self._log_dir)
+
+ self._kafka_servers = ['127.0.0.1:9092']
+ self._test_topic = 'kafka_op_test_topic'
+
+ if 'kafka_install_dir' in os.environ:
+ kafka_install_dir = os.environ.get('kafka_install_dir', None)
+
+ zookeeper_config_raw = '%s/config/zookeeper.properties' % kafka_install_dir
+ zookeeper_config = os.path.join(self._test_dir, 'zookeeper.properties')
+ with open(zookeeper_config, 'w') as fout:
+ with open(zookeeper_config_raw, 'r') as fin:
+ for line_str in fin:
+ if line_str.startswith('dataDir='):
+ fout.write('dataDir=%s/zookeeper\n' % self._test_dir)
+ else:
+ fout.write(line_str)
+ cmd = 'bash %s/bin/zookeeper-server-start.sh %s' % (kafka_install_dir,
+ zookeeper_config)
+ log_file = os.path.join(self._log_dir, 'zookeeper.log')
+ self._zookeeper_proc = test_utils.run_cmd(cmd, log_file)
+
+ kafka_config_raw = '%s/config/server.properties' % kafka_install_dir
+ kafka_config = os.path.join(self._test_dir, 'server.properties')
+ with open(kafka_config, 'w') as fout:
+ with open(kafka_config_raw, 'r') as fin:
+ for line_str in fin:
+ if line_str.startswith('log.dirs='):
+ fout.write('log.dirs=%s/kafka\n' % self._test_dir)
+ else:
+ fout.write(line_str)
+ cmd = 'bash %s/bin/kafka-server-start.sh %s' % (kafka_install_dir,
+ kafka_config)
+ log_file = os.path.join(self._log_dir, 'kafka_server.log')
+ self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)
+
+ started = False
+ while not started:
+ if self._kafka_server_proc.poll(
+ ) and self._kafka_server_proc.returncode:
+ logging.warning('start kafka server failed, will retry.')
+ os.system('cat %s' % log_file)
+ self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)
+ time.sleep(5)
+ else:
+ try:
+ admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
+ logging.info('old topics: %s' % (','.join(admin_clt.list_topics())))
+ admin_clt.close()
+ started = True
+ except kafka.errors.NoBrokersAvailable:
+ time.sleep(2)
+ self._create_topic()
+ else:
+ self._zookeeper_proc = None
+ self._kafka_server_proc = None
+ self._should_stop = False
+ self._producer = None
+
+ def _create_topic(self, num_partitions=2):
+ admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
+
+ logging.info('create topic: %s' % self._test_topic)
+ topic_list = [
+ NewTopic(
+ name=self._test_topic,
+ num_partitions=num_partitions,
+ replication_factor=1)
+ ]
+
+ admin_clt.create_topics(new_topics=topic_list, validate_only=False)
+ logging.info('all topics: %s' % (','.join(admin_clt.list_topics())))
+ admin_clt.close()
+
+ def _create_producer(self, generate_func):
+ # start produce thread
+
+ prod = threading.Thread(target=generate_func)
+ prod.start()
+ return prod
+
+ def _stop_producer(self):
+ if self._producer is not None:
+ self._should_stop = True
+ self._producer.join()
+
+ def tearDown(self):
+ try:
+ self._stop_producer()
+ if self._kafka_server_proc is not None:
+ self._kafka_server_proc.terminate()
+ except Exception as ex:
+ logging.warning('exception terminate kafka proc: %s' % str(ex))
+
+ try:
+ if self._zookeeper_proc is not None:
+ self._zookeeper_proc.terminate()
+ except Exception as ex:
+ logging.warning('exception terminate zookeeper proc: %s' % str(ex))
+
+ test_utils.set_gpu_id(None)
+ if self._success:
+ test_utils.clean_up(self._test_dir)
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ,
+ 'Only execute when kafka is available')
+ def test_kafka_ops(self):
+ try:
+ test_utils.set_gpu_id(None)
+
+ def _generate():
+ producer = KafkaProducer(
+ bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
+ i = 0
+ while not self._should_stop:
+ msg = 'user_id_%d' % i
+ if six.PY3:
+ msg = msg.encode('utf-8')
+ producer.send(self._test_topic, msg)
+ i += 1
+ producer.close()
+
+ self._producer = self._create_producer(_generate)
+
+ group = 'dataset_consumer'
+ k = KafkaDataset(
+ servers=self._kafka_servers[0],
+ topics=[self._test_topic + ':0', self._test_topic + ':1'],
+ group=group,
+ eof=True,
+ # control the maximal read of each partition
+ config_global=['max.partition.fetch.bytes=1048576'],
+ message_key=True,
+ message_offset=True)
+
+ batch_dataset = k.batch(5)
+
+ iterator = iterator_ops.Iterator.from_structure(
+ batch_dataset.output_types)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ sess = tf.Session()
+ sess.run(init_batch_op)
+
+ p = sess.run(get_next)
+
+ self.assertEquals(len(p), 3)
+ offset = p[2]
+ self.assertEquals(offset[0], '0:0')
+ self.assertEquals(offset[1], '0:1')
+
+ p = sess.run(get_next)
+ offset = p[2]
+ self.assertEquals(offset[0], '0:5')
+ self.assertEquals(offset[1], '0:6')
+
+ max_iter = 300
+ while max_iter > 0:
+ sess.run(get_next)
+ max_iter -= 1
+ except tf.errors.OutOfRangeError:
+ pass
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ,
+ 'Only execute when kafka is available')
+ def test_kafka_train(self):
+ try:
+ # start produce thread
+ self._producer = self._create_producer(self._generate)
+
+ test_utils.set_gpu_id(None)
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_avazu_kafka.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+ def _generate(self):
+ producer = KafkaProducer(
+ bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
+ while not self._should_stop:
+ with open('data/test/dwd_avazu_ctr_deepmodel_10w.csv', 'r') as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ if self._should_stop:
+ break
+ if six.PY3:
+ line_str = line_str.encode('utf-8')
+ producer.send(self._test_topic, line_str)
+ producer.close()
+ logging.info('data generation thread done.')
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ,
+ 'Only execute when kafka is available')
+ def test_kafka_train_chief_redundant(self):
+ try:
+ # start produce thread
+ self._producer = self._create_producer(self._generate)
+
+ test_utils.set_gpu_id(None)
+
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/deepfm_combo_avazu_kafka_chief_redundant.config',
+ self._test_dir,
+ num_evaluator=1)
+ self.assertTrue(self._success)
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ,
+ 'Only execute when kafka is available')
+ def test_kafka_train_v2(self):
+ try:
+ # start produce thread
+ self._producer = self._create_producer(self._generate)
+
+ test_utils.set_gpu_id(None)
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_avazu_kafka_time_offset.config',
+ self._test_dir)
+
+ self.assertTrue(self._success)
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+ @unittest.skipIf(
+ 'kafka_install_dir' not in os.environ or 'oss_path' not in os.environ or
+ 'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ or
+ 'oss_sk' not in os.environ, 'Only execute when kafka and oss are available')
+ def test_kafka_processor(self):
+ self._test_kafka_processor(
+ 'samples/model_config/taobao_fg_incr_save.config')
+
+ @unittest.skipIf(
+ 'kafka_install_dir' not in os.environ or 'oss_path' not in os.environ or
+ 'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ or
+ 'oss_sk' not in os.environ, 'Only execute when kafka and oss are available')
+ def test_kafka_processor_ev(self):
+ self._test_kafka_processor(
+ 'samples/model_config/taobao_fg_incr_save_ev.config')
+
+ def _test_kafka_processor(self, config_path):
+ self._success = False
+ success = test_utils.test_distributed_train_eval(
+ config_path, self._test_dir, total_steps=500)
+ self.assertTrue(success)
+ export_cmd = """
+ python -m easy_rec.python.export --pipeline_config_path %s/pipeline.config
+ --export_dir %s/export/sep/ --oss_path=%s --oss_ak=%s --oss_sk=%s --oss_endpoint=%s
+ --asset_files ./samples/rtp_fg/fg.json
+ --checkpoint_path %s/train/model.ckpt-0
+ """ % (self._test_dir, self._test_dir, os.environ['oss_path'],
+ os.environ['oss_ak'], os.environ['oss_sk'],
+ os.environ['oss_endpoint'], self._test_dir)
+ proc = test_utils.run_cmd(export_cmd,
+ '%s/log_export_sep.txt' % self._test_dir)
+ proc.wait()
+ self.assertTrue(proc.returncode == 0)
+ files = gfile.Glob(os.path.join(self._test_dir, 'export/sep/[1-9][0-9]*'))
+ export_sep_dir = files[0]
+
+ predict_cmd = """
+ python -m easy_rec.python.inference.processor.test --saved_model_dir %s
+ --input_path data/test/rtp/taobao_test_feature.txt
+ --output_path %s/processor.out --test_dir %s
+ """ % (export_sep_dir, self._test_dir, self._test_dir)
+ envs = dict(os.environ)
+ envs['PROCESSOR_TEST'] = '1'
+ proc = test_utils.run_cmd(
+ predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs)
+ proc.wait()
+ self.assertTrue(proc.returncode == 0)
+
+ with open('%s/processor.out' % self._test_dir, 'r') as fin:
+ processor_out = []
+ for line_str in fin:
+ line_str = line_str.strip()
+ processor_out.append(json.loads(line_str))
+
+ predictor = Predictor(os.path.join(self._test_dir, 'train/export/final/'))
+ with open('data/test/rtp/taobao_test_feature.txt', 'r') as fin:
+ inputs = []
+ for line_str in fin:
+ line_str = line_str.strip()
+ line_tok = line_str.split(';')[-1]
+ line_tok = line_tok.split(chr(2))
+ inputs.append(line_tok)
+ output_res = predictor.predict(inputs, batch_size=1024)
+
+ with open('%s/predictor.out' % self._test_dir, 'w') as fout:
+ for i in range(len(output_res)):
+ fout.write(
+ json.dumps(output_res[i], cls=numpy_utils.NumpyEncoder) + '\n')
+
+ for i in range(len(output_res)):
+ val0 = output_res[i]['probs']
+ val1 = processor_out[i]['probs']
+ diff = np.abs(val0 - val1)
+ assert diff < 1e-4, 'too much difference[%.6f] >= 1e-4' % diff
+ self._success = True
+
+ @unittest.skipIf('kafka_install_dir' not in os.environ,
+ 'Only execute when kafka is available')
+ def test_kafka_train_v3(self):
+ try:
+ # start produce thread
+ self._producer = self._create_producer(self._generate)
+
+ test_utils.set_gpu_id(None)
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_avazu_kafka_time_offset2.config',
+ self._test_dir)
+
+ self.assertTrue(self._success)
+ except Exception as ex:
+ self._success = False
+ raise ex
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/easy_rec/python/test/local_incr_test.py b/easy_rec/python/test/local_incr_test.py
new file mode 100644
index 000000000..ad2d657f3
--- /dev/null
+++ b/easy_rec/python/test/local_incr_test.py
@@ -0,0 +1,122 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import json
+import logging
+import os
+import unittest
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.utils import numpy_utils
+from easy_rec.python.utils import test_utils
+
+
+class LocalIncrTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._success = True
+ self._test_dir = test_utils.get_tmp_dir()
+
+ logging.info('Testing %s.%s, test_dir=%s' %
+ (type(self).__name__, self._testMethodName, self._test_dir))
+ self._log_dir = os.path.join(self._test_dir, 'logs')
+ if not gfile.IsDirectory(self._log_dir):
+ gfile.MakeDirs(self._log_dir)
+
+ @unittest.skipIf(
+ 'oss_path' not in os.environ or
+ 'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ or
+ 'oss_sk' not in os.environ, 'Only execute when oss config is available')
+ def test_incr_save(self):
+ self._test_incr_save(
+ 'samples/model_config/taobao_fg_incr_save_local.config')
+
+ @unittest.skipIf(
+ 'oss_path' not in os.environ or
+ 'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ or
+ 'oss_sk' not in os.environ, 'Only execute when oss config is available')
+ def test_incr_save_ev(self):
+ self._test_incr_save(
+ 'samples/model_config/taobao_fg_incr_save_ev_local.config')
+
+ @unittest.skipIf(
+ 'oss_path' not in os.environ or
+ 'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ or
+ 'oss_sk' not in os.environ, 'Only execute when oss config is available')
+ def test_incr_save_share_ev(self):
+ self._test_incr_save(
+ 'samples/model_config/taobao_fg_incr_save_share_ev_local.config')
+
+ def _test_incr_save(self, config_path):
+ self._success = False
+ success = test_utils.test_distributed_train_eval(
+ config_path,
+ self._test_dir,
+ total_steps=100,
+ edit_config_json={
+ 'train_config.incr_save_config.fs.mount_path':
+ os.path.join(self._test_dir, 'train/incr_save/')
+ })
+ self.assertTrue(success)
+ export_cmd = """
+ python -m easy_rec.python.export --pipeline_config_path %s/pipeline.config
+ --export_dir %s/export/sep/ --oss_path=%s --oss_ak=%s --oss_sk=%s --oss_endpoint=%s
+ --asset_files ./samples/rtp_fg/fg.json
+ --checkpoint_path %s/train/model.ckpt-0
+ """ % (self._test_dir, self._test_dir, os.environ['oss_path'],
+ os.environ['oss_ak'], os.environ['oss_sk'],
+ os.environ['oss_endpoint'], self._test_dir)
+ proc = test_utils.run_cmd(export_cmd,
+ '%s/log_export_sep.txt' % self._test_dir)
+ proc.wait()
+ self.assertTrue(proc.returncode == 0)
+ files = gfile.Glob(os.path.join(self._test_dir, 'export/sep/[1-9][0-9]*'))
+ export_sep_dir = files[0]
+
+ predict_cmd = """
+ python -m easy_rec.python.inference.processor.test --saved_model_dir %s
+ --input_path data/test/rtp/taobao_test_feature.txt
+ --output_path %s/processor.out --test_dir %s
+ """ % (export_sep_dir, self._test_dir, self._test_dir)
+ envs = dict(os.environ)
+ envs['PROCESSOR_TEST'] = '1'
+ proc = test_utils.run_cmd(
+ predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs)
+ proc.wait()
+ self.assertTrue(proc.returncode == 0)
+
+ with open('%s/processor.out' % self._test_dir, 'r') as fin:
+ processor_out = []
+ for line_str in fin:
+ line_str = line_str.strip()
+ processor_out.append(json.loads(line_str))
+
+ predictor = Predictor(os.path.join(self._test_dir, 'train/export/final/'))
+ with open('data/test/rtp/taobao_test_feature.txt', 'r') as fin:
+ inputs = []
+ for line_str in fin:
+ line_str = line_str.strip()
+ line_tok = line_str.split(';')[-1]
+ line_tok = line_tok.split(chr(2))
+ inputs.append(line_tok)
+ output_res = predictor.predict(inputs, batch_size=1024)
+
+ with open('%s/predictor.out' % self._test_dir, 'w') as fout:
+ for i in range(len(output_res)):
+ fout.write(
+ json.dumps(output_res[i], cls=numpy_utils.NumpyEncoder) + '\n')
+
+ for i in range(len(output_res)):
+ val0 = output_res[i]['probs']
+ val1 = processor_out[i]['probs']
+ diff = np.abs(val0 - val1)
+ assert diff < 1e-4, 'too much difference[%.6f] >= 1e-4' % diff
+ self._success = True
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/easy_rec/python/test/loss_test.py b/easy_rec/python/test/loss_test.py
index da2afd11c..f78b74ce6 100644
--- a/easy_rec/python/test/loss_test.py
+++ b/easy_rec/python/test/loss_test.py
@@ -4,6 +4,8 @@
from easy_rec.python.loss.circle_loss import circle_loss
from easy_rec.python.loss.circle_loss import get_anchor_positive_triplet_mask
+from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
+
from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
if tf.__version__ >= '2.0':
@@ -12,6 +14,16 @@
class LossTest(tf.test.TestCase):
+ def test_f1_reweighted_loss(self):
+ print('test_f1_reweighted_loss')
+ logits = tf.constant([0.1, 0.5, 0.3, 0.8, -0.1, 0.3])
+ labels = tf.constant([1, 1, 0, 0, 1, 1])
+ loss = f1_reweight_sigmoid_cross_entropy(
+ labels=labels, logits=logits, beta_square=4)
+ with self.test_session() as sess:
+ loss_val = sess.run(loss)
+ self.assertAlmostEqual(loss_val, 0.47844395, delta=1e-5)
+
def test_softmax_loss_with_negative_mining(self):
print('test_softmax_loss_with_negative_mining')
user_emb = tf.constant([[0.1, 0.5, 0.3], [0.8, -0.1, 0.3], [0.28, 0.3, 0.9],
@@ -23,10 +35,10 @@ def test_softmax_loss_with_negative_mining(self):
label = tf.constant([1, 1, 0, 0, 1, 1])
loss = softmax_loss_with_negative_mining(
- user_emb, item_emb, label, num_negative_samples=1)
+ user_emb, item_emb, label, num_negative_samples=2, seed=1)
with self.test_session() as sess:
loss_val = sess.run(loss)
- self.assertAlmostEqual(loss_val, 0.5240243, delta=1e-5)
+ self.assertAlmostEqual(loss_val, 0.48577175, delta=1e-5)
def test_circle_loss(self):
print('test_circle_loss')
diff --git a/easy_rec/python/test/odps_run.py b/easy_rec/python/test/odps_run.py
index 4b6179fbc..84bd44f9b 100644
--- a/easy_rec/python/test/odps_run.py
+++ b/easy_rec/python/test/odps_run.py
@@ -10,7 +10,7 @@
import oss2
import tensorflow as tf
-from easy_rec.python.test.odps_test import OdpsTest
+from easy_rec.python.test.odps_test_cls import OdpsTest
from easy_rec.python.test.odps_test_prepare import prepare
from easy_rec.python.test.odps_test_util import OdpsOSSConfig
from easy_rec.python.test.odps_test_util import delete_oss_path
@@ -27,13 +27,11 @@ class TestPipelineOnOdps(tf.test.TestCase):
"""train eval export test on odps."""
def test_deepfm(self):
- start_files = [
- 'deep_fm/create_external_deepfm_table.sql',
- 'deep_fm/create_inner_deepfm_table.sql'
- ]
+ start_files = ['deep_fm/create_inner_deepfm_table.sql']
test_files = [
'deep_fm/train_deepfm_model.sql', 'deep_fm/eval_deepfm.sql',
- 'deep_fm/export_deepfm.sql', 'deep_fm/predict_deepfm.sql'
+ 'deep_fm/export_deepfm.sql', 'deep_fm/predict_deepfm.sql',
+ 'deep_fm/export_rtp_ckpt.sql'
]
end_file = ['deep_fm/drop_table.sql']
@@ -42,10 +40,7 @@ def test_deepfm(self):
tot.drop_table()
def test_mmoe(self):
- start_files = [
- 'mmoe/create_external_mmoe_table.sql',
- 'mmoe/create_inner_mmoe_table.sql'
- ]
+ start_files = ['mmoe/create_inner_mmoe_table.sql']
test_files = [
'mmoe/train_mmoe_model.sql',
'mmoe/eval_mmoe.sql',
@@ -59,7 +54,6 @@ def test_mmoe(self):
def test_dssm(self):
start_files = [
- 'dssm/create_external_dssm_table.sql',
'dssm/create_inner_dssm_table.sql',
]
test_files = [
@@ -76,15 +70,13 @@ def test_dssm(self):
tot.drop_table()
def test_multi_tower(self):
- start_files = [
- 'multi_tower/create_external_multi_tower_table.sql',
- 'multi_tower/create_inner_multil_tower_table.sql',
- ]
+ start_files = ['multi_tower/create_inner_multi_tower_table.sql']
test_files = [
'multi_tower/train_multil_tower_din_model.sql',
'multi_tower/train_multil_tower_bst_model.sql',
'multi_tower/eval_multil_tower.sql',
'multi_tower/export_multil_tower.sql',
+ 'multi_tower/export_again_multi_tower.sql',
'multi_tower/predict_multil_tower.sql',
]
end_file = ['multi_tower/drop_multil_tower_table.sql']
@@ -93,16 +85,13 @@ def test_multi_tower(self):
tot.drop_table()
def test_other(self):
- start_files = [
- 'deep_fm/create_external_deepfm_table.sql',
- 'deep_fm/create_inner_deepfm_table.sql'
- ]
+ start_files = ['deep_fm/create_inner_deepfm_table.sql']
test_files = [
# 'other_test/test_train_gpuRequired_mirrored', # fails when run online,
# 'other_test/test_train_distribute_strategy_collective', # fails when run online,
'other_test/test_train_hpo_with_evaluator.sql',
- 'other_test/test_train_version.sql',
- 'other_test/test_train_distribute_strategy_ess.sql',
+ # 'other_test/test_train_version.sql',
+ # 'other_test/test_train_distribute_strategy_ess.sql',
'other_test/test_train_before_export.sql',
'other_test/test_eval_checkpoint_path.sql',
'other_test/test_export_checkpoint_path.sql',
@@ -115,10 +104,7 @@ def test_other(self):
tot.drop_table()
def test_best_exporter(self):
- start_files = [
- 'deep_fm/create_external_deepfm_table.sql',
- 'deep_fm/create_inner_deepfm_table.sql'
- ]
+ start_files = ['deep_fm/create_inner_deepfm_table.sql']
test_files = [
'other_test/test_train_best_export.sql',
]
@@ -172,10 +158,7 @@ def test_embedding_variable(self):
tot.drop_table()
def test_multi_value_export(self):
- start_files = [
- 'multi_value/create_external_multi_value_table.sql',
- 'multi_value/create_inner_multi_value_table.sql',
- ]
+ start_files = ['multi_value/create_inner_multi_value_table.sql']
test_files = ['multi_value/train_multi_tower_model.sql']
end_file = ['multi_value/drop_table.sql']
tot = OdpsTest(start_files, test_files, end_file, odps_oss_config)
@@ -184,11 +167,12 @@ def test_multi_value_export(self):
def test_boundary_test(self):
start_files = [
- 'boundary/create_external_boundary_table.sql',
'boundary/create_inner_boundary_table.sql',
]
test_files = [
- 'boundary/train_multi_tower_model.sql', 'boundary/train_compat.sql'
+ 'boundary/train_multi_tower_model.sql',
+ 'boundary/finetune_multi_tower_model.sql',
+ 'boundary/finetune_multi_tower_conti.sql', 'boundary/train_compat.sql'
]
end_file = ['boundary/drop_table.sql']
tot = OdpsTest(start_files, test_files, end_file, odps_oss_config)
@@ -196,12 +180,8 @@ def test_boundary_test(self):
tot.drop_table()
def test_vector_retrieve(self):
- start_files = [
- 'vector_retrieve/create_inner_vector_table.sql'
- ]
- test_files = [
- 'vector_retrieve/run_vector_retrieve.sql'
- ]
+ start_files = ['vector_retrieve/create_inner_vector_table.sql']
+ test_files = ['vector_retrieve/run_vector_retrieve.sql']
end_file = ['vector_retrieve/drop_table.sql']
tot = OdpsTest(start_files, test_files, end_file, odps_oss_config)
tot.start_test()
@@ -219,6 +199,11 @@ def test_vector_retrieve(self):
parser.add_argument('--arn', type=str, default=None, help='oss rolearn')
parser.add_argument(
'--odpscmd', type=str, default='odpscmd', help='odpscmd path')
+ parser.add_argument(
+ '--algo_name',
+ type=str,
+ default='easy_rec_ext',
+      help='algo name: easy_rec_ext or easy_rec_ext15 (pai-tf 1.15)')
parser.add_argument(
'--algo_project', type=str, default=None, help='algo project name')
parser.add_argument(
@@ -228,6 +213,12 @@ def test_vector_retrieve(self):
help='algo resource project name')
parser.add_argument(
'--algo_version', type=str, default=None, help='algo version')
+ parser.add_argument(
+ '--is_outer',
+ type=int,
+ default=1,
+      help='is outer pai or inner pai; the arguments differ slightly for historical reasons'
+ )
args, unknown_args = parser.parse_known_args()
sys.argv = [sys.argv[0]]
for unk_arg in unknown_args:
@@ -245,10 +236,15 @@ def test_vector_retrieve(self):
odps_oss_config.algo_res_project = args.algo_res_project
if args.algo_version:
odps_oss_config.algo_version = args.algo_version
+ algo_names = ['easy_rec_ext15', 'easy_rec_ext']
+  assert args.algo_name in algo_names, 'algo_name must be one of: %s' % (
+ ','.join(algo_names))
+ odps_oss_config.algo_name = args.algo_name
if args.arn:
odps_oss_config.arn = args.arn
if args.bucket_name:
odps_oss_config.bucket_name = args.bucket_name
+ odps_oss_config.is_outer = args.is_outer
prepare(odps_oss_config)
tf.test.main()
diff --git a/easy_rec/python/test/odps_test.py b/easy_rec/python/test/odps_test_cls.py
similarity index 100%
rename from easy_rec/python/test/odps_test.py
rename to easy_rec/python/test/odps_test_cls.py
diff --git a/easy_rec/python/test/odps_test_prepare.py b/easy_rec/python/test/odps_test_prepare.py
index 155c8baaa..e4b0c23d2 100644
--- a/easy_rec/python/test/odps_test_prepare.py
+++ b/easy_rec/python/test/odps_test_prepare.py
@@ -1,6 +1,7 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import glob
import logging
import os
import shutil
@@ -24,7 +25,9 @@ def download_data(ali_bucket, script_path):
if os.path.exists(os.path.join(script_path, 'test')):
shutil.rmtree(os.path.join(script_path, 'test'))
- for obj in oss2.ObjectIterator(ali_bucket, prefix='test/odps/data/'):
+ # download data from oss://${ali_bucket}/data/odps_test/
+ # to script_path/test_data
+ for obj in oss2.ObjectIterator(ali_bucket, prefix='data/odps_test/'):
obj_key = obj.key
tmp_oss_dir = os.path.split(obj_key)[0]
obj_path = os.path.join(script_path, tmp_oss_dir)
@@ -36,14 +39,31 @@ def download_data(ali_bucket, script_path):
if obj_key.endswith('/'):
continue
- dst_name = obj_key.replace('test/odps/data/', 'test_data/')
+ dst_name = obj_key.replace('data/odps_test/', 'test_data/')
dst_path = os.path.join(script_path, dst_name)
dst_dir, _ = os.path.split(dst_path)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
ali_bucket.get_object_to_file(obj_key, dst_path)
- logging.info('down file %s to %s completed' %
- ('oss://easyrec/' + obj_key, dst_path))
+ logging.info('down file oss://%s/%s to %s completed' %
+ (ali_bucket.bucket_name, obj_key, dst_path))
+
+
+def merge_files(merge_dir, merge_out):
+ """Merge files in merge_dir into merge_out.
+
+ Args:
+ merge_dir: files of directory to merge.
+ merge_out: merged output file.
+ """
+ input_files = list(glob.glob(merge_dir + '/*'))
+ logging.info('merge %s into %s' % (','.join(input_files), merge_out))
+ with open(merge_out, 'w') as fout:
+    for input_path in input_files:
+ with open(input_path, 'r') as fin:
+ for line_str in fin:
+ fout.write(line_str)
+ return merge_out
def change_files(odps_oss_config, file_path):
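
Note: a minimal usage sketch of the merge_files helper added above; the paths are hypothetical, and in the tests the call is driven by the `tunnel upload` rewriting shown in the next hunk.

    # hypothetical paths: combine the per-part files under deepfm_train/ into one file
    merged = merge_files('/tmp/test_data/deepfm_train', '/tmp/test_data/deepfm_train_all')
    print('merged file: %s' % merged)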
@@ -58,8 +78,10 @@ def change_files(odps_oss_config, file_path):
return
endpoint = odps_oss_config.endpoint.replace('http://', '')
- endpoint_internal = endpoint.replace('.aliyuncs.com',
- '-internal.aliyuncs.com')
+ # endpoint_internal = endpoint.replace('.aliyuncs.com',
+ # '-internal.aliyuncs.com')
+
+ test_data_dir = os.path.join(odps_oss_config.temp_dir, 'test_data')
with open(file_path, 'r') as fin:
lines = fin.readlines()
@@ -67,6 +89,7 @@ def change_files(odps_oss_config, file_path):
with open(file_path, 'w') as fw:
for line in lines:
if 'pai' in line.lower() and 'easy_rec_ext' in line.lower():
+ line = 'pai -name ' + odps_oss_config.algo_name + '\n'
if odps_oss_config.algo_project:
line += '-project=%s\n' % odps_oss_config.algo_project
if odps_oss_config.algo_res_project:
@@ -74,15 +97,34 @@ def change_files(odps_oss_config, file_path):
if odps_oss_config.algo_version:
line += '-Dversion=%s\n' % odps_oss_config.algo_version
- line = line.replace('{OSS_BUCKET_NAME}', odps_oss_config.bucket_name)
+ if odps_oss_config.is_outer:
+ line = line.replace('{OSS_BUCKET_NAME}', odps_oss_config.bucket_name)
+ line = line.replace('{ROLEARN}', odps_oss_config.arn)
+ line = line.replace('{OSS_ENDPOINT}', endpoint)
+ else:
+ tmp_e = odps_oss_config.endpoint
+ # tmp_e = tmp_e.replace('oss-cn-', 'cn-')
+ # tmp_e = tmp_e.replace('.aliyuncs.com', '.oss-internal.aliyun-inc.com')
+ if '-Dbuckets=' in line:
+ line = '-Dbuckets=oss://%s/?role_arn=%s&host=%s\n' % (
+ odps_oss_config.bucket_name, odps_oss_config.arn, tmp_e)
+ elif '-Darn=' in line or '-DossHost' in line:
+ continue
+ line = line.replace('{OSS_BUCKET_NAME}', odps_oss_config.bucket_name)
+
line = line.replace('{TIME_STAMP}', str(odps_oss_config.time_stamp))
+ if 'tunnel upload' in line:
+ line = line.replace('{TEST_DATA_DIR}', test_data_dir)
+ # merge files
+ toks = [x for x in line.split(' ') if x != '']
+ merge_path = toks[2]
+ merge_dir = '_'.join(merge_path.split('_')[:-1])
+ if not os.path.exists(merge_path):
+ merge_files(merge_dir, merge_path)
+
# for emr odps test only
line = line.replace('{TEMP_DIR}', str(odps_oss_config.temp_dir))
-
- line = line.replace('{ROLEARN}', odps_oss_config.arn)
- line = line.replace('{OSS_ENDPOINT_INTERNAL}', endpoint_internal)
- line = line.replace('{OSS_ENDPOINT}', endpoint)
line = line.replace('{ODPS_PROJ_NAME}', odps_oss_config.project_name)
line = line.replace('{EXP_NAME}', odps_oss_config.exp_dir)
fw.write(line)
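
Note: the `tunnel upload` branch above derives the merge inputs from the upload path itself; a self-contained sketch of that derivation (the sample statement is illustrative, not taken from the scripts).

    # sample tunnel upload statement after {TEST_DATA_DIR} substitution
    line = 'tunnel upload /tmp/test_data/deepfm_train_all deepfm_train_table;\n'
    toks = [x for x in line.split(' ') if x != '']
    merge_path = toks[2]                              # local file passed to tunnel upload
    merge_dir = '_'.join(merge_path.split('_')[:-1])  # drop the trailing _all suffix
    print(merge_path)   # /tmp/test_data/deepfm_train_all
    print(merge_dir)    # /tmp/test_data/deepfm_train -> merged by merge_files() if missing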
@@ -98,7 +140,7 @@ def put_data_to_bucket(odps_oss_config):
odps_oss_config.oss_secret,
odps_oss_config.endpoint,
odps_oss_config.bucket_name)
- for sub_dir in ['configs', 'test_data']:
+ for sub_dir in ['configs']:
for root, dirs, files in os.walk(
os.path.join(odps_oss_config.temp_dir, sub_dir)):
for one_file in files:
diff --git a/easy_rec/python/test/odps_test_util.py b/easy_rec/python/test/odps_test_util.py
index f3b540354..35dc7f743 100644
--- a/easy_rec/python/test/odps_test_util.py
+++ b/easy_rec/python/test/odps_test_util.py
@@ -1,7 +1,6 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
-import configparser
import logging
import os
import time
@@ -38,7 +37,7 @@ class OdpsOSSConfig:
def __init__(self, script_path='./samples/odps_script'):
self.time_stamp = int(time.time())
- temp_dir = os.environ.get('TEST_DIR', '/tmp')
+ temp_dir = os.environ.get('TMPDIR', '/tmp')
self.exp_dir = 'easy_rec_odps_test_%d' % self.time_stamp
self.temp_dir = os.path.join(temp_dir, self.exp_dir)
self.log_dir = os.path.join(self.temp_dir, 'logs/')
@@ -59,16 +58,16 @@ def __init__(self, script_path='./samples/odps_script'):
self.odpscmd_path = os.environ.get('ODPS_CMD_PATH', 'odpscmd')
self.odps_config_path = ''
- # input table project name replace {ODPS_PROJ_NAME} in
- # samples/odps_script:
- # grep ODPS_PROJ_NAME -r samples/odps_script/
+
self.project_name = ''
self.dh_id = ''
self.dh_key = ''
- self.dh_endpoint = ''
- self.dh_topic = ''
- self.dh_project = ''
+
+ self.dh_endpoint = 'https://dh-cn-beijing.aliyuncs.com/'
+ self.dh_topic = 'easy_rec_test'
+ self.dh_project = 'easy_rec_test'
+
self.odps_endpoint = ''
self.dh = None
@@ -78,15 +77,11 @@ def __init__(self, script_path='./samples/odps_script'):
self.algo_project = None
self.algo_res_project = None
self.algo_version = None
+ self.algo_name = 'easy_rec_ext'
- def load_dh_config(self, config_path):
- configer = configparser.ConfigParser()
- configer.read(config_path, encoding='utf-8')
- self.dh_id = configer.get('datahub', 'access_id')
- self.dh_key = configer.get('datahub', 'access_key')
- self.dh_endpoint = configer.get('datahub', 'endpoint')
- self.dh_topic = configer.get('datahub', 'topic_name')
- self.dh_project = configer.get('datahub', 'project')
+ # default to outer environment
+  # the differences are in the ossHost, buckets and arn settings
+ self.is_outer = True
def load_oss_config(self, config_path):
with open(config_path, 'r') as fin:
@@ -106,10 +101,18 @@ def load_odps_config(self, config_path):
for line_str in fin:
line_str = line_str.strip()
line_str = line_str.replace(' ', '')
- if line_str.startswith('project_name='):
- self.project_name = line_str[len('project_name='):]
- if line_str.startswith('end_point='):
- self.odps_endpoint = line_str[len('end_point='):]
+ key_str = 'project_name='
+ if line_str.startswith(key_str):
+ self.project_name = line_str[len(key_str):]
+ key_str = 'end_point='
+ if line_str.startswith(key_str):
+ self.odps_endpoint = line_str[len(key_str):]
+ key_str = 'access_id='
+ if line_str.startswith(key_str):
+ self.dh_id = line_str[len(key_str):]
+ key_str = 'access_key='
+ if line_str.startswith(key_str):
+ self.dh_key = line_str[len(key_str):]
def clean_topic(self, dh_project):
if not dh_project:
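
Note: the repeated startswith checks in load_odps_config above could equally be table-driven; a small self-contained sketch of that alternative (the key names mirror the fields parsed in the diff).

    def parse_odps_config(lines):
        """Parse key=value lines into a dict for the keys the tests care about."""
        keys = ('project_name', 'end_point', 'access_id', 'access_key')
        conf = {}
        for line_str in lines:
            line_str = line_str.strip().replace(' ', '')
            for key in keys:
                if line_str.startswith(key + '='):
                    conf[key] = line_str[len(key) + 1:]
        return conf

    print(parse_odps_config(['project_name = my_proj', 'end_point=http://service.odps.aliyun.com/api']))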
@@ -154,46 +157,50 @@ def init_dh_and_odps(self):
self.odpsTable = 'deepfm_train_%s' % self.time_stamp
self.clean_project()
read_odps = DataFrame(self.odps.get_table(self.odpsTable))
- col = read_odps.schema.names
+ col_name = read_odps.schema.names
col_type = [self.get_input_type(str(i)) for i in read_odps.schema.types]
try:
- self.dh.create_project(self.dh_project, 'EasyRecTest')
+ self.dh.create_project(self.dh_project, comment='EasyRecTest')
logging.info('create project success!')
except ResourceExistException:
- logging.info('project %s already exist!' % self.dh_project)
- except Exception as ex:
- logging.info(traceback.format_exc(ex))
- record_schema = RecordSchema.from_lists(col, col_type)
+ logging.warning('project %s already exist!' % self.dh_project)
+ except Exception:
+ logging.error(traceback.format_exc())
+ record_schema = RecordSchema.from_lists(col_name, col_type)
try:
# project_name, topic_name, shard_count, life_cycle, record_schema, comment
- self.dh.create_tuple_topic(self.dh_project, self.dh_topic, 7, 3,
- record_schema, 'easyrec_datahub')
- logging.info('create tuple topic success!')
+ self.dh.create_tuple_topic(
+ self.dh_project,
+ self.dh_topic,
+ 7,
+ 3,
+ record_schema,
+ comment='EasyRecTest')
+ logging.info('create tuple topic %s success!' % self.dh_topic)
except ResourceExistException:
logging.info('topic %s already exist!' % self.dh_topic)
except Exception as ex:
- logging.error('exception:', ex)
+ logging.error('exception:%s' % str(ex))
logging.error(traceback.format_exc())
try:
self.dh.wait_shards_ready(self.dh_project, self.dh_topic)
- logging.info('shards all ready')
+ logging.info('datahub[%s,%s] shards all ready' %
+ (self.dh_project, self.dh_topic))
topic_result = self.dh.get_topic(self.dh_project, self.dh_topic)
if topic_result.record_type != RecordType.TUPLE:
- logging.error('topic type illegal! ')
+ logging.error('invalid topic type: %s' % str(topic_result.record_type))
record_schema = topic_result.record_schema
t = self.odps.get_table(self.odpsTable)
with t.open_reader() as reader:
- size = 0
record_list = []
- for data in reader[0:1000]:
+ for data in reader:
record = TupleRecord(values=data.values, schema=record_schema)
record_list.append(record)
- if size % 1000:
- self.dh.put_records(self.dh_project, self.dh_topic, record_list)
- record_list = []
- size += 1
- except Exception as e:
- logging.error(e)
+ for i in range(10):
+ self.dh.put_records(self.dh_project, self.dh_topic, record_list)
+ except Exception as ex:
+ logging.error('exception: %s' % str(ex))
+ logging.error(traceback.format_exc())
def get_oss_bucket(oss_key, oss_secret, endpoint, bucket_name):
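
Note: a condensed sketch of the record-push logic in init_dh_and_odps above; dh, reader and record_schema are assumed to be the pydatahub client, the ODPS table reader and the tuple schema built in the diff, and the 10x replay simply inflates the streaming test data.

    def replay_table_to_datahub(dh, project, topic, reader, record_schema, repeats=10):
        """Push every row of an ODPS table reader into a DataHub tuple topic `repeats` times."""
        from datahub.models import TupleRecord  # pydatahub
        records = [
            TupleRecord(values=row.values, schema=record_schema) for row in reader
        ]
        for _ in range(repeats):
            dh.put_records(project, topic, records)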
diff --git a/easy_rec/python/test/pre_check_test.py b/easy_rec/python/test/pre_check_test.py
new file mode 100644
index 000000000..58b295157
--- /dev/null
+++ b/easy_rec/python/test/pre_check_test.py
@@ -0,0 +1,54 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.utils import test_utils
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+gfile = tf.gfile
+
+
+class CheckTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._test_dir = test_utils.get_tmp_dir()
+ self._success = True
+ logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))
+ logging.info('test dir: %s' % self._test_dir)
+
+ def tearDown(self):
+ test_utils.set_gpu_id(None)
+ if self._success:
+ test_utils.clean_up(self._test_dir)
+
+ def test_csv_input_train_with_check(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_taobao.config',
+ self._test_dir,
+ check_mode=True)
+ self.assertTrue(self._success)
+
+ def test_rtp_input_train_with_check(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/taobao_fg.config',
+ self._test_dir,
+ check_mode=True)
+ self.assertTrue(self._success)
+
+ def test_csv_input_with_pre_check(self):
+ self._success = test_utils.test_single_pre_check(
+ 'samples/model_config/dbmtl_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_rtp_input_with_pre_check(self):
+ self._success = test_utils.test_single_pre_check(
+ 'samples/model_config/dbmtl_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/easy_rec/python/test/predictor_test.py b/easy_rec/python/test/predictor_test.py
index f17ff0263..7ad8ae36e 100644
--- a/easy_rec/python/test/predictor_test.py
+++ b/easy_rec/python/test/predictor_test.py
@@ -9,7 +9,9 @@
import numpy as np
import tensorflow as tf
+from easy_rec.python.inference.csv_predictor import CSVPredictor
from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.utils import config_util
from easy_rec.python.utils import test_utils
from easy_rec.python.utils.test_utils import RunAsSubprocess
@@ -123,37 +125,217 @@ def test_fm_pred_dict(self):
class PredictorTestOnDS(tf.test.TestCase):
def setUp(self):
- self._test_input_path = 'data/test/inference/taobao_infer_data.txt'
- self._test_output_path = 'data/test/inference/taobao_infer_result'
- self.gpus = test_utils.get_available_gpus()
- self.assertTrue(len(self.gpus) > 0, 'no available gpu on this machine')
- logging.info('available gpus %s' % self.gpus)
- test_utils.set_gpu_id(self.gpus[0])
+ self._test_dir = test_utils.get_tmp_dir()
+ self._test_output_path = None
logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))
def tearDown(self):
- if (os.path.exists(self._test_output_path)):
+ if self._test_output_path and (os.path.exists(self._test_output_path)):
shutil.rmtree(self._test_output_path)
test_utils.set_gpu_id(None)
@RunAsSubprocess
def test_local_pred(self):
- predictor = Predictor('data/test/inference/tb_multitower_export/')
+ test_input_path = 'data/test/inference/taobao_infer_data.txt'
+ self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
+ saved_model_dir = 'data/test/inference/tb_multitower_export/'
+ pipeline_config_path = os.path.join(saved_model_dir,
+ 'assets/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+ predictor = CSVPredictor(
+ saved_model_dir,
+ pipeline_config.data_config,
+ output_sep=';',
+ selected_cols='')
+
+ predictor.predict_impl(
+ test_input_path,
+ self._test_output_path,
+ reserved_cols='ALL_COLUMNS',
+ output_cols='ALL_COLUMNS',
+ slice_id=0,
+ slice_num=1)
+ header_truth = 'logits;probs;clk;buy;pid;adgroup_id;cate_id;campaign_id;customer;'\
+ 'brand;user_id;cms_segid;cms_group_id;final_gender_code;age_level;pvalue_level;' \
+ 'shopping_level;occupation;new_user_class_level;tag_category_list;tag_brand_list;price'
+
+ with open(self._test_output_path + '/part-0.csv', 'r') as f:
+ output_res = f.readlines()
+ self.assertTrue(len(output_res) == 101)
+ self.assertEqual(output_res[0].strip(), header_truth)
+
+ @RunAsSubprocess
+ def test_local_pred_with_header(self):
+ test_input_path = 'data/test/inference/taobao_infer_data_with_header.txt'
+ self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
+ saved_model_dir = 'data/test/inference/tb_multitower_export/'
+ pipeline_config_path = os.path.join(saved_model_dir,
+ 'assets/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+ pipeline_config.data_config.with_header = True
+
+ predictor = CSVPredictor(
+ saved_model_dir,
+ pipeline_config.data_config,
+ with_header=True,
+ output_sep=';',
+ selected_cols='')
+
predictor.predict_impl(
- self._test_input_path,
+ test_input_path,
self._test_output_path,
reserved_cols='ALL_COLUMNS',
output_cols='ALL_COLUMNS',
slice_id=0,
- slice_num=1,
- input_sep=',',
- output_sep=';')
+ slice_num=1)
+ header_truth = 'logits;probs;clk;buy;pid;adgroup_id;cate_id;campaign_id;customer;'\
+ 'brand;user_id;cms_segid;cms_group_id;final_gender_code;age_level;pvalue_level;' \
+ 'shopping_level;occupation;new_user_class_level;tag_category_list;tag_brand_list;price'
+
+ with open(self._test_output_path + '/part-0.csv', 'r') as f:
+ output_res = f.readlines()
+ self.assertTrue(len(output_res) == 101)
+ self.assertEqual(output_res[0].strip(), header_truth)
- with open(self._test_output_path + '/slice_0.csv', 'r') as f:
+ @RunAsSubprocess
+ def test_local_pred_without_config(self):
+ test_input_path = 'data/test/inference/taobao_infer_data.txt'
+ self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
+ saved_model_dir = 'data/test/inference/tb_multitower_export/'
+ self._success = test_utils.test_single_predict(self._test_dir,
+ test_input_path,
+ self._test_output_path,
+ saved_model_dir)
+ self.assertTrue(self._success)
+ with open(self._test_output_path + '/part-0.csv', 'r') as f:
output_res = f.readlines()
self.assertTrue(len(output_res) == 101)
+ @RunAsSubprocess
+ def test_local_pred_with_part_col(self):
+ test_input_path = 'data/test/inference/taobao_infer_data.txt'
+ self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
+ saved_model_dir = 'data/test/inference/tb_multitower_export/'
+ pipeline_config_path = os.path.join(saved_model_dir,
+ 'assets/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+
+ predictor = CSVPredictor(
+ saved_model_dir,
+ pipeline_config.data_config,
+ output_sep=';',
+ selected_cols='')
+
+ predictor.predict_impl(
+ test_input_path,
+ self._test_output_path,
+ reserved_cols='clk,buy,user_id,adgroup_id',
+ output_cols='probs',
+ slice_id=0,
+ slice_num=1)
+ header_truth = 'probs;clk;buy;user_id;adgroup_id'
+
+ with open(self._test_output_path + '/part-0.csv', 'r') as f:
+ output_res = f.readlines()
+ self.assertTrue(len(output_res) == 101)
+ self.assertEqual(output_res[0].strip(), header_truth)
+
+ @RunAsSubprocess
+ def test_local_pred_rtp(self):
+ test_input_path = 'data/test/inference/taobao_infer_rtp_data.txt'
+ self._test_output_path = os.path.join(self._test_dir,
+ 'taobao_test_feature_result')
+ saved_model_dir = 'data/test/inference/tb_multitower_rtp_export/'
+ pipeline_config_path = os.path.join(saved_model_dir,
+ 'assets/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+
+ predictor = CSVPredictor(
+ saved_model_dir,
+ pipeline_config.data_config,
+ output_sep=';',
+ selected_cols='0,3')
+ predictor.predict_impl(
+ test_input_path,
+ self._test_output_path,
+ reserved_cols='ALL_COLUMNS',
+ output_cols='ALL_COLUMNS',
+ slice_id=0,
+ slice_num=1)
+ header_truth = 'logits;probs;clk;no_used_1;no_used_2;features'
+ with open(self._test_output_path + '/part-0.csv', 'r') as f:
+ output_res = f.readlines()
+ self.assertTrue(len(output_res) == 101)
+ self.assertEqual(output_res[0].strip(), header_truth)
+
+ @RunAsSubprocess
+ def test_local_pred_rtp_with_part_col(self):
+ test_input_path = 'data/test/inference/taobao_infer_rtp_data.txt'
+ self._test_output_path = os.path.join(self._test_dir,
+ 'taobao_test_feature_result')
+ saved_model_dir = 'data/test/inference/tb_multitower_rtp_export/'
+ pipeline_config_path = os.path.join(saved_model_dir,
+ 'assets/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+
+ predictor = CSVPredictor(
+ saved_model_dir,
+ pipeline_config.data_config,
+ output_sep=';',
+ selected_cols='0,3')
+ predictor.predict_impl(
+ test_input_path,
+ self._test_output_path,
+ reserved_cols='clk,features,no_used_1',
+ output_cols='ALL_COLUMNS',
+ slice_id=0,
+ slice_num=1)
+ header_truth = 'logits;probs;clk;features;no_used_1'
+ with open(self._test_output_path + '/part-0.csv', 'r') as f:
+ output_res = f.readlines()
+ self.assertTrue(len(output_res) == 101)
+ self.assertEqual(output_res[0].strip(), header_truth)
+
+ @RunAsSubprocess
+ def test_local_pred_embedding(self):
+ test_input_path = 'data/test/inference/taobao_item_feature_data.csv'
+ self._test_output_path = os.path.join(self._test_dir, 'taobao_item_feature')
+ saved_model_dir = 'data/test/inference/dssm_item_model/'
+ pipeline_config_path = os.path.join(saved_model_dir,
+ 'assets/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+ predictor = CSVPredictor(
+ saved_model_dir,
+ pipeline_config.data_config,
+ ds_vector_recall=True,
+ output_sep=';',
+ selected_cols='pid,adgroup_id,cate_id,campaign_id,customer,brand,price')
+
+ predictor.predict_impl(
+ test_input_path,
+ self._test_output_path,
+ reserved_cols='adgroup_id',
+ output_cols='item_emb',
+ slice_id=0,
+ slice_num=1)
+
+ with open(self._test_output_path + '/part-0.csv', 'r') as f:
+ output_res = f.readlines()
+ self.assertTrue(
+ output_res[1] ==
+ '-0.187066,-0.027638,-0.117294,0.115318,-0.273561,0.035698,-0.055832,'
+ '0.226849,-0.105808,-0.152751,0.081528,-0.183329,0.134619,0.185392,'
+ '0.096774,0.104428,0.161868,0.269710,-0.268538,0.138760,-0.170105,'
+ '0.232625,-0.121130,0.198466,-0.078941,0.017774,0.268834,-0.238553,0.084058,'
+ '-0.269466,-0.289651,0.179517;620392\n')
+
class PredictorTestV2(tf.test.TestCase):
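
Note: the new predictor cases above all follow one calling pattern; a condensed sketch using the same saved model and arguments exercised by these tests (the output directory is hypothetical).

    import os

    from easy_rec.python.inference.csv_predictor import CSVPredictor
    from easy_rec.python.utils import config_util

    saved_model_dir = 'data/test/inference/tb_multitower_export/'
    pipeline_config = config_util.get_configs_from_pipeline_file(
        os.path.join(saved_model_dir, 'assets/pipeline.config'), False)
    predictor = CSVPredictor(
        saved_model_dir,
        pipeline_config.data_config,
        output_sep=';',
        selected_cols='')
    predictor.predict_impl(
        'data/test/inference/taobao_infer_data.txt',
        '/tmp/taobao_infer_result',                   # hypothetical output dir
        reserved_cols='clk,buy,user_id,adgroup_id',   # or 'ALL_COLUMNS'
        output_cols='probs',                          # or 'ALL_COLUMNS'
        slice_id=0,
        slice_num=1)
    # predictions are written to /tmp/taobao_infer_result/part-0.csv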
diff --git a/easy_rec/python/test/run.py b/easy_rec/python/test/run.py
index cfcc44bfd..0c7ac4c79 100644
--- a/easy_rec/python/test/run.py
+++ b/easy_rec/python/test/run.py
@@ -2,11 +2,15 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
import os
+import sys
import unittest
import tensorflow as tf
+from easy_rec.python.utils import test_utils
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -15,55 +19,119 @@
tf.app.flags.DEFINE_string('pattern', '*_test.py', 'test file pattern')
tf.app.flags.DEFINE_string('test_dir', 'easy_rec/python/test',
'directory to be tested')
+tf.app.flags.DEFINE_integer('num_parallel', 10,
+                    'number of cases executed in parallel.')
+tf.app.flags.DEFINE_integer('timeout', 3600,
+                    'maximum execution time in seconds for each case.')
FLAGS = tf.flags.FLAGS
def gather_test_cases(test_dir, pattern):
- test_suite = unittest.TestSuite()
discover = unittest.defaultTestLoader.discover(
test_dir, pattern=pattern, top_level_dir=None)
all_tests = []
for suite_discovered in discover:
-
for test_case in suite_discovered:
- test_suite.addTest(test_case)
+ if 'ModuleImportFailure' in str(test_case):
+ logging.error('Failed to gather case: %s' % str(test_case))
+ sys.exit(1)
+ if '_FailedTest' in str(test_case):
+ logging.error('Failed to gather case: %s' % str(test_case))
+ logging.error('Detail message: %s' % test_case.debug())
+ sys.exit(1)
if hasattr(test_case, '__iter__'):
for subcase in test_case:
- if FLAGS.list_tests or FLAGS.list_test_to_file:
- print(subcase.id())
- tid = subcase.id().split('.')[0]
- if tid not in all_tests:
- all_tests.append(tid)
+ toks = subcase.id().split('.')
+ case_file = toks[0]
+ case_name = '.'.join(toks[1:])
+ if (case_file, case_name) not in all_tests:
+ all_tests.append((case_file, case_name))
else:
- if FLAGS.list_tests or FLAGS.list_test_to_file:
- print(test_case.id())
- tid = subcase.id().split('.')[0]
- if tid not in all_tests:
- all_tests.append(tid)
+      toks = test_case.id().split('.')
+ case_file = toks[0]
+ case_name = '.'.join(toks[1:])
+ if (case_file, case_name) not in all_tests:
+ all_tests.append((case_file, case_name))
if FLAGS.list_test_to_file:
- print('save test lists to %s' % FLAGS.list_test_to_file)
+ logging.info('Total number of cases: %d' % len(all_tests))
+ logging.info('save test lists to %s' % FLAGS.list_test_to_file)
with open(FLAGS.list_test_to_file, 'w') as fout:
- for t_name in all_tests:
- fout.write('%s\n' % t_name)
- return test_suite
+ for t_file, t_name in all_tests:
+ fout.write('%s %s\n' % (t_file, t_name))
+ elif FLAGS.list_tests:
+ logging.info('Total number of cases: %d' % len(all_tests))
+ for t_file, t_name in all_tests:
+ logging.info('\t%s.%s' % (t_file, t_name))
+ return all_tests
def main(argv):
- runner = unittest.TextTestRunner()
- test_suite = gather_test_cases(os.path.abspath(FLAGS.test_dir), FLAGS.pattern)
+ all_tests = gather_test_cases(os.path.abspath(FLAGS.test_dir), FLAGS.pattern)
if FLAGS.list_tests or FLAGS.list_test_to_file:
return
- result = runner.run(test_suite)
- if not result.wasSuccessful():
- print('FailNum: %d ErrorNum: %d' %
- (len(result.failures), len(result.errors)))
+ test_dir = os.environ.get('TEST_DIR', '.')
+ if not os.path.isdir(test_dir):
+ os.makedirs(test_dir)
+ test_log_dir = os.path.join(test_dir, 'logs')
+ if not os.path.exists(test_log_dir):
+ os.makedirs(test_log_dir)
+ logging.info('Total number of cases: %d test_dir: %s' %
+ (len(all_tests), test_dir))
+
+ max_num_port_per_proc = 3
+ total_port_num = (max_num_port_per_proc + 2) * FLAGS.num_parallel * 10
+ all_available_ports = test_utils.get_ports_base(total_port_num).tolist()
+
+ procs = {}
+ failed_cases = []
+ for case_file, case_name in all_tests:
+ while len(procs) >= FLAGS.num_parallel:
+ procs_done = []
+ for proc in procs:
+ if proc.poll() is not None:
+ if proc.returncode != 0:
+ fail_file, fail_name, _ = procs[proc]
+ failed_cases.append((fail_file, fail_name, proc.returncode))
+ procs_done.append(proc)
+ for proc in procs_done:
+ _, _, tmp_ports = procs[proc]
+ all_available_ports.extend([int(x) for x in tmp_ports.split(',')])
+ del procs[proc]
+ cmd = 'python -m easy_rec.python.test.%s %s' % (case_file, case_name)
+ log_file = '%s/%s.%s.log' % (test_log_dir, case_file, case_name)
+ tmp_ports = ','.join(
+ [str(x) for x in all_available_ports[:max_num_port_per_proc]])
+ all_available_ports = all_available_ports[max_num_port_per_proc:]
+
+ logging.info('Run %s.%s Log: %s' % (case_file, case_name, log_file))
+ case_envs = dict(os.environ)
+ case_envs['ports'] = tmp_ports
+ proc = test_utils.run_cmd(cmd, log_file, env=case_envs)
+ procs[proc] = (case_file, case_name, tmp_ports)
+
+ for proc in procs:
+ try:
+ test_utils.proc_wait(
+ proc, timeout=int(os.environ.get('TEST_TIME_OUT', 1200)))
+ except Exception as ex:
+      fail_file, fail_name, _ = procs[proc]
+ logging.info('Case Exception: %s.%s %s' % (fail_file, fail_name, str(ex)))
+ proc.kill()
+
+ if proc.returncode != 0:
+ fail_file, fail_name, _ = procs[proc]
+ failed_cases.append((fail_file, fail_name, proc.returncode))
+
+ if len(failed_cases) > 0:
+ logging.info('Number Cases Failed: %d' % len(failed_cases))
+ for fail_file, fail_name, exit_code in failed_cases:
+ logging.info('\t%s.%s failed, exit_code:%d log: %s.%s.log' %
+ (fail_file, fail_name, exit_code, fail_file, fail_name))
+ return 1
else:
- if 'UnitTestSucceedFlag' in os.environ:
- flag_file = os.environ['UnitTestSucceedFlag']
- with open(flag_file, 'w') as fout:
- fout.write('unit succeed.')
- print('create flag file: %s' % flag_file)
+ logging.info('TestSucceed.')
+ return 0
if __name__ == '__main__':
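
Note: a self-contained sketch of the polling-pool pattern that the rewritten main() above implements; the test_utils helpers are replaced by plain subprocess calls, and a sleep is added to avoid a busy wait.

    import os
    import subprocess
    import time


    def run_pool(cmds, num_parallel, ports, ports_per_proc=3):
        """Run shell commands with at most num_parallel alive at once.

        Each command receives ports_per_proc ports through the `ports` env var;
        the ports are returned to the pool once the command finishes.
        """
        procs = {}
        failed = []
        for cmd in cmds:
            while len(procs) >= num_parallel:
                done = [p for p in procs if p.poll() is not None]
                for p in done:
                    done_cmd, used_ports = procs.pop(p)
                    if p.returncode != 0:
                        failed.append(done_cmd)
                    ports.extend(used_ports)
                if not done:
                    time.sleep(1)
            used_ports, ports = ports[:ports_per_proc], ports[ports_per_proc:]
            env = dict(os.environ, ports=','.join(str(x) for x in used_ports))
            proc = subprocess.Popen(cmd, shell=True, env=env)
            procs[proc] = (cmd, used_ports)
        for proc, (done_cmd, _) in procs.items():
            if proc.wait() != 0:
                failed.append(done_cmd)
        return failed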
diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py
index a3bc6fc92..60745f16f 100644
--- a/easy_rec/python/test/train_eval_test.py
+++ b/easy_rec/python/test/train_eval_test.py
@@ -1,22 +1,43 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import glob
import logging
import os
+import threading
+import time
import unittest
-from distutils.version import LooseVersion
import numpy as np
+import six
import tensorflow as tf
+from distutils.version import LooseVersion
+from tensorflow.python.platform import gfile
from easy_rec.python.main import predict
from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import test_utils
+try:
+ import graphlearn as gl
+except Exception:
+ gl = None
+
+try:
+ import horovod as hvd
+except Exception:
+ hvd = None
+
+try:
+ from sparse_operation_kit import experiment as sok
+except Exception:
+ sok = None
+
+tf_version = tf.__version__
if tf.__version__ >= '2.0':
tf = tf.compat.v1
-gfile = tf.gfile
class TrainEvalTest(tf.test.TestCase):
@@ -42,6 +63,18 @@ def test_deepfm_with_combo_feature(self):
'samples/model_config/deepfm_combo_on_avazu_ctr.config', self._test_dir)
self.assertTrue(self._success)
+ def test_deepfm_with_combo_v2_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_v2_on_avazu_ctr.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_deepfm_with_combo_v3_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_v3_on_avazu_ctr.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_deepfm_freeze_gradient(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_freeze_gradient.config', self._test_dir)
@@ -71,10 +104,20 @@ def test_wide_and_deep(self):
self._test_dir)
self.assertTrue(self._success)
+ def test_wide_and_deep_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/wide_and_deep_backbone_on_avazau.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_dlrm(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dlrm_on_taobao.config', self._test_dir)
+ def test_dlrm_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dlrm_backbone_on_taobao.config', self._test_dir)
+
def test_adamw_optimizer(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_combo_on_avazu_adamw_ctr.config',
@@ -88,29 +131,46 @@ def test_momentumw_optimizer(self):
self.assertTrue(self._success)
def test_deepfm_with_param_edit(self):
+ model_dir = os.path.join(self._test_dir, 'train_new')
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
self._test_dir,
- hyperparam_str='{"model_dir":"experiments/deepfm_multi_cls_on_avazu_ctr", '
- '"model_config.deepfm.wide_output_dim": 32}')
+ hyperparam_str='{"model_dir":"%s", '
+ '"model_config.deepfm.wide_output_dim": 32}' % model_dir)
self.assertTrue(self._success)
+ config_path = os.path.join(model_dir, 'pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(config_path)
+ self.assertTrue(pipeline_config.model_dir == model_dir)
+ self.assertTrue(pipeline_config.model_config.deepfm.wide_output_dim == 32)
def test_multi_tower(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_multi_tower_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/multi_tower_backbone_on_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_multi_tower_gauc(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_on_taobao_gauc.config',
self._test_dir)
self.assertTrue(self._success)
+ def test_multi_tower_session_auc(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/multi_tower_on_taobao_session_auc.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_multi_tower_save_checkpoint_secs(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_save_secs_on_taobao.config',
self._test_dir,
- total_steps=500)
+ total_steps=100)
ckpts_times = []
ckpt_dir = os.path.join(self._test_dir, 'train')
for filepath in os.listdir(ckpt_dir):
@@ -119,9 +179,11 @@ def test_multi_tower_save_checkpoint_secs(self):
# remove last ckpt time
ckpts_times = np.array(sorted(ckpts_times)[:-1])
# ensure interval is 20s
+ diffs = list(ckpts_times[1:] - ckpts_times[:-1])
+ logging.info('nearby ckpts_times diff = %s' % diffs)
self.assertAllClose(
ckpts_times[1:] - ckpts_times[:-1], [20] * (len(ckpts_times) - 1),
- atol=8)
+ atol=20)
self.assertTrue(self._success)
def test_keep_ckpt_max(self):
@@ -129,7 +191,6 @@ def test_keep_ckpt_max(self):
def _post_check_func(pipeline_config):
ckpt_prefix = os.path.join(pipeline_config.model_dir, 'model.ckpt-*.meta')
ckpts = gfile.Glob(ckpt_prefix)
- print(ckpts)
assert len(ckpts) == 3, 'invalid number of checkpoints: %d' % len(ckpts)
self._success = test_utils.test_single_train_eval(
@@ -154,8 +215,9 @@ def _post_check_func(pipeline_config):
self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_best_export_on_taobao.config',
self._test_dir,
- total_steps=1000,
- post_check_func=_post_check_func)
+ total_steps=800,
+ post_check_func=_post_check_func,
+ timeout=3000)
self.assertTrue(self._success)
def test_latest_ckpt(self):
@@ -179,6 +241,67 @@ def _post_check_func(pipeline_config):
post_check_func=_post_check_func)
self.assertTrue(self._success)
+ def test_oss_stop_signal(self):
+ train_dir = os.path.join(self._test_dir, 'train/')
+
+ def _watch_func():
+ while True:
+ tmp_ckpt = estimator_utils.latest_checkpoint(train_dir)
+ if tmp_ckpt is not None:
+ version = estimator_utils.get_ckpt_version(tmp_ckpt)
+ if version > 30:
+ break
+ time.sleep(1)
+ stop_file = os.path.join(train_dir, 'OSS_STOP_SIGNAL')
+ with open(stop_file, 'w') as fout:
+ fout.write('OSS_STOP_SIGNAL')
+
+ watch_th = threading.Thread(target=_watch_func)
+ watch_th.start()
+
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/taobao_fg_signal_stop.config',
+ self._test_dir,
+ total_steps=1000)
+ self.assertTrue(self._success)
+ watch_th.join()
+ final_ckpt = estimator_utils.latest_checkpoint(train_dir)
+ ckpt_version = estimator_utils.get_ckpt_version(final_ckpt)
+ logging.info('final ckpt version = %d' % ckpt_version)
+ self._success = ckpt_version < 1000
+ assert ckpt_version < 1000
+
+ def test_dead_line_stop_signal(self):
+ train_dir = os.path.join(self._test_dir, 'train/')
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/dead_line_stop.config',
+ self._test_dir,
+ total_steps=1000)
+ self.assertTrue(self._success)
+ final_ckpt = estimator_utils.latest_checkpoint(train_dir)
+ ckpt_version = estimator_utils.get_ckpt_version(final_ckpt)
+ logging.info('final ckpt version = %d' % ckpt_version)
+ self._success = ckpt_version < 1000
+ assert ckpt_version < 1000
+
+ def test_fine_tune_latest_ckpt_path(self):
+
+ def _post_check_func(pipeline_config):
+ logging.info('model_dir: %s' % pipeline_config.model_dir)
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ os.path.join(pipeline_config.model_dir, 'pipeline.config'), False)
+ logging.info('fine_tune_checkpoint: %s' %
+ pipeline_config.train_config.fine_tune_checkpoint)
+ return pipeline_config.train_config.fine_tune_checkpoint == \
+ 'data/test/mt_ckpt/model.ckpt-100'
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/multi_tower_on_taobao.config',
+ self._test_dir,
+ fine_tune_checkpoint='data/test/mt_ckpt',
+ post_check_func=_post_check_func)
+ self.assertTrue(self._success)
+
def test_fine_tune_ckpt(self):
def _post_check_func(pipeline_config):
@@ -215,21 +338,63 @@ def test_fm(self):
'samples/model_config/fm_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_place_embed_on_cpu(self):
+ os.environ['place_embedding_on_cpu'] = 'True'
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/fm_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
def test_din(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/din_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_din_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/din_backbone_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
def test_bst(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/bst_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_bst_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/bst_backbone_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_cl4srec(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/cl4srec_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
def test_dcn(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dcn_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_ziln_loss(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/mlp_on_taobao_with_ziln_loss.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_fibinet(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/fibinet_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_masknet(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/masknet_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_dcn_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dcn_backbone_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
def test_dcn_with_f1(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dcn_f1_on_taobao.config', self._test_dir)
@@ -240,6 +405,75 @@ def test_autoint(self):
'samples/model_config/autoint_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_uniter(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/uniter_on_movielens.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_highway(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/highway_on_movielens.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ # @unittest.skipIf(
+ # LooseVersion(tf.__version__) >= LooseVersion('2.0.0'),
+ # 'has no CustomOp when tf version == 2.4')
+ # def test_custom_op(self):
+ # self._success = test_utils.test_single_train_eval(
+ # 'samples/model_config/cl4srec_on_taobao_with_custom_op.config',
+ # self._test_dir)
+ # self.assertTrue(self._success)
+
+ def test_cdn(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/cdn_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_ppnet(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/ppnet_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_uniter_only_text_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/uniter_on_movielens_only_text_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_uniter_only_image_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/uniter_on_movielens_only_image_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_cmbf(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/cmbf_on_movielens.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_cmbf_with_multi_loss(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/cmbf_with_multi_loss.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_cmbf_has_other_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/cmbf_on_movielens_has_other_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_cmbf_only_text_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/cmbf_on_movielens_only_text_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_cmbf_only_image_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/cmbf_on_movielens_only_image_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_dssm(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_on_taobao.config', self._test_dir)
@@ -255,24 +489,35 @@ def test_metric_learning(self):
'samples/model_config/metric_learning_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_neg_sampler_on_taobao.config',
self._test_dir)
self.assertTrue(self._success)
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_v2(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_neg_sampler_v2_on_taobao.config',
self._test_dir)
self.assertTrue(self._success)
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_hard_neg_sampler(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_hard_neg_sampler_on_taobao.config',
self._test_dir)
self.assertTrue(self._success)
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dssm_hard_neg_regular_sampler(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dssm_hard_neg_sampler_regular_on_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_hard_neg_sampler_v2(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config',
@@ -359,11 +604,6 @@ def test_deepfm_with_sigmoid_l2_loss(self):
self._test_dir)
self.assertTrue(self._success)
- # def test_deepfm_with_sequence_attention(self):
- # self._success = test_utils.test_single_train_eval(
- # 'samples/model_config/deppfm_seq_attn_on_taobao.config', self._test_dir)
- # self.assertTrue(self._success)
-
def test_deepfm_with_embedding_learning_rate(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_combo_on_avazu_emblr_ctr.config',
@@ -376,11 +616,28 @@ def test_deepfm_with_eval_online(self):
self._test_dir)
self.assertTrue(self._success)
+ def test_deepfm_with_eval_online_gauc(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_on_avazu_eval_online_gauc_ctr.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_mmoe_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/mmoe_backbone_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_mmoe_with_multi_loss(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/mmoe_on_taobao_with_multi_loss.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_mmoe_deprecated(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_taobao_deprecated.config', self._test_dir)
@@ -392,7 +649,13 @@ def test_simple_multi_task(self):
self._test_dir)
self.assertTrue(self._success)
- def test_essm(self):
+ def test_simple_multi_task_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/simple_multi_task_backbone_on_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_esmm(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/esmm_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
@@ -402,11 +665,37 @@ def test_tag_kv_input(self):
'samples/model_config/kv_tag.config', self._test_dir)
self.assertTrue(self._success)
+ def test_aitm(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/aitm_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
def test_dbmtl(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dbmtl_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ def test_dbmtl_backbone(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_backbone_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_dbmtl_cmbf(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_cmbf_on_movielens.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_dbmtl_uniter(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_uniter_on_movielens.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_dbmtl_with_multi_loss(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_taobao_with_multi_loss.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_early_stop(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_early_stop_on_taobao.config',
@@ -425,6 +714,12 @@ def test_early_stop_dis(self):
self._test_dir)
self.assertTrue(self._success)
+ def test_latest_export_with_asset(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/din_on_taobao_latest_export.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_incompatible_restore(self):
def _post_check_func(config):
@@ -444,13 +739,51 @@ def _post_check_func(config):
def test_dbmtl_variational_dropout(self):
self._success = test_utils.test_single_train_eval(
- 'samples/model_config/dbmtl_variational_dropout.config', self._test_dir)
+ 'samples/model_config/dbmtl_variational_dropout.config',
+ self._test_dir,
+ post_check_func=test_utils.test_feature_selection)
self.assertTrue(self._success)
def test_dbmtl_variational_dropout_feature_num(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dbmtl_variational_dropout_feature_num.config',
- self._test_dir)
+ self._test_dir,
+ post_check_func=test_utils.test_feature_selection)
+ self.assertTrue(self._success)
+
+  def test_esmm_variational_dropout(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/esmm_variational_dropout_on_taobao.config',
+ self._test_dir,
+ post_check_func=test_utils.test_feature_selection)
+ self.assertTrue(self._success)
+
+ def test_fm_variational_dropout(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/fm_variational_dropout_on_taobao.config',
+ self._test_dir,
+ post_check_func=test_utils.test_feature_selection)
+ self.assertTrue(self._success)
+
+ def test_deepfm_with_combo_feature_variational_dropout(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_combo_variational_dropout_on_avazu_ctr.config',
+ self._test_dir,
+ post_check_func=test_utils.test_feature_selection)
+ self.assertTrue(self._success)
+
+ def test_dbmtl_sequence_variational_dropout(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_variational_dropout_on_sequence_feature_taobao.config',
+ self._test_dir,
+ post_check_func=test_utils.test_feature_selection)
+ self.assertTrue(self._success)
+
+ def test_din_variational_dropout(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/din_varitional_dropout_on_taobao.config',
+ self._test_dir,
+ post_check_func=test_utils.test_feature_selection)
self.assertTrue(self._success)
def test_rocket_launching(self):
@@ -464,6 +797,12 @@ def test_rocket_launching_feature_based(self):
self._test_dir)
self.assertTrue(self._success)
+ def test_rocket_launching_with_rtp_input(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/rocket_launching_with_rtp_input.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_dbmtl_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dbmtl_mmoe_on_taobao.config', self._test_dir)
@@ -474,6 +813,37 @@ def test_train_with_ps_worker(self):
'samples/model_config/multi_tower_on_taobao.config', self._test_dir)
self.assertTrue(self._success)
+ @unittest.skip("Timeout on CI machine")
+ def test_fit_on_eval(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/multi_tower_on_taobao.config',
+ self._test_dir,
+ total_steps=10,
+ num_evaluator=1,
+ fit_on_eval=True)
+ self.assertTrue(self._success)
+
+ def test_unbalance_data(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/multi_tower_on_taobao_unblanace.config',
+ self._test_dir,
+ total_steps=0,
+ num_epoch=1,
+ num_evaluator=1)
+ self.assertTrue(self._success)
+
+ def test_train_with_ps_worker_with_evaluator(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/multi_tower_on_taobao.config',
+ self._test_dir,
+ num_evaluator=1)
+ self.assertTrue(self._success)
+ final_export_dir = os.path.join(self._test_dir, 'train/export/final')
+ all_saved_files = glob.glob(final_export_dir + '/*/saved_model.pb')
+ logging.info('final_export_dir=%s all_saved_files=%s' %
+ (final_export_dir, ','.join(all_saved_files)))
+ self.assertTrue(len(all_saved_files) == 1)
+
def test_train_with_ps_worker_chief_redundant(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/multi_tower_on_taobao_chief_redundant.config',
@@ -501,6 +871,18 @@ def test_batch_tfrecord_input(self):
self._test_dir)
self.assertTrue(self._success)
+ def test_autodis_embedding(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_on_criteo_with_autodis.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_periodic_embedding(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_on_criteo_with_periodic.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_sample_weight(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_with_sample_weight.config', self._test_dir)
@@ -511,20 +893,42 @@ def test_dssm_sample_weight(self):
'samples/model_config/dssm_with_sample_weight.config', self._test_dir)
self.assertTrue(self._success)
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dssm_neg_sampler_with_sample_weight(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dssm_neg_sampler_with_sample_weight.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
@unittest.skipIf(
- LooseVersion(tf.__version__) < LooseVersion('2.3.0'),
- 'MultiWorkerMirroredStrategy need tf version > 2.3')
+ LooseVersion(tf.__version__) != LooseVersion('2.3.0'),
+ 'MultiWorkerMirroredStrategy need tf version == 2.3')
def test_train_with_multi_worker_mirror(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/multi_tower_multi_worker_mirrored_strategy_on_taobao.config',
self._test_dir)
self.assertTrue(self._success)
+ @unittest.skipIf(
+ LooseVersion(tf.__version__) != LooseVersion('2.3.0'),
+ 'MultiWorkerMirroredStrategy need tf version == 2.3')
+ def test_train_mmoe_with_multi_worker_mirror(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/mmoe_mirrored_strategy_on_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
def test_fg_dtype(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/taobao_fg_test_dtype.config', self._test_dir)
self.assertTrue(self._success)
+ @unittest.skipIf(six.PY2, 'Only run in python3')
+ def test_share_not_used(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/share_not_used.config', self._test_dir)
+ self.assertTrue(self._success)
+
def test_sequence_autoint(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/autoint_on_sequence_feature_taobao.config',
@@ -549,15 +953,9 @@ def test_sequence_dssm(self):
self._test_dir)
self.assertTrue(self._success)
- # def test_sequence_essm(self):
- # self._success = test_utils.test_single_train_eval(
- # 'samples/model_config/essm_on_sequence_feature_taobao.config',
- # self._test_dir)
- # self.assertTrue(self._success)
-
- def test_sequence_fm(self):
+ def test_sequence_esmm(self):
self._success = test_utils.test_single_train_eval(
- 'samples/model_config/fm_on_sequence_feature_taobao.config',
+ 'samples/model_config/esmm_on_sequence_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)
@@ -589,12 +987,313 @@ def test_sequence_wide_and_deep(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/wide_and_deep_on_sequence_feature_taobao.config',
self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_numeric_boundary_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_numeric_hash_bucket_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_numeric_hash_bucket_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_numeric_raw_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_numeric_raw_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_numeric_num_buckets_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_numeric_num_buckets_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_multi_numeric_boundary_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_multi_numeric_boundary_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_multi_numeric_hash_bucket_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_multi_numeric_hash_bucket_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_multi_numeric_raw_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_multi_numeric_raw_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_multi_numeric_num_buckets_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_multi_numeric_num_buckets_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_multi_sequence_dbmtl(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_multi_sequence_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
def test_multi_optimizer(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/wide_and_deep_two_opti.config', self._test_dir)
self.assertTrue(self._success)
+ def test_embedding_separate_optimizer(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_expr_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/multi_tower_on_taobao_for_expr.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_gzip_data(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/din_on_gzip_data.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_cmd_config_param(self):
+
+ def _post_check_config(pipeline_config):
+ train_saved_config_path = os.path.join(self._test_dir,
+ 'train/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ train_saved_config_path)
+ assert pipeline_config.model_config.deepfm.wide_output_dim == 8,\
+ 'invalid model_config.deepfm.wide_output_dim=%d' % \
+ pipeline_config.model_config.deepfm.wide_output_dim
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
+ self._test_dir,
+ post_check_func=_post_check_config,
+ extra_cmd_args='--model_config.deepfm.wide_output_dim 8')
+
+ def test_cmd_config_param_v2(self):
+
+ def _post_check_config(pipeline_config):
+ train_saved_config_path = os.path.join(self._test_dir,
+ 'train/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ train_saved_config_path)
+ assert pipeline_config.model_config.deepfm.wide_output_dim == 1,\
+ 'invalid model_config.deepfm.wide_output_dim=%d' % \
+ pipeline_config.model_config.deepfm.wide_output_dim
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
+ self._test_dir,
+ post_check_func=_post_check_config,
+ extra_cmd_args='--model_config.deepfm.wide_output_dim=1')
+
+ def test_cmd_config_param_v3(self):
+
+ def _post_check_config(pipeline_config):
+ train_saved_config_path = os.path.join(self._test_dir,
+ 'train/pipeline.config')
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ train_saved_config_path)
+ assert pipeline_config.model_config.deepfm.wide_output_dim == 3,\
+ 'invalid model_config.deepfm.wide_output_dim=%d' % \
+ pipeline_config.model_config.deepfm.wide_output_dim
+
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
+ self._test_dir,
+ post_check_func=_post_check_config,
+ extra_cmd_args='--model_config.deepfm.wide_output_dim="3"')
+
+ def test_distribute_eval_deepfm_multi_cls(self):
+ cur_eval_path = 'data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls'
+ self._success = test_utils.test_distributed_eval(
+ 'samples/model_config/deepfm_distribute_eval_multi_cls_on_avazu_ctr.config',
+ cur_eval_path, self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_distribute_eval_deepfm_single_cls(self):
+ cur_eval_path = 'data/test/distribute_eval_test/dwd_distribute_eval_avazu_out_test_combo'
+ self._success = test_utils.test_distributed_eval(
+ 'samples/model_config/deepfm_distribute_eval_combo_on_avazu_ctr.config',
+ cur_eval_path, self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_distribute_eval_dssm_pointwise_classification(self):
+ cur_eval_path = 'data/test/distribute_eval_test/dssm_distribute_eval_pointwise_classification_taobao_ckpt'
+ self._success = test_utils.test_distributed_eval(
+ 'samples/model_config/dssm_distribute_eval_pointwise_classification_on_taobao.config',
+ cur_eval_path, self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_distribute_eval_dssm_reg(self):
+ cur_eval_path = 'data/test/distribute_eval_test/dssm_distribute_eval_reg_taobao_ckpt'
+ self._success = test_utils.test_distributed_eval(
+ 'samples/model_config/dssm_distribute_eval_reg_on_taobao.config',
+ cur_eval_path, self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_distribute_eval_dropout(self):
+ cur_eval_path = 'data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt'
+ self._success = test_utils.test_distributed_eval(
+ 'samples/model_config/dropoutnet_distribute_eval_on_taobao.config',
+ cur_eval_path, self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_distribute_eval_esmm(self):
+ cur_eval_path = 'data/test/distribute_eval_test/esmm_distribute_eval_taobao_ckpt'
+ self._success = test_utils.test_distributed_eval(
+ 'samples/model_config/esmm_distribute_eval_on_taobao.config',
+ cur_eval_path, self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_share_no_used(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/share_embedding_not_used.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dssm_neg_sampler_sequence_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dssm_neg_sampler_sequence_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dssm_neg_sampler_need_key_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dssm_neg_sampler_need_key_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_dbmtl_on_multi_numeric_boundary_need_key_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_multi_numeric_boundary_need_key_feature_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_dbmtl_on_multi_numeric_boundary_allow_key_transform(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ def test_dbmtl_on_multi_numeric_boundary_aux_hist_seq(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_aux_hist_seq_taobao.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_multi_tower_recall_neg_sampler_sequence_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_multi_tower_recall_neg_sampler_only_sequence_feature(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/multi_tower_recall_neg_sampler_only_sequence_feature.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(hvd is None, 'horovod is not installed')
+ def test_horovod(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/deepfm_combo_on_avazu_ctr.config',
+ self._test_dir,
+ use_hvd=True)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(hvd is None or sok is None,
+ 'horovod and sok is not installed')
+ def test_sok(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/multi_tower_on_taobao_sok.config',
+ self._test_dir,
+ use_hvd=True)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(
+ six.PY2 or tf_version.split('.')[0] != '2',
+ 'only run on python3 and tf 2.x')
+ def test_train_parquet(self):
+ os.environ[constant.NO_ARITHMETRIC_OPTI] = '1'
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dlrm_on_criteo_parquet.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(hvd is None, 'horovod is not installed')
+ def test_train_parquet_embedding_parallel(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/dlrm_on_criteo_parquet_ep.config',
+ self._test_dir,
+ use_hvd=True)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(hvd is None, 'horovod is not installed')
+ def test_train_parquet_embedding_parallel_v2(self):
+ self._success = test_utils.test_distributed_train_eval(
+ 'samples/model_config/dlrm_on_criteo_parquet_ep_v2.config',
+ self._test_dir,
+ use_hvd=True)
+ self.assertTrue(self._success)
+
+ def test_pdn(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/pdn_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dssm_senet(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dssm_senet_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dssm_backbone_on_taobao(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dssm_on_taobao_backbone.config', self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dssm_senet_backbone_on_taobao(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dssm_senet_on_taobao_backbone.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_parallel_dssm_backbone_on_taobao(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/parallel_dssm_on_taobao_backbone.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+  def test_xdeepfm_backbone_on_taobao(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/xdeepfm_on_taobao_backbone.config',
+ self._test_dir)
+ self.assertTrue(self._success)
+
+ @unittest.skipIf(gl is None, 'graphlearn is not installed')
+ def test_dat_on_taobao(self):
+ self._success = test_utils.test_single_train_eval(
+ 'samples/model_config/dat_on_taobao.config', self._test_dir)
+ self.assertTrue(self._success)
+
if __name__ == '__main__':
tf.test.main()
diff --git a/easy_rec/python/test/util_test.py b/easy_rec/python/test/util_test.py
index c660145f6..c14524488 100644
--- a/easy_rec/python/test/util_test.py
+++ b/easy_rec/python/test/util_test.py
@@ -4,6 +4,8 @@
import tensorflow as tf
from easy_rec.python.utils import estimator_utils
+from easy_rec.python.utils.dag import DAG
+from easy_rec.python.utils.expr_util import get_expression
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -20,6 +22,64 @@ def test_get_ckpt_version(self):
'oss://easyrec/ckpts/model.ckpt-6500')
assert ver == 6500, 'invalid version: %s' % str(ver)
+ def test_get_expression_greater(self):
+ result = get_expression('age_level>item_age_level',
+ ['age_level', 'item_age_level'])
+ assert result == "tf.greater(parsed_dict['age_level'], parsed_dict['item_age_level'])"
+
+ def test_get_expression_greater_equal(self):
+ result = get_expression('age_level>=item_age_level',
+ ['age_level', 'item_age_level'])
+ assert result == "tf.greater_equal(parsed_dict['age_level'], parsed_dict['item_age_level'])"
+
+  def test_get_expression_less(self):
+    result = get_expression('age_level<item_age_level',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.less(parsed_dict['age_level'], parsed_dict['item_age_level'])"
+
+  def test_get_expression_and_or(self):
+    result = get_expression('(age_level>3)&(item_age_level<1)',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.greater(parsed_dict['age_level'], 3) & tf.less(parsed_dict['item_age_level'], 1)"
+
+    result = get_expression(
+        '(age_level>3)|(item_age_level<1)',
+        ['age_level', 'item_age_level'])
+    assert result == "tf.greater(parsed_dict['age_level'], 3) | tf.less(parsed_dict['item_age_level'], 1)"
+
+ def test_dag(self):
+ dag = DAG()
+ dag.add_node('a')
+ dag.add_node('b')
+ dag.add_node('c')
+ dag.add_node('d')
+ dag.add_edge('a', 'b')
+ dag.add_edge('a', 'd')
+ dag.add_edge('b', 'c')
+ order = dag.topological_sort()
+ idx_a = order.index('a')
+ idx_b = order.index('b')
+ idx_c = order.index('c')
+ idx_d = order.index('d')
+ assert idx_a < idx_b
+ assert idx_a < idx_d
+ assert idx_b < idx_c
+ c = dag.all_downstreams('b')
+ assert c == ['c']
+ leaf = dag.all_leaves()
+ assert leaf == ['c', 'd']
+
if __name__ == '__main__':
tf.test.main()
diff --git a/easy_rec/python/test/zero_inflated_lognormal_test.py b/easy_rec/python/test/zero_inflated_lognormal_test.py
new file mode 100644
index 000000000..f512e48e8
--- /dev/null
+++ b/easy_rec/python/test/zero_inflated_lognormal_test.py
@@ -0,0 +1,53 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+import tensorflow as tf
+from scipy import stats
+
+from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+# Absolute error tolerance in asserting array near.
+_ERR_TOL = 1e-6
+
+
+# softplus function that calculates log(1+exp(x))
+def _softplus(x):
+ return np.log(1.0 + np.exp(x))
+
+
+# sigmoid function that calculates 1/(1+exp(-x))
+def _sigmoid(x):
+ return 1 / (1 + np.exp(-x))
+
+
+class ZeroInflatedLognormalLossTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(ZeroInflatedLognormalLossTest, self).setUp()
+ self.logits = np.array([[.1, .2, .3], [.4, .5, .6]])
+ self.labels = np.array([[0.], [1.5]])
+
+ def zero_inflated_lognormal(self, labels, logits):
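+    # numpy/scipy reference implementation used to check
+    # zero_inflated_lognormal_loss: the first logit is the zero/non-zero
+    # classification logit p, the next two parameterize the lognormal loc and
+    # (softplus-transformed) scale; zero labels incur softplus(p), positive
+    # labels incur softplus(-p) - lognormal_logpdf(label).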
+ positive_logits = logits[..., :1]
+ loss_zero = _softplus(positive_logits)
+ loc = logits[..., 1:2]
+ scale = np.maximum(
+ _softplus(logits[..., 2:]), np.sqrt(tf.keras.backend.epsilon()))
+ log_prob_non_zero = stats.lognorm.logpdf(
+ x=labels, s=scale, loc=0, scale=np.exp(loc))
+ loss_non_zero = _softplus(-positive_logits) - log_prob_non_zero
+ return np.mean(np.where(labels == 0., loss_zero, loss_non_zero), axis=-1)
+
+ def test_loss_value(self):
+ expected_loss = self.zero_inflated_lognormal(self.labels, self.logits)
+ expected_loss = np.average(expected_loss)
+ loss = zero_inflated_lognormal_loss(self.labels, self.logits)
+ self.assertNear(self.evaluate(loss), expected_loss, _ERR_TOL)
+
+
+if __name__ == '__main__':
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/easy_rec/python/tools/add_boundaries_to_config.py b/easy_rec/python/tools/add_boundaries_to_config.py
index 09d2d9a1d..18d5f6037 100644
--- a/easy_rec/python/tools/add_boundaries_to_config.py
+++ b/easy_rec/python/tools/add_boundaries_to_config.py
@@ -3,11 +3,13 @@
import json
import logging
import os
+import sys
import common_io
import tensorflow as tf
from easy_rec.python.utils import config_util
+from easy_rec.python.utils import io_util
if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -61,4 +63,5 @@ def main(argv):
if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
diff --git a/easy_rec/python/tools/add_feature_info_to_config.py b/easy_rec/python/tools/add_feature_info_to_config.py
new file mode 100644
index 000000000..7594d038b
--- /dev/null
+++ b/easy_rec/python/tools/add_feature_info_to_config.py
@@ -0,0 +1,145 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+import os
+import sys
+
+import tensorflow as tf
+
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import io_util
+from easy_rec.python.utils.hive_utils import HiveUtils
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+logging.basicConfig(
+ format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+ level=logging.INFO)
+tf.app.flags.DEFINE_string('template_config_path', None,
+ 'Path to template pipeline config '
+ 'file.')
+tf.app.flags.DEFINE_string('output_config_path', None,
+ 'Path to output pipeline config '
+ 'file.')
+tf.app.flags.DEFINE_string('config_table', '', 'config table')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(argv):
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ FLAGS.template_config_path)
+ sels = 'feature,feature_info,message'
+ feature_info_map = {}
+ drop_feature_names = []
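+  # The config table is expected to provide three columns per row:
+  #   feature      - feature name
+  #   feature_info - JSON string, e.g. {"embedding_dim": 8, "boundary": [...]}
+  #                  or {"embedding_dim": 8, "hash_bucket_size": 1000}
+  #   message      - free text; features whose message contains 'DROP IT' are dropped
+  # Special rows __NUM_STEPS__ and __DECAY_STEPS__ carry the new num_steps and
+  # decay_steps values (the example values above are illustrative).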
+
+ if pipeline_config.WhichOneof('train_path') == 'hive_train_input':
+ hive_util = HiveUtils(
+ data_config=pipeline_config.data_config,
+ hive_config=pipeline_config.hive_train_input,
+ selected_cols=sels,
+ record_defaults=['', '', ''])
+ reader = hive_util.hive_read_line(FLAGS.config_table)
+ for record in reader:
+ feature_name = record[0][0]
+ feature_info_map[feature_name] = json.loads(record[0][1])
+ if 'DROP IT' in record[0][2]:
+ drop_feature_names.append(feature_name)
+
+ else:
+ import common_io
+ reader = common_io.table.TableReader(FLAGS.config_table, selected_cols=sels)
+ while True:
+ try:
+ record = reader.read()
+ feature_name = record[0][0]
+ feature_info_map[feature_name] = json.loads(record[0][1])
+ if 'DROP IT' in record[0][2]:
+ drop_feature_names.append(feature_name)
+ except common_io.exception.OutOfRangeException:
+ reader.close()
+ break
+
+ feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
+ if drop_feature_names:
+ tmp_feature_configs = feature_configs[:]
+ for fea_cfg in tmp_feature_configs:
+ fea_name = fea_cfg.input_names[0]
+ if fea_name in drop_feature_names:
+ feature_configs.remove(fea_cfg)
+ for feature_config in feature_configs:
+ feature_name = feature_config.input_names[0]
+ if feature_name in feature_info_map:
+ logging.info('edited %s' % feature_name)
+ feature_config.embedding_dim = int(
+ feature_info_map[feature_name]['embedding_dim'])
+ logging.info('modify embedding_dim to %s' % feature_config.embedding_dim)
+ if 'boundary' in feature_info_map[feature_name]:
+ feature_config.ClearField('boundaries')
+ feature_config.boundaries.extend(
+ [float(i) for i in feature_info_map[feature_name]['boundary']])
+ logging.info('modify boundaries to %s' % feature_config.boundaries)
+ elif 'hash_bucket_size' in feature_info_map[feature_name]:
+ feature_config.hash_bucket_size = int(
+ feature_info_map[feature_name]['hash_bucket_size'])
+ logging.info('modify hash_bucket_size to %s' %
+ feature_config.hash_bucket_size)
+ # modify num_steps
+ pipeline_config.train_config.num_steps = feature_info_map['__NUM_STEPS__'][
+ 'num_steps']
+ logging.info('modify num_steps to %s' %
+ pipeline_config.train_config.num_steps)
+ # modify decay_steps
+ optimizer_configs = pipeline_config.train_config.optimizer_config
+ for optimizer_config in optimizer_configs:
+ optimizer = optimizer_config.WhichOneof('optimizer')
+ optimizer = getattr(optimizer_config, optimizer)
+ learning_rate = optimizer.learning_rate.WhichOneof('learning_rate')
+ learning_rate = getattr(optimizer.learning_rate, learning_rate)
+ if hasattr(learning_rate, 'decay_steps'):
+ learning_rate.decay_steps = feature_info_map['__DECAY_STEPS__'][
+ 'decay_steps']
+ logging.info('modify decay_steps to %s' % learning_rate.decay_steps)
+
+ for feature_group in pipeline_config.model_config.feature_groups:
+ feature_names = feature_group.feature_names
+ reserved_features = []
+ for feature_name in feature_names:
+ if feature_name not in drop_feature_names:
+ reserved_features.append(feature_name)
+ else:
+ logging.info('drop feature: %s' % feature_name)
+ feature_group.ClearField('feature_names')
+ feature_group.feature_names.extend(reserved_features)
+ for sequence_feature in feature_group.sequence_features:
+ seq_att_maps = sequence_feature.seq_att_map
+ for seq_att in seq_att_maps:
+ keys = seq_att.key
+ reserved_keys = []
+ for key in keys:
+ if key not in drop_feature_names:
+ reserved_keys.append(key)
+ else:
+ logging.info('drop sequence feature key: %s' % key)
+ seq_att.ClearField('key')
+ seq_att.key.extend(reserved_keys)
+
+ hist_seqs = seq_att.hist_seq
+ reserved_hist_seqs = []
+ for hist_seq in hist_seqs:
+ if hist_seq not in drop_feature_names:
+ reserved_hist_seqs.append(hist_seq)
+ else:
+ logging.info('drop sequence feature hist_seq: %s' % hist_seq)
+ seq_att.ClearField('hist_seq')
+ seq_att.hist_seq.extend(reserved_hist_seqs)
+
+ config_dir, config_name = os.path.split(FLAGS.output_config_path)
+ config_util.save_pipeline_config(pipeline_config, config_dir, config_name)
+
+
+if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+ tf.app.run()
diff --git a/easy_rec/python/tools/criteo/__init__.py b/easy_rec/python/tools/criteo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/easy_rec/python/tools/criteo/convert_data.py b/easy_rec/python/tools/criteo/convert_data.py
new file mode 100644
index 000000000..6382865e7
--- /dev/null
+++ b/easy_rec/python/tools/criteo/convert_data.py
@@ -0,0 +1,157 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import gzip
+import logging
+import multiprocessing
+import os
+import traceback
+
+import numpy as np
+import pandas as pd
+import six
+from tensorflow.python.platform import gfile
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+
+def save_np_bin(labels, dense_arr, cate_arr, prefix):
+ with gfile.GFile(prefix + '_label.bin', 'wb') as fout:
+ fout.write(np.array(labels, dtype=np.int32).tobytes())
+ with gfile.GFile(prefix + '_dense.bin', 'wb') as fout:
+ fout.write(np.array(dense_arr, dtype=np.float32).tobytes())
+ with gfile.GFile(prefix + '_category.bin', 'wb') as fout:
+ fout.write(np.array(cate_arr, dtype=np.float32).tobytes())
+
+
+def save_parquet(labels, dense_arr, cate_arr, prefix):
+ df = {'is_click': labels}
+ for i in range(1, 14):
+ df['f' + str(i)] = dense_arr[:, i - 1]
+ for i in range(1, 27):
+ df['c' + str(i)] = cate_arr[:, i - 1]
+ df = pd.DataFrame(df)
+ save_path = prefix + '.parquet'
+ logging.info('save to %s' % save_path)
+ df.to_parquet(save_path)
+
+
+def convert(input_path, prefix, part_record_num, save_format):
+ logging.info('start to convert %s, part_record_num=%d, save_format=%s' %
+ (input_path, part_record_num, save_format))
+ save_func = save_np_bin
+ if save_format == 'parquet':
+ save_func = save_parquet
+ batch_size = part_record_num
+ labels = np.zeros([batch_size], dtype=np.int32)
+ dense_arr = np.zeros([batch_size, 13], dtype=np.float32)
+ cate_arr = np.zeros([batch_size, 26], dtype=np.uint32)
+ part_id = 0
+ total_line = 0
+ try:
+ sid = 0
+ with gfile.GFile(input_path, 'rb') as gz_fin:
+ for line_str in gzip.GzipFile(fileobj=gz_fin, mode='rb'):
+ if six.PY3:
+ line_str = str(line_str, 'utf-8')
+ line_str = line_str.strip()
+ line_toks = line_str.split('\t')
+ labels[sid] = int(line_toks[0])
+
+ for j in range(1, 14):
+ x = line_toks[j]
+ dense_arr[sid, j - 1] = float(x) if x != '' else 0.0
+
+ for j in range(14, 40):
+ x = line_toks[j]
+ cate_arr[sid, j - 14] = int(x, 16) if x != '' else 0
+
+ sid += 1
+ if sid == batch_size:
+ save_func(labels, dense_arr, cate_arr, prefix + '_' + str(part_id))
+ logging.info('\t%s write part: %d' % (input_path, part_id))
+ part_id += 1
+ total_line += sid
+ sid = 0
+ if sid > 0:
+ save_func(labels[:sid], dense_arr[:sid], cate_arr[:sid],
+ prefix + '_' + str(part_id))
+ logging.info('\t%s write final part: %d' % (input_path, part_id))
+ part_id += 1
+ total_line += sid
+ except Exception as ex:
+ logging.error('convert %s failed: %s' % (input_path, str(ex)))
+ logging.error(traceback.format_exc())
+ return
+ logging.info('done convert %s, total_line=%d, part_num=%d' %
+ (input_path, total_line, part_id))
+
+
+if __name__ == '__main__':
+ """Convert criteo 1T data to binary format.
+
+ The outputs are stored in multiple parts, each with at most part_record_num samples.
+ Each part consists of 3 files:
+ xxx_yyy_label.bin,
+ xxx_yyy_dense.bin,
+ xxx_yyy_category.bin,
+  xxx is in range [0-23]; the range of yyy is determined by part_record_num.
+
+  If part_record_num is set to the default value 8M, there will be 535 parts. We converted
+  the data on a machine with 64GB memory; if your memory is limited, you can convert the .gz
+  files one by one, or set a smaller part_record_num.
+ """
+
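+  # Example invocation (paths below are illustrative, not shipped with the repo):
+  #   python -m easy_rec.python.tools.criteo.convert_data \
+  #     --input_dir /path/to/criteo_1t --save_dir /path/to/criteo_bin \
+  #     --save_format parquet --dt 0 1 2
+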
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input_dir', type=str, default=None, help='criteo 1t data dir')
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default=None,
+ help='criteo binary data output dir ')
+ parser.add_argument(
+ '--save_format',
+ type=str,
+ default='npy',
+ help='save format, choices: npy|parquet')
+ parser.add_argument(
+ '--part_record_num',
+ type=int,
+ default=1024 * 1024 * 8,
+ help='the maximal number of samples in each binary file')
+ parser.add_argument(
+ '--dt',
+ nargs='*',
+ type=int,
+ help='select days to convert, default to select all: 0-23')
+
+ args = parser.parse_args()
+
+ assert args.input_dir, 'input_dir is not set'
+ assert args.save_dir, 'save_dir is not set'
+
+ save_dir = args.save_dir
+ if not save_dir.endswith('/'):
+ save_dir = save_dir + '/'
+ if not gfile.IsDirectory(save_dir):
+ gfile.MakeDirs(save_dir)
+
+ if args.dt is None or len(args.dt) == 0:
+ days = list(range(0, 24))
+ else:
+ days = list(args.dt)
+
+ proc_arr = []
+ for d in days:
+ input_path = os.path.join(args.input_dir, 'day_%d.gz' % d)
+ prefix = os.path.join(args.save_dir, str(d))
+ proc = multiprocessing.Process(
+ target=convert,
+ args=(input_path, prefix, args.part_record_num, args.save_format))
+ proc.start()
+ proc_arr.append(proc)
+ for proc in proc_arr:
+ proc.join()
diff --git a/easy_rec/python/tools/faiss_index_pai.py b/easy_rec/python/tools/faiss_index_pai.py
new file mode 100644
index 000000000..e9ebe3f89
--- /dev/null
+++ b/easy_rec/python/tools/faiss_index_pai.py
@@ -0,0 +1,116 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import print_function
+
+import logging
+import os
+import sys
+
+import faiss
+import numpy as np
+import tensorflow as tf
+
+from easy_rec.python.utils import io_util
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+tf.app.flags.DEFINE_string('tables', '', 'tables passed by pai command')
+tf.app.flags.DEFINE_integer('batch_size', 1024, 'batch size')
+tf.app.flags.DEFINE_integer('embedding_dim', 32, 'embedding dimension')
+tf.app.flags.DEFINE_string('index_output_dir', '', 'index output directory')
+tf.app.flags.DEFINE_string('index_type', 'IVFFlat', 'index type')
+tf.app.flags.DEFINE_integer('ivf_nlist', 1000, 'nlist')
+tf.app.flags.DEFINE_integer('hnsw_M', 32, 'hnsw M')
+tf.app.flags.DEFINE_integer('hnsw_efConstruction', 200, 'hnsw efConstruction')
+tf.app.flags.DEFINE_integer('debug', 0, 'debug index')
+
+FLAGS = tf.app.flags.FLAGS
+
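+# Each row of the input table is expected to hold two columns:
+#   column 0 - entity id, written line by line to <index_output_dir>/id_mapping
+#   column 1 - comma separated embedding values of length embedding_dim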
+
+def main(argv):
+ reader = tf.python_io.TableReader(
+ FLAGS.tables, slice_id=0, slice_count=1, capacity=FLAGS.batch_size * 2)
+ i = 0
+ id_map_f = tf.gfile.GFile(
+ os.path.join(FLAGS.index_output_dir, 'id_mapping'), 'w')
+ embeddings = []
+ while True:
+ try:
+ records = reader.read(FLAGS.batch_size)
+ for j, record in enumerate(records):
+        if isinstance(record[0], bytes):
+          eid = record[0].decode('utf-8')
+        else:
+          eid = str(record[0])
+        id_map_f.write('%s\n' % eid)
+
+ embeddings.extend(
+ [list(map(float, record[1].split(b','))) for record in records])
+ i += 1
+ if i % 100 == 0:
+ logging.info('read %d embeddings.' % (i * FLAGS.batch_size))
+ except tf.python_io.OutOfRangeException:
+ break
+ reader.close()
+ id_map_f.close()
+
+ logging.info('Building faiss index..')
+ if FLAGS.index_type == 'IVFFlat':
+ quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
+ index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, FLAGS.ivf_nlist,
+ faiss.METRIC_INNER_PRODUCT)
+ elif FLAGS.index_type == 'HNSWFlat':
+ index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, FLAGS.hnsw_M,
+ faiss.METRIC_INNER_PRODUCT)
+ index.hnsw.efConstruction = FLAGS.hnsw_efConstruction
+ else:
+ raise NotImplementedError
+
+ embeddings = np.array(embeddings)
+ if FLAGS.index_type == 'IVFFlat':
+ logging.info('train embeddings...')
+ index.train(embeddings)
+
+ logging.info('build embeddings...')
+ index.add(embeddings)
+ faiss.write_index(index, 'faiss_index')
+
+ with tf.gfile.GFile(
+ os.path.join(FLAGS.index_output_dir, 'faiss_index'), 'wb') as f_out:
+ with open('faiss_index', 'rb') as f_in:
+ f_out.write(f_in.read())
+
+ if FLAGS.debug != 0:
+ # IVFFlat
+ for ivf_nlist in [100, 500, 1000, 2000]:
+ quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
+ index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, ivf_nlist,
+ faiss.METRIC_INNER_PRODUCT)
+ index.train(embeddings)
+ index.add(embeddings)
+ index_name = 'faiss_index_ivfflat_nlist%d' % ivf_nlist
+ faiss.write_index(index, index_name)
+ with tf.gfile.GFile(
+ os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
+ with open(index_name, 'rb') as f_in:
+ f_out.write(f_in.read())
+
+ # HNSWFlat
+ for hnsw_M in [16, 32, 64, 128]:
+ for hnsw_efConstruction in [64, 128, 256, 512, 1024, 2048, 4096, 8196]:
+ if hnsw_efConstruction < hnsw_M * 2:
+ continue
+ index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, hnsw_M,
+ faiss.METRIC_INNER_PRODUCT)
+ index.hnsw.efConstruction = hnsw_efConstruction
+ index.add(embeddings)
+ index_name = 'faiss_index_hnsw_M%d_ef%d' % (hnsw_M, hnsw_efConstruction)
+ faiss.write_index(index, index_name)
+ with tf.gfile.GFile(
+ os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
+ with open(index_name, 'rb') as f_in:
+ f_out.write(f_in.read())
+
+
+if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+ tf.app.run()
diff --git a/easy_rec/python/tools/feature_selection.py b/easy_rec/python/tools/feature_selection.py
index f80986642..f50a00fac 100644
--- a/easy_rec/python/tools/feature_selection.py
+++ b/easy_rec/python/tools/feature_selection.py
@@ -3,27 +3,34 @@
import json
import os
+import sys
from collections import OrderedDict
import numpy as np
import pandas as pd
import tensorflow as tf
+from tensorflow.python.framework.meta_graph import read_meta_graph_file
from easy_rec.python.utils import config_util
+from easy_rec.python.utils import io_util
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
import matplotlib # NOQA
matplotlib.use('Agg') # NOQA
import matplotlib.pyplot as plt # NOQA
tf.app.flags.DEFINE_string('model_type', 'variational_dropout',
- 'feature selection model tyoe')
+ 'feature selection model type')
tf.app.flags.DEFINE_string('config_path', '',
'feature selection model config path')
tf.app.flags.DEFINE_string('checkpoint_path', None,
'feature selection model checkpoint path')
tf.app.flags.DEFINE_string('output_dir', '',
'feature selection result directory')
-tf.app.flags.DEFINE_integer('topk', 100, 'select topk importance features')
+tf.app.flags.DEFINE_integer(
+ 'topk', 100, 'select topk importance features for each feature group')
tf.app.flags.DEFINE_string('fg_path', '', 'fg config path')
tf.app.flags.DEFINE_bool('visualize', False,
'visualization feature selection result or not')
@@ -50,63 +57,83 @@ def __init__(self,
def process(self):
tf.logging.info('Loading logit_p of VariationalDropout layer ...')
- feature_dim_dropout_p, embedding_wise_variational_dropout = self._feature_dim_dropout_ratio(
+ feature_dim_dropout_p_map, embedding_wise_variational_dropout = self._feature_dim_dropout_ratio(
)
- tf.logging.info('Calculating feature importance ...')
- feature_importance = self._get_feature_importance(
- feature_dim_dropout_p, embedding_wise_variational_dropout)
+ feature_importance_map = {}
+ for group_name, feature_dim_dropout_p in feature_dim_dropout_p_map.items():
+ tf.logging.info('Calculating %s feature importance ...' % group_name)
+ feature_importance = self._get_feature_importance(
+ feature_dim_dropout_p, embedding_wise_variational_dropout)
+ feature_importance_map[group_name] = feature_importance
- tf.logging.info('Processing model config ...')
- self._process_config(feature_importance)
+ tf.logging.info('Dump %s feature importance to csv ...' % group_name)
+ self._dump_to_csv(feature_importance, group_name)
- tf.logging.info('Dump feature importance to csv ...')
- self._dump_to_csv(feature_importance)
+ if self._visualize:
+ tf.logging.info('Visualizing %s feature importance ...' % group_name)
+ if embedding_wise_variational_dropout:
+ self._visualize_embedding_dim_importance(feature_dim_dropout_p)
+ self._visualize_feature_importance(feature_importance, group_name)
- if self._visualize:
- tf.logging.info('Visualizing feature importance ...')
- if embedding_wise_variational_dropout:
- self._visualize_embedding_dim_importance(feature_dim_dropout_p)
- self._visualize_feature_importance(feature_importance)
+ tf.logging.info('Processing model config ...')
+ self._process_config(feature_importance_map)
def _feature_dim_dropout_ratio(self):
"""Get dropout ratio of embedding-wise or feature-wise."""
config = config_util.get_configs_from_pipeline_file(self._config_path)
assert config.model_config.HasField(
'variational_dropout'), 'variational_dropout must be in model_config'
+
embedding_wise_variational_dropout = config.model_config.variational_dropout.embedding_wise_variational_dropout
- features_dim = {
- cfg.input_names[0]: cfg.embedding_dim
- if cfg.HasField('embedding_dim') else cfg.raw_input_dim
- for cfg in config_util.get_compatible_feature_configs(config)
- }
- features = list(config.model_config.feature_groups[0].feature_names)
+
if self._checkpoint_path is None or len(self._checkpoint_path) == 0:
checkpoint_path = tf.train.latest_checkpoint(config.model_dir)
else:
checkpoint_path = self._checkpoint_path
+ meta_graph_def = read_meta_graph_file(checkpoint_path + '.meta')
+ features_dimension_map = dict()
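+    # each entry in the 'variational_dropout' collection is a JSON encoded
+    # pair of (feature_group_name, [[feature_name, dimension], ...]);
+    # an empty group name maps to the default group 'all'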
+ for col_def in meta_graph_def.collection_def[
+ 'variational_dropout'].bytes_list.value:
+ name, features_dimension = json.loads(col_def)
+ name = 'all' if name == '' else name
+ features_dimension_map[name] = OrderedDict(features_dimension)
+
tf.logging.info('Reading checkpoint from %s ...' % checkpoint_path)
reader = tf.train.NewCheckpointReader(checkpoint_path)
- logit_p = reader.get_tensor('logit_p')
- feature_dims_importance = tf.sigmoid(logit_p)
- with tf.Session() as sess:
- feature_dims_importance = feature_dims_importance.eval(session=sess)
- feature_dim_dropout_p = {}
- if embedding_wise_variational_dropout:
- index_end = 0
- for feature_name in features:
- index_start = index_end
- index_end = index_start + features_dim[feature_name]
- feature_dim_dropout_p[feature_name] = feature_dims_importance[
- index_start:index_end]
- else:
- index = 0
- for feature_name in features:
- feature_dim_dropout_p[feature_name] = feature_dims_importance[index]
- index += 1
- return feature_dim_dropout_p, embedding_wise_variational_dropout
+ feature_dim_dropout_p_map = {}
+ for feature_group in config.model_config.feature_groups:
+ group_name = feature_group.group_name
+
+ logit_p_name = 'logit_p' if group_name == 'all' else 'logit_p_%s' % group_name
+ try:
+ logit_p = reader.get_tensor(logit_p_name)
+ except Exception:
+ print('get `logit_p` failed, try to get `backbone/logit_p`')
+ logit_p = reader.get_tensor('backbone/' + logit_p_name)
+ feature_dims_importance = tf.sigmoid(logit_p)
+ with tf.Session() as sess:
+ feature_dims_importance = feature_dims_importance.eval(session=sess)
+
+ feature_dim_dropout_p = {}
+ if embedding_wise_variational_dropout:
+ index_end = 0
+ for feature_name, feature_dim in features_dimension_map[
+ group_name].items():
+ index_start = index_end
+ index_end = index_start + feature_dim
+ feature_dim_dropout_p[feature_name] = feature_dims_importance[
+ index_start:index_end]
+ else:
+ index = 0
+ for feature_name in features_dimension_map[group_name].keys():
+ feature_dim_dropout_p[feature_name] = feature_dims_importance[index]
+ index += 1
+
+ feature_dim_dropout_p_map[group_name] = feature_dim_dropout_p
+ return feature_dim_dropout_p_map, embedding_wise_variational_dropout
def _get_feature_importance(self, feature_dim_dropout_p,
embedding_wise_variational_dropout):
@@ -123,19 +150,40 @@ def _get_feature_importance(self, feature_dim_dropout_p,
sorted(feature_dim_dropout_p.items(), key=lambda e: e[1]))
return feature_importance
- def _process_config(self, feature_importance):
+ def _process_config(self, feature_importance_map):
"""Process model config and fg config with feature selection."""
- selected_features = set()
- for i, (feature_name, _) in enumerate(feature_importance.items()):
- if i < self._topk:
- selected_features.add(feature_name)
+ excluded_features = set()
+ for group_name, feature_importance in feature_importance_map.items():
+ for i, (feature_name, _) in enumerate(feature_importance.items()):
+ if i >= self._topk:
+ excluded_features.add(feature_name)
+
config = config_util.get_configs_from_pipeline_file(self._config_path)
+ # keep sequence features and side-infos
+ sequence_features = set()
+ for feature_group in config.model_config.feature_groups:
+ for sequence_feature in feature_group.sequence_features:
+ for seq_att_map in sequence_feature.seq_att_map:
+ for key in seq_att_map.key:
+ sequence_features.add(key)
+ for hist_seq in seq_att_map.hist_seq:
+ sequence_features.add(hist_seq)
+ # compat with din
+ for sequence_feature in config.model_config.seq_att_groups:
+ for seq_att_map in sequence_feature.seq_att_map:
+ for key in seq_att_map.key:
+ sequence_features.add(key)
+ for hist_seq in seq_att_map.hist_seq:
+ sequence_features.add(hist_seq)
+ excluded_features = excluded_features - sequence_features
+
feature_configs = []
for feature_config in config_util.get_compatible_feature_configs(config):
feature_name = feature_config.feature_name if feature_config.HasField('feature_name') \
else feature_config.input_names[0]
- if feature_name in selected_features:
+ if feature_name not in excluded_features:
feature_configs.append(feature_config)
+
if config.feature_configs:
config.ClearField('feature_configs')
config.feature_configs.extend(feature_configs)
@@ -146,7 +194,7 @@ def _process_config(self, feature_importance):
for feature_group in config.model_config.feature_groups:
feature_names = []
for feature_name in feature_group.feature_names:
- if feature_name in selected_features:
+ if feature_name not in excluded_features:
feature_names.append(feature_name)
feature_group.ClearField('feature_names')
feature_group.feature_names.extend(feature_names)
@@ -159,7 +207,10 @@ def _process_config(self, feature_importance):
fg_json = json.load(f, object_pairs_hook=OrderedDict)
features = []
for feature in fg_json['features']:
- if feature['feature_name'] in selected_features:
+ if 'feature_name' in feature:
+ if feature['feature_name'] not in excluded_features:
+ features.append(feature)
+ else:
features.append(feature)
fg_json['features'] = features
with tf.gfile.Open(
@@ -167,10 +218,11 @@ def _process_config(self, feature_importance):
'w') as f:
json.dump(fg_json, f, indent=4)
- def _dump_to_csv(self, feature_importance):
+ def _dump_to_csv(self, feature_importance, group_name):
"""Dump feature importance data to a csv file."""
with tf.gfile.Open(
- os.path.join(self._output_dir, 'feature_dropout_ratio.csv'), 'w') as f:
+ os.path.join(self._output_dir,
+ 'feature_dropout_ratio_%s.csv' % group_name), 'w') as f:
df = pd.DataFrame(
columns=['feature_name', 'mean_drop_p'],
data=[list(kv) for kv in feature_importance.items()])
@@ -215,7 +267,7 @@ def _visualize_embedding_dim_importance(self, feature_dim_dropout_p):
with tf.gfile.GFile(img_path, 'wb') as f:
plt.savefig(f, format='png')
- def _visualize_feature_importance(self, feature_importance):
+ def _visualize_feature_importance(self, feature_importance, group_name):
"""Draw feature importance histogram."""
df = pd.DataFrame(
columns=['feature_name', 'mean_drop_p'],
@@ -243,11 +295,13 @@ def _visualize_feature_importance(self, feature_importance):
plt.grid(linestyle='--', alpha=0.5)
plt.xlim(0, 1)
with tf.gfile.GFile(
- os.path.join(self._output_dir, 'feature_dropout_pic.png'), 'wb') as f:
+ os.path.join(self._output_dir,
+ 'feature_dropout_pic_%s.png' % group_name), 'wb') as f:
plt.savefig(f, format='png')
if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
if FLAGS.model_type == 'variational_dropout':
fs = VariationalDropoutFS(
FLAGS.config_path,
diff --git a/easy_rec/python/tools/hit_rate_ds.py b/easy_rec/python/tools/hit_rate_ds.py
new file mode 100644
index 000000000..5528e0aa2
--- /dev/null
+++ b/easy_rec/python/tools/hit_rate_ds.py
@@ -0,0 +1,223 @@
+# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+# """Evaluation of Top k hitrate."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import logging
+import os
+import sys
+
+import graphlearn as gl
+import tensorflow as tf
+
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import io_util
+from easy_rec.python.utils.config_util import process_multi_file_input_path
+from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
+from easy_rec.python.utils.hit_rate_utils import load_graph
+from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
+from easy_rec.python.utils.hive_utils import HiveUtils
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA
+
+logging.basicConfig(
+ format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+ level=logging.INFO)
+
+tf.app.flags.DEFINE_string('item_emb_table', '', 'item embedding table name')
+tf.app.flags.DEFINE_string('gt_table', '', 'ground truth table name')
+tf.app.flags.DEFINE_string('hitrate_details_result', '',
+ 'hitrate detail file path')
+tf.app.flags.DEFINE_string('total_hitrate_result', '',
+ 'total hitrate result file path')
+
+tf.app.flags.DEFINE_string('pipeline_config_path', '', 'pipeline config path')
+tf.app.flags.DEFINE_integer('batch_size', 512, 'batch size')
+tf.app.flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
+tf.app.flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
+tf.app.flags.DEFINE_integer('top_k', '5', 'top_k hitrate.')
+tf.app.flags.DEFINE_integer('knn_metric', '0', '0(l2) or 1(ip).')
+tf.app.flags.DEFINE_bool('knn_strict', False, 'use exact search.')
+tf.app.flags.DEFINE_integer('timeout', '60', 'timeout')
+tf.app.flags.DEFINE_integer('num_interests', 1, 'max number of interests')
+tf.app.flags.DEFINE_string('gt_table_field_sep', '\t', 'gt_table_field_sep')
+tf.app.flags.DEFINE_string('item_emb_table_field_sep', '\t',
+ 'item_emb_table_field_sep')
+tf.app.flags.DEFINE_bool('is_on_ds', False, help='is on ds')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def compute_hitrate(g, gt_all, hitrate_writer, gt_table=None):
+ """Compute hitrate of each worker.
+
+ Args:
+ g: a GL Graph instance.
+    gt_all: iterable of ground truth record batches.
+    hitrate_writer: writer of the hitrate details file.
+    gt_table: ground truth table name (not used here).
+
+ Returns:
+ total_hits: total hits of this worker.
+ total_gt_count: total count of ground truth items of this worker.
+ """
+ total_hits = 0.0
+ total_gt_count = 0.0
+
+ for gt_record in gt_all:
+ gt_record = list(gt_record)
+ hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists = \
+ compute_hitrate_batch(g, gt_record, FLAGS.emb_dim, FLAGS.num_interests, FLAGS.top_k)
+ total_hits += hits
+ total_gt_count += gt_count
+
+ src_ids = [str(ids) for ids in src_ids]
+ hitrates = [str(hitrate) for hitrate in hitrates]
+ topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
+ topk_dists = [
+ ','.join('|'.join(str(x)
+ for x in dist)
+ for dist in dists)
+ for dists in recall_distances
+ ]
+ bad_cases = [','.join(str(x) for x in bad_case) for bad_case in bad_cases]
+ bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]
+
+ hitrate_writer.write('\n'.join([
+ '\t'.join(line) for line in zip(src_ids, topk_recalls, topk_dists,
+ hitrates, bad_cases, bad_dists)
+ ]))
+ print('total_hits: ', total_hits)
+ print('total_gt_count: ', total_gt_count)
+ return total_hits, total_gt_count
+
+
+def gt_hdfs(gt_table, batch_size, gt_file_sep):
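+  # Read ground truth records from text files (local or HDFS, gfile.Glob
+  # patterns supported) and yield them in batches of at most batch_size.
+  # Each line holds gt_file_sep separated fields; field 0 (id) and field 3
+  # (emb_num) are cast to int, the remaining fields are passed through as-is.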
+
+ if '*' in gt_table or ',' in gt_table:
+ file_paths = tf.gfile.Glob(gt_table.split(','))
+ elif tf.gfile.IsDirectory(gt_table):
+ file_paths = tf.gfile.Glob(os.path.join(gt_table, '*'))
+ else:
+ file_paths = tf.gfile.Glob(gt_table)
+
+ batch_list, i = [], 0
+ for file_path in file_paths:
+ with tf.gfile.GFile(file_path, 'r') as fin:
+ for gt in fin:
+ i += 1
+ gt_list = gt.strip().split(gt_file_sep)
+        # cast the id and emb_num fields to int
+        gt_list[0], gt_list[3] = int(gt_list[0]), int(gt_list[3])
+        batch_list.append(tuple(gt_list))
+ if i >= batch_size:
+ yield batch_list
+ batch_list, i = [], 0
+ if i != 0:
+ yield batch_list
+
+
+def main():
+ tf_config = json.loads(os.environ['TF_CONFIG'])
+ worker_count = len(tf_config['cluster']['worker'])
+ task_index = tf_config['task']['index']
+ job_name = tf_config['task']['type']
+
+ hitrate_details_result = FLAGS.hitrate_details_result
+ total_hitrate_result = FLAGS.total_hitrate_result
+ i_emb_table = FLAGS.item_emb_table
+ gt_table = FLAGS.gt_table
+
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ FLAGS.pipeline_config_path)
+ logging.info('i_emb_table %s', i_emb_table)
+
+ input_type = pipeline_config.data_config.input_type
+ input_type_name = DatasetConfig.InputType.Name(input_type)
+ if input_type_name == 'CSVInput':
+ i_emb_table = process_multi_file_input_path(i_emb_table)
+ else:
+ hive_utils = HiveUtils(
+ data_config=pipeline_config.data_config,
+ hive_config=pipeline_config.hive_train_input)
+ i_emb_table = hive_utils.get_table_location(i_emb_table)
+
+ g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
+ FLAGS.knn_strict)
+ gl.set_tracker_mode(0)
+ gl.set_field_delimiter(FLAGS.item_emb_table_field_sep)
+
+ cluster = tf.train.ClusterSpec({
+ 'ps': tf_config['cluster']['ps'],
+ 'worker': tf_config['cluster']['worker']
+ })
+ server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
+
+ if job_name == 'ps':
+ server.join()
+ else:
+ worker_hosts = [
+ str(host.split(':')[0]) + ':888' + str(i)
+ for i, host in enumerate(tf_config['cluster']['worker'])
+ ]
+ worker_hosts = ','.join(worker_hosts)
+ g.init(task_index=task_index, task_count=worker_count, hosts=worker_hosts)
+ # Your model, use g to do some operation, such as sampling
+
+ if input_type_name == 'CSVInput':
+ gt_all = gt_hdfs(gt_table, FLAGS.batch_size, FLAGS.gt_table_field_sep)
+ else:
+ gt_reader = HiveUtils(
+ data_config=pipeline_config.data_config,
+ hive_config=pipeline_config.hive_train_input,
+ selected_cols='*')
+ gt_all = gt_reader.hive_read_lines(gt_table, FLAGS.batch_size)
+ if not tf.gfile.IsDirectory(hitrate_details_result):
+ tf.gfile.MakeDirs(hitrate_details_result)
+ hitrate_details_result = os.path.join(hitrate_details_result,
+ 'part-%s' % task_index)
+ details_writer = tf.gfile.GFile(hitrate_details_result, 'w')
+ print('Start compute hitrate...')
+ total_hits, total_gt_count = compute_hitrate(g, gt_all, details_writer,
+ gt_table)
+ var_total_hitrate, var_worker_count = reduce_hitrate(
+ cluster, total_hits, total_gt_count, task_index)
+
+ with tf.train.MonitoredTrainingSession(
+ master=server.target, is_chief=(task_index == 0)) as sess:
+ outs = sess.run([var_total_hitrate, var_worker_count])
+
+ # write after all workers have completed the calculation of hitrate.
+ print('outs: ', outs)
+ if outs[1] == worker_count:
+ logging.info(outs)
+ with tf.gfile.GFile(total_hitrate_result, 'w') as total_writer:
+ total_writer.write(str(outs[0]))
+
+ details_writer.close()
+ g.close()
+ print('Compute hitrate done.')
+
+
+if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+ main()
diff --git a/easy_rec/python/tools/hit_rate_pai.py b/easy_rec/python/tools/hit_rate_pai.py
new file mode 100644
index 000000000..977df20be
--- /dev/null
+++ b/easy_rec/python/tools/hit_rate_pai.py
@@ -0,0 +1,138 @@
+# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Evaluation of Top k hitrate."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import tensorflow as tf
+
+from easy_rec.python.utils import io_util
+from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
+from easy_rec.python.utils.hit_rate_utils import load_graph
+from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_integer('task_index', None, 'Task index')
+flags.DEFINE_integer('task_count', None, 'Task count')
+flags.DEFINE_string('job_name', None, 'worker or ps or aligraph')
+flags.DEFINE_string('ps_hosts', '', 'ps hosts')
+flags.DEFINE_string('worker_hosts', '', 'worker hosts')
+flags.DEFINE_string('tables', '', 'input odps tables name')
+flags.DEFINE_string('outputs', '', 'output odps tables name')
+flags.DEFINE_integer('batch_size', 512, 'batch size')
+flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
+flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
+flags.DEFINE_integer('top_k', '5', 'top_k hitrate.')
+flags.DEFINE_integer('knn_metric', '0', '0(l2) or 1(ip).')
+flags.DEFINE_bool('knn_strict', False, 'use exact search.')
+flags.DEFINE_integer('timeout', '60', 'timeout')
+flags.DEFINE_integer('num_interests', 1, 'max number of interests')
+
+
+def compute_hitrate(g, gt_reader, hitrate_writer):
+ """Compute hitrate of each worker.
+
+ Args:
+ g: a GL Graph instance.
+ gt_reader: odps reader of input trigger_items_table.
+ hitrate_writer: odps writer of hitrate table.
+
+ Returns:
+ total_hits: total hits of this worker.
+ total_gt_count: total count of ground truth items of this worker.
+ """
+ total_hits = 0.0
+ total_gt_count = 0.0
+ while True:
+ try:
+ gt_record = gt_reader.read(FLAGS.batch_size)
+ hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists = \
+ compute_hitrate_batch(g, gt_record, FLAGS.emb_dim, FLAGS.num_interests, FLAGS.top_k)
+ total_hits += hits
+ total_gt_count += gt_count
+ topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
+ topk_dists = [
+ ','.join(str(x) for x in dists) for dists in recall_distances
+ ]
+ bad_cases = [','.join(str(x) for x in case) for case in bad_cases]
+ bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]
+
+ hitrate_writer.write(
+ list(
+ zip(src_ids, topk_recalls, topk_dists, hitrates, bad_cases,
+ bad_dists)),
+ indices=[0, 1, 2, 3, 4, 5])
+ except tf.python_io.OutOfRangeException:
+ break
+ return total_hits, total_gt_count
+
+
+def main():
+ worker_count = len(FLAGS.worker_hosts.split(','))
+ input_tables = FLAGS.tables.split(',')
+ if FLAGS.recall_type == 'u2i':
+ i_emb_table, gt_table = input_tables
+ g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
+ FLAGS.knn_strict)
+ else:
+ i_emb_table, gt_table = input_tables[-2], input_tables[-1]
+ g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
+ FLAGS.knn_strict)
+ hitrate_details_table, total_hitrate_table = FLAGS.outputs.split(',')
+
+ cluster = tf.train.ClusterSpec({
+ 'ps': FLAGS.ps_hosts.split(','),
+ 'worker': FLAGS.worker_hosts.split(',')
+ })
+ server = tf.train.Server(
+ cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
+ if FLAGS.job_name == 'ps':
+ server.join()
+ else:
+ g.init(task_index=FLAGS.task_index, task_count=worker_count)
+ gt_reader = tf.python_io.TableReader(
+ gt_table,
+ slice_id=FLAGS.task_index,
+ slice_count=worker_count,
+ capacity=2048)
+ details_writer = tf.python_io.TableWriter(
+ hitrate_details_table, slice_id=FLAGS.task_index)
+ print('Start compute hitrate...')
+ total_hits, total_gt_count = compute_hitrate(g, gt_reader, details_writer)
+ var_total_hitrate, var_worker_count = reduce_hitrate(
+ cluster, total_hits, total_gt_count, FLAGS.task_index)
+
+ with tf.train.MonitoredTrainingSession(
+ master=server.target, is_chief=(FLAGS.task_index == 0)) as sess:
+ outs = sess.run([var_total_hitrate, var_worker_count])
+
+ # write after all workers have completed the calculation of hitrate.
+ if outs[1] == worker_count:
+ with tf.python_io.TableWriter(total_hitrate_table) as total_writer:
+ total_writer.write([outs[0]], indices=[0])
+
+ gt_reader.close()
+ details_writer.close()
+ g.close()
+ print('Compute hitrate done.')
+
+
+if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+ main()
diff --git a/easy_rec/python/tools/pre_check.py b/easy_rec/python/tools/pre_check.py
new file mode 100644
index 000000000..da7f1923b
--- /dev/null
+++ b/easy_rec/python/tools/pre_check.py
@@ -0,0 +1,120 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+import os
+import sys
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import fg_util
+from easy_rec.python.utils import io_util
+from easy_rec.python.utils.check_utils import check_env_and_input_path
+from easy_rec.python.utils.check_utils import check_sequence
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+logging.basicConfig(
+ format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+ level=logging.INFO)
+tf.app.flags.DEFINE_string('pipeline_config_path', None,
+ 'Path to pipeline config '
+ 'file.')
+tf.app.flags.DEFINE_multi_string(
+ 'data_input_path', None, help='data input path')
+
+FLAGS = tf.app.flags.FLAGS
+
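+# Example invocation (paths are illustrative):
+#   python -m easy_rec.python.tools.pre_check \
+#     --pipeline_config_path my_pipeline.config \
+#     --data_input_path part1.csv --data_input_path part2.csv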
+
+def _get_input_fn(data_config,
+ feature_configs,
+ data_path=None,
+ export_config=None):
+ """Build estimator input function.
+
+ Args:
+ data_config: dataset config
+ feature_configs: FeatureConfig
+ data_path: input_data_path
+ export_config: configuration for exporting models,
+ only used to build input_fn when exporting models
+
+ Returns:
+ subclass of Input
+ """
+ input_class_map = {y: x for x, y in data_config.InputType.items()}
+ input_cls_name = input_class_map[data_config.input_type]
+
+ input_class = Input.create_class(input_cls_name)
+ if 'TF_CONFIG' in os.environ:
+ tf_config = json.loads(os.environ['TF_CONFIG'])
+ worker_num = len(tf_config['cluster']['worker'])
+ task_index = tf_config['task']['index']
+ else:
+ worker_num = 1
+ task_index = 0
+
+ input_obj = input_class(
+ data_config,
+ feature_configs,
+ data_path,
+ task_index=task_index,
+ task_num=worker_num,
+ check_mode=True)
+ input_fn = input_obj.create_input(export_config)
+ return input_fn
+
+
+def load_pipeline_config(pipeline_config_path):
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path, False)
+ if pipeline_config.fg_json_path:
+ fg_util.load_fg_json_to_config(pipeline_config)
+ config_util.auto_expand_share_feature_configs(pipeline_config)
+ return pipeline_config
+
+
+def run_check(pipeline_config, input_path):
+ logging.info('data_input_path: %s' % input_path)
+ check_env_and_input_path(pipeline_config, input_path)
+ feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
+ eval_input_fn = _get_input_fn(pipeline_config.data_config, feature_configs,
+ input_path)
+ eval_spec = tf.estimator.EvalSpec(
+ name='val',
+ input_fn=eval_input_fn,
+ steps=None,
+ throttle_secs=10,
+ exporters=[])
+ input_iter = eval_spec.input_fn(
+ mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
+ with tf.Session() as sess:
+ try:
+      while True:
+ input_feas, input_lbls = input_iter.get_next()
+ features = sess.run(input_feas)
+ check_sequence(pipeline_config, features)
+ except tf.errors.OutOfRangeError:
+      logging.info('pre-check finished.')
+
+
+def main(argv):
+ assert FLAGS.pipeline_config_path, 'pipeline_config_path should not be empty when checking!'
+  pipeline_config = load_pipeline_config(FLAGS.pipeline_config_path)
+
+ if FLAGS.data_input_path:
+ input_path = ','.join(FLAGS.data_input_path)
+ else:
+ assert pipeline_config.train_input_path or pipeline_config.eval_input_path, \
+ 'input_path should not be empty when checking!'
+ input_path = pipeline_config.train_input_path + ',' + pipeline_config.eval_input_path
+
+ run_check(pipeline_config, input_path)
+
+
+if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+ tf.app.run()
diff --git a/easy_rec/python/tools/predict_and_chk.py b/easy_rec/python/tools/predict_and_chk.py
index 8cc0f70f1..bc7353f76 100644
--- a/easy_rec/python/tools/predict_and_chk.py
+++ b/easy_rec/python/tools/predict_and_chk.py
@@ -3,12 +3,20 @@
import argparse
import json
import logging
+import os
import sys
import numpy as np
+import easy_rec
from easy_rec.python.inference.predictor import Predictor
+try:
+ import tensorflow as tf
+ tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libembed_op.so'))
+except Exception as ex:
+ logging.warning('exception: %s' % str(ex))
+
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
@@ -23,6 +31,11 @@
'--cmp_res_path', type=str, default=None, help='compare result path')
parser.add_argument(
'--cmp_key', type=str, default='probs', help='compare key')
+ parser.add_argument(
+ '--rtp_fea_id',
+ type=int,
+ default=-1,
+ help='rtp feature column index, default to the last column')
parser.add_argument('--tol', type=float, default=1e-5, help='tolerance')
parser.add_argument(
'--label_id',
@@ -30,9 +43,15 @@
type=int,
help='the label column, which is to be excluded')
parser.add_argument(
- '--separator', type=str, default='', help='separator between features')
+ '--separator',
+ type=str,
+ default='',
+ help='separator between features, default to \\u0002')
parser.add_argument(
- '--rtp_separator', type=str, default='', help='separator')
+ '--rtp_separator',
+ type=str,
+ default='',
+ help='separator, default to \\u0001')
args = parser.parse_args()
if not args.saved_model_dir:
@@ -51,17 +70,22 @@
logging.info('separator: ' + args.separator)
predictor = Predictor(args.saved_model_dir)
+ if len(predictor.input_names) == 1:
+ assert len(
+ args.label_id
+ ) == 0, 'label_id should not be set if rtp feature format is used.'
+
with open(args.input_path, 'r') as fin:
batch_input = []
for line_str in fin:
line_str = line_str.strip()
line_tok = line_str.split(args.rtp_separator)
- feature = line_tok[-1]
+ feature = line_tok[args.rtp_fea_id]
feature = [
x for fid, x in enumerate(feature.split(args.separator))
if fid not in args.label_id
]
- if len(predictor.input_names) == 1:
+ if 'features' in predictor.input_names:
feature = args.separator.join(feature)
batch_input.append(feature)
output = predictor.predict(batch_input)
diff --git a/easy_rec/python/tools/read_kafka.py b/easy_rec/python/tools/read_kafka.py
new file mode 100644
index 000000000..57578b863
--- /dev/null
+++ b/easy_rec/python/tools/read_kafka.py
@@ -0,0 +1,55 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import logging
+import os
+import sys
+
+from kafka import KafkaConsumer
+from kafka.structs import TopicPartition
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--servers', type=str, default='localhost:9092')
+ parser.add_argument('--topic', type=str, default=None)
+ parser.add_argument('--group', type=str, default='consumer')
+ parser.add_argument('--partitions', type=str, default=None)
+ parser.add_argument('--timeout', type=float, default=float('inf'))
+ parser.add_argument('--save_dir', type=str, default=None)
+ args = parser.parse_args()
+
+ if args.topic is None:
+ logging.error('--topic is not set')
+ sys.exit(1)
+
+ servers = args.servers.split(',')
+ consumer = KafkaConsumer(
+ group_id=args.group,
+ bootstrap_servers=servers,
+ consumer_timeout_ms=args.timeout * 1000)
+
+ if args.partitions is not None:
+ partitions = [int(x) for x in args.partitions.split(',')]
+ else:
+ partitions = consumer.partitions_for_topic(args.topic)
+ logging.info('partitions: %s' % partitions)
+
+ topics = [
+ TopicPartition(topic=args.topic, partition=part_id)
+ for part_id in partitions
+ ]
+ consumer.assign(topics)
+ consumer.seek_to_beginning()
+
+ record_id = 0
+ for x in consumer:
+ logging.info('%d: key=%s\toffset=%d\ttimestamp=%d\tlen=%d' %
+ (record_id, x.key, x.offset, x.timestamp, len(x.value)))
+ if args.save_dir is not None:
+ save_path = os.path.join(args.save_dir, x.key)
+ with open(save_path, 'wb') as fout:
+ fout.write(x.value)
+ record_id += 1
diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py
index 8a36b8a5c..d86791708 100644
--- a/easy_rec/python/tools/split_model_pai.py
+++ b/easy_rec/python/tools/split_model_pai.py
@@ -2,21 +2,32 @@
import copy
import logging
import os
+import sys
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
+from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model.utils_impl import get_variables_path
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
+from easy_rec.python.utils import io_util
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+ from tensorflow.python.saved_model.path_helpers import get_variables_path
+else:
+ from tensorflow.python.saved_model.utils_impl import get_variables_path
+
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_dir', '', '')
tf.app.flags.DEFINE_string('user_model_dir', '', '')
tf.app.flags.DEFINE_string('item_model_dir', '', '')
+tf.app.flags.DEFINE_string('user_fg_json_path', '', '')
+tf.app.flags.DEFINE_string('item_fg_json_path', '', '')
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
@@ -196,7 +207,7 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
graph = ops.get_default_graph()
importer.import_graph_def(inference_graph, name='')
for name in variables_to_keep:
- variable = graph.get_tensor_by_name(name)
+ variable = _from_proto_fn(variable_protos[name.split(':')[0]])
graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
saver = tf_saver.Saver()
saver.restore(sess, get_variables_path(model_dir))
@@ -234,9 +245,15 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
config_path = os.path.join(model_dir, 'assets/pipeline.config')
assert tf.gfile.Exists(config_path)
dst_path = os.path.join(part_dir, 'assets')
- dst_config_path = os.path.join(part_dir, 'assets/pipeline.config')
+ dst_config_path = os.path.join(dst_path, 'pipeline.config')
tf.gfile.MkDir(dst_path)
tf.gfile.Copy(config_path, dst_config_path)
+ if part_name == 'user' and FLAGS.user_fg_json_path:
+ dst_fg_path = os.path.join(dst_path, 'fg.json')
+ tf.gfile.Copy(FLAGS.user_fg_json_path, dst_fg_path)
+ if part_name == 'item' and FLAGS.item_fg_json_path:
+ dst_fg_path = os.path.join(dst_path, 'fg.json')
+ tf.gfile.Copy(FLAGS.item_fg_json_path, dst_fg_path)
def main(argv):
@@ -265,4 +282,5 @@ def main(argv):
if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
diff --git a/easy_rec/python/tools/split_pdn_model_pai.py b/easy_rec/python/tools/split_pdn_model_pai.py
new file mode 100644
index 000000000..78932c297
--- /dev/null
+++ b/easy_rec/python/tools/split_pdn_model_pai.py
@@ -0,0 +1,272 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
+import logging
+import os
+import sys
+
+import tensorflow as tf
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model.utils_impl import get_variables_path
+from tensorflow.python.tools import saved_model_utils
+from tensorflow.python.training import saver as tf_saver
+
+from easy_rec.python.utils import io_util
+
+FLAGS = tf.app.flags.FLAGS
+tf.app.flags.DEFINE_string('model_dir', '', '')
+tf.app.flags.DEFINE_string('trigger_model_dir', '', '')
+tf.app.flags.DEFINE_string('sim_model_dir', '', '')
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+
+def search_pb(directory):
+ dir_list = []
+ for root, dirs, files in tf.gfile.Walk(directory):
+ for f in files:
+ _, ext = os.path.splitext(f)
+ if ext == '.pb':
+ dir_list.append(root)
+ if len(dir_list) == 0:
+    raise ValueError('saved model is not found in directory %s' % directory)
+  elif len(dir_list) > 1:
+    raise ValueError('multiple saved models found in directory %s' % directory)
+
+ return dir_list[0]
+
+
+def _node_name(name):
+ if name.startswith('^'):
+ return name[1:]
+ else:
+ return name.split(':')[0]
+
+
+def extract_sub_graph(graph_def, dest_nodes, variable_protos):
+ """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
+
+ Args:
+ graph_def: graph_pb2.GraphDef
+ dest_nodes: a list includes output node names
+
+ Returns:
+ out: the GraphDef of the sub-graph.
+ variables_to_keep: variables to be kept for saver.
+ """
+ if not isinstance(graph_def, graph_pb2.GraphDef):
+ raise TypeError('graph_def must be a graph_pb2.GraphDef proto.')
+
+ edges = {}
+ name_to_node_map = {}
+ node_seq = {}
+ seq = 0
+ nodes_to_keep = set()
+ variables_to_keep = set()
+
+ for node in graph_def.node:
+ n = _node_name(node.name)
+ name_to_node_map[n] = node
+ edges[n] = [_node_name(item) for item in node.input]
+ node_seq[n] = seq
+ seq += 1
+ for d in dest_nodes:
+ assert d in name_to_node_map, "'%s' is not in graph" % d
+
+ next_to_visit = dest_nodes[:]
+ while next_to_visit:
+ n = next_to_visit[0]
+
+ if n in variable_protos:
+ proto = variable_protos[n]
+ next_to_visit.append(_node_name(proto.initial_value_name))
+ next_to_visit.append(_node_name(proto.initializer_name))
+ next_to_visit.append(_node_name(proto.snapshot_name))
+ variables_to_keep.add(proto.variable_name)
+
+ del next_to_visit[0]
+ if n in nodes_to_keep:
+ continue
+ # make sure n is in edges
+ if n in edges:
+ nodes_to_keep.add(n)
+ next_to_visit += edges[n]
+ nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])
+
+ out = graph_pb2.GraphDef()
+ for n in nodes_to_keep_list:
+ out.node.extend([copy.deepcopy(name_to_node_map[n])])
+ out.library.CopyFrom(graph_def.library)
+ out.versions.CopyFrom(graph_def.versions)
+
+ return out, variables_to_keep
+
+
+def load_meta_graph_def(model_dir):
+ """Load meta graph def in saved model.
+
+ Args:
+ model_dir: saved model directory.
+
+ Returns:
+ meta_graph_def: a MetaGraphDef.
+ variable_protos: a dict of VariableDef.
+ input_tensor_names: signature inputs in saved model.
+ output_tensor_names: signature outputs in saved model.
+ """
+ input_tensor_names = {}
+ output_tensor_names = {}
+ variable_protos = {}
+
+ meta_graph_def = saved_model_utils.get_meta_graph_def(
+ model_dir, tf.saved_model.tag_constants.SERVING)
+ signatures = meta_graph_def.signature_def
+ collections = meta_graph_def.collection_def
+
+ # parse collection_def in SavedModel
+ for key, col_def in collections.items():
+ if key in ops.GraphKeys._VARIABLE_COLLECTIONS:
+ tf.logging.info('[Collection] %s:' % key)
+ for value in col_def.bytes_list.value:
+ proto_type = ops.get_collection_proto_type(key)
+ proto = proto_type()
+ proto.ParseFromString(value)
+ tf.logging.info('%s' % proto.variable_name)
+ variable_node_name = _node_name(proto.variable_name)
+ if variable_node_name not in variable_protos:
+ variable_protos[variable_node_name] = proto
+
+ # parse signature info for SavedModel
+ for sig_name in signatures:
+ if signatures[
+ sig_name].method_name == tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
+ tf.logging.info('[Signature] inputs:')
+ for input_name in signatures[sig_name].inputs:
+ input_tensor_shape = []
+ input_tensor = signatures[sig_name].inputs[input_name]
+ for dim in input_tensor.tensor_shape.dim:
+ input_tensor_shape.append(int(dim.size))
+ tf.logging.info('"%s": %s; %s' %
+ (input_name, _TYPE_TO_STRING[input_tensor.dtype],
+ input_tensor_shape))
+ input_tensor_names[input_name] = input_tensor.name
+ tf.logging.info('[Signature] outputs:')
+ for output_name in signatures[sig_name].outputs:
+ output_tensor_shape = []
+ output_tensor = signatures[sig_name].outputs[output_name]
+ for dim in output_tensor.tensor_shape.dim:
+ output_tensor_shape.append(int(dim.size))
+ tf.logging.info('"%s": %s; %s' %
+ (output_name, _TYPE_TO_STRING[output_tensor.dtype],
+ output_tensor_shape))
+ output_tensor_names[output_name] = output_tensor.name
+
+ return meta_graph_def, variable_protos, input_tensor_names, output_tensor_names
+
+
+def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
+ output_tensor_names, part_name, part_dir):
+ """Export subpart saved model.
+
+ Args:
+ model_dir: saved model directory.
+ meta_graph_def: a MetaGraphDef.
+ variable_protos: a dict of VariableDef.
+ input_tensor_names: signature inputs in saved model.
+ output_tensor_names: signature outputs in saved model.
+ part_name: subpart model name, user or item.
+ part_dir: subpart model export directory.
+ """
+ output_tensor_names = {
+ x: output_tensor_names[x]
+ for x in output_tensor_names.keys()
+ if part_name in x
+ }
+ output_node_names = [
+ _node_name(output_tensor_names[x]) for x in output_tensor_names.keys()
+ ]
+
+ inference_graph, variables_to_keep = extract_sub_graph(
+ meta_graph_def.graph_def, output_node_names, variable_protos)
+
+ tf.reset_default_graph()
+ with tf.Session() as sess:
+ with sess.graph.as_default():
+ graph = ops.get_default_graph()
+ importer.import_graph_def(inference_graph, name='')
+ for name in variables_to_keep:
+ variable = graph.get_tensor_by_name(name)
+ graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
+ saver = tf_saver.Saver()
+ saver.restore(sess, get_variables_path(model_dir))
+
+ builder = tf.saved_model.builder.SavedModelBuilder(part_dir)
+ signature_inputs = {}
+ for input_name in input_tensor_names:
+ try:
+ tensor_info = tf.saved_model.utils.build_tensor_info(
+ graph.get_tensor_by_name(input_tensor_names[input_name]))
+ signature_inputs[input_name] = tensor_info
+ except Exception:
+ print('ignore input: %s' % input_name)
+
+ signature_outputs = {}
+ for output_name in output_tensor_names:
+ tensor_info = tf.saved_model.utils.build_tensor_info(
+ graph.get_tensor_by_name(output_tensor_names[output_name]))
+ signature_outputs[output_name] = tensor_info
+
+ prediction_signature = (
+ tf.saved_model.signature_def_utils.build_signature_def(
+ inputs=signature_inputs,
+ outputs=signature_outputs,
+ method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
+ ))
+
+ builder.add_meta_graph_and_variables(
+ sess, [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ prediction_signature,
+ })
+ builder.save()
+ config_path = os.path.join(model_dir, 'assets/pipeline.config')
+ assert tf.gfile.Exists(config_path)
+ dst_path = os.path.join(part_dir, 'assets')
+ dst_config_path = os.path.join(dst_path, 'pipeline.config')
+ tf.gfile.MkDir(dst_path)
+ tf.gfile.Copy(config_path, dst_config_path)
+
+
+def main(argv):
+ model_dir = search_pb(FLAGS.model_dir)
+ tf.logging.info('Loading meta graph...')
+ meta_graph_def, variable_protos, input_tensor_names, output_tensor_names = load_meta_graph_def(
+ model_dir)
+ tf.logging.info('Exporting trigger part model...')
+ export(
+ model_dir,
+ meta_graph_def,
+ variable_protos,
+ input_tensor_names,
+ output_tensor_names,
+ part_name='trigger_out',
+ part_dir=FLAGS.trigger_model_dir)
+ tf.logging.info('Exporting sim part model...')
+ export(
+ model_dir,
+ meta_graph_def,
+ variable_protos,
+ input_tensor_names,
+ output_tensor_names,
+ part_name='sim_out',
+ part_dir=FLAGS.sim_model_dir)
+
+
+if __name__ == '__main__':
+ sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+ tf.app.run()
diff --git a/easy_rec/python/tools/view_saved_model.py b/easy_rec/python/tools/view_saved_model.py
new file mode 100644
index 000000000..022bcf1aa
--- /dev/null
+++ b/easy_rec/python/tools/view_saved_model.py
@@ -0,0 +1,39 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import logging
+
+from google.protobuf import text_format
+from tensorflow.core.protobuf import saved_model_pb2
+from tensorflow.python.platform.gfile import GFile
+
+logging.basicConfig(
+ format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+ level=logging.INFO)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input', type=str, default=None, help='saved model path')
+ parser.add_argument(
+ '--output', type=str, default=None, help='saved model save path')
+ args = parser.parse_args()
+
+ assert args.input is not None and args.output is not None
+
+ logging.info('saved_model_path: %s' % args.input)
+
+ saved_model = saved_model_pb2.SavedModel()
+ if args.input.endswith('.pb'):
+ with GFile(args.input, 'rb') as fin:
+ saved_model.ParseFromString(fin.read())
+ else:
+ with GFile(args.input, 'r') as fin:
+ text_format.Merge(fin.read(), saved_model)
+
+ if args.output.endswith('.pbtxt'):
+ with GFile(args.output, 'w') as fout:
+ fout.write(text_format.MessageToString(saved_model, as_utf8=True))
+ else:
+ with GFile(args.output, 'wb') as fout:
+ fout.write(saved_model.SerializeToString())
diff --git a/easy_rec/python/tools/write_kafka.py b/easy_rec/python/tools/write_kafka.py
new file mode 100644
index 000000000..5dfa7dfd2
--- /dev/null
+++ b/easy_rec/python/tools/write_kafka.py
@@ -0,0 +1,65 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import logging
+import sys
+
+# from kafka import KafkaConsumer
+from kafka import KafkaAdminClient
+from kafka import KafkaProducer
+from kafka.admin import NewTopic
+
+# from kafka.structs import TopicPartition
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--servers', type=str, default='localhost:9092')
+ parser.add_argument('--topic', type=str, default=None)
+ parser.add_argument('--group', type=str, default='consumer')
+ parser.add_argument('--partitions', type=str, default=None)
+ parser.add_argument('--timeout', type=float, default=float('inf'))
+ # file to send
+ parser.add_argument('--input_path', type=str, default=None)
+ args = parser.parse_args()
+
+ if args.input_path is None:
+ logging.error('input_path is not set')
+ sys.exit(1)
+
+ if args.topic is None:
+ logging.error('topic is not set')
+ sys.exit(1)
+
+ servers = args.servers.split(',')
+
+ admin_clt = KafkaAdminClient(bootstrap_servers=servers)
+ if args.topic not in admin_clt.list_topics():
+ admin_clt.create_topics(
+ new_topics=[
+ NewTopic(
+ name=args.topic,
+ num_partitions=1,
+ replication_factor=1,
+ topic_configs={'max.message.bytes': 1024 * 1024 * 1024})
+ ],
+ validate_only=False)
+ logging.info('create increment save topic: %s' % args.topic)
+ admin_clt.close()
+
+ producer = KafkaProducer(
+ bootstrap_servers=servers,
+ request_timeout_ms=args.timeout * 1000,
+ api_version=(0, 10, 1))
+
+ i = 1
+ with open(args.input_path, 'r') as fin:
+ for line_str in fin:
+ producer.send(args.topic, line_str.encode('utf-8'))
+      i += 1
+      if i % 100 == 0:
+        logging.info('progress: %d' % i)
+ producer.close()
diff --git a/easy_rec/python/train_eval.py b/easy_rec/python/train_eval.py
index 97a22df51..bafdf0c1a 100644
--- a/easy_rec/python/train_eval.py
+++ b/easy_rec/python/train_eval.py
@@ -1,17 +1,27 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
import json
import logging
import os
import tensorflow as tf
-from tensorflow.python.lib.io import file_io
from easy_rec.python.main import _train_and_evaluate_impl
+from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util
+from easy_rec.python.utils import ds_util
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import hpo_util
+from easy_rec.python.utils.config_util import process_neg_sampler_data_path
+from easy_rec.python.utils.config_util import set_eval_input_path
+from easy_rec.python.utils.config_util import set_train_input_path
+
+if tf.__version__.startswith('1.'):
+ from tensorflow.python.platform import gfile
+else:
+ import tensorflow.io.gfile as gfile
from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA
@@ -21,127 +31,171 @@
logging.basicConfig(
format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
level=logging.INFO)
-tf.app.flags.DEFINE_string('pipeline_config_path', None,
- 'Path to pipeline config '
- 'file.')
-tf.app.flags.DEFINE_bool('continue_train', False,
- 'continue train using existing '
- 'model dir')
-tf.app.flags.DEFINE_string(
- 'hpo_param_path', None, help='hyperparam tuning param path')
-tf.app.flags.DEFINE_string(
- 'hpo_metric_save_path', None, help='hyperparameter save metric path')
-tf.app.flags.DEFINE_string(
- 'model_dir', None, help='will update the model_dir in pipeline_config')
-tf.app.flags.DEFINE_multi_string(
- 'train_input_path', None, help='train data input path')
-tf.app.flags.DEFINE_multi_string(
- 'eval_input_path', None, help='eval data input path')
-tf.app.flags.DEFINE_string(
- 'fine_tune_checkpoint',
- None,
- help='will update the train_config.fine_tune_checkpoint in pipeline_config')
-tf.app.flags.DEFINE_string(
- 'edit_config_json',
- None,
- help='edit pipeline config str, example: {"model_dir":"experiments/",'
- '"feature_config.feature[0].boundaries":[4,5,6,7]}')
-tf.app.flags.DEFINE_bool(
- 'ignore_finetune_ckpt_error', False,
- 'During incremental training, ignore the problem of missing fine_tune_checkpoint files'
-)
-tf.app.flags.DEFINE_string('odps_config', None, help='odps config path')
-tf.app.flags.DEFINE_bool('is_on_ds', False, help='is on ds')
-FLAGS = tf.app.flags.FLAGS
-
-
-def main(argv):
- if FLAGS.pipeline_config_path is not None:
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--pipeline_config_path',
+ type=str,
+ default=None,
+ help='Path to pipeline config file.')
+ parser.add_argument(
+ '--continue_train',
+ action='http://github.com/store_true',
+ default=False,
+ help='continue train using existing model_dir')
+ parser.add_argument(
+ '--hpo_param_path',
+ type=str,
+ default=None,
+ help='hyperparam tuning param path')
+ parser.add_argument(
+ '--hpo_metric_save_path',
+ type=str,
+ default=None,
+ help='hyperparameter save metric path')
+ parser.add_argument(
+ '--model_dir',
+ type=str,
+ default=None,
+ help='will update the model_dir in pipeline_config')
+ parser.add_argument(
+ '--train_input_path',
+ type=str,
+ nargs='*',
+ default=None,
+ help='train data input path')
+ parser.add_argument(
+ '--eval_input_path',
+ type=str,
+ nargs='*',
+ default=None,
+ help='eval data input path')
+ parser.add_argument(
+ '--fit_on_eval',
+ action='http://github.com/store_true',
+ default=False,
+ help='Fit evaluation data after fitting and evaluating train data')
+ parser.add_argument(
+ '--fit_on_eval_steps',
+ type=int,
+ default=None,
+ help='Fit evaluation data steps')
+ parser.add_argument(
+ '--fine_tune_checkpoint',
+ type=str,
+ default=None,
+ help='will update the train_config.fine_tune_checkpoint in pipeline_config'
+ )
+ parser.add_argument(
+ '--edit_config_json',
+ type=str,
+ default=None,
+ help='edit pipeline config str, example: {"model_dir":"experiments/",'
+ '"feature_config.feature[0].boundaries":[4,5,6,7]}')
+ parser.add_argument(
+ '--ignore_finetune_ckpt_error',
+ action='http://github.com/store_true',
+ default=False,
+ help='During incremental training, ignore the problem of missing fine_tune_checkpoint files'
+ )
+ parser.add_argument(
+ '--odps_config', type=str, default=None, help='odps config path')
+ parser.add_argument(
+ '--is_on_ds', action='http://github.com/store_true', default=False, help='is on ds')
+ parser.add_argument(
+ '--check_mode',
+ action='http://github.com/store_true',
+ default=False,
+ help='is use check mode')
+ parser.add_argument(
+ '--selected_cols', type=str, default=None, help='select input columns')
+ parser.add_argument('--gpu', type=str, default=None, help='gpu id')
+ args, extra_args = parser.parse_known_args()
+
+ if args.gpu is not None:
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+
+ edit_config_json = {}
+ if args.edit_config_json:
+ edit_config_json = json.loads(args.edit_config_json)
+
+ if extra_args is not None and len(extra_args) > 0:
+ config_util.parse_extra_config_param(extra_args, edit_config_json)
+
+ if args.pipeline_config_path is not None:
pipeline_config = config_util.get_configs_from_pipeline_file(
- FLAGS.pipeline_config_path, False)
- if FLAGS.model_dir:
- pipeline_config.model_dir = FLAGS.model_dir
+ args.pipeline_config_path, False)
+ if args.selected_cols:
+ pipeline_config.data_config.selected_cols = args.selected_cols
+ if args.model_dir:
+ pipeline_config.model_dir = args.model_dir
logging.info('update model_dir to %s' % pipeline_config.model_dir)
- if FLAGS.train_input_path:
- pipeline_config.train_input_path = ','.join(FLAGS.train_input_path)
- logging.info('update train_input_path to %s' %
- pipeline_config.train_input_path)
- if FLAGS.eval_input_path:
- pipeline_config.eval_input_path = ','.join(FLAGS.eval_input_path)
- logging.info('update eval_input_path to %s' %
- pipeline_config.eval_input_path)
- if FLAGS.fine_tune_checkpoint:
- if file_io.file_exists(FLAGS.fine_tune_checkpoint):
- pipeline_config.train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint
- logging.info('update fine_tune_checkpoint to %s' %
- pipeline_config.train_config.fine_tune_checkpoint)
- else:
- assert FLAGS.ignore_finetune_ckpt_error, 'fine_tune_checkpoint(%s) is not exists.' % FLAGS.fine_tune_checkpoint
+ if args.train_input_path:
+ set_train_input_path(pipeline_config, args.train_input_path)
+ if args.eval_input_path:
+ set_eval_input_path(pipeline_config, args.eval_input_path)
+
+ if args.fine_tune_checkpoint:
+ ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
+ args.fine_tune_checkpoint, args.ignore_finetune_ckpt_error)
+
+ if ckpt_path:
+ pipeline_config.train_config.fine_tune_checkpoint = ckpt_path
if pipeline_config.fg_json_path:
fg_util.load_fg_json_to_config(pipeline_config)
- if FLAGS.odps_config:
- os.environ['ODPS_CONFIG_FILE_PATH'] = FLAGS.odps_config
+ if args.odps_config:
+ os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
- if FLAGS.is_on_ds:
+ if len(edit_config_json) > 0:
+ fine_tune_checkpoint = edit_config_json.get('train_config', {}).get(
+ 'fine_tune_checkpoint', None)
+ if fine_tune_checkpoint:
+ ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
+          fine_tune_checkpoint, args.ignore_finetune_ckpt_error)
+ edit_config_json['train_config']['fine_tune_checkpoint'] = ckpt_path
+ config_util.edit_config(pipeline_config, edit_config_json)
+
+ process_neg_sampler_data_path(pipeline_config)
+
+ if args.is_on_ds:
+ ds_util.set_on_ds()
set_tf_config_and_get_train_worker_num_on_ds()
if pipeline_config.train_config.fine_tune_checkpoint:
- fine_tune_ckpt_path = pipeline_config.train_config.fine_tune_checkpoint
- if fine_tune_ckpt_path.endswith('/') or tf.gfile.IsDirectory(
- fine_tune_ckpt_path + '/'):
- fine_tune_ckpt_path = estimator_utils.latest_checkpoint(
- fine_tune_ckpt_path)
- logging.info(
- 'ckpt_path is model_dir, will use the latest checkpoint: %s' %
- fine_tune_ckpt_path)
-
- if fine_tune_ckpt_path.startswith('hdfs://'):
- tmpdir = os.path.dirname(fine_tune_ckpt_path.replace('hdfs://', ''))
- tmpdir = os.path.join('/tmp/experiments', tmpdir)
- logging.info('will cache fine_tune_ckpt to local dir: %s' % tmpdir)
- if tf.gfile.IsDirectory(tmpdir):
- tf.gfile.DeleteRecursively(tmpdir)
- tf.gfile.MakeDirs(tmpdir)
- for src_path in tf.gfile.Glob(fine_tune_ckpt_path + '*'):
- dst_path = os.path.join(tmpdir, os.path.basename(src_path))
- logging.info('will copy %s to local path %s' % (src_path, dst_path))
- tf.gfile.Copy(src_path, dst_path, overwrite=True)
- ckpt_filename = os.path.basename(fine_tune_ckpt_path)
- fine_tune_ckpt_path = os.path.join(tmpdir, ckpt_filename)
- pipeline_config.train_config.fine_tune_checkpoint = fine_tune_ckpt_path
- logging.info('will restore from %s' % fine_tune_ckpt_path)
-
- if FLAGS.hpo_param_path:
- with tf.gfile.GFile(FLAGS.hpo_param_path, 'r') as fin:
+ ds_util.cache_ckpt(pipeline_config)
+
+ if pipeline_config.train_config.train_distribute in [
+ DistributionStrategy.HorovodStrategy,
+ ]:
+ estimator_utils.init_hvd()
+ elif pipeline_config.train_config.train_distribute in [
+ DistributionStrategy.EmbeddingParallelStrategy,
+ DistributionStrategy.SokStrategy
+ ]:
+ estimator_utils.init_hvd()
+ estimator_utils.init_sok()
+
+ if args.hpo_param_path:
+ with gfile.GFile(args.hpo_param_path, 'r') as fin:
hpo_config = json.load(fin)
hpo_params = hpo_config['param']
config_util.edit_config(pipeline_config, hpo_params)
config_util.auto_expand_share_feature_configs(pipeline_config)
- _train_and_evaluate_impl(pipeline_config, FLAGS.continue_train)
+ _train_and_evaluate_impl(pipeline_config, args.continue_train,
+ args.check_mode)
hpo_util.save_eval_metrics(
pipeline_config.model_dir,
- metric_save_path=FLAGS.hpo_metric_save_path,
+ metric_save_path=args.hpo_metric_save_path,
has_evaluator=False)
- elif FLAGS.edit_config_json:
- config_json = json.loads(FLAGS.edit_config_json)
- fine_tune_checkpoint = config_json.get(
- 'train_config.fine_tune_checkpoint', None)
- if fine_tune_checkpoint:
- if not file_io.file_exists(fine_tune_checkpoint):
- assert FLAGS.ignore_finetune_ckpt_error, 'fine_tune_checkpoint(%s) is not exists.' % fine_tune_checkpoint
- config_json.pop('train_config.fine_tune_checkpoint', None)
- logging.info('fine_tune_checkpoint(%s) is not exists. Drop it.' %
- fine_tune_checkpoint)
- config_util.edit_config(pipeline_config, config_json)
- config_util.auto_expand_share_feature_configs(pipeline_config)
- _train_and_evaluate_impl(pipeline_config, FLAGS.continue_train)
else:
config_util.auto_expand_share_feature_configs(pipeline_config)
- _train_and_evaluate_impl(pipeline_config, FLAGS.continue_train)
+ _train_and_evaluate_impl(
+ pipeline_config,
+ args.continue_train,
+ args.check_mode,
+ fit_on_eval=args.fit_on_eval,
+ fit_on_eval_steps=args.fit_on_eval_steps)
else:
raise ValueError('pipeline_config_path should not be empty when training!')
-
-
-if __name__ == '__main__':
- tf.app.run()
diff --git a/easy_rec/python/utils/__init__.py b/easy_rec/python/utils/__init__.py
index e69de29bb..09dc89476 100644
--- a/easy_rec/python/utils/__init__.py
+++ b/easy_rec/python/utils/__init__.py
@@ -0,0 +1,15 @@
+class conditional(object):
+ """Wrap another context manager and enter it only if condition is true."""
+
+ def __init__(self, condition, contextmanager):
+ self.condition = condition
+ self.contextmanager = contextmanager
+
+ def __enter__(self):
+ """Conditionally enter a context manager."""
+ if self.condition:
+ return self.contextmanager.__enter__()
+
+ def __exit__(self, *args):
+ if self.condition:
+ return self.contextmanager.__exit__(*args)
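
A minimal usage sketch of the `conditional` wrapper above; the wrapped context manager here is purely illustrative and not part of the diff:

import contextlib

from easy_rec.python.utils import conditional


@contextlib.contextmanager
def timing_scope():
  # stand-in for any real context manager, e.g. a device or profiling scope
  print('enter scope')
  yield
  print('leave scope')


# The inner context manager is entered only when the condition is True;
# with False, the body still runs but timing_scope() is never entered.
with conditional(True, timing_scope()):
  pass
with conditional(False, timing_scope()):
  pass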
diff --git a/easy_rec/python/utils/activation.py b/easy_rec/python/utils/activation.py
new file mode 100644
index 000000000..89044f7a3
--- /dev/null
+++ b/easy_rec/python/utils/activation.py
@@ -0,0 +1,120 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import numpy as np
+import six
+import tensorflow as tf
+
+from easy_rec.python.utils.load_class import load_by_path
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def dice(_x, axis=-1, epsilon=1e-9, name='dice', training=True):
+ """The Data Adaptive Activation Function in DIN.
+
+ Which can be viewed as a generalization of PReLu,
+ and can adaptively adjust the rectified point according to distribution of input data.
+
+ Arguments
+ - **axis** : Integer, the axis that should be used to compute data distribution (typically the features axis).
+ - **epsilon** : Small float added to variance to avoid dividing by zero.
+
+ References
+ - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
+ Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining.
+ ACM, 2018: 1059-1068.] (https://arxiv.org/pdf/1706.06978.pdf)
+ """
+ alphas = tf.get_variable(
+ 'alpha_' + name,
+ _x.get_shape()[-1],
+ initializer=tf.constant_initializer(0.0),
+ dtype=tf.float32)
+ inputs_normed = tf.layers.batch_normalization(
+ inputs=_x,
+ axis=axis,
+ epsilon=epsilon,
+ center=False,
+ scale=False,
+ training=training)
+ x_p = tf.sigmoid(inputs_normed)
+ return alphas * (1.0 - x_p) * _x + x_p * _x
+
+
+def gelu(x, name='gelu'):
+ """Gaussian Error Linear Unit.
+
+ This is a smoother version of the RELU.
+ Original paper: https://arxiv.org/abs/1606.08415
+
+ Args:
+ x: float Tensor to perform activation.
+ name: name for this activation
+
+ Returns:
+ `x` with the GELU activation applied.
+ """
+ with tf.name_scope(name):
+ cdf = 0.5 * (1.0 + tf.tanh(
+ (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
+ return x * cdf
+
+
+def swish(x, name='swish'):
+ with tf.name_scope(name):
+ return x * tf.sigmoid(x)
+
+
+def get_activation(activation_string, **kwargs):
+ """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
+
+ Args:
+ activation_string: String name of the activation function.
+
+ Returns:
+ A Python function corresponding to the activation function. If
+ `activation_string` is None, empty, or "linear", this will return None.
+ If `activation_string` is not a string, it will return `activation_string`.
+
+ Raises:
+ ValueError: The `activation_string` does not correspond to a known
+ activation.
+ """
+ # We assume that anything that's not a string is already an activation
+ # function, so we just return it.
+ if not isinstance(activation_string, six.string_types):
+ return activation_string
+
+ if not activation_string:
+ return None
+
+ act = activation_string.lower()
+ if act == 'linear':
+ return None
+ elif act == 'relu':
+ return tf.nn.relu
+ elif act == 'gelu':
+ return gelu
+ elif act == 'leaky_relu':
+ return tf.nn.leaky_relu
+ elif act == 'prelu':
+ if len(kwargs) == 0:
+ return tf.nn.leaky_relu
+ return tf.keras.layers.PReLU(**kwargs)
+ elif act == 'dice':
+ return lambda x, name='dice': dice(x, name=name, **kwargs)
+ elif act == 'elu':
+ return tf.nn.elu
+ elif act == 'selu':
+ return tf.nn.selu
+ elif act == 'tanh':
+ return tf.tanh
+ elif act == 'swish':
+ if tf.__version__ < '1.13.0':
+ return swish
+ return tf.nn.swish
+ elif act == 'sigmoid':
+ return tf.nn.sigmoid
+ else:
+ return load_by_path(activation_string)
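
A short sketch of the string-to-callable mapping implemented by get_activation above; the tensor values are arbitrary and only for illustration:

import tensorflow as tf

from easy_rec.python.utils.activation import get_activation

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

act = get_activation('gelu')       # returns the gelu function defined above
relu = get_activation('relu')      # returns tf.nn.relu
same = get_activation(tf.nn.elu)   # non-strings are returned unchanged

x = tf.constant([[-1.0, 0.0, 2.0]])
y = act(x)  # gelu applied elementwise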
diff --git a/easy_rec/python/utils/check_utils.py b/easy_rec/python/utils/check_utils.py
new file mode 100644
index 000000000..5a7551745
--- /dev/null
+++ b/easy_rec/python/utils/check_utils.py
@@ -0,0 +1,87 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tensorflow as tf
+
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def check_split(line, sep, required_field_num, field_name=''):
+  assert sep, 'must have separator.' + (
+      (' field: %s.' % field_name) if field_name else '')
+
+  for one_line in line:
+    field_num = len(one_line.split(sep))
+    if field_name:
+      assert_info = 'sep[%s] maybe invalid. field_num=%d, required_num=%d, field: %s, value: %s, ' \
+                    'please check separator and data.' % \
+                    (sep, field_num, required_field_num, field_name, one_line)
+    else:
+      assert_info = 'sep[%s] maybe invalid. field_num=%d, required_num=%d, current line is: %s, ' \
+                    'please check separator and data.' % \
+                    (sep, field_num, required_field_num, one_line)
+    assert field_num == required_field_num, assert_info
+  return True
+
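For illustration, a call to check_split on made-up CSV-style lines (the field values are hypothetical):

from easy_rec.python.utils.check_utils import check_split

# every input line must split into exactly 3 fields with ','
check_split(['0.5,1,click', '0.3,7,buy'], ',', 3, field_name='sample')
# a wrong separator or field count raises an AssertionError with a hint message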
+
+def check_string_to_number(field_vals, field_name):
+ for val in field_vals:
+ try:
+ float(val)
+ except: # noqa: E722
+ assert False, 'StringToNumber ERROR: cannot convert string_to_number, field: %s, value: %s. ' \
+ 'please check data.' % (field_name, val)
+ return True
+
+
+def check_sequence(pipeline_config_path, features):
+ seq_att_groups = pipeline_config_path.model_config.seq_att_groups
+ if not seq_att_groups:
+ return
+ for seq_att_group in seq_att_groups:
+ seq_att_maps = seq_att_group.seq_att_map
+ if not seq_att_maps:
+ return
+ for seq_att_map in seq_att_maps:
+ assert len(seq_att_map.key) == len(seq_att_map.hist_seq), \
+          'The size of hist_seq must be equal to the size of key in one seq_att_map.'
+ size_list = []
+ for hist_seq in seq_att_map.hist_seq:
+ cur_seq_size = len(features[hist_seq].values)
+ size_list.append(cur_seq_size)
+ hist_seqs = ' '.join(seq_att_map.hist_seq)
+ assert len(set(size_list)) == 1, \
+ 'SequenceFeature Error: The size in [%s] should be consistent. Please check input: [%s].' % \
+ (hist_seqs, hist_seqs)
+
+
+def check_env_and_input_path(pipeline_config, input_path):
+ input_type = pipeline_config.data_config.input_type
+ input_type_name = DatasetConfig.InputType.Name(input_type)
+ ignore_input_list = [
+ DatasetConfig.InputType.TFRecordInput,
+ DatasetConfig.InputType.BatchTFRecordInput,
+ DatasetConfig.InputType.KafkaInput,
+ DatasetConfig.InputType.DataHubInput,
+ DatasetConfig.InputType.HiveInput,
+ DatasetConfig.InputType.DummyInput,
+ ]
+ if input_type in ignore_input_list:
+ return True
+ assert_info = 'Current InputType is %s, InputPath is %s. Please check InputType and InputPath.' % \
+ (input_type_name, input_path)
+ if input_type_name.startswith('Odps'):
+ # is on pai
+ for path in input_path.split(','):
+ if not path.startswith('odps://'):
+ assert False, assert_info
+ return True
+ else:
+ # local or ds
+ for path in input_path.split(','):
+ if path.startswith('odps://'):
+ assert False, assert_info
+ return True
diff --git a/easy_rec/python/utils/config_util.py b/easy_rec/python/utils/config_util.py
index bb65bd278..3c6f385e7 100644
--- a/easy_rec/python/utils/config_util.py
+++ b/easy_rec/python/utils/config_util.py
@@ -5,11 +5,15 @@
Such as Hyper parameter tuning or automatic feature expanding.
"""
+import datetime
import json
import logging
import os
import re
+import sys
+import numpy as np
+import six
import tensorflow as tf
from google.protobuf import json_format
from google.protobuf import text_format
@@ -17,11 +21,28 @@
from easy_rec.python.protos import pipeline_pb2
from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
+from easy_rec.python.utils import pai_util
+from easy_rec.python.utils.hive_utils import HiveUtils
if tf.__version__ >= '2.0':
tf = tf.compat.v1
+def search_pipeline_config(directory):
+ dir_list = []
+ for root, dirs, files in tf.gfile.Walk(directory):
+ for f in files:
+ _, ext = os.path.splitext(f)
+ if ext == '.config':
+ dir_list.append(os.path.join(root, f))
+ if len(dir_list) == 0:
+ raise ValueError('config is not found in directory %s' % directory)
+ elif len(dir_list) > 1:
+    raise ValueError('multiple config files found in directory %s' % directory)
+ logging.info('use pipeline config: %s' % dir_list[0])
+ return dir_list[0]
+
+
def get_configs_from_pipeline_file(pipeline_config_path, auto_expand=True):
"""Reads config from a file containing pipeline_pb2.EasyRecConfig.
@@ -155,6 +176,19 @@ def save_pipeline_config(pipeline_config,
save_message(pipeline_config, pipeline_config_path)
+def _get_basic_types():
+ dtypes = [
+ bool, int, str, float,
+ type(u''), np.float16, np.float32, np.float64, np.char, np.byte, np.uint8,
+ np.int8, np.int16, np.uint16, np.uint32, np.int32, np.uint64, np.int64,
+ bool, str
+ ]
+ if six.PY2:
+ dtypes.append(long) # noqa: F821
+
+ return dtypes
+
+
def edit_config(pipeline_config, edit_config_json):
"""Update params specified by automl.
@@ -163,6 +197,22 @@ def edit_config(pipeline_config, edit_config_json):
edit_config_json: edit config json
"""
+ def _type_convert(proto, val, parent=None):
+ if type(val) != type(proto):
+ try:
+ if isinstance(proto, bool):
+ assert val in ['True', 'true', 'False', 'false']
+ val = val in ['True', 'true']
+ else:
+ val = type(proto)(val)
+ except ValueError as ex:
+ if parent is None:
+ raise ex
+ assert isinstance(proto, int)
+ val = getattr(parent, val)
+ assert isinstance(val, int)
+ return val
+
def _get_attr(obj, attr, only_last=False):
# only_last means we only return the last element in paths array
attr_toks = [x.strip() for x in attr.split('.') if x != '']
@@ -238,14 +288,9 @@ def _get_attr(obj, attr, only_last=False):
for tid, update_obj in enumerate(update_objs):
tmp, tmp_parent, _, _ = _get_attr(
update_obj, cond_key, only_last=True)
- if type(cond_val) != type(tmp):
- try:
- cond_val = type(tmp)(cond_val)
- except ValueError:
- # to support for enumerations like IdFeature
- assert isinstance(tmp, int)
- cond_val = getattr(tmp_parent, cond_val)
- assert isinstance(cond_val, int)
+
+ cond_val = _type_convert(tmp, cond_val, tmp_parent)
+
if op_func(tmp, cond_val):
obj_id = tid
paths.append((update_obj, update_objs, None, obj_id))
@@ -272,18 +317,15 @@ def _get_attr(obj, attr, only_last=False):
tmp_paths = _get_attr(update_obj, param_key)
# update a set of objs
for tmp_val, tmp_obj, tmp_name, tmp_id in tmp_paths:
- basic_types = [int, str, float, bool, type(u'')]
+ # list and dict are not basic types, must be handle separately
+ basic_types = _get_basic_types()
if type(tmp_val) in basic_types:
# simple type cast
- try:
- tmp_val = type(tmp_val)(param_val)
- if tmp_name is None:
- tmp_obj[tmp_id] = tmp_val
- else:
- setattr(tmp_obj, tmp_name, tmp_val)
- except ValueError:
- # for enumeration types
- text_format.Merge('%s:%s' % (tmp_name, param_val), tmp_obj)
+ tmp_val = _type_convert(tmp_val, param_val, tmp_obj)
+ if tmp_name is None:
+ tmp_obj[tmp_id] = tmp_val
+ else:
+ setattr(tmp_obj, tmp_name, tmp_val)
elif 'Scalar' in str(type(tmp_val)) and 'ClearField' in dir(tmp_obj):
tmp_obj.ClearField(tmp_name)
text_format.Parse('%s:%s' % (tmp_name, param_val), tmp_obj)
@@ -337,7 +379,11 @@ def add_boundaries_to_config(pipeline_config, tables):
for feature_config in feature_configs:
feature_name = feature_config.input_names[0]
if feature_name in feature_boundaries_info:
- feature_config.feature_type = feature_config.RawFeature
+ if feature_config.feature_type != feature_config.SequenceFeature:
+ logging.info(
+ 'feature = {0}, type = {1}, will turn to RawFeature.'.format(
+ feature_name, feature_config.feature_type))
+ feature_config.feature_type = feature_config.RawFeature
feature_config.hash_bucket_size = 0
feature_config.ClearField('boundaries')
feature_config.boundaries.extend(feature_boundaries_info[feature_name])
@@ -350,3 +396,226 @@ def get_compatible_feature_configs(pipeline_config):
else:
feature_configs = pipeline_config.feature_config.features
return feature_configs
+
+
+def parse_time(time_data):
+ """Parse time string to timestamp.
+
+ Args:
+ time_data: could be two formats: '%Y%m%d %H:%M:%S' or '%s'
+ Return:
+ timestamp: int
+ """
+ if isinstance(time_data, str) or isinstance(time_data, type(u'')):
+ if len(time_data) == 17:
+ return int(
+ datetime.datetime.strptime(time_data,
+ '%Y%m%d %H:%M:%S').strftime('%s'))
+ elif len(time_data) == 10:
+ return int(time_data)
+ else:
+      assert False, 'invalid time string: %s' % time_data
+ else:
+ return int(time_data)
+
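A brief sketch of the two formats parse_time accepts; the integer returned for the first call depends on the local timezone used by strftime('%s'):

from easy_rec.python.utils.config_util import parse_time

ts1 = parse_time('20230101 12:00:00')  # 17-char '%Y%m%d %H:%M:%S' string
ts2 = parse_time('1672545600')         # 10-char unix-timestamp string
ts3 = parse_time(1672545600)           # non-strings are cast to int directly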
+
+def search_fg_json(directory):
+ dir_list = []
+ for root, dirs, files in tf.gfile.Walk(directory):
+ for f in files:
+ _, ext = os.path.splitext(f)
+ if ext == '.json':
+ dir_list.append(os.path.join(root, f))
+ if len(dir_list) == 0:
+ return None
+ elif len(dir_list) > 1:
+    raise ValueError('multiple fg.json files found in directory %s' % directory)
+ logging.info('use fg.json: %s' % dir_list[0])
+ return dir_list[0]
+
+
+def get_input_name_from_fg_json(fg_json):
+ if not fg_json:
+ return []
+ input_names = []
+ for fea in fg_json['features']:
+ if 'feature_name' in fea:
+ if 'stub_type' in fea and fea['stub_type']:
+ continue
+ input_names.append(fea['feature_name'])
+ elif 'sequence_name' in fea:
+ sequence_name = fea['sequence_name']
+ for seq_fea in fea['features']:
+ assert 'feature_name' in seq_fea
+ if 'stub_type' in seq_fea and seq_fea['stub_type']:
+ continue
+ feature_name = seq_fea['feature_name']
+ input_names.append(sequence_name + '__' + feature_name)
+ return input_names
+
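A toy fg.json fragment and the input names get_input_name_from_fg_json derives from it; the feature names and types are illustrative only:

from easy_rec.python.utils.config_util import get_input_name_from_fg_json

fg_json = {
    'features': [
        {'feature_name': 'user_id', 'feature_type': 'id_feature'},
        {'feature_name': 'price', 'feature_type': 'raw_feature', 'stub_type': True},
        {'sequence_name': 'click_seq',
         'features': [{'feature_name': 'item_id', 'feature_type': 'id_feature'}]},
    ]
}
# stub features are skipped, sequence sub-features are prefixed with '<sequence_name>__'
print(get_input_name_from_fg_json(fg_json))
# -> ['user_id', 'click_seq__item_id']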
+
+def get_train_input_path(pipeline_config):
+ input_name = pipeline_config.WhichOneof('train_path')
+ return getattr(pipeline_config, input_name)
+
+
+def get_eval_input_path(pipeline_config):
+ input_name = pipeline_config.WhichOneof('eval_path')
+ return getattr(pipeline_config, input_name)
+
+
+def get_model_dir_path(pipeline_config):
+ model_dir = pipeline_config.model_dir
+ return model_dir
+
+
+def set_train_input_path(pipeline_config, train_input_path):
+ if pipeline_config.WhichOneof('train_path') == 'hive_train_input':
+ if isinstance(train_input_path, list):
+ assert len(
+ train_input_path
+ ) <= 1, 'only support one hive_train_input.table_name when hive input'
+ pipeline_config.hive_train_input.table_name = train_input_path[0]
+ else:
+ assert len(
+ train_input_path.split(',')
+ ) <= 1, 'only support one hive_train_input.table_name when hive input'
+ pipeline_config.hive_train_input.table_name = train_input_path
+ logging.info('update hive_train_input.table_name to %s' %
+ pipeline_config.hive_train_input.table_name)
+
+ elif pipeline_config.WhichOneof('train_path') == 'kafka_train_input':
+ if isinstance(train_input_path, list):
+ pipeline_config.kafka_train_input = ','.join(train_input_path)
+ else:
+ pipeline_config.kafka_train_input = train_input_path
+ elif pipeline_config.WhichOneof('train_path') == 'parquet_train_input':
+ if isinstance(train_input_path, list):
+ pipeline_config.parquet_train_input = ','.join(train_input_path)
+ else:
+ pipeline_config.parquet_train_input = train_input_path
+ else:
+ if isinstance(train_input_path, list):
+ pipeline_config.train_input_path = ','.join(train_input_path)
+ else:
+ pipeline_config.train_input_path = train_input_path
+ logging.info('update train_input_path to %s' %
+ pipeline_config.train_input_path)
+ return pipeline_config
+
+
+def set_eval_input_path(pipeline_config, eval_input_path):
+ if pipeline_config.WhichOneof('eval_path') == 'hive_eval_input':
+ if isinstance(eval_input_path, list):
+ assert len(
+ eval_input_path
+ ) <= 1, 'only support one hive_eval_input.table_name when hive input'
+ pipeline_config.hive_eval_input.table_name = eval_input_path[0]
+ else:
+ assert len(
+ eval_input_path.split(',')
+ ) <= 1, 'only support one hive_eval_input.table_name when hive input'
+ pipeline_config.hive_eval_input.table_name = eval_input_path
+ logging.info('update hive_eval_input.table_name to %s' %
+ pipeline_config.hive_eval_input.table_name)
+ elif pipeline_config.WhichOneof('eval_path') == 'parquet_eval_input':
+ if isinstance(eval_input_path, list):
+ pipeline_config.parquet_eval_input = ','.join(eval_input_path)
+ else:
+ pipeline_config.parquet_eval_input = eval_input_path
+ elif pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input':
+ if isinstance(eval_input_path, list):
+ pipeline_config.kafka_eval_input = ','.join(eval_input_path)
+ else:
+ pipeline_config.kafka_eval_input = eval_input_path
+ else:
+ if isinstance(eval_input_path, list):
+ pipeline_config.eval_input_path = ','.join(eval_input_path)
+ else:
+ pipeline_config.eval_input_path = eval_input_path
+ logging.info('update eval_input_path to %s' %
+ pipeline_config.eval_input_path)
+ return pipeline_config
+
+
+def process_data_path(data_path, hive_util):
+ if data_path.startswith('hdfs://'):
+ return data_path
+ if re.match(r'(.*)\.(.*)', data_path):
+ hdfs_path = hive_util.get_table_location(data_path)
+ assert hdfs_path, "Can't find hdfs path of %s" % data_path
+ logging.info('update %s to %s' % (data_path, hdfs_path))
+ return hdfs_path
+ return data_path
+
+
+def process_neg_sampler_data_path(pipeline_config):
+ # replace neg_sampler hive table => hdfs path
+ if pai_util.is_on_pai():
+ return
+ if not pipeline_config.data_config.HasField('sampler'):
+ return
+ # not using hive, so not need to process it
+ if pipeline_config.WhichOneof('train_path') != 'hive_train_input':
+ return
+ hive_util = HiveUtils(
+ data_config=pipeline_config.data_config,
+ hive_config=pipeline_config.hive_train_input)
+ sampler_type = pipeline_config.data_config.WhichOneof('sampler')
+ sampler_config = getattr(pipeline_config.data_config, sampler_type)
+ if hasattr(sampler_config, 'input_path'):
+ sampler_config.input_path = process_data_path(sampler_config.input_path,
+ hive_util)
+ if hasattr(sampler_config, 'user_input_path'):
+ sampler_config.user_input_path = process_data_path(
+ sampler_config.user_input_path, hive_util)
+ if hasattr(sampler_config, 'item_input_path'):
+ sampler_config.item_input_path = process_data_path(
+ sampler_config.item_input_path, hive_util)
+ if hasattr(sampler_config, 'pos_edge_input_path'):
+ sampler_config.pos_edge_input_path = process_data_path(
+ sampler_config.pos_edge_input_path, hive_util)
+ if hasattr(sampler_config, 'hard_neg_edge_input_path'):
+ sampler_config.hard_neg_edge_input_path = process_data_path(
+ sampler_config.hard_neg_edge_input_path, hive_util)
+
+
+def parse_extra_config_param(extra_args, edit_config_json):
+ arg_num = len(extra_args)
+ arg_id = 0
+ while arg_id < arg_num:
+ if extra_args[arg_id].startswith('--data_config.') or \
+ extra_args[arg_id].startswith('--train_config.') or \
+ extra_args[arg_id].startswith('--feature_config.') or \
+ extra_args[arg_id].startswith('--model_config.') or \
+ extra_args[arg_id].startswith('--export_config.') or \
+ extra_args[arg_id].startswith('--eval_config.'):
+ tmp_arg = extra_args[arg_id][2:]
+ if '=' in tmp_arg:
+ sep_pos = tmp_arg.find('=')
+ k = tmp_arg[:sep_pos]
+ v = tmp_arg[(sep_pos + 1):]
+ v = v.strip(' "\'')
+ edit_config_json[k] = v
+ arg_id += 1
+ elif arg_id + 1 < len(extra_args):
+ edit_config_json[tmp_arg] = extra_args[arg_id + 1].strip(' "\'')
+ arg_id += 2
+ else:
+ logging.error('missing value for arg: %s' % extra_args[arg_id])
+ sys.exit(1)
+ else:
+ logging.error('unknown args: %s' % extra_args[arg_id])
+ sys.exit(1)
+
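A sketch of how leftover command-line arguments become config edits via parse_extra_config_param; the argument names and values here are hypothetical, and the values remain strings until edit_config casts them later:

from easy_rec.python.utils.config_util import parse_extra_config_param

edit_config_json = {}
extra_args = [
    '--train_config.num_steps=1000',   # key=value form
    '--data_config.batch_size', '4096'  # key value form
]
parse_extra_config_param(extra_args, edit_config_json)
# edit_config_json == {'train_config.num_steps': '1000',
#                      'data_config.batch_size': '4096'}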
+
+def process_multi_file_input_path(sampler_config_input_path):
+
+ if '*' in sampler_config_input_path:
+ input_path = ','.join(
+ file_path
+ for file_path in tf.gfile.Glob(sampler_config_input_path.split(',')))
+ else:
+ input_path = sampler_config_input_path
+
+ return input_path
diff --git a/easy_rec/python/utils/constant.py b/easy_rec/python/utils/constant.py
index 9df831a89..84366fbf5 100644
--- a/easy_rec/python/utils/constant.py
+++ b/easy_rec/python/utils/constant.py
@@ -1,4 +1,43 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
SAMPLE_WEIGHT = 'SAMPLE_WEIGHT'
+
+DENSE_UPDATE_VARIABLES = 'DENSE_UPDATE_VARIABLES'
+
+SPARSE_UPDATE_VARIABLES = 'SPARSE_UPDATE_VARIABLES'
+ENABLE_AVX_STR_SPLIT = 'ENABLE_AVX_STR_SPLIT'
+
+# Environment variable to control whether to sort
+# feature columns by name; by default sorting is not
+# enabled. The flag exists for backward compatibility.
+SORT_COL_BY_NAME = 'SORT_COL_BY_NAME'
+
+# arithmetic_optimization causes significantly slower training
+# of a test case:
+# train_eval_test.TrainEvalTest.test_train_parquet
+NO_ARITHMETRIC_OPTI = 'NO_ARITHMETRIC_OPTI'
+
+# shard embedding var_name collection
+EmbeddingParallel = 'EmbeddingParallel'
+
+# environ variable to force embedding placement on cpu
+EmbeddingOnCPU = 'place_embedding_on_cpu'
+
+# clear ps counter queue at start to reuse existing ps
+ClearPsCounterQueue = 'CLEAR_PS_COUNTER_QUEUE'
+
+
+def enable_avx_str_split():
+ os.environ[ENABLE_AVX_STR_SPLIT] = '1'
+
+
+def has_avx_str_split():
+ return ENABLE_AVX_STR_SPLIT in os.environ and os.environ[
+ ENABLE_AVX_STR_SPLIT] == '1'
+
+
+def disable_avx_str_split():
+ del os.environ[ENABLE_AVX_STR_SPLIT]
diff --git a/easy_rec/python/utils/convert_rtp_fg.py b/easy_rec/python/utils/convert_rtp_fg.py
index a6e8e1199..d665fcd74 100644
--- a/easy_rec/python/utils/convert_rtp_fg.py
+++ b/easy_rec/python/utils/convert_rtp_fg.py
@@ -3,6 +3,7 @@
import json
import logging
import sys
+import traceback
import tensorflow as tf
from google.protobuf import text_format
@@ -17,6 +18,8 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
+MAX_HASH_BUCKET_SIZE = 9223372036854775807
+
def _gen_raw_config(feature, input_field, feature_config, is_multi,
curr_embed_dim):
@@ -32,7 +35,7 @@ def _gen_raw_config(feature, input_field, feature_config, is_multi,
feature_config.embedding_dim = curr_embed_dim
else:
feature_config.feature_type = feature_config.RawFeature
- input_field.default_val = feature.get('default_value', '0.0')
+ input_field.default_val = str(feature.get('default_value', '0.0'))
raw_input_dim = feature.get('value_dimension', 1)
if raw_input_dim > 1:
feature_config.raw_input_dim = raw_input_dim
@@ -42,6 +45,8 @@ def _gen_raw_config(feature, input_field, feature_config, is_multi,
if 'boundaries' in feature:
feature_config.boundaries.extend(feature['boundaries'])
feature_config.embedding_dim = curr_embed_dim
+ if 'normalizer_fn' in feature:
+ feature_config.normalizer_fn = feature['normalizer_fn']
def _set_hash_bucket(feature, feature_config, input_field):
@@ -55,6 +60,12 @@ def _set_hash_bucket(feature, feature_config, input_field):
'it is suggested to set max_partitions > 1 for large hash buckets[%s]'
% feature['feature_name'])
sys.exit(1)
+ if feature.get('filter_freq', -1) >= 0:
+ feature_config.ev_params.filter_freq = feature['filter_freq']
+ feature_config.hash_bucket_size = MAX_HASH_BUCKET_SIZE
+ if feature.get('steps_to_live', -1) >= 0:
+ feature_config.ev_params.steps_to_live = feature['steps_to_live']
+ feature_config.hash_bucket_size = MAX_HASH_BUCKET_SIZE
elif 'vocab_file' in feature:
feature_config.vocab_file = feature['vocab_file']
elif 'vocab_list' in feature:
@@ -72,7 +83,6 @@ def process_features(feature_type,
pipeline_config,
embedding_dim,
incol_separator,
- sub_value_type=None,
is_sequence=False):
feature_config = FeatureConfig()
feature_config.input_names.append(feature_name)
@@ -81,13 +91,28 @@ def process_features(feature_type,
input_field.input_name = feature_name
curr_embed_dim = feature.get('embedding_dimension',
feature.get('embedding_dim', embedding_dim))
- curr_combiner = feature.get('combiner', 'mean')
+ curr_combiner = feature.get('combiner', 'sum')
if feature.get('is_cache', False):
logging.info('will cache %s' % feature_name)
feature_config.is_cache = True
is_multi = feature.get('is_multi', False)
# is_seq = feature.get('is_seq', False)
- if feature_type == 'id_feature':
+ if is_sequence:
+ feature_config.feature_type = feature_config.SequenceFeature
+ feature_config.embedding_dim = curr_embed_dim
+ if feature_type == 'raw_feature':
+ feature_config.sub_feature_type = feature_config.RawFeature
+ input_field.default_val = feature.get('default_value', '0.0')
+ raw_input_dim = feature.get('value_dimension', 1)
+ if 'boundaries' in feature:
+ feature_config.boundaries.extend(feature['boundaries'])
+ if raw_input_dim > 1:
+ feature_config.raw_input_dim = raw_input_dim
+ else:
+ feature_config.sub_feature_type = feature_config.IdFeature
+ _set_hash_bucket(feature, feature_config, input_field)
+ feature_config.combiner = curr_combiner
+ elif feature_type == 'id_feature':
if is_multi:
feature_config.feature_type = feature_config.TagFeature
kv_separator = feature.get('kv_separator', None)
@@ -106,12 +131,9 @@ def process_features(feature_type,
_gen_raw_config(feature, input_field, feature_config, is_multi,
curr_embed_dim)
else:
- if is_multi:
- feature_config.feature_type = feature_config.TagFeature
- if feature_config.get('needWeighting', False):
- feature_config.kv_separator = ''
- else:
- feature_config.feature_type = feature_config.IdFeature
+ feature_config.feature_type = feature_config.TagFeature
+ if feature.get('needWeighting', False):
+ feature_config.kv_separator = ''
feature_config.embedding_dim = curr_embed_dim
_set_hash_bucket(feature, feature_config, input_field)
feature_config.combiner = curr_combiner
@@ -123,12 +145,9 @@ def process_features(feature_type,
if feature.get('matchType', '') == 'multihit':
is_multi = True
if need_discrete:
- if is_multi:
- feature_config.feature_type = feature_config.TagFeature
- if feature_config.get('needWeighting', False):
- feature_config.kv_separator = ''
- else:
- feature_config.feature_type = feature_config.IdFeature
+ feature_config.feature_type = feature_config.TagFeature
+ if feature.get('needWeighting', False):
+ feature_config.kv_separator = ''
feature_config.embedding_dim = curr_embed_dim
_set_hash_bucket(feature, feature_config, input_field)
feature_config.combiner = curr_combiner
@@ -154,8 +173,6 @@ def process_features(feature_type,
if 'shared_name' in feature:
feature_config.embedding_name = feature['shared_name']
# pipeline_config.feature_configs.append(feature_config)
- if is_sequence:
- feature_config.feature_type = feature_config.SequenceFeature
if pipeline_config.feature_configs:
pipeline_config.feature_configs.append(feature_config)
else:
@@ -229,9 +246,6 @@ def load_input_field_and_feature_config(rtp_fg,
for sub_feature in feature['features']:
sub_feature_type = sub_feature['feature_type']
sub_feature_name = sub_feature['feature_name']
- sub_value_type = None
- if 'value_type' in sub_feature:
- sub_value_type = sub_feature['value_type']
all_sub_feature_name = sequence_name + '_' + sub_feature_name
pipeline_config = process_features(
sub_feature_type,
@@ -240,11 +254,10 @@ def load_input_field_and_feature_config(rtp_fg,
pipeline_config,
embedding_dim,
incol_separator,
- sub_value_type,
is_sequence=True)
- except Exception as ex:
- print('Exception: %s %s' % (type(ex), str(ex)))
- print(feature)
+ except Exception:
+ logging.error('convert feature[%s] exception[%s]' %
+ (str(feature), traceback.format_exc()))
sys.exit(1)
return pipeline_config
diff --git a/easy_rec/python/utils/dag.py b/easy_rec/python/utils/dag.py
new file mode 100644
index 000000000..8d0d4b094
--- /dev/null
+++ b/easy_rec/python/utils/dag.py
@@ -0,0 +1,192 @@
+import logging
+from collections import OrderedDict
+from collections import defaultdict
+from copy import copy
+from copy import deepcopy
+
+
+class DAG(object):
+ """Directed acyclic graph implementation."""
+
+ def __init__(self):
+ """Construct a new DAG with no nodes or edges."""
+ self.reset_graph()
+
+ def add_node(self, node_name, graph=None):
+ """Add a node if it does not exist yet, or error out."""
+ if not graph:
+ graph = self.graph
+ if node_name in graph:
+ raise KeyError('node %s already exists' % node_name)
+ graph[node_name] = set()
+
+ def add_node_if_not_exists(self, node_name, graph=None):
+ try:
+ self.add_node(node_name, graph=graph)
+ except KeyError:
+ logging.info('node %s already exists' % node_name)
+
+ def delete_node(self, node_name, graph=None):
+ """Deletes this node and all edges referencing it."""
+ if not graph:
+ graph = self.graph
+ if node_name not in graph:
+ raise KeyError('node %s does not exist' % node_name)
+ graph.pop(node_name)
+
+ for node, edges in graph.items():
+ if node_name in edges:
+ edges.remove(node_name)
+
+ def delete_node_if_exists(self, node_name, graph=None):
+ try:
+ self.delete_node(node_name, graph=graph)
+ except KeyError:
+ logging.info('node %s does not exist' % node_name)
+
+ def add_edge(self, ind_node, dep_node, graph=None):
+ """Add an edge (dependency) between the specified nodes."""
+ if not graph:
+ graph = self.graph
+ if ind_node not in graph or dep_node not in graph:
+ raise KeyError('one or more nodes do not exist in graph')
+ test_graph = deepcopy(graph)
+ test_graph[ind_node].add(dep_node)
+ is_valid, message = self.validate(test_graph)
+ if is_valid:
+ graph[ind_node].add(dep_node)
+ else:
+ raise Exception('invalid DAG: %s' % message)
+
+ def delete_edge(self, ind_node, dep_node, graph=None):
+ """Delete an edge from the graph."""
+ if not graph:
+ graph = self.graph
+ if dep_node not in graph.get(ind_node, []):
+ raise KeyError('this edge does not exist in graph')
+ graph[ind_node].remove(dep_node)
+
+ def rename_edges(self, old_task_name, new_task_name, graph=None):
+ """Change references to a task in existing edges."""
+ if not graph:
+ graph = self.graph
+ for node, edges in graph.items():
+
+ if node == old_task_name:
+ graph[new_task_name] = copy(edges)
+ del graph[old_task_name]
+
+ else:
+ if old_task_name in edges:
+ edges.remove(old_task_name)
+ edges.add(new_task_name)
+
+ def predecessors(self, node, graph=None):
+ """Returns a list of all predecessors of the given node."""
+ if graph is None:
+ graph = self.graph
+ return [key for key in graph if node in graph[key]]
+
+ def downstream(self, node, graph=None):
+ """Returns a list of all nodes this node has edges towards."""
+ if graph is None:
+ graph = self.graph
+ if node not in graph:
+ raise KeyError('node %s is not in graph' % node)
+ return list(graph[node])
+
+ def all_downstreams(self, node, graph=None):
+ """Returns a list of all nodes ultimately downstream of the given node in the dependency graph.
+
+ in topological order.
+ """
+ if graph is None:
+ graph = self.graph
+ nodes = [node]
+ nodes_seen = set()
+ i = 0
+ while i < len(nodes):
+ downstreams = self.downstream(nodes[i], graph)
+ for downstream_node in downstreams:
+ if downstream_node not in nodes_seen:
+ nodes_seen.add(downstream_node)
+ nodes.append(downstream_node)
+ i += 1
+ return list(
+ filter(lambda node: node in nodes_seen,
+ self.topological_sort(graph=graph)))
+
+ def all_leaves(self, graph=None):
+ """Return a list of all leaves (nodes with no downstreams)."""
+ if graph is None:
+ graph = self.graph
+ return [key for key in graph if not graph[key]]
+
+ def from_dict(self, graph_dict):
+ """Reset the graph and build it from the passed dictionary.
+
+ The dictionary takes the form of {node_name: [directed edges]}
+ """
+ self.reset_graph()
+ for new_node in graph_dict.keys():
+ self.add_node(new_node)
+ for ind_node, dep_nodes in graph_dict.items():
+ if not isinstance(dep_nodes, list):
+ raise TypeError('dict values must be lists')
+ for dep_node in dep_nodes:
+ self.add_edge(ind_node, dep_node)
+
+ def reset_graph(self):
+ """Restore the graph to an empty state."""
+ self.graph = OrderedDict()
+
+ def independent_nodes(self, graph=None):
+ """Returns a list of all nodes in the graph with no dependencies."""
+ if graph is None:
+ graph = self.graph
+
+ dependent_nodes = set(
+ node for dependents in graph.values() for node in dependents)
+ return [node for node in graph.keys() if node not in dependent_nodes]
+
+ def validate(self, graph=None):
+ """Returns (Boolean, message) of whether DAG is valid."""
+ graph = graph if graph is not None else self.graph
+ if len(self.independent_nodes(graph)) == 0:
+ return False, 'no independent nodes detected'
+ try:
+ self.topological_sort(graph)
+ except ValueError:
+ return False, 'failed topological sort'
+ return True, 'valid'
+
+ def topological_sort(self, graph=None):
+ """Returns a topological ordering of the DAG.
+
+ Raises an error if this is not possible (graph is not valid).
+ """
+ if graph is None:
+ graph = self.graph
+ result = []
+ in_degree = defaultdict(lambda: 0)
+
+ for u in graph:
+ for v in graph[u]:
+ in_degree[v] += 1
+ ready = [node for node in graph if not in_degree[node]]
+
+ while ready:
+ u = ready.pop()
+ result.append(u)
+ for v in graph[u]:
+ in_degree[v] -= 1
+ if in_degree[v] == 0:
+ ready.append(v)
+
+ if len(result) == len(graph):
+ return result
+ else:
+ raise ValueError('graph is not acyclic')
+
+ def size(self):
+ return len(self.graph)
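# --- Illustrative usage sketch (not part of the patch) for the DAG class added
# --- above; the node names below are hypothetical.
from easy_rec.python.utils.dag import DAG

dag = DAG()
dag.from_dict({'a': ['b', 'c'], 'b': ['d'], 'c': ['d'], 'd': []})
print(dag.independent_nodes())   # ['a']
print(dag.all_downstreams('b'))  # ['d']
print(dag.topological_sort())    # one valid order, e.g. ['a', 'c', 'b', 'd']
dag.delete_edge('b', 'd')
print(dag.all_leaves())          # ['b', 'd']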
diff --git a/easy_rec/python/utils/distribution_utils.py b/easy_rec/python/utils/distribution_utils.py
index def9c08ac..22ae9ecf8 100644
--- a/easy_rec/python/utils/distribution_utils.py
+++ b/easy_rec/python/utils/distribution_utils.py
@@ -47,7 +47,7 @@ def set_tf_config_and_get_train_worker_num(
'set_tf_config_and_get_train_worker_num: distribute_strategy = %d' %
distribute_strategy)
worker_hosts = worker_hosts.split(',')
- ps_hosts = ps_hosts.split(',')
+ ps_hosts = ps_hosts.split(',') if ps_hosts else []
total_worker_num = len(worker_hosts)
train_worker_num = total_worker_num
@@ -96,7 +96,7 @@ def set_tf_config_and_get_train_worker_num(
cluster = {'chief': [worker_hosts[0]], 'worker': worker_hosts[2:]}
if distribute_strategy != DistributionStrategy.NoStrategy:
cluster['evaluator'] = [worker_hosts[1]]
- if len(ps_hosts) > 1:
+ if len(ps_hosts) > 0:
cluster['ps'] = ps_hosts
if job_name == 'ps':
os.environ['TF_CONFIG'] = json.dumps({
@@ -166,7 +166,7 @@ def set_tf_config_and_get_train_worker_num(
else:
cluster = {'chief': [worker_hosts[0]], 'worker': worker_hosts[1:]}
train_worker_num = len(worker_hosts)
- if len(ps_hosts) > 1:
+ if len(ps_hosts) > 0:
cluster['ps'] = ps_hosts
if job_name == 'ps':
os.environ['TF_CONFIG'] = json.dumps({
@@ -216,7 +216,8 @@ def set_tf_config_and_get_train_worker_num(
def set_tf_config_and_get_train_worker_num_on_ds():
- assert 'TF_CONFIG' in os.environ, "'TF_CONFIG' must in os.environ"
+ if 'TF_CONFIG' not in os.environ:
+ return
tf_config = json.loads(os.environ['TF_CONFIG'])
if 'cluster' in tf_config and 'ps' in tf_config['cluster'] and (
'evaluator' not in tf_config['cluster']):
@@ -241,3 +242,27 @@ def set_tf_config_and_get_train_worker_num_on_ds():
easyrec_tf_config['task']['type'] = tf_config['task']['type']
easyrec_tf_config['task']['index'] = tf_config['task']['index']
os.environ['TF_CONFIG'] = json.dumps(easyrec_tf_config)
+
+
+def set_tf_config_and_get_distribute_eval_worker_num_on_ds():
+ assert 'TF_CONFIG' in os.environ, "'TF_CONFIG' must in os.environ"
+ tf_config = json.loads(os.environ['TF_CONFIG'])
+ if 'cluster' in tf_config and 'ps' in tf_config['cluster'] and (
+ 'evaluator' not in tf_config['cluster']):
+ easyrec_tf_config = dict()
+ easyrec_tf_config['cluster'] = {}
+ easyrec_tf_config['task'] = {}
+ easyrec_tf_config['cluster']['ps'] = tf_config['cluster']['ps']
+ easyrec_tf_config['cluster']['chief'] = [tf_config['cluster']['worker'][0]]
+ easyrec_tf_config['cluster']['worker'] = tf_config['cluster']['worker'][1:]
+
+ if tf_config['task']['type'] == 'worker' and tf_config['task']['index'] == 0:
+ easyrec_tf_config['task']['type'] = 'chief'
+ easyrec_tf_config['task']['index'] = 0
+ elif tf_config['task']['type'] == 'worker':
+ easyrec_tf_config['task']['type'] = tf_config['task']['type']
+ easyrec_tf_config['task']['index'] = tf_config['task']['index'] - 1
+ else:
+ easyrec_tf_config['task']['type'] = tf_config['task']['type']
+ easyrec_tf_config['task']['index'] = tf_config['task']['index']
+ os.environ['TF_CONFIG'] = json.dumps(easyrec_tf_config)
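# --- Illustrative sketch (not part of the patch): how
# --- set_tf_config_and_get_distribute_eval_worker_num_on_ds() remaps TF_CONFIG
# --- when the cluster has ps nodes but no explicit evaluator; host names are
# --- hypothetical.
import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['ps0:2222'],
        'worker': ['w0:2222', 'w1:2222', 'w2:2222'],
    },
    'task': {'type': 'worker', 'index': 1},
})
# After the call, TF_CONFIG becomes:
#   cluster = {'ps': ['ps0:2222'], 'chief': ['w0:2222'],
#              'worker': ['w1:2222', 'w2:2222']}
#   task    = {'type': 'worker', 'index': 0}
# worker 0 would instead be promoted to {'type': 'chief', 'index': 0}.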
diff --git a/easy_rec/python/utils/ds_util.py b/easy_rec/python/utils/ds_util.py
new file mode 100644
index 000000000..883e7bcee
--- /dev/null
+++ b/easy_rec/python/utils/ds_util.py
@@ -0,0 +1,65 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+import subprocess
+import traceback
+
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.utils import estimator_utils
+
+
+def is_on_ds():
+ # IS_ON_DS is set in train_eval,
+ # which is the entry point on the DataScience platform
+ return 'IS_ON_DS' in os.environ
+
+
+def set_on_ds():
+ logging.info('set on ds environment variable: IS_ON_DS')
+ os.environ['IS_ON_DS'] = '1'
+
+
+def cache_ckpt(pipeline_config):
+ fine_tune_ckpt_path = pipeline_config.train_config.fine_tune_checkpoint
+ if not fine_tune_ckpt_path.startswith('hdfs://'):
+ # there is no need to cache if remote directories are mounted
+ return
+
+ if estimator_utils.is_ps() or estimator_utils.is_chief(
+ ) or estimator_utils.is_master():
+ tmpdir = os.path.dirname(fine_tune_ckpt_path.replace('hdfs://', ''))
+ tmpdir = os.path.join('/tmp/experiments', tmpdir)
+ logging.info('will cache fine_tune_ckpt to local dir: %s' % tmpdir)
+ if gfile.IsDirectory(tmpdir):
+ gfile.DeleteRecursively(tmpdir)
+ gfile.MakeDirs(tmpdir)
+ src_files = gfile.Glob(fine_tune_ckpt_path + '*')
+ src_files.sort()
+ data_files = [x for x in src_files if '.data-' in x]
+ meta_files = [x for x in src_files if '.data-' not in x]
+ if estimator_utils.is_ps():
+ _, _, ps_id = estimator_utils.parse_tf_config()
+ ps_id = (ps_id % len(data_files))
+ data_files = data_files[ps_id:] + data_files[:ps_id]
+ src_files = meta_files + data_files
+ else:
+ src_files = meta_files
+ for src_path in src_files:
+ _, file_name = os.path.split(src_path)
+ dst_path = os.path.join(tmpdir, file_name)
+ logging.info('will copy %s to local path %s' % (src_path, dst_path))
+ try:
+ output = subprocess.check_output(
+ 'hadoop fs -get %s %s' % (src_path, dst_path), shell=True)
+ logging.info('copy succeed: %s' % output)
+ except Exception:
+ logging.warning('exception: %s' % traceback.format_exc())
+ ckpt_filename = os.path.basename(fine_tune_ckpt_path)
+ fine_tune_ckpt_path = os.path.join(tmpdir, ckpt_filename)
+ pipeline_config.train_config.fine_tune_checkpoint = fine_tune_ckpt_path
+ logging.info('will restore from %s' % fine_tune_ckpt_path)
+ else:
+ # workers do not have to create the restore graph
+ pipeline_config.train_config.ClearField('fine_tune_checkpoint')
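# --- Sketch (not part of the patch) of the shard rotation cache_ckpt() applies
# --- for ps tasks, so that each ps starts pulling a different .data shard;
# --- the checkpoint file names below are hypothetical.
data_files = [
    'model.ckpt-100.data-00000-of-00003',
    'model.ckpt-100.data-00001-of-00003',
    'model.ckpt-100.data-00002-of-00003',
]
ps_id = 2 % len(data_files)  # e.g. the third ps task
data_files = data_files[ps_id:] + data_files[:ps_id]
# -> ['...00002-of-00003', '...00000-of-00003', '...00001-of-00003']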
diff --git a/easy_rec/python/utils/embedding_utils.py b/easy_rec/python/utils/embedding_utils.py
new file mode 100644
index 000000000..960513801
--- /dev/null
+++ b/easy_rec/python/utils/embedding_utils.py
@@ -0,0 +1,73 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import proto_util
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def get_norm_name_to_ids():
+ """Get normalize embedding name(including kv variables) to ids.
+
+ Return:
+ normalized names to ids mapping.
+ """
+ norm_name_to_ids = {}
+ for x in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
+ norm_name, part_id = proto_util.get_norm_embed_name(x[0].name)
+ norm_name_to_ids[norm_name] = 1
+
+ for tid, t in enumerate(norm_name_to_ids.keys()):
+ norm_name_to_ids[t] = str(tid)
+ return norm_name_to_ids
+
+
+def get_sparse_name_to_ids():
+ """Get embedding variable(including kv variables) name to ids mapping.
+
+ Return:
+ variable names to ids mapping.
+ """
+ norm_name_to_ids = get_norm_name_to_ids()
+ name_to_ids = {}
+ for x in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
+ norm_name, _ = proto_util.get_norm_embed_name(x[0].name)
+ name_to_ids[x[0].name] = norm_name_to_ids[norm_name]
+ return name_to_ids
+
+
+def get_dense_name_to_ids():
+ dense_train_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
+ norm_name_to_ids = {}
+ for tid, x in enumerate(dense_train_vars):
+ norm_name_to_ids[x.op.name] = tid
+ return norm_name_to_ids
+
+
+embedding_parallel = False
+
+
+def set_embedding_parallel():
+ global embedding_parallel
+ embedding_parallel = True
+
+
+def is_embedding_parallel():
+ global embedding_parallel
+ return embedding_parallel
+
+
+def sort_col_by_name():
+ return constant.SORT_COL_BY_NAME in os.environ
+
+
+def embedding_on_cpu():
+ place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
+ place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+ return place_on_cpu
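# --- Illustrative sketch (not part of the patch): the helpers above are driven
# --- by environment variables; constant.EmbeddingOnCPU holds the env-var name
# --- and its value is eval()'d, so any Python literal such as 'True' works.
import os
from easy_rec.python.utils import constant
from easy_rec.python.utils import embedding_utils

os.environ[constant.EmbeddingOnCPU] = 'True'
assert embedding_utils.embedding_on_cpu()

embedding_utils.set_embedding_parallel()
assert embedding_utils.is_embedding_parallel()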
diff --git a/easy_rec/python/utils/estimator_utils.py b/easy_rec/python/utils/estimator_utils.py
index a12880113..ea15063d1 100644
--- a/easy_rec/python/utils/estimator_utils.py
+++ b/easy_rec/python/utils/estimator_utils.py
@@ -8,25 +8,64 @@
import logging
import os
import re
+import sys
import time
-from distutils.version import LooseVersion
import numpy as np
import six
import tensorflow as tf
from tensorflow.core.framework.summary_pb2 import Summary
+from tensorflow.python.client import device_lib
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
from tensorflow.python.training.summary_io import SummaryWriterCache
+from easy_rec.python.ops.incr_record import get_sparse_indices
+from easy_rec.python.ops.incr_record import kv_resource_incr_gather
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import embedding_utils
from easy_rec.python.utils import shape_utils
+from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer # NOQA
+
+try:
+ import horovod.tensorflow as hvd
+except Exception:
+ hvd = None
+
+try:
+ from sparse_operation_kit import experiment as sok
+except Exception:
+ sok = None
+
+try:
+ from kafka import KafkaProducer, KafkaAdminClient
+ from kafka.admin import NewTopic
+except ImportError as ex:
+ logging.warning('kafka-python is not installed: %s' % str(ex))
+
if tf.__version__ >= '2.0':
tf = tf.compat.v1
- SessionRunHook = tf.estimator.SessionRunHook
- CheckpointSaverHook = tf.estimator.CheckpointSaverHook
-else:
- SessionRunHook = tf.train.SessionRunHook
- CheckpointSaverHook = tf.train.CheckpointSaverHook
+SessionRunHook = session_run_hook.SessionRunHook
+CheckpointSaverHook = basic_session_run_hooks.CheckpointSaverHook
+
+
+def tensor_log_format_func(tensor_dict):
+ prefix = ''
+ if 'step' in tensor_dict:
+ prefix = 'global step %s: ' % tensor_dict['step']
+ stats = []
+ for k in tensor_dict:
+ if k == 'step':
+ continue
+ tensor_value = tensor_dict[k]
+ stats.append('%s = %s' % (k, tensor_value))
+ return prefix + ', '.join(stats)
class ExitBarrierHook(SessionRunHook):
@@ -111,10 +150,10 @@ def _check_flag_file(is_chief, flag_file):
logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
(is_chief, flag_file))
if is_chief:
- with tf.gfile.GFile(flag_file, 'w') as fout:
+ with gfile.GFile(flag_file, 'w') as fout:
fout.write('atexit time: %d' % int(time.time()))
else:
- while not tf.gfile.Exists(flag_file):
+ while not gfile.Exists(flag_file):
time.sleep(1)
from atexit import register
@@ -208,10 +247,10 @@ def _check_flag_file(is_chief, flag_file):
logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
(is_chief, flag_file))
if is_chief:
- with tf.gfile.GFile(flag_file, 'w') as fout:
+ with gfile.GFile(flag_file, 'w') as fout:
fout.write('atexit time: %d' % int(time.time()))
else:
- while not tf.gfile.Exists(flag_file):
+ while not gfile.Exists(flag_file):
time.sleep(1)
from atexit import register
@@ -235,7 +274,7 @@ def __init__(self, num_steps, filename, is_chief):
self._num_steps = num_steps
self._is_chief = is_chief
if self._is_chief:
- self._progress_file = tf.gfile.GFile(filename, 'w')
+ self._progress_file = gfile.GFile(filename, 'w')
self._progress_file.write('0.00\n')
self._progress_interval = 0.01 # 1%
self._last_progress_cnt = 0
@@ -276,7 +315,9 @@ def __init__(self,
checkpoint_basename='model.ckpt',
scaffold=None,
listeners=None,
- write_graph=True):
+ write_graph=True,
+ data_offset_var=None,
+ increment_save_config=None):
"""Initializes a `CheckpointSaverHook`.
Args:
@@ -290,6 +331,8 @@ def __init__(self,
Used for callbacks that run immediately before or after this hook saves
the checkpoint.
write_graph: whether to save graph.pbtxt.
+ data_offset_var: data offset variable.
+ increment_save_config: parameters for saving increment checkpoints.
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
@@ -303,14 +346,124 @@ def __init__(self,
checkpoint_basename=checkpoint_basename,
scaffold=scaffold,
listeners=listeners)
+ self._cuda_profile_start = 0
+ self._cuda_profile_stop = 0
+ self._steps_per_run = 1
self._write_graph = write_graph
+ self._data_offset_var = data_offset_var
+
+ self._task_idx, self._task_num = get_task_index_and_num()
+
+ if increment_save_config is not None:
+ self._kafka_timeout_ms = int(os.environ.get('KAFKA_TIMEOUT', 600)) * 1000
+ logging.info('KAFKA_TIMEOUT: %dms' % self._kafka_timeout_ms)
+ self._kafka_max_req_size = int(
+ os.environ.get('KAFKA_MAX_REQ_SIZE', 1024 * 1024 * 64))
+ logging.info('KAFKA_MAX_REQ_SIZE: %d' % self._kafka_max_req_size)
+ self._kafka_max_msg_size = int(
+ os.environ.get('KAFKA_MAX_MSG_SIZE', 1024 * 1024 * 1024))
+ logging.info('KAFKA_MAX_MSG_SIZE: %d' % self._kafka_max_msg_size)
+
+ self._dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
+ self._sparse_name_to_ids = embedding_utils.get_sparse_name_to_ids()
+
+ with gfile.GFile(
+ os.path.join(checkpoint_dir, constant.DENSE_UPDATE_VARIABLES),
+ 'w') as fout:
+ json.dump(self._dense_name_to_ids, fout, indent=2)
+
+ save_secs = increment_save_config.dense_save_secs
+ save_steps = increment_save_config.dense_save_steps
+ self._dense_timer = SecondOrStepTimer(
+ every_secs=save_secs if save_secs > 0 else None,
+ every_steps=save_steps if save_steps > 0 else None)
+ save_secs = increment_save_config.sparse_save_secs
+ save_steps = increment_save_config.sparse_save_steps
+ self._sparse_timer = SecondOrStepTimer(
+ every_secs=save_secs if save_secs > 0 else None,
+ every_steps=save_steps if save_steps > 0 else None)
+
+ self._dense_timer.update_last_triggered_step(0)
+ self._sparse_timer.update_last_triggered_step(0)
+
+ self._sparse_indices = []
+ self._sparse_values = []
+ sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
+ for sparse_var, indice_dtype in sparse_train_vars:
+ with ops.control_dependencies([tf.train.get_global_step()]):
+ with ops.colocate_with(sparse_var):
+ sparse_indice = get_sparse_indices(
+ var_name=sparse_var.op.name, ktype=indice_dtype)
+ # sparse_indice = sparse_indice.global_indices
+ self._sparse_indices.append(sparse_indice)
+ if 'EmbeddingVariable' in str(type(sparse_var)):
+ self._sparse_values.append(
+ kv_resource_incr_gather(
+ sparse_var._handle, sparse_indice,
+ np.zeros(sparse_var.shape.as_list(), dtype=np.float32)))
+ # sparse_var.sparse_read(sparse_indice))
+ else:
+ self._sparse_values.append(
+ array_ops.gather(sparse_var, sparse_indice))
+
+ self._kafka_producer = None
+ self._incr_save_dir = None
+ if increment_save_config.HasField('kafka'):
+ self._topic = increment_save_config.kafka.topic
+ logging.info('increment save topic: %s' % self._topic)
+
+ admin_clt = KafkaAdminClient(
+ bootstrap_servers=increment_save_config.kafka.server,
+ request_timeout_ms=self._kafka_timeout_ms,
+ api_version_auto_timeout_ms=self._kafka_timeout_ms)
+ if self._topic not in admin_clt.list_topics():
+ admin_clt.create_topics(
+ new_topics=[
+ NewTopic(
+ name=self._topic,
+ num_partitions=1,
+ replication_factor=1,
+ topic_configs={
+ 'max.message.bytes': self._kafka_max_msg_size
+ })
+ ],
+ validate_only=False)
+ logging.info('create increment save topic: %s' % self._topic)
+ admin_clt.close()
+
+ servers = increment_save_config.kafka.server.split(',')
+ self._kafka_producer = KafkaProducer(
+ bootstrap_servers=servers,
+ max_request_size=self._kafka_max_req_size,
+ api_version_auto_timeout_ms=self._kafka_timeout_ms,
+ request_timeout_ms=self._kafka_timeout_ms)
+ elif increment_save_config.HasField('fs'):
+ fs = increment_save_config.fs
+ if fs.relative:
+ self._incr_save_dir = os.path.join(checkpoint_dir, fs.incr_save_dir)
+ else:
+ self._incr_save_dir = fs.incr_save_dir
+ if not self._incr_save_dir.endswith('/'):
+ self._incr_save_dir += '/'
+ if not gfile.IsDirectory(self._incr_save_dir):
+ gfile.MakeDirs(self._incr_save_dir)
+ elif increment_save_config.HasField('datahub'):
+ raise NotImplementedError('datahub increment saving is in development.')
+ else:
+ raise ValueError(
+ 'incr_update not specified correctly, must be oneof: kafka,fs')
+
+ self._debug_save_update = increment_save_config.debug_save_update
+ else:
+ self._dense_timer = None
+ self._sparse_timer = None
def after_create_session(self, session, coord):
global_step = session.run(self._global_step_tensor)
if self._write_graph:
# We do write graph and saver_def at the first call of before_run.
# We cannot do this in begin, since we let other hooks to change graph and
- # add variables in begin. Graph is finalized after all begin calls.
+ # add variables at begin. Graph is finalized after all begin calls.
tf.train.write_graph(tf.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir, 'graph.pbtxt')
saver_def = self._get_saver().saver_def if self._get_saver() else None
@@ -319,16 +472,149 @@ def after_create_session(self, session, coord):
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
- # when tf version > 1.10.0, we use defaut training strategy, which saves ckpt
- # at first train step
- if LooseVersion(tf.__version__) >= LooseVersion('1.10.0'):
- # The checkpoint saved here is the state at step "global_step".
- self._save(session, global_step)
+
+ # save for step 0
+ self._save(session, global_step)
+
self._timer.update_last_triggered_step(global_step)
def before_run(self, run_context): # pylint: disable=unused-argument
return tf.train.SessionRunArgs(self._global_step_tensor)
+ def _send_dense(self, global_step, session):
+ dense_train_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
+ dense_train_vals = session.run(dense_train_vars)
+ logging.info('global_step=%d, increment save dense variables' % global_step)
+
+ # build msg header
+ msg_num = len(dense_train_vals)
+ msg_ids = [self._dense_name_to_ids[x.op.name] for x in dense_train_vars]
+ # 0 mean dense update message
+ msg_header = [0, msg_num, global_step]
+ for msg_id, x in zip(msg_ids, dense_train_vals):
+ msg_header.append(msg_id)
+ msg_header.append(x.size)
+
+ # build msg body
+ bytes_buf = np.array(msg_header, dtype=np.int32).tobytes()
+ for x in dense_train_vals:
+ bytes_buf += x.tobytes()
+
+ if self._kafka_producer is not None:
+ msg_key = 'dense_update_%d' % global_step
+ send_res = self._kafka_producer.send(
+ self._topic, bytes_buf, key=msg_key.encode('utf-8'))
+ logging.info('kafka send dense: %d exception: %s' %
+ (global_step, send_res.exception))
+
+ if self._incr_save_dir is not None:
+ save_path = os.path.join(self._incr_save_dir,
+ 'dense_update_%d' % global_step)
+ with gfile.GFile(save_path, 'wb') as fout:
+ fout.write(bytes_buf)
+ save_flag = save_path + '.done'
+ with gfile.GFile(save_flag, 'w') as fout:
+ fout.write('dense_update_%d' % global_step)
+
+ if self._debug_save_update and self._incr_save_dir is None:
+ base_dir, _ = os.path.split(self._save_path)
+ incr_save_dir = os.path.join(base_dir, 'incr_save/')
+ if not gfile.Exists(incr_save_dir):
+ gfile.MakeDirs(incr_save_dir)
+ save_path = os.path.join(incr_save_dir, 'dense_update_%d' % global_step)
+ with gfile.GFile(save_path, 'wb') as fout:
+ fout.write(bytes_buf)
+
+ logging.info(
+ 'global_step=%d, increment update dense variables, msg_num=%d' %
+ (global_step, msg_num))
+
+ def _send_sparse(self, global_step, session):
+ sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
+ sparse_res = session.run(self._sparse_indices + self._sparse_values)
+ msg_num = int(len(sparse_res) / 2)
+
+ sel_ids = [i for i in range(msg_num) if len(sparse_res[i]) > 0]
+ sparse_key_res = [sparse_res[i] for i in sel_ids]
+ sparse_val_res = [sparse_res[i + msg_num] for i in sel_ids]
+ sparse_train_vars = [sparse_train_vars[i][0] for i in sel_ids]
+
+ sel_embed_ids = [
+ self._sparse_name_to_ids[x.name] for x in sparse_train_vars
+ ]
+
+ msg_num = len(sel_ids)
+
+ if msg_num == 0:
+ logging.warning('there are no sparse updates, will skip this send: %d' %
+ global_step)
+ return
+
+ # build msg header
+ # 1 means sparse update messages
+ msg_header = [1, msg_num, global_step]
+ for tmp_id, tmp_key in zip(sel_embed_ids, sparse_key_res):
+ msg_header.append(tmp_id)
+ msg_header.append(len(tmp_key))
+ bytes_buf = np.array(msg_header, dtype=np.int32).tobytes()
+
+ # build msg body
+ for tmp_id, tmp_key, tmp_val, tmp_var in zip(sel_embed_ids, sparse_key_res,
+ sparse_val_res,
+ sparse_train_vars):
+ # for non kv embedding variables, add partition offset to tmp_key
+ if 'EmbeddingVariable' not in str(type(tmp_var)):
+ if tmp_var._save_slice_info is not None:
+ tmp_key += tmp_var._save_slice_info.var_offset[0]
+ bytes_buf += tmp_key.tobytes()
+ bytes_buf += tmp_val.tobytes()
+ if self._kafka_producer is not None:
+ msg_key = 'sparse_update_%d' % global_step
+ send_res = self._kafka_producer.send(
+ self._topic, bytes_buf, key=msg_key.encode('utf-8'))
+ logging.info('kafka send sparse: %d %s' %
+ (global_step, send_res.exception))
+
+ if self._incr_save_dir is not None:
+ save_path = os.path.join(self._incr_save_dir,
+ 'sparse_update_%d' % global_step)
+ with gfile.GFile(save_path, 'wb') as fout:
+ fout.write(bytes_buf)
+ save_flag = save_path + '.done'
+ with gfile.GFile(save_flag, 'w') as fout:
+ fout.write('sparse_update_%d' % global_step)
+
+ if self._debug_save_update and self._incr_save_dir is None:
+ base_dir, _ = os.path.split(self._save_path)
+ incr_save_dir = os.path.join(base_dir, 'incr_save/')
+ if not gfile.Exists(incr_save_dir):
+ gfile.MakeDirs(incr_save_dir)
+ save_path = os.path.join(incr_save_dir, 'sparse_update_%d' % global_step)
+ with gfile.GFile(save_path, 'wb') as fout:
+ fout.write(bytes_buf)
+
+ logging.info(
+ 'global_step=%d, increment update sparse variables, msg_num=%d, msg_size=%d'
+ % (global_step, msg_num, len(bytes_buf)))
+
+ def after_run(self, run_context, run_values):
+ super(CheckpointSaverHook, self).after_run(run_context, run_values)
+ stale_global_step = run_values.results
+ global_step = -1
+ if self._dense_timer is not None and self._dense_timer.should_trigger_for_step(
+ stale_global_step + self._steps_per_run):
+ global_step = run_context.session.run(self._global_step_tensor)
+ self._dense_timer.update_last_triggered_step(global_step)
+ self._send_dense(global_step, run_context.session)
+
+ if self._sparse_timer is not None and self._sparse_timer.should_trigger_for_step(
+ stale_global_step + self._steps_per_run):
+ if global_step < 0:
+ global_step = run_context.session.run(self._global_step_tensor)
+
+ self._sparse_timer.update_last_triggered_step(global_step)
+ self._send_sparse(global_step, run_context.session)
+
def _save(self, session, step):
"""Saves the latest checkpoint, returns should_stop."""
logging.info('Saving checkpoints for %d into %s.', step, self._save_path)
@@ -336,12 +622,22 @@ def _save(self, session, step):
for l in self._listeners: # noqa: E741
l.before_save(session, step)
+ if self._data_offset_var is not None:
+ save_data_offset = session.run(self._data_offset_var)
+ data_offset_json = {}
+ for x in save_data_offset:
+ if x:
+ data_offset_json.update(json.loads(x))
+ save_offset_path = os.path.join(self._checkpoint_dir,
+ 'model.ckpt-%d.offset' % step)
+ with gfile.GFile(save_offset_path, 'w') as fout:
+ json.dump(data_offset_json, fout)
+
self._get_saver().save(
session,
self._save_path,
global_step=step,
write_meta_graph=self._write_graph)
- save_dir, save_name = os.path.split(self._save_path)
self._summary_writer.add_session_log(
tf.SessionLog(
@@ -357,6 +653,18 @@ def _save(self, session, step):
should_stop = True
return should_stop
+ def end(self, session):
+ global_step = session.run(self._global_step_tensor)
+ super(CheckpointSaverHook, self).end(session)
+ if self._dense_timer is not None and \
+ global_step != self._dense_timer.last_triggered_step():
+ self._dense_timer.update_last_triggered_step(global_step)
+ self._send_dense(global_step, session)
+ if self._sparse_timer is not None and \
+ global_step != self._sparse_timer.last_triggered_step():
+ self._sparse_timer.update_last_triggered_step(global_step)
+ self._send_sparse(global_step, session)
+
class NumpyCheckpointRestoreHook(SessionRunHook):
"""Restore variable from numpy checkpoint."""
@@ -395,7 +703,7 @@ def begin(self):
vars_not_inited[var_name] = ','.join([str(s) for s in var_shape])
self._restore_op = tf.group(assign_ops)
- with tf.gfile.GFile(self._ckpt_path[:-4] + '_not_inited.txt', 'w') as f:
+ with gfile.GFile(self._ckpt_path[:-4] + '_not_inited.txt', 'w') as f:
for var_name in sorted(vars_not_inited.keys()):
f.write('%s:%s\n' % (var_name, vars_not_inited[var_name]))
assert not has_shape_unmatch, 'exist variable shape not match, restore failed'
@@ -516,7 +824,7 @@ def end(self, session):
eval_result_file = os.path.join(self._output_dir,
'online_eval_result.txt-%s' % global_step)
logging.info('Saving online eval result to file %s' % eval_result_file)
- with tf.gfile.GFile(eval_result_file, 'w') as ofile:
+ with gfile.GFile(eval_result_file, 'w') as ofile:
result_to_write = {}
for key in sorted(metric_value_dict):
# convert numpy float to python float
@@ -540,6 +848,8 @@ def parse_tf_config():
def get_task_index_and_num():
+ if hvd is not None and 'HOROVOD_RANK' in os.environ:
+ return hvd.rank(), hvd.size()
cluster, task_type, task_index = parse_tf_config()
if 'worker' not in cluster:
return 0, 1
@@ -571,6 +881,31 @@ def get_ckpt_version(ckpt_path):
return int(toks[-1])
+def get_latest_checkpoint_from_checkpoint_path(checkpoint_path,
+ ignore_ckpt_error):
+ ckpt_path = None
+ if checkpoint_path.endswith('/') or gfile.IsDirectory(checkpoint_path + '/'):
+ checkpoint_dir = checkpoint_path
+ if not checkpoint_dir.endswith('/'):
+ checkpoint_dir = checkpoint_dir + '/'
+ if gfile.Exists(checkpoint_dir):
+ ckpt_path = latest_checkpoint(checkpoint_dir)
+ if ckpt_path:
+ logging.info(
+ 'fine_tune_checkpoint is directory, will use the latest checkpoint: %s'
+ % ckpt_path)
+ else:
+ assert ignore_ckpt_error, 'fine_tune_checkpoint(%s) does not exist.' % checkpoint_path
+ else:
+ assert ignore_ckpt_error, 'fine_tune_checkpoint(%s) does not exist.' % checkpoint_path
+ elif gfile.Exists(checkpoint_path + '.index'):
+ ckpt_path = checkpoint_path
+ logging.info('update fine_tune_checkpoint to %s' % checkpoint_path)
+ else:
+ assert ignore_ckpt_error, 'fine_tune_checkpoint(%s) does not exist.' % checkpoint_path
+ return ckpt_path
+
+
def latest_checkpoint(model_dir):
"""Find lastest checkpoint under a directory.
@@ -580,14 +915,26 @@ def latest_checkpoint(model_dir):
Return:
model_path: xx/model.ckpt-2000
"""
- ckpt_metas = tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.meta'))
- if len(ckpt_metas) == 0:
+ try:
+ ckpt_metas = gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))
+
+ if len(ckpt_metas) == 0:
+ return None
+
+ if len(ckpt_metas) > 1:
+ ckpt_metas.sort(key=lambda x: get_ckpt_version(x))
+ ckpt_path = os.path.splitext(ckpt_metas[-1])[0]
+ return ckpt_path
+ except errors_impl.NotFoundError:
return None
- if len(ckpt_metas) > 1:
- ckpt_metas.sort(key=lambda x: get_ckpt_version(x))
- ckpt_path = os.path.splitext(ckpt_metas[-1])[0]
- return ckpt_path
+
+def get_trained_steps(model_dir):
+ ckpt_path = latest_checkpoint(model_dir)
+ if ckpt_path is not None:
+ return int(ckpt_path.split('-')[-1])
+ else:
+ return 0
def master_to_chief():
@@ -620,9 +967,70 @@ def chief_to_master():
return None
+def is_ps():
+ if 'TF_CONFIG' in os.environ:
+ tf_config = json.loads(os.environ['TF_CONFIG'])
+ if 'task' in tf_config:
+ return tf_config['task']['type'] == 'ps'
+ return False
+
+
def is_chief():
+ if has_hvd():
+ return hvd.rank() == 0
+
if 'TF_CONFIG' in os.environ:
tf_config = json.loads(os.environ['TF_CONFIG'])
if 'task' in tf_config:
return tf_config['task']['type'] in ['chief', 'master']
return True
+
+
+def is_master():
+ if 'TF_CONFIG' in os.environ:
+ tf_config = json.loads(os.environ['TF_CONFIG'])
+ if 'task' in tf_config:
+ return tf_config['task']['type'] == 'master'
+ return True
+
+
+def is_evaluator():
+ if 'TF_CONFIG' in os.environ:
+ tf_config = json.loads(os.environ['TF_CONFIG'])
+ if 'task' in tf_config:
+ return tf_config['task']['type'] == 'evaluator'
+ return False
+
+
+def has_hvd():
+ return hvd is not None and 'HOROVOD_RANK' in os.environ
+
+
+def has_sok():
+ return sok is not None and 'ENABLE_SOK' in os.environ
+
+
+def init_hvd():
+ if hvd is None:
+ logging.error(
+ 'horovod is not installed: HOROVOD_WITH_TENSORFLOW=1 pip install horovod'
+ )
+ sys.exit(1)
+
+ hvd.init()
+ os.environ['HOROVOD_RANK'] = str(hvd.rank())
+
+
+def init_sok():
+ try:
+ sok.init()
+ os.environ['ENABLE_SOK'] = '1'
+ return True
+ except Exception:
+ logging.warning('sok is not installed')
+ return False
+
+
+def get_available_gpus():
+ local_device_protos = device_lib.list_local_devices()
+ return [x.name for x in local_device_protos if x.device_type == 'GPU']
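# --- Illustrative sketch (not part of the patch): decoding the dense-update
# --- message produced by CheckpointSaverHook._send_dense() above. Per that
# --- code, the layout is an int32 header
# --- [0, msg_num, global_step, id_0, size_0, ..., id_n, size_n] followed by the
# --- flattened variable values; float32 values are assumed here.
import numpy as np

def decode_dense_update(buf):
  msg_type, msg_num, global_step = np.frombuffer(buf, dtype=np.int32, count=3)
  assert msg_type == 0, 'not a dense update message'
  # (variable id, value count) pairs follow the 3-int header (12 bytes)
  pairs = np.frombuffer(buf, dtype=np.int32, count=2 * msg_num, offset=12)
  var_ids, sizes = pairs[0::2], pairs[1::2]
  offset = 12 + 8 * msg_num
  updates = {}
  for var_id, size in zip(var_ids, sizes):
    updates[var_id] = np.frombuffer(
        buf, dtype=np.float32, count=size, offset=offset)
    offset += 4 * size
  return global_step, updates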
diff --git a/easy_rec/python/utils/export_big_model.py b/easy_rec/python/utils/export_big_model.py
index 248d6d021..243847a5f 100644
--- a/easy_rec/python/utils/export_big_model.py
+++ b/easy_rec/python/utils/export_big_model.py
@@ -1,12 +1,14 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
import logging
import os
import time
import numpy as np
import tensorflow as tf
+from google.protobuf import json_format
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops.variables import global_variables
@@ -17,12 +19,15 @@
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training.device_setter import replica_device_setter
from tensorflow.python.training.monitored_session import ChiefSessionCreator
+from tensorflow.python.training.monitored_session import Scaffold
from tensorflow.python.training.saver import export_meta_graph
import easy_rec
+from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import io_util
from easy_rec.python.utils import proto_util
+from easy_rec.python.utils.meta_graph_editor import EMBEDDING_INITIALIZERS
from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor
if tf.__version__ >= '2.0':
@@ -32,6 +37,8 @@
ConfigProto = config_pb2.ConfigProto
GPUOptions = config_pb2.GPUOptions
+INCR_UPDATE_SIGNATURE_KEY = 'incr_update_sig'
+
def export_big_model(export_dir, pipeline_config, redis_params,
serving_input_fn, estimator, checkpoint_path, verbose):
@@ -57,7 +64,8 @@ def export_big_model(export_dir, pipeline_config, redis_params,
logging.warning('load libwrite_sparse_kv.so failed: %s' % str(ex))
sparse_kv_module = None
if not checkpoint_path:
- checkpoint_path = tf.train.latest_checkpoint(pipeline_config.model_dir)
+ checkpoint_path = estimator_utils.latest_checkpoint(
+ pipeline_config.model_dir)
logging.info('checkpoint_path = %s' % checkpoint_path)
server = None
@@ -247,7 +255,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,
with GFile(embed_name_to_id_file, 'w') as fout:
for tmp_norm_name in norm_name_to_ids:
fout.write('%s\t%s\n' % (tmp_norm_name, norm_name_to_ids[tmp_norm_name]))
- tf.add_to_collection(
+ ops.add_to_collection(
tf.GraphKeys.ASSET_FILEPATHS,
tf.constant(
embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))
@@ -255,6 +263,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,
export_dir = os.path.join(export_dir,
meta_graph_def.meta_info_def.meta_graph_version)
export_dir = io_util.fix_oss_dir(export_dir)
+ logging.info('export_dir=%s' % export_dir)
if Exists(export_dir):
logging.info('will delete old dir: %s' % export_dir)
DeleteRecursively(export_dir)
@@ -282,6 +291,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,
saver = tf.train.Saver()
with tf.Session(target=server.target if server else '') as sess:
saver.restore(sess, checkpoint_path)
+
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
@@ -295,7 +305,7 @@ def export_big_model(export_dir, pipeline_config, redis_params,
# remove temporary files
Remove(embed_name_to_id_file)
- return
+ return export_dir
def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
@@ -308,7 +318,8 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
kv_module = tf.load_op_library(write_kv_lib_path)
if not checkpoint_path:
- checkpoint_path = tf.train.latest_checkpoint(pipeline_config.model_dir)
+ checkpoint_path = estimator_utils.latest_checkpoint(
+ pipeline_config.model_dir)
logging.info('checkpoint_path = %s' % checkpoint_path)
server = None
@@ -489,6 +500,7 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
oss_timeout=oss_params.get('oss_timeout', 1500),
meta_graph_def=meta_graph_def,
norm_name_to_ids=norm_name_to_ids,
+ incr_update_params=oss_params.get('incr_update', None),
debug_dir=export_dir if verbose else '')
meta_graph_editor.edit_graph_for_oss()
tf.reset_default_graph()
@@ -500,14 +512,49 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
with GFile(embed_name_to_id_file, 'w') as fout:
for tmp_norm_name in norm_name_to_ids:
fout.write('%s\t%s\n' % (tmp_norm_name, norm_name_to_ids[tmp_norm_name]))
- tf.add_to_collection(
- tf.GraphKeys.ASSET_FILEPATHS,
+ ops.add_to_collection(
+ ops.GraphKeys.ASSET_FILEPATHS,
tf.constant(
embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))
+ if 'incr_update' in oss_params:
+ dense_train_vars_path = os.path.join(
+ os.path.dirname(checkpoint_path), constant.DENSE_UPDATE_VARIABLES)
+ ops.add_to_collection(
+ ops.GraphKeys.ASSET_FILEPATHS,
+ tf.constant(
+ dense_train_vars_path,
+ dtype=tf.string,
+ name=constant.DENSE_UPDATE_VARIABLES))
+
+ asset_file = 'incr_update.txt'
+ asset_file_path = os.path.join(export_dir, asset_file)
+ with GFile(asset_file_path, 'w') as fout:
+ incr_update = oss_params['incr_update']
+ incr_update_json = {}
+ if 'kafka' in incr_update:
+ incr_update_json['storage'] = 'kafka'
+ incr_update_json['kafka'] = json.loads(
+ json_format.MessageToJson(
+ incr_update['kafka'], preserving_proto_field_name=True))
+ elif 'datahub' in incr_update:
+ incr_update_json['storage'] = 'datahub'
+ incr_update_json['datahub'] = json.loads(
+ json_format.MessageToJson(
+ incr_update['datahub'], preserving_proto_field_name=True))
+ elif 'fs' in incr_update:
+ incr_update_json['storage'] = 'fs'
+ incr_update_json['fs'] = {'incr_save_dir': incr_update['fs'].mount_path}
+ json.dump(incr_update_json, fout, indent=2)
+
+ ops.add_to_collection(
+ ops.GraphKeys.ASSET_FILEPATHS,
+ tf.constant(asset_file_path, dtype=tf.string, name=asset_file))
+
export_dir = os.path.join(export_dir,
meta_graph_def.meta_info_def.meta_graph_version)
export_dir = io_util.fix_oss_dir(export_dir)
+ logging.info('export_dir=%s' % export_dir)
if Exists(export_dir):
logging.info('will delete old dir: %s' % export_dir)
DeleteRecursively(export_dir)
@@ -518,6 +565,7 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
tmp = graph.get_tensor_by_name(inputs[tmp_key].name)
tensor_info_inputs[tmp_key] = \
tf.saved_model.utils.build_tensor_info(tmp)
+
tensor_info_outputs = {}
for tmp_key in outputs:
tmp = graph.get_tensor_by_name(outputs[tmp_key].name)
@@ -529,23 +577,54 @@ def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
outputs=tensor_info_outputs,
method_name=signature_constants.PREDICT_METHOD_NAME))
+ if 'incr_update' in oss_params:
+ incr_update_inputs = meta_graph_editor.sparse_update_inputs
+ incr_update_outputs = meta_graph_editor.sparse_update_outputs
+ incr_update_inputs.update(meta_graph_editor.dense_update_inputs)
+ incr_update_outputs.update(meta_graph_editor.dense_update_outputs)
+ tensor_info_incr_update_inputs = {}
+ tensor_info_incr_update_outputs = {}
+ for tmp_key in incr_update_inputs:
+ tmp = graph.get_tensor_by_name(incr_update_inputs[tmp_key].name)
+ tensor_info_incr_update_inputs[tmp_key] = \
+ tf.saved_model.utils.build_tensor_info(tmp)
+ for tmp_key in incr_update_outputs:
+ tmp = graph.get_tensor_by_name(incr_update_outputs[tmp_key].name)
+ tensor_info_incr_update_outputs[tmp_key] = \
+ tf.saved_model.utils.build_tensor_info(tmp)
+ incr_update_signature = (
+ tf.saved_model.signature_def_utils.build_signature_def(
+ inputs=tensor_info_incr_update_inputs,
+ outputs=tensor_info_incr_update_outputs,
+ method_name=signature_constants.PREDICT_METHOD_NAME))
+ else:
+ incr_update_signature = None
+
session_config = ConfigProto(
allow_soft_placement=True, log_device_placement=True)
saver = tf.train.Saver()
with tf.Session(target=server.target if server else '') as sess:
saver.restore(sess, checkpoint_path)
+ main_op = tf.group([
+ Scaffold.default_local_init_op(),
+ ops.get_collection(EMBEDDING_INITIALIZERS)
+ ])
+ incr_update_sig_map = {
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
+ }
+ if incr_update_signature is not None:
+ incr_update_sig_map[INCR_UPDATE_SIGNATURE_KEY] = incr_update_signature
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
- signature_def_map={
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
- },
+ signature_def_map=incr_update_sig_map,
assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
saver=saver,
+ main_op=main_op,
strip_default_attrs=True,
clear_devices=True)
builder.save()
# remove temporary files
Remove(embed_name_to_id_file)
- return
+ return export_dir
diff --git a/easy_rec/python/utils/expr_util.py b/easy_rec/python/utils/expr_util.py
new file mode 100644
index 000000000..a236bdfd9
--- /dev/null
+++ b/easy_rec/python/utils/expr_util.py
@@ -0,0 +1,118 @@
+from collections import deque
+
+
+def _process_multi_expr(expr):
+ expr = expr.strip()
+ size = len(expr)
+ idx = 0
+ two_expr = ['>=', '<=', '==']
+ expr_list = []
+ while (idx < size):
+ if idx + 2 <= size and expr[idx:idx + 2] in two_expr:
+ expr_list.append(expr[idx:idx + 2])
+ idx += 2
+ else:
+ expr_list.append(expr[idx])
+ idx += 1
+ return expr_list
+
+
+def _process_enum(enum, input_names, prefix=''):
+ enum = enum.strip()
+ if enum in input_names:
+ enum = "parsed_dict['%s']" % (prefix + enum)
+ return enum
+
+
+def _get_expression_list(expression, input_names, prefix=''):
+ ops = [
+ '+', '-', '*', '/', '(', ')', '>', '>=', '<', '<=', '==', '=', '&', '|'
+ ]
+ expression_list = []
+ enum = ''
+ pre_expr = ''
+
+ for i in expression:
+ if i in ops:
+ if enum:
+ expression_list.append(_process_enum(enum, input_names, prefix=prefix))
+ enum = ''
+ pre_expr += i
+ else:
+ enum += i
+ if pre_expr:
+ expression_list.extend(_process_multi_expr(pre_expr))
+ pre_expr = ''
+ if enum:
+ expression_list.append(_process_enum(enum, input_names, prefix=prefix))
+ if pre_expr:
+ expression_list.extend(_process_multi_expr(pre_expr))
+
+ final_expression_list = ['']
+ ops = ['(', ')', '>=', '<=', '==', '>', '<', '&', '|']
+ for expr in expression_list:
+ if expr in ops:
+ final_expression_list.append(expr)
+ elif final_expression_list[-1] not in ops:
+ final_expression_list[-1] += expr
+ else:
+ final_expression_list.append(expr)
+ final_expression_list = [expr for expr in final_expression_list if expr]
+ return final_expression_list
+
+
+def _solve(enum, sign, stack):
+ if len(stack) == 0 or enum == '' or sign == '':
+ return enum
+ op1 = stack.pop()
+ op2 = enum
+ if sign == '>':
+ result = 'tf.greater(%s, %s)' % (op1, op2)
+ elif sign == '>=':
+ result = 'tf.greater_equal(%s, %s)' % (op1, op2)
+ elif sign == '<':
+ result = 'tf.less(%s, %s)' % (op1, op2)
+ elif sign == '<=':
+ result = 'tf.less_equal(%s, %s)' % (op1, op2)
+ elif sign == '==':
+ result = 'tf.equal(%s, %s)' % (op1, op2)
+ elif sign == '&':
+ result = '%s & %s' % (op1, op2)
+ elif sign == '|':
+ result = '%s | %s' % (op1, op2)
+ else:
+ assert False
+ return result
+
+
+def _expression_eval(expr_list):
+ ops = ['>', '>=', '<', '<=', '==', '&', '|', '(', ')']
+ stack = deque()
+ sign = ''
+ operand = ''
+ for c in expr_list:
+ if c == ' ':
+ continue
+ elif c not in ops:
+ operand = c
+ elif c == '(':
+ stack.append(sign)
+ sign = ''
+ else:
+ result = _solve(operand, sign, stack)
+ operand = ''
+ if c == ')':
+ sign = stack.pop()
+ operand = _solve(result, sign, stack)
+ sign = ''
+ else:
+ sign = c
+ stack.append(result)
+ expr_str = _solve(operand, sign, stack)
+ return expr_str
+
+
+def get_expression(expression, input_names, prefix=''):
+ expression_list = _get_expression_list(expression, input_names, prefix=prefix)
+ expression = _expression_eval(expression_list)
+ return expression
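# --- Illustrative examples (not part of the patch) of the strings produced by
# --- get_expression() above; 'age' and 'ctr' are hypothetical input names.
from easy_rec.python.utils.expr_util import get_expression

get_expression('age>=18', ['age'])
# -> "tf.greater_equal(parsed_dict['age'], 18)"

get_expression('(age>=18)&(ctr<0.5)', ['age', 'ctr'], prefix='u_')
# -> "tf.greater_equal(parsed_dict['u_age'], 18) & tf.less(parsed_dict['u_ctr'], 0.5)"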
diff --git a/easy_rec/python/utils/fg_util.py b/easy_rec/python/utils/fg_util.py
index fb5694287..c394444bf 100644
--- a/easy_rec/python/utils/fg_util.py
+++ b/easy_rec/python/utils/fg_util.py
@@ -17,6 +17,11 @@ def load_fg_json_to_config(pipeline_config):
fg_json_path = pipeline_config.fg_json_path
if not fg_json_path:
return
+
+ if fg_json_path.startswith('!'):
+ # already loaded
+ return
+
label_fields = pipeline_config.data_config.label_fields
with tf.gfile.GFile(fg_json_path, 'r') as fin:
rtp_fg = json.load(fin)
@@ -26,7 +31,10 @@ def load_fg_json_to_config(pipeline_config):
pipeline_config.data_config.ClearField('input_fields')
pipeline_config.ClearField('feature_configs')
- pipeline_config.feature_config.ClearField('features')
+
+ # not clear features so that we could define extra features
+ # which is not defined in fg.json
+ # pipeline_config.feature_config.ClearField('features')
for input_config in fg_config.data_config.input_fields:
in_config = DatasetConfig.Field()
@@ -38,4 +46,8 @@ def load_fg_json_to_config(pipeline_config):
fea_config.CopyFrom(fc)
pipeline_config.feature_config.features.append(fea_config)
logging.info('data_config and feature_config has been replaced by fg_json.')
+
+ # signal that it is already loaded
+ pipeline_config.fg_json_path = '!' + pipeline_config.fg_json_path
+
return pipeline_config
diff --git a/easy_rec/python/utils/hit_rate_utils.py b/easy_rec/python/utils/hit_rate_utils.py
new file mode 100644
index 000000000..7ae313548
--- /dev/null
+++ b/easy_rec/python/utils/hit_rate_utils.py
@@ -0,0 +1,220 @@
+import logging
+
+import graphlearn as gl
+import numpy as np
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def load_graph(i_emb_table, emb_dim, knn_metric, timeout, knn_strict):
+ """Load embedding tables in GL.
+
+ that used to lookup embedding and do knn search.
+ """
+ gl.set_knn_metric(knn_metric)
+ gl.set_timeout(timeout)
+ option = gl.IndexOption()
+ option.name = 'knn'
+ if knn_strict:
+ # option.index_type = "flat"
+ option.index_type = 'ivfflat'
+ option.nlist = 5
+ option.nprobe = 5
+ else:
+ option.index_type = 'ivfflat'
+ option.nlist = 5
+ option.nprobe = 2
+ g = gl.Graph().node(
+ i_emb_table,
+ node_type='i',
+ decoder=gl.Decoder(attr_types=['float'] * emb_dim, attr_delimiter=','),
+ option=option)
+ return g
+
+
+def batch_hitrate(src_ids,
+ recall_ids,
+ recall_distances,
+ gt_items,
+ num_interests,
+ mask=None):
+ """Compute hitrate of a batch of src ids.
+
+ Args:
+ src_ids: trigger id, a numpy array.
+ recall_ids: recalled ids by src_ids, a numpy array.
+ recall_distances: corresponding distances of recalled ids, a numpy array.
+ gt_items: batch of ground truth item ids list, a list of list.
+ num_interests: max number of interests.
+ mask: some models have different number of interests.
+
+ Returns:
+ hitrates: hitrate of src_ids, a list.
+ bad_cases: bad cases, a list of list.
+ bad_dists: distances of bad cases, a list of list.
+ hits: total hit counts of a batch of src ids, a scalar.
+ gt_count: total ground truth items num of a batch of src ids, a scalar.
+ """
+ hitrates = []
+ bad_cases = []
+ bad_dists = []
+ hits = 0.0
+ gt_count = 0.0
+ for idx, src_id in enumerate(src_ids):
+ recall_id = recall_ids[idx]
+ recall_distance = recall_distances[idx]
+
+ bad_case = {}
+ gt_items_size = len(gt_items[idx])
+ hit_ids = []
+ if gt_items_size == 0: # just skip invalid record.
+ print('Id {:d} has no related items sequence, just skip.'.format(src_id))
+ continue
+ for interest_id in range(num_interests):
+ if not mask[idx, interest_id]:
+ continue
+ for k, id in enumerate(recall_id[interest_id]):
+ if id in gt_items[idx]:
+ if id not in hit_ids:
+ hit_ids.append(id)
+ else:
+ dis = recall_distance[interest_id][k]
+ if id not in bad_case:
+ bad_case[id] = dis
+ elif dis < bad_case[id]:
+ bad_case[id] = dis
+ hit_count = float(len(hit_ids))
+ hitrates.append(hit_count / gt_items_size)
+ hits += hit_count
+ gt_count += gt_items_size
+ bad_cases.append([x for x in bad_case])
+ bad_dists.append([bad_case[x] for x in bad_case])
+ return hitrates, bad_cases, bad_dists, hits, gt_count
+
+
+def reduce_hitrate(cluster, hits, count, task_index):
+ """Reduce hitrate of all workers.
+
+ Args:
+ cluster: tf cluster.
+ hits: total_hits of each worker.
+ count: total count of ground truth items of each worker.
+ task_index: worker index.
+
+ Returns:
+ var_total_hitrate: variable of total hitrate.
+ var_worker_count: variable used to mark the number of workers that
+ have completed the hitrate calculation.
+ """
+ with tf.device(
+ tf.train.replica_device_setter(
+ worker_device='/job:worker/task:%d' % task_index, cluster=cluster)):
+ with tf.variable_scope('hitrate_var', reuse=tf.AUTO_REUSE):
+ var_worker_count = tf.get_variable(
+ 'worker_count',
+ shape=(),
+ dtype=tf.int32,
+ initializer=tf.zeros_initializer())
+ var_hits = tf.get_variable(
+ 'hits',
+ shape=(),
+ dtype=tf.float32,
+ initializer=tf.zeros_initializer())
+ var_gt_count = tf.get_variable(
+ 'gt_count',
+ shape=(),
+ dtype=tf.float32,
+ initializer=tf.zeros_initializer())
+ var_total_hitrate = tf.get_variable(
+ 'total_hitate',
+ shape=(),
+ dtype=tf.float32,
+ initializer=tf.zeros_initializer())
+
+ var_hits = tf.assign_add(var_hits, hits, use_locking=True)
+ var_gt_count = tf.assign_add(var_gt_count, count, use_locking=True)
+ var_gt_count = tf.Print(
+ var_gt_count, [var_gt_count, var_hits],
+ message='var_gt_count/var_hits')
+ var_total_hitrate = tf.assign(
+ var_total_hitrate, var_hits / var_gt_count, use_locking=True)
+ with tf.control_dependencies([var_total_hitrate]):
+ var_worker_count = tf.assign_add(var_worker_count, 1, use_locking=True)
+ return var_total_hitrate, var_worker_count
+
+
+def compute_hitrate_batch(g, gt_record, emb_dim, num_interests, top_k):
+ """Reduce hitrate of one batch.
+
+ Args:
+ g: a GL Graph instance.
+ gt_record: record list of ground truth.
+ emb_dim: embedding dim.
+ num_interests: max number of interests.
+ top_k: top_k hitrate.
+
+ Returns:
+ hits: total hit counts of a batch of src ids, a scalar.
+ gt_count: total ground truth items num of a batch of src ids, a scalar.
+ src_ids: src ids, a list.
+ recall_ids: recall ids, a list.
+ recall_distances: recall distances, a list.
+ hitrates: hitrate of a batch of src_ids, a list.
+ bad_cases: bad cases, a list of list.
+ bad_dists: distances of bad cases, a list of list.
+ """
+
+ def _to_float_attrs(x):
+ # in case the user embedding is not present
+ if x == '':
+ return np.zeros([emb_dim], dtype=np.float32)
+ embed = np.array(x.split(','), dtype=np.float32)
+ assert len(embed) == emb_dim, 'invalid embed len=%d, x=%s' % (len(embed), x)
+ return embed
+
+ def _to_multi_float_attrs(x, userid):
+ if x == '':
+ arr = [_to_float_attrs(x) for i in range(num_interests)]
+ else:
+ arr = [_to_float_attrs(sub_x) for sub_x in x.split('|')]
+ assert len(arr) == num_interests, 'invalid arr len=%d, x=%s, userid=%s' % (
+ len(arr), x, userid)
+ return arr
+
+ src_ids = np.array([src_items[0] for src_items in gt_record])
+ user_embedding = np.array([
+ _to_multi_float_attrs(src_items[2], src_items[0])
+ for src_items in gt_record
+ ])
+ user_emb_num = [src_items[3] for src_items in gt_record]
+
+ print('max(user_emb_num) = %d len(src_ids) = %d' %
+ (np.max(user_emb_num), len(src_ids)))
+
+ # a list of list.
+ gt_items = [
+ list(map(int, src_items[1].split(','))) for src_items in gt_record
+ ]
+
+ logging.info('src_nodes.float_attrs.shape=%s' % str(user_embedding.shape))
+ user_embedding = user_embedding.reshape([-1, user_embedding.shape[-1]])
+ # numpy array
+ recall_ids, recall_distances = g.search('i', user_embedding,
+ gl.KnnOption(k=top_k))
+ logging.info('recall_ids.shape=%s' % str(recall_ids.shape))
+
+ def _make_mask(lens):
+ mask = np.ones([len(lens), num_interests], dtype=np.float32)
+ for tmp_id, tmp_len in enumerate(lens):
+ mask[tmp_id, int(tmp_len):] = 0
+ return mask
+
+ mask = _make_mask(user_emb_num)
+ recall_ids = recall_ids.reshape([-1, num_interests, recall_ids.shape[-1]])
+ recall_distances = recall_distances.reshape(
+ [-1, num_interests, recall_distances.shape[-1]])
+ hitrates, bad_cases, bad_dists, hits, gt_count = batch_hitrate(
+ src_ids, recall_ids, recall_distances, gt_items, num_interests, mask)
+ return hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists
diff --git a/easy_rec/python/utils/hive_utils.py b/easy_rec/python/utils/hive_utils.py
new file mode 100644
index 000000000..250344f13
--- /dev/null
+++ b/easy_rec/python/utils/hive_utils.py
@@ -0,0 +1,183 @@
+# -*- coding: utf-8 -*-
+import logging
+
+try:
+ from pyhive import hive
+ from pyhive.exc import ProgrammingError
+except ImportError:
+ logging.warning('pyhive is not installed.')
+
+
+class TableInfo(object):
+
+ def __init__(self, tablename, selected_cols, partition_kv, limit_num):
+ self.tablename = tablename
+ self.selected_cols = selected_cols
+ self.partition_kv = partition_kv
+ self.limit_num = limit_num
+
+ def gen_sql(self):
+ part = ''
+ if self.partition_kv and len(self.partition_kv) > 0:
+ res = []
+ for k, v in self.partition_kv.items():
+ res.append('{}={}'.format(k, v))
+      part = ' and '.join(res)
+ sql = """select {}
+ from {}""".format(self.selected_cols, self.tablename)
+
+ if part:
+ sql += """
+ where {}
+ """.format(part)
+ if self.limit_num is not None and self.limit_num > 0:
+ sql += ' limit {}'.format(self.limit_num)
+ return sql
+
+
+class HiveUtils(object):
+ """Common IO based interface, could run at local or on data science."""
+
+ def __init__(self,
+ data_config,
+ hive_config,
+ selected_cols='',
+ record_defaults=[],
+ task_index=0,
+ task_num=1):
+
+ self._data_config = data_config
+ self._hive_config = hive_config
+
+ self._num_epoch = data_config.num_epochs
+ self._num_epoch_record = 0
+ self._task_index = task_index
+ self._task_num = task_num
+ self._selected_cols = selected_cols
+ self._record_defaults = record_defaults
+
+ def _construct_table_info(self, table_name, limit_num):
+ # sample_table/dt=2014-11-23/name=a
+ segs = table_name.split('/')
+ table_name = segs[0].strip()
+    if len(segs) > 1:
+ partition_kv = {i.split('=')[0]: i.split('=')[1] for i in segs[1:]}
+ else:
+ partition_kv = None
+
+ table_info = TableInfo(table_name, self._selected_cols, partition_kv,
+ limit_num)
+ return table_info
+
+ def _construct_hive_connect(self):
+ conn = hive.Connection(
+ host=self._hive_config.host,
+ port=self._hive_config.port,
+ username=self._hive_config.username,
+ database=self._hive_config.database)
+ return conn
+
+ def hive_read_line(self, input_path, limit_num=None):
+ table_info = self._construct_table_info(input_path, limit_num)
+ conn = self._construct_hive_connect()
+ cursor = conn.cursor()
+ sql = table_info.gen_sql()
+ cursor.execute(sql)
+
+ while True:
+ data = cursor.fetchmany(size=1)
+ if len(data) == 0:
+ break
+ yield data
+
+ cursor.close()
+ conn.close()
+
+ def hive_read_lines(self, input_path, batch_size, limit_num=None):
+ table_info = self._construct_table_info(input_path, limit_num)
+ conn = self._construct_hive_connect()
+ cursor = conn.cursor()
+ sql = table_info.gen_sql()
+ cursor.execute(sql)
+
+ while True:
+ data = cursor.fetchmany(size=batch_size)
+ if len(data) == 0:
+ break
+ yield data
+
+ cursor.close()
+ conn.close()
+
+ def run_sql(self, sql):
+ conn = self._construct_hive_connect()
+ cursor = conn.cursor()
+ cursor.execute(sql)
+ try:
+ data = cursor.fetchall()
+ except ProgrammingError:
+ data = []
+ return data
+
+ def is_table_or_partition_exist(self,
+ table_name,
+ partition_name=None,
+ partition_val=None):
+ if partition_name and partition_val:
+ sql = 'show partitions %s partition(%s=%s)' % (table_name, partition_name,
+ partition_val)
+ try:
+ res = self.run_sql(sql)
+ if not res:
+ return False
+ else:
+ return True
+ except: # noqa: E722
+ return False
+
+ else:
+ sql = 'desc %s' % table_name
+ try:
+ self.run_sql(sql)
+ return True
+ except: # noqa: E722
+ return False
+
+ def get_table_location(self, input_path):
+ conn = self._construct_hive_connect()
+ cursor = conn.cursor()
+ partition = ''
+ if len(input_path.split('/')) == 2:
+ table_name, partition = input_path.split('/')
+ partition += '/'
+ else:
+ table_name = input_path
+ sql = 'desc formatted %s' % table_name
+ cursor.execute(sql)
+    data = cursor.fetchall()
+ for line in data:
+ if line[0].startswith('Location'):
+ return line[1].strip() + '/' + partition
+ return None
+
+ def get_all_cols(self, input_path):
+ conn = self._construct_hive_connect()
+ cursor = conn.cursor()
+ sql = 'desc %s' % input_path.split('/')[0]
+ cursor.execute(sql)
+    data = cursor.fetchall()
+ col_names = []
+ cols_types = []
+ pt_name = ''
+ if len(input_path.split('/')) == 2:
+ pt_name = input_path.split('/')[1].split('=')[0]
+
+ for col in data:
+ col_name = col[0].strip()
+ if col_name and (not col_name.startswith('#')) and (col_name
+ not in col_names):
+ if col_name != pt_name:
+ col_names.append(col_name)
+ cols_types.append(col[1].strip())
+
+ return col_names, cols_types
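+
+
+# A minimal usage sketch (illustrative only): `data_config` and `hive_config`
+# are assumed to be parsed EasyRec protos, the table path is a placeholder,
+# and `process` is a hypothetical consumer of the fetched rows.
+#
+#   utils = HiveUtils(data_config, hive_config, selected_cols='*')
+#   for rows in utils.hive_read_lines('sample_table/dt=20220101', 1024):
+#     process(rows)
+#   col_names, col_types = utils.get_all_cols('sample_table/dt=20220101')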
diff --git a/easy_rec/python/utils/hpo_util.py b/easy_rec/python/utils/hpo_util.py
index 753c2eb99..bebf08476 100644
--- a/easy_rec/python/utils/hpo_util.py
+++ b/easy_rec/python/utils/hpo_util.py
@@ -38,6 +38,13 @@ def get_all_eval_result(event_file_pattern):
def save_eval_metrics(model_dir, metric_save_path, has_evaluator=True):
+ """Save evaluation metrics.
+
+ Args:
+ model_dir: train model directory
+ metric_save_path: metric saving path
+    has_evaluator: whether evaluation runs on a separate evaluator instead of on the master.
+ """
def _get_eval_event_file_pattern():
eval_dir = os.path.join(model_dir, 'eval_val/')
diff --git a/easy_rec/python/utils/hvd_utils.py b/easy_rec/python/utils/hvd_utils.py
new file mode 100644
index 000000000..223283486
--- /dev/null
+++ b/easy_rec/python/utils/hvd_utils.py
@@ -0,0 +1,56 @@
+# -*- encoding: utf-8 -*-
+import logging
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.training import session_run_hook
+
+from easy_rec.python.utils import constant
+
+# from horovod.tensorflow.compression import Compression
+try:
+ from horovod.tensorflow.functions import broadcast_variables
+except Exception:
+ pass
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class BroadcastGlobalVariablesHook(session_run_hook.SessionRunHook):
+ """SessionRunHook that will broadcast all global variables from root rank to all other processes during initialization.
+
+ This is necessary to ensure consistent initialization of all workers when
+ training is started with random weights or restored from a checkpoint.
+ """ # noqa: E501
+
+ def __init__(self, root_rank, device=''):
+ """Construct a new BroadcastGlobalVariablesHook that will broadcast all global variables from root rank to all other processes during initialization.
+
+ Args:
+ root_rank:
+ Rank that will send data, other ranks will receive data.
+ device:
+ Device to be used for broadcasting. Uses GPU by default
+ if Horovod was built with HOROVOD_GPU_OPERATIONS.
+ """ # noqa: E501
+ super(BroadcastGlobalVariablesHook, self).__init__()
+ self.root_rank = root_rank
+ self.bcast_op = None
+ self.device = device
+
+ def begin(self):
+ bcast_vars = []
+ embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
+ for x in tf.global_variables():
+ # if '/embedding' not in x.name and 'DynamicVariable' not in str(type(x)):
+ if x.name not in embed_para_vars:
+ bcast_vars.append(x)
+ logging.info('will broadcast variable: name=%s shape=%s' %
+ (x.name, x.get_shape()))
+ if not self.bcast_op or self.bcast_op.graph != tf.get_default_graph():
+ with tf.device(self.device):
+ self.bcast_op = broadcast_variables(bcast_vars, self.root_rank)
+
+ def after_create_session(self, session, coord):
+ session.run(self.bcast_op)
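+
+
+# A minimal usage sketch (illustrative only, not part of this module's API):
+# it assumes horovod is installed and that the training script has already
+# called hvd.init(); `train_op` is a placeholder for the actual training op.
+#
+#   import horovod.tensorflow as hvd
+#   hvd.init()
+#   hooks = [BroadcastGlobalVariablesHook(root_rank=0)]
+#   with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
+#     while not sess.should_stop():
+#       sess.run(train_op)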
diff --git a/easy_rec/python/utils/input_utils.py b/easy_rec/python/utils/input_utils.py
index 8be3ba5b4..cd8a5b975 100644
--- a/easy_rec/python/utils/input_utils.py
+++ b/easy_rec/python/utils/input_utils.py
@@ -1,6 +1,7 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
+import pandas as pd
import tensorflow as tf
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
@@ -37,21 +38,24 @@ def get_type_defaults(field_type, default_val=''):
return type_defaults[field_type]
-def string_to_number(field, ftype, name=''):
+def string_to_number(field, ftype, default_value, name=''):
"""Type conversion for parsing rtp fg input format.
Args:
field: field to be converted.
ftype: field dtype set in DatasetConfig.
+ default_value: default value for this field
name: field name, used to name the conversion ops.
Returns: the converted field tensor.
"""
- tmp_field = field
+ default_vals = tf.tile(tf.constant([str(default_value)]), tf.shape(field))
+ field = tf.where(tf.greater(tf.strings.length(field), 0), field, default_vals)
+
if ftype in [DatasetConfig.INT32, DatasetConfig.INT64]:
# Int type is not supported in fg.
# If you specify INT32 or INT64 in DatasetConfig, you need to perform a cast here.
tmp_field = tf.string_to_number(
- field, tf.double, name='field_as_int_%s' % name)
+ field, tf.double, name='field_as_flt_%s' % name)
if ftype in [DatasetConfig.INT64]:
tmp_field = tf.cast(tmp_field, tf.int64)
else:
@@ -64,6 +68,41 @@ def string_to_number(field, ftype, name=''):
field, tf.float64, name='field_as_flt_%s' % name)
elif ftype in [DatasetConfig.BOOL]:
tmp_field = tf.logical_or(tf.equal(field, 'True'), tf.equal(field, 'true'))
+ elif ftype in [DatasetConfig.STRING]:
+ tmp_field = field
else:
- assert 'invalid types: %s' % str(ftype)
+ assert False, 'invalid types: %s' % str(ftype)
return tmp_field
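+
+# A minimal sketch (illustrative only) of the default-value handling above:
+# empty strings are replaced with str(default_value) before the cast; the
+# field name 'clk' is just a placeholder.
+#
+#   field = tf.constant(['1', '', '3'])
+#   string_to_number(field, DatasetConfig.INT64, 0, name='clk')
+#   # -> int64 tensor [1, 0, 3]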
+
+
+def np_to_tf_type(np_type):
+ _types_map = {
+ int: tf.int32,
+ np.int32: tf.int32,
+ np.int64: tf.int64,
+ str: tf.string,
+ np.float: tf.float32,
+ np.float32: tf.float32,
+ float: tf.float32,
+ np.double: tf.float64
+ }
+ if np_type in _types_map:
+ return _types_map[np_type]
+ else:
+ return tf.string
+
+
+def get_tf_type_from_parquet_file(cols, parquet_file):
+ # gfile not supported, read_parquet requires random access
+ input_data = pd.read_parquet(parquet_file, columns=cols)
+ tf_types = []
+ for col in cols:
+ obj = input_data[col][0]
+ if isinstance(obj, list):
+ data_type = type(obj[0])
+ elif isinstance(obj, np.ndarray):
+ data_type = type(obj[0])
+ else:
+ data_type = type(obj)
+ tf_types.append(np_to_tf_type(data_type))
+ return tf_types
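+
+
+# A minimal usage sketch (illustrative only): the file name and column names
+# are placeholders; the resulting list could look like
+# [tf.int64, tf.int64, tf.float32] depending on the stored dtypes.
+#
+#   cols = ['user_id', 'item_id', 'label']
+#   tf_types = get_tf_type_from_parquet_file(cols, 'train.parquet')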
diff --git a/easy_rec/python/utils/io_util.py b/easy_rec/python/utils/io_util.py
index c7702c318..92c9c8a1f 100644
--- a/easy_rec/python/utils/io_util.py
+++ b/easy_rec/python/utils/io_util.py
@@ -4,6 +4,7 @@
isort:skip_file
"""
+import logging
from future import standard_library
standard_library.install_aliases()
@@ -15,12 +16,15 @@
import tensorflow as tf
from six.moves import http_client
from six.moves import urllib
-
+import json
if six.PY2:
from urllib import quote
else:
from urllib.parse import quote
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
EASY_REC_RES_DIR = 'easy_rec_user_resources'
HTTP_MAX_NUM_RETRY = 5
HTTP_MAX_TIMEOUT = 600
@@ -165,3 +169,114 @@ def fix_oss_dir(path):
if path.startswith('oss://') and not path.endswith('/'):
return path + '/'
return path
+
+
+def save_data_to_json_path(json_path, data):
+ with tf.gfile.GFile(json_path, 'w') as fout:
+ fout.write(json.dumps(data))
+ assert tf.gfile.Exists(json_path), 'in_save_data_to_json_path, save_failed'
+
+
+def read_data_from_json_path(json_path):
+ if json_path and tf.gfile.Exists(json_path):
+ with tf.gfile.GFile(json_path, 'r') as fin:
+ data = json.loads(fin.read())
+ return data
+ else:
+    logging.info('json_path does not exist, return None')
+ return None
+
+
+def convert_tf_flags_to_argparse(flags):
+ """Convert tf.app.flags.FLAGS to argparse.ArgumentParser.
+
+ Args:
+ flags: tf.app.flags.FLAGS
+ Returns:
+    argparse.ArgumentParser: configured ArgumentParser object
+ """
+ import argparse
+ import ast
+ parser = argparse.ArgumentParser()
+
+ args = {}
+ for flag in flags._flags().values():
+ flag_name = flag.name
+ if flag_name in args:
+ args[flag_name][0] = True
+ continue
+ default = flag.value
+ flag_type = type(default)
+ help_str = flag.help or ''
+ args[flag_name] = [
+ False, flag_type, default, help_str,
+ flag.choices if hasattr(flag, 'choices') else None
+ ]
+
+ def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+ for flag_name, (multi, flag_type, default, help_str, choices) in args.items():
+ if flag_type == bool:
+ parser.add_argument(
+ '--' + flag_name,
+ type=str2bool,
+ nargs='?',
+ const=True,
+ default=False,
+ help=help_str)
+ elif flag_type == str:
+ if choices:
+ parser.add_argument(
+ '--' + flag_name,
+ type=str,
+ choices=choices,
+ default=default,
+ help=help_str)
+ elif multi:
+ parser.add_argument(
+ '--' + flag_name,
+ type=str,
+ action='http://github.com/append',
+ default=default,
+ help=help_str)
+ else:
+ parser.add_argument(
+ '--' + flag_name, type=str, default=default, help=help_str)
+ elif flag_type in (list, dict):
+ parser.add_argument(
+ '--' + flag_name,
+ type=lambda s: ast.literal_eval(s),
+ default=default,
+ help=help_str)
+ elif flag_type in (int, float):
+ parser.add_argument(
+ '--' + flag_name, type=flag_type, default=default, help=help_str)
+ else:
+ parser.add_argument(
+ '--' + flag_name, type=str, default=default, help=help_str)
+ return parser
+
+
+def filter_unknown_args(flags, args):
+ """Filter unknown args."""
+ known_args = [args[0]]
+ parser = convert_tf_flags_to_argparse(flags)
+ args, unknown = parser.parse_known_args(args)
+ if len(unknown) > 1:
+ logging.info('undefined arguments: %s', ', '.join(unknown[1:]))
+ for key, value in vars(args).items():
+ if value is None:
+ continue
+ if type(value) in (list, dict) and not value:
+ continue
+ known_args.append('--' + key + '=' + str(value))
+ logging.info('defined arguments: %s', ', '.join(known_args[1:]))
+ return known_args
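+
+
+# A minimal usage sketch (illustrative only), assuming the caller defines its
+# flags via tf.app.flags: drop arguments the current entry point does not
+# know about before re-launching a command.
+#
+#   import sys
+#   flags = tf.app.flags.FLAGS
+#   argv = filter_unknown_args(flags, sys.argv)
+#   # argv keeps sys.argv[0] plus only the '--key=value' pairs known to flags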
diff --git a/easy_rec/python/utils/load_class.py b/easy_rec/python/utils/load_class.py
index 4821b743d..9ac749c76 100644
--- a/easy_rec/python/utils/load_class.py
+++ b/easy_rec/python/utils/load_class.py
@@ -7,6 +7,7 @@
import os
import pkgutil
import pydoc
+import traceback
from abc import ABCMeta
import six
@@ -36,6 +37,8 @@ def load_by_path(path):
path = path.strip()
if path == '' or path is None:
return None
+ if 'lambda' in path:
+ return eval(path)
components = path.split('.')
if components[0] == 'tf':
components[0] = 'tensorflow'
@@ -43,7 +46,7 @@ def load_by_path(path):
try:
return pydoc.locate(path)
except pydoc.ErrorDuringImport:
- logging.error('load %s failed' % path)
+ logging.error('load %s failed: %s' % (path, traceback.format_exc()))
return None
@@ -217,3 +220,30 @@ def create_class(cls, name):
return newclass
return RegisterABCMeta
+
+
+def load_keras_layer(name):
+ """Load keras layer class.
+
+ Args:
+ name: keras layer name
+
+  Returns:
+ (layer_class, is_customize)
+ """
+ name = name.strip()
+ if name == '' or name is None:
+ return None
+
+ path = 'easy_rec.python.layers.keras.' + name
+ try:
+ cls = pydoc.locate(path)
+ if cls is not None:
+ return cls, True
+ path = 'tensorflow.keras.layers.' + name
+ return pydoc.locate(path), False
+ except pydoc.ErrorDuringImport:
+ print('load keras layer %s failed' % name)
+ logging.error('load keras layer %s failed: %s' %
+ (name, traceback.format_exc()))
+ return None, False
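+
+
+# A minimal usage sketch (illustrative only): when EasyRec does not define a
+# customized layer with the given name, the lookup falls back to
+# tensorflow.keras.layers and is_customize is False.
+#
+#   layer_cls, is_customize = load_keras_layer('Dense')
+#   layer = layer_cls(units=8)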
diff --git a/easy_rec/python/utils/meta_graph_editor.py b/easy_rec/python/utils/meta_graph_editor.py
index d01c6fb8e..9fc75f1fe 100644
--- a/easy_rec/python/utils/meta_graph_editor.py
+++ b/easy_rec/python/utils/meta_graph_editor.py
@@ -2,14 +2,22 @@
import logging
import os
+import numpy as np
import tensorflow as tf
from google.protobuf import text_format
+from tensorflow.python.framework import ops
from tensorflow.python.platform.gfile import GFile
+# from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.loader_impl import SavedModelLoader
+from easy_rec.python.utils import conditional
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import embedding_utils
from easy_rec.python.utils import proto_util
+EMBEDDING_INITIALIZERS = 'embedding_initializers'
+
class MetaGraphEditor:
@@ -27,6 +35,7 @@ def __init__(self,
oss_timeout=0,
meta_graph_def=None,
norm_name_to_ids=None,
+ incr_update_params=None,
debug_dir=''):
self._lookup_op = tf.load_op_library(lookup_lib_path)
self._debug_dir = debug_dir
@@ -39,7 +48,10 @@ def __init__(self,
else:
assert meta_graph_def, 'either saved_model_dir or meta_graph_def must be set'
tf.reset_default_graph()
- tf.train.import_meta_graph(meta_graph_def)
+ from tensorflow.python.framework import meta_graph
+ meta_graph.import_scoped_meta_graph_with_return_elements(
+ meta_graph_def, clear_devices=True)
+ # tf.train.import_meta_graph(meta_graph_def)
self._meta_graph_version = meta_graph_def.meta_info_def.meta_graph_version
self._signature_def = meta_graph_def.signature_def[
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@@ -74,6 +86,31 @@ def __init__(self,
self._oss_sk = oss_sk
self._oss_timeout = oss_timeout
+ self._incr_update_params = incr_update_params
+
+ # increment update placeholders
+ self._embedding_update_inputs = {}
+ self._embedding_update_outputs = {}
+
+ self._dense_update_inputs = {}
+ self._dense_update_outputs = {}
+
+ @property
+ def sparse_update_inputs(self):
+ return self._embedding_update_inputs
+
+ @property
+ def sparse_update_outputs(self):
+ return self._embedding_update_outputs
+
+ @property
+ def dense_update_inputs(self):
+ return self._dense_update_inputs
+
+ @property
+ def dense_update_outputs(self):
+ return self._dense_update_outputs
+
@property
def graph_def(self):
return self._meta_graph_def.graph_def
@@ -314,9 +351,11 @@ def _get_tensor_by_name(tensor_name):
if not self._embed_name_to_ids:
embed_name_uniq = list(set(self._embed_names))
self._embed_name_to_ids = {
- t: str(tid) for tid, t in enumerate(embed_name_uniq)
+ t: tid for tid, t in enumerate(embed_name_uniq)
}
- self._embed_ids = [self._embed_name_to_ids[x] for x in self._embed_names]
+ self._embed_ids = [
+ int(self._embed_name_to_ids[x]) for x in self._embed_names
+ ]
self._is_cache_from_redis = [
proto_util.is_cache_from_redis(x, self._redis_cache_names)
@@ -332,22 +371,25 @@ def _get_tensor_by_name(tensor_name):
def add_lookup_op(self, lookup_input_indices, lookup_input_values,
lookup_input_shapes, lookup_input_weights):
logging.info('add custom lookup operation to lookup embeddings from redis')
+ self._lookup_outs = [None for i in range(len(lookup_input_values))]
for i in range(len(lookup_input_values)):
if lookup_input_values[i].dtype == tf.int32:
lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
- self._lookup_outs = self._lookup_op.kv_lookup(
- lookup_input_indices,
- lookup_input_values,
- lookup_input_shapes,
- lookup_input_weights,
- url=self._redis_url,
- password=self._redis_passwd,
- timeout=self._redis_timeout,
- combiners=self._embed_combiners,
- embedding_dims=self._embed_dims,
- embedding_names=self._embed_ids,
- cache=self._is_cache_from_redis,
- version=self._meta_graph_version)
+ for i in range(len(self._lookup_outs)):
+ i_1 = i + 1
+ self._lookup_outs[i] = self._lookup_op.kv_lookup(
+ lookup_input_indices[i:i_1],
+ lookup_input_values[i:i_1],
+ lookup_input_shapes[i:i_1],
+ lookup_input_weights[i:i_1],
+ url=self._redis_url,
+ password=self._redis_passwd,
+ timeout=self._redis_timeout,
+ combiners=self._embed_combiners[i:i_1],
+ embedding_dims=self._embed_dims[i:i_1],
+ embedding_names=self._embed_ids[i:i_1],
+ cache=self._is_cache_from_redis,
+ version=self._meta_graph_version)[0]
meta_graph_def = tf.train.export_meta_graph()
@@ -362,9 +404,32 @@ def add_lookup_op(self, lookup_input_indices, lookup_input_values,
def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
lookup_input_shapes, lookup_input_weights):
logging.info('add custom lookup operation to lookup embeddings from oss')
- for i in range(len(lookup_input_values)):
- if lookup_input_values[i].dtype == tf.int32:
- lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
+ place_on_cpu = os.getenv('place_embedding_on_cpu')
+ place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+ with conditional(place_on_cpu, ops.device('/CPU:0')):
+ for i in range(len(lookup_input_values)):
+ if lookup_input_values[i].dtype == tf.int32:
+ lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
+ # N = len(lookup_input_indices)
+ # self._lookup_outs = [ None for _ in range(N) ]
+ # for i in range(N):
+ # i_1 = i + 1
+ # self._lookup_outs[i] = self._lookup_op.oss_read_kv(
+ # lookup_input_indices[i:i_1],
+ # lookup_input_values[i:i_1],
+ # lookup_input_shapes[i:i_1],
+ # lookup_input_weights[i:i_1],
+ # osspath=self._oss_path,
+ # endpoint=self._oss_endpoint,
+ # ak=self._oss_ak,
+ # sk=self._oss_sk,
+ # timeout=self._oss_timeout,
+ # combiners=self._embed_combiners[i:i_1],
+ # embedding_dims=self._embed_dims[i:i_1],
+ # embedding_ids=self._embed_ids[i:i_1],
+ # embedding_is_kv=self._embed_is_kv[i:i_1],
+ # shared_name='embedding_lookup_res',
+ # name='embedding_lookup_fused/lookup')[0]
self._lookup_outs = self._lookup_op.oss_read_kv(
lookup_input_indices,
lookup_input_values,
@@ -377,8 +442,60 @@ def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
timeout=self._oss_timeout,
combiners=self._embed_combiners,
embedding_dims=self._embed_dims,
- embedding_names=self._embed_ids,
- embedding_is_kv=self._embed_is_kv)
+ embedding_ids=self._embed_ids,
+ embedding_is_kv=self._embed_is_kv,
+ shared_name='embedding_lookup_res',
+ name='embedding_lookup_fused/lookup')
+
+ N = np.max([int(x) for x in self._embed_ids]) + 1
+ uniq_embed_ids = [x for x in range(N)]
+ uniq_embed_dims = [0 for x in range(N)]
+ uniq_embed_combiners = ['mean' for x in range(N)]
+ uniq_embed_is_kvs = [0 for x in range(N)]
+ for embed_id, embed_combiner, embed_is_kv, embed_dim in zip(
+ self._embed_ids, self._embed_combiners, self._embed_is_kv,
+ self._embed_dims):
+ uniq_embed_combiners[embed_id] = embed_combiner
+ uniq_embed_is_kvs[embed_id] = embed_is_kv
+ uniq_embed_dims[embed_id] = embed_dim
+
+ lookup_init_op = self._lookup_op.oss_init(
+ osspath=self._oss_path,
+ endpoint=self._oss_endpoint,
+ ak=self._oss_ak,
+ sk=self._oss_sk,
+ combiners=uniq_embed_combiners,
+ embedding_dims=uniq_embed_dims,
+ embedding_ids=uniq_embed_ids,
+ embedding_is_kv=uniq_embed_is_kvs,
+ N=N,
+ shared_name='embedding_lookup_res',
+ name='embedding_lookup_fused/init')
+
+ ops.add_to_collection(EMBEDDING_INITIALIZERS, lookup_init_op)
+
+ if self._incr_update_params is not None:
+ # all sparse variables are updated by a single custom operation
+ message_ph = tf.placeholder(tf.int8, [None], name='incr_update/message')
+ embedding_update = self._lookup_op.embedding_update(
+ message=message_ph,
+ shared_name='embedding_lookup_res',
+ name='embedding_lookup_fused/embedding_update')
+ self._embedding_update_inputs['incr_update/sparse/message'] = message_ph
+ self._embedding_update_outputs[
+ 'incr_update/sparse/embedding_update'] = embedding_update
+
+ # dense variables are updated one by one
+ dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
+ for x in ops.get_collection(constant.DENSE_UPDATE_VARIABLES):
+ dense_var_id = dense_name_to_ids[x.op.name]
+ dense_input_name = 'incr_update/dense/%d/input' % dense_var_id
+ dense_output_name = 'incr_update/dense/%d/output' % dense_var_id
+ dense_update_input = tf.placeholder(
+ tf.float32, x.get_shape(), name=dense_input_name)
+ self._dense_update_inputs[dense_input_name] = dense_update_input
+ dense_assign_op = tf.assign(x, dense_update_input)
+ self._dense_update_outputs[dense_output_name] = dense_assign_op
meta_graph_def = tf.train.export_meta_graph()
diff --git a/easy_rec/python/utils/multi_optimizer.py b/easy_rec/python/utils/multi_optimizer.py
index 9e5cefbda..c34c4abe0 100644
--- a/easy_rec/python/utils/multi_optimizer.py
+++ b/easy_rec/python/utils/multi_optimizer.py
@@ -38,6 +38,9 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
update_ops.append(opt.apply_gradients(tmp, None))
return tf.group(update_ops)
+ def open_auto_record(self, flag=True):
+ super(MultiOptimizer, self).open_auto_record(flag)
+
def get_slot(self, var, name):
raise NotImplementedError('not implemented')
# for opt in self._opts:
diff --git a/easy_rec/python/utils/numpy_utils.py b/easy_rec/python/utils/numpy_utils.py
new file mode 100644
index 000000000..cfda857d2
--- /dev/null
+++ b/easy_rec/python/utils/numpy_utils.py
@@ -0,0 +1,18 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+
+import numpy as np
+
+
+class NumpyEncoder(json.JSONEncoder):
+ """For encode numpy arrays."""
+
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ return json.JSONEncoder.default(self, obj)
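+
+
+# A minimal usage sketch (illustrative only):
+#
+#   data = {'ids': np.arange(3), 'score': np.float32(0.5)}
+#   json.dumps(data, cls=NumpyEncoder)
+#   # -> '{"ids": [0, 1, 2], "score": 0.5}'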
diff --git a/easy_rec/python/utils/odps_util.py b/easy_rec/python/utils/odps_util.py
index 99e33bc32..47a9d9870 100644
--- a/easy_rec/python/utils/odps_util.py
+++ b/easy_rec/python/utils/odps_util.py
@@ -1,6 +1,8 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Common functions used for odps input."""
+from tensorflow.python.framework import dtypes
+
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
@@ -25,6 +27,18 @@ def is_type_compatiable(odps_type, input_type):
return False
+def odps_type_to_input_type(odps_type):
+ """Check that odps_type are compatiable with input_type."""
+ odps_type_map = {
+ 'bigint': DatasetConfig.INT64,
+ 'string': DatasetConfig.STRING,
+ 'double': DatasetConfig.DOUBLE
+ }
+ assert odps_type in odps_type_map, 'only support [bigint, string, double]'
+ input_type = odps_type_map[odps_type]
+ return input_type
+
+
def check_input_field_and_types(data_config):
"""Check compatibility of input in data_config.
@@ -52,3 +66,14 @@ def check_input_field_and_types(data_config):
assert is_type_compatiable(tmp_type, y), \
'feature[%s] type error: odps %s is not compatible with input_type %s' % (
x, tmp_type, DatasetConfig.FieldType.Name(y))
+
+
+def odps_type_2_tf_type(odps_type):
+ if odps_type == 'string':
+ return dtypes.string
+ elif odps_type == 'bigint':
+ return dtypes.int64
+ elif odps_type in ['double', 'float']:
+ return dtypes.float32
+ else:
+ return dtypes.string
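+
+
+# Minimal sketches of the two helpers above (illustrative only):
+#
+#   odps_type_to_input_type('bigint')  # -> DatasetConfig.INT64
+#   odps_type_2_tf_type('double')      # -> dtypes.float32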
diff --git a/easy_rec/python/utils/proto_util.py b/easy_rec/python/utils/proto_util.py
index 2f8255858..c96d41a78 100644
--- a/easy_rec/python/utils/proto_util.py
+++ b/easy_rec/python/utils/proto_util.py
@@ -51,7 +51,8 @@ def get_norm_embed_name(name, verbose=False):
# input_layer/app_category_embedding/app_category_embedding_weights/SparseReshape
# => input_layer/app_category_embedding
for i in range(0, len(name_toks) - 1):
- if name_toks[i + 1].endswith('_embedding_weights'):
+ if name_toks[i + 1].endswith('_embedding_weights') or \
+ '_embedding_weights_' in name_toks[i + 1]:
tmp_name = '/'.join(name_toks[:i + 1])
if verbose:
logging.info('norm %s to %s' % (name, tmp_name))
diff --git a/easy_rec/python/utils/shape_utils.py b/easy_rec/python/utils/shape_utils.py
index e51e1fbb4..f54521513 100644
--- a/easy_rec/python/utils/shape_utils.py
+++ b/easy_rec/python/utils/shape_utils.py
@@ -388,3 +388,45 @@ def assert_rank(tensor, expected_rank, name=None):
'For the tensor `%s` in scope `%s`, the actual rank '
'`%d` (shape = %s) is not equal to the expected rank `%s`' %
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
+
+
+def truncate_sequence(seq_emb, seq_len, limited_len):
+
+ def truncate(seq_embed, seq_length):
+ seq_embed = tf.slice(seq_embed, [0, 0, 0],
+ [shape[0], limited_len, shape[2]])
+ seq_length = tf.where(
+ tf.greater(seq_length, limited_len),
+ tf.ones_like(seq_length) * limited_len, seq_length)
+ return seq_embed, seq_length
+
+ def keep(seq_embed, seq_length):
+ return seq_embed, seq_length
+
+ shape = get_shape_list(seq_emb)
+ max_seq_len = shape[1]
+
+ return tf.cond(max_seq_len > limited_len, lambda: truncate(seq_emb, seq_len),
+ lambda: keep(seq_emb, seq_len))
+
+
+def pad_or_truncate_sequence(seq_emb, seq_len, fixed_len):
+ padding_length = fixed_len - tf.shape(seq_emb)[1]
+
+ def padding():
+ paddings = tf.stack([[0, 0], [0, padding_length], [0, 0]])
+ padded = tf.pad(seq_emb, paddings)
+ return padded, seq_len
+
+ def truncate():
+ sliced = tf.slice(seq_emb, [0, 0, 0], [-1, fixed_len, -1])
+ length = tf.where(seq_len < fixed_len, seq_len,
+ tf.ones_like(seq_len) *
+ fixed_len) if seq_len is not None else None
+ return sliced, length
+
+ def keep():
+ return seq_emb, seq_len
+
+ return tf.cond(padding_length > 0, padding,
+ lambda: tf.cond(padding_length < 0, truncate, keep))
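+
+
+# A minimal usage sketch (illustrative only): force a [batch, time, dim]
+# sequence embedding to a fixed length of 10 steps; the batch is zero padded
+# or truncated along the time axis as needed.
+#
+#   seq_emb = tf.random.normal([4, 7, 16])
+#   seq_len = tf.constant([7, 5, 3, 7])
+#   fixed_emb, fixed_len = pad_or_truncate_sequence(seq_emb, seq_len, 10)
+#   # fixed_emb has shape [4, 10, 16]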
diff --git a/easy_rec/python/utils/test_utils.py b/easy_rec/python/utils/test_utils.py
index 9d719507d..b71249ac8 100644
--- a/easy_rec/python/utils/test_utils.py
+++ b/easy_rec/python/utils/test_utils.py
@@ -16,6 +16,7 @@
import string
import subprocess
import time
+import six
from multiprocessing import Process
from subprocess import getstatusoutput
from tensorflow.python.platform import gfile
@@ -23,9 +24,14 @@
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util
from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
+from easy_rec.python.utils.io_util import read_data_from_json_path
+from easy_rec.python.utils import constant
TEST_DIR = './tmp/easy_rec_test'
+# parallel runs of tests could take more time
+TEST_TIME_OUT = int(os.environ.get('TEST_TIME_OUT', 1800))
+
def get_hdfs_tmp_dir(test_dir):
"""Create a randomly of directory in HDFS."""
@@ -37,17 +43,34 @@ def get_hdfs_tmp_dir(test_dir):
return test_rand_dir
+def proc_wait(proc, timeout=1200):
+ t0 = time.time()
+ while proc.poll() is None and time.time() - t0 < timeout:
+ time.sleep(1)
+ if proc.poll() is None:
+ logging.warning('proc[pid=%d] timeout[%d], will kill the proc' %
+ (proc.pid, timeout))
+ proc.terminate()
+ while proc.poll() is None:
+ time.sleep(1)
+
+
def get_tmp_dir():
- tmp_name = ''.join(
- [random.choice(string.ascii_letters + string.digits) for i in range(8)])
- if os.environ.get('TEST_DIR', '') != '':
- global TEST_DIR
- TEST_DIR = os.environ['TEST_DIR']
- dir_name = os.path.join(TEST_DIR, tmp_name)
- if os.path.exists(dir_name):
- shutil.rmtree(dir_name)
- os.makedirs(dir_name)
- return dir_name
+ max_retry = 5
+ while max_retry > 0:
+ tmp_name = ''.join([
+ random.choice(string.ascii_letters + string.digits) for i in range(12)
+ ])
+ if os.environ.get('TEST_DIR', '') != '':
+ global TEST_DIR
+ TEST_DIR = os.environ['TEST_DIR']
+ dir_name = os.path.join(TEST_DIR, tmp_name)
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
+ return dir_name
+ else:
+ max_retry -= 1
+ raise RuntimeError('Failed to get_tmp_dir: max_retry=%d' % max_retry)
def clear_all_tmp_dirs():
@@ -72,12 +95,17 @@ def get_available_gpus():
return gpus
-def run_cmd(cmd_str, log_file):
+def run_cmd(cmd_str, log_file, env=None):
"""Run a shell cmd."""
- logging.info('%s > %s 2>&1 ' % (cmd_str, log_file))
+ cmd_str = cmd_str.replace('\r', ' ').replace('\n', ' ')
+ logging.info('RUNCMD: %s > %s 2>&1 ' % (cmd_str, log_file))
with open(log_file, 'w') as lfile:
- return subprocess.Popen(
- cmd_str.split(), stdout=lfile, stderr=subprocess.STDOUT)
+ proc = subprocess.Popen(
+ cmd_str, stdout=lfile, stderr=subprocess.STDOUT, shell=True, env=env)
+ if six.PY2:
+ # for debug purpose
+ proc.args = cmd_str
+ return proc
def RunAsSubprocess(f):
@@ -132,7 +160,10 @@ def _replace_data_for_test(data_path):
return data_path
-def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
+def _load_config_for_test(pipeline_config_path,
+ test_dir,
+ total_steps=50,
+ num_epochs=0):
pipeline_config = config_util.get_configs_from_pipeline_file(
pipeline_config_path)
train_config = pipeline_config.train_config
@@ -141,17 +172,25 @@ def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
train_config.num_steps = total_steps
# change model_dir
- pipeline_config.model_dir = test_dir + '/train'
+ pipeline_config.model_dir = os.path.join(test_dir, 'train')
logging.info('test_model_dir %s' % pipeline_config.model_dir)
eval_config.num_examples = max(10, data_config.batch_size)
- data_config.num_epochs = 0
+ data_config.num_epochs = num_epochs
+ return pipeline_config
+
+
+def _load_config_for_distribute_eval(pipeline_config_path, test_dir):
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ pipeline_config_path)
+ pipeline_config.model_dir = test_dir
+ logging.info('test_model_dir %s' % pipeline_config.model_dir)
return pipeline_config
def test_datahub_train_eval(pipeline_config_path,
+ odps_oss_config,
test_dir,
process_pipeline_func=None,
- hyperparam_str='',
total_steps=50,
post_check_func=None):
gpus = get_available_gpus()
@@ -174,17 +213,32 @@ def test_datahub_train_eval(pipeline_config_path,
pipeline_config.train_config.train_distribute = 0
pipeline_config.train_config.num_gpus_per_worker = 1
pipeline_config.train_config.sync_replicas = False
+
+ pipeline_config.datahub_train_input.akId = odps_oss_config.dh_id
+ pipeline_config.datahub_train_input.akSecret = odps_oss_config.dh_key
+ pipeline_config.datahub_train_input.region = odps_oss_config.dh_endpoint
+ pipeline_config.datahub_train_input.project = odps_oss_config.dh_project
+ pipeline_config.datahub_train_input.topic = odps_oss_config.dh_topic
+
+ pipeline_config.datahub_eval_input.akId = odps_oss_config.dh_id
+ pipeline_config.datahub_eval_input.akSecret = odps_oss_config.dh_key
+ pipeline_config.datahub_eval_input.region = odps_oss_config.dh_endpoint
+ pipeline_config.datahub_eval_input.project = odps_oss_config.dh_project
+ pipeline_config.datahub_eval_input.topic = odps_oss_config.dh_topic
+
if process_pipeline_func is not None:
assert callable(process_pipeline_func)
pipeline_config = process_pipeline_func(pipeline_config)
config_util.save_pipeline_config(pipeline_config, test_dir)
test_pipeline_config_path = os.path.join(test_dir, 'pipeline.config')
- train_cmd = 'python3 -m easy_rec.python.train_eval --pipeline_config_path %s %s' % (
- test_pipeline_config_path, hyperparam_str)
+ train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s' % \
+ test_pipeline_config_path
proc = run_cmd(train_cmd, '%s/log_%s.txt' % (test_dir, 'master'))
- proc.wait()
+ proc_wait(proc, timeout=TEST_TIME_OUT)
if proc.returncode != 0:
- logging.error('train %s failed' % test_pipeline_config_path)
+ logging.warning(
+ 'train %s failed[pid=%d][code=%d][args=%s]' %
+ (test_pipeline_config_path, proc.pid, proc.returncode, proc.args))
return False
if post_check_func:
return post_check_func(pipeline_config)
@@ -202,7 +256,11 @@ def test_single_train_eval(pipeline_config_path,
process_pipeline_func=None,
hyperparam_str='',
total_steps=50,
- post_check_func=None):
+ post_check_func=None,
+ check_mode=False,
+ fine_tune_checkpoint=None,
+ extra_cmd_args=None,
+ timeout=-1):
gpus = get_available_gpus()
if len(gpus) > 0:
set_gpu_id(gpus[0])
@@ -228,10 +286,18 @@ def test_single_train_eval(pipeline_config_path,
pipeline_config = process_pipeline_func(pipeline_config)
config_util.save_pipeline_config(pipeline_config, test_dir)
test_pipeline_config_path = os.path.join(test_dir, 'pipeline.config')
- train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s %s' % (
- test_pipeline_config_path, hyperparam_str)
+ train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path=' + test_pipeline_config_path
+ if hyperparam_str:
+ train_cmd += ' --edit_config_json=\'%s\'' % hyperparam_str
+ if fine_tune_checkpoint:
+ train_cmd += ' --fine_tune_checkpoint %s' % fine_tune_checkpoint
+ if check_mode:
+ train_cmd += ' --check_mode'
+ if extra_cmd_args:
+ train_cmd += ' '
+ train_cmd += extra_cmd_args
proc = run_cmd(train_cmd, '%s/log_%s.txt' % (test_dir, 'master'))
- proc.wait()
+ proc_wait(proc, timeout=TEST_TIME_OUT if timeout < 0 else timeout)
if proc.returncode != 0:
logging.error('train %s failed' % test_pipeline_config_path)
return False
@@ -240,6 +306,72 @@ def test_single_train_eval(pipeline_config_path,
return True
+def test_single_pre_check(pipeline_config_path, test_dir):
+ gpus = get_available_gpus()
+ if len(gpus) > 0:
+ set_gpu_id(gpus[0])
+ else:
+ set_gpu_id(None)
+
+ if not isinstance(pipeline_config_path, EasyRecConfig):
+ logging.info('testing pipeline config %s' % pipeline_config_path)
+ if 'TF_CONFIG' in os.environ:
+ del os.environ['TF_CONFIG']
+
+ if isinstance(pipeline_config_path, EasyRecConfig):
+ pipeline_config = pipeline_config_path
+ else:
+ pipeline_config = _load_config_for_test(pipeline_config_path, test_dir)
+
+ pipeline_config.train_config.train_distribute = 0
+ pipeline_config.train_config.num_gpus_per_worker = 1
+ pipeline_config.train_config.sync_replicas = False
+
+ config_util.save_pipeline_config(pipeline_config, test_dir)
+ test_pipeline_config_path = os.path.join(test_dir, 'pipeline.config')
+ train_cmd = 'python -m easy_rec.python.tools.pre_check --pipeline_config_path %s ' % (
+ test_pipeline_config_path)
+
+ proc = run_cmd(train_cmd, '%s/log_%s.txt' % (test_dir, 'master'))
+ proc_wait(proc, timeout=TEST_TIME_OUT)
+ if proc.returncode != 0:
+ logging.error('train %s failed' % test_pipeline_config_path)
+ return False
+ return True
+
+
+def test_single_predict(test_dir, input_path, output_path, saved_model_dir):
+ gpus = get_available_gpus()
+ if len(gpus) > 0:
+ set_gpu_id(gpus[0])
+ else:
+ set_gpu_id(None)
+
+ predict_cmd = 'python -m easy_rec.python.predict --input_path %s --output_path %s --saved_model_dir %s' % (
+ input_path, output_path, saved_model_dir)
+
+ proc = run_cmd(predict_cmd, '%s/log_%s.txt' % (test_dir, 'master'))
+ proc_wait(proc, timeout=TEST_TIME_OUT)
+ if proc.returncode != 0:
+ logging.error('predict failed')
+ return False
+ return True
+
+
+def test_feature_selection(pipeline_config):
+ model_dir = pipeline_config.model_dir
+ pipeline_config_path = os.path.join(model_dir, 'pipeline.config')
+ output_dir = os.path.join(model_dir, 'feature_selection')
+ cmd = 'python -m easy_rec.python.tools.feature_selection --config_path %s ' \
+ '--output_dir %s --topk 5 --visualize true' % (pipeline_config_path, output_dir)
+ proc = run_cmd(cmd, os.path.join(model_dir, 'log_feature_selection.txt'))
+ proc_wait(proc, timeout=TEST_TIME_OUT)
+ if proc.returncode != 0:
+ logging.error('feature selection %s failed' % pipeline_config_path)
+ return False
+ return True
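+
+
+# A minimal usage sketch of the prediction helper above (illustrative only;
+# all paths are placeholders):
+#
+#   ok = test_single_predict(
+#       test_dir='/tmp/easy_rec_test',
+#       input_path='data/test.csv',
+#       output_path='/tmp/easy_rec_test/pred.csv',
+#       saved_model_dir='examples/ckpt/export/final')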
+
+
def yaml_replace(train_yaml_path,
pipline_config_path,
test_pipeline_config_path,
@@ -292,7 +424,7 @@ def test_hdfs_train_eval(pipeline_config_path,
logging.info('test_pipeline_config_path is %s' % test_pipeline_config_path)
train_cmd = 'el_submit -yaml %s' % train_yaml_path
proc = subprocess.Popen(train_cmd.split(), stderr=subprocess.STDOUT)
- proc.wait()
+ proc_wait(proc, timeout=TEST_TIME_OUT)
if proc.returncode != 0:
logging.error('train %s failed' % test_pipeline_config_path)
logging.error('train_yaml %s failed' % train_yaml_path)
@@ -324,7 +456,7 @@ def test_hdfs_eval(pipeline_config_path,
logging.info('test_pipeline_config_path is %s' % test_pipeline_config_path)
eval_cmd = 'el_submit -yaml %s' % eval_yaml_path
proc = subprocess.Popen(eval_cmd.split(), stderr=subprocess.STDOUT)
- proc.wait()
+ proc_wait(proc, timeout=TEST_TIME_OUT)
if proc.returncode != 0:
logging.error('eval %s failed' % test_pipeline_config_path)
logging.error('eval_yaml %s failed' % eval_yaml_path)
@@ -358,7 +490,7 @@ def test_hdfs_export(pipeline_config_path,
logging.info('test_pipeline_config_path is %s' % test_pipeline_config_path)
eval_cmd = 'el_submit -yaml %s' % export_yaml_path
proc = subprocess.Popen(eval_cmd.split(), stderr=subprocess.STDOUT)
- proc.wait()
+ proc_wait(proc, timeout=TEST_TIME_OUT)
if proc.returncode != 0:
logging.error('export %s failed' % test_pipeline_config_path)
logging.error('export_yaml %s failed' % export_yaml_path)
@@ -375,7 +507,7 @@ def _ports_in_use(ports):
return stat == 0
-def _get_ports(num_worker):
+def get_ports_base(num_worker):
port_base = int(os.environ.get('PORT_BASE', 10000))
num_try = 10
for i in range(num_try):
@@ -385,25 +517,48 @@ def _get_ports(num_worker):
logging.info('ports %s in use, retry...' % ports)
-def _ps_worker_train(pipeline_config_path, test_dir, num_worker):
+def _get_ports(num_worker):
+  # ports taken from the environment to avoid port conflicts when
+  # multiple test cases run in parallel
+ if 'ports' in os.environ:
+ ports = os.environ['ports']
+ port_arr = [int(x) for x in ports.split(',')]
+ assert len(port_arr) >= num_worker, 'not enough ports: %s, required: %d'\
+ % (ports, num_worker)
+ return port_arr[:num_worker]
+ else:
+ return get_ports_base(num_worker)
+
+
+def _ps_worker_train(pipeline_config_path,
+ test_dir,
+ num_worker,
+ num_evaluator=0,
+ fit_on_eval=False,
+ fit_on_eval_steps=None):
gpus = get_available_gpus()
# not enough gpus, run on cpu only
if len(gpus) < num_worker:
gpus = [None] * num_worker
ports = _get_ports(num_worker + 1)
- tf_config = {
- 'cluster': {
- 'master': ['localhost:%d' % ports[0]],
- 'worker': ['localhost:%d' % ports[i] for i in range(1, num_worker)],
- 'ps': ['localhost:%d' % ports[-1]]
- }
+ chief_or_master = 'master' if num_evaluator == 0 else 'chief'
+ cluster = {
+ chief_or_master: ['localhost:%d' % ports[0]],
+ 'worker': ['localhost:%d' % ports[i] for i in range(1, num_worker)],
+ 'ps': ['localhost:%d' % ports[-1]]
}
+ tf_config = {'cluster': cluster}
procs = {}
- tf_config['task'] = {'type': 'master', 'index': 0}
+ tf_config['task'] = {'type': chief_or_master, 'index': 0}
os.environ['TF_CONFIG'] = json.dumps(tf_config)
set_gpu_id(gpus[0])
train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s' % pipeline_config_path
- procs['master'] = run_cmd(train_cmd, '%s/log_%s.txt' % (test_dir, 'master'))
+ if fit_on_eval:
+ train_cmd += ' --fit_on_eval'
+ if fit_on_eval_steps is not None:
+ train_cmd += ' --fit_on_eval_steps ' + str(int(fit_on_eval_steps))
+ procs[chief_or_master] = run_cmd(
+ train_cmd, '%s/log_%s.txt' % (test_dir, chief_or_master))
tf_config['task'] = {'type': 'ps', 'index': 0}
os.environ['TF_CONFIG'] = json.dumps(tf_config)
set_gpu_id('')
@@ -416,6 +571,63 @@ def _ps_worker_train(pipeline_config_path, test_dir, num_worker):
worker_name = 'worker_%d' % idx
procs[worker_name] = run_cmd(train_cmd,
'%s/log_%s.txt' % (test_dir, worker_name))
+ if num_evaluator > 0:
+ tf_config['task'] = {'type': 'evaluator', 'index': 0}
+ os.environ['TF_CONFIG'] = json.dumps(tf_config)
+ set_gpu_id('')
+ procs['evaluator'] = run_cmd(train_cmd,
+ '%s/log_%s.txt' % (test_dir, 'evaluator'))
+
+ return procs
+
+
+def _ps_worker_distribute_eval(pipeline_config_path,
+ checkpoint_path,
+ test_dir,
+ num_worker,
+ num_evaluator=0):
+ gpus = get_available_gpus()
+ # not enough gpus, run on cpu only
+ if len(gpus) < num_worker:
+ gpus = [None] * num_worker
+ ports = _get_ports(num_worker + 1)
+ chief_or_master = 'master' if num_evaluator == 0 else 'chief'
+ cluster = {
+ chief_or_master: ['localhost:%d' % ports[0]],
+ 'worker': ['localhost:%d' % ports[i] for i in range(1, num_worker)],
+ 'ps': ['localhost:%d' % ports[-1]]
+ }
+ tf_config = {'cluster': cluster}
+ procs = {}
+ tf_config['task'] = {'type': chief_or_master, 'index': 0}
+ os.environ['TF_CONFIG'] = json.dumps(tf_config)
+ os.environ[constant.SORT_COL_BY_NAME] = '1'
+ set_gpu_id(gpus[0])
+ train_cmd = 'python -m easy_rec.python.eval --pipeline_config_path {} --checkpoint_path {} \
+ --distribute_eval True --eval_result_path distribute_eval_result.txt'.format(
+ pipeline_config_path, checkpoint_path)
+ procs[chief_or_master] = run_cmd(
+ train_cmd, '%s/distribute_eval_log_%s.txt' % (test_dir, chief_or_master))
+ tf_config['task'] = {'type': 'ps', 'index': 0}
+ os.environ['TF_CONFIG'] = json.dumps(tf_config)
+ set_gpu_id('')
+ procs['ps'] = run_cmd(train_cmd,
+ '%s/distribute_eval_log_%s.txt' % (test_dir, 'ps'))
+
+ for idx in range(num_worker - 1):
+ tf_config['task'] = {'type': 'worker', 'index': idx}
+ os.environ['TF_CONFIG'] = json.dumps(tf_config)
+ set_gpu_id(gpus[idx + 1])
+ worker_name = 'worker_%d' % idx
+ procs[worker_name] = run_cmd(
+ train_cmd, '%s/distribute_eval_log_%s.txt' % (test_dir, worker_name))
+ if num_evaluator > 0:
+ tf_config['task'] = {'type': 'evaluator', 'index': 0}
+ os.environ['TF_CONFIG'] = json.dumps(tf_config)
+ set_gpu_id('')
+ procs['evaluator'] = run_cmd(
+ train_cmd, '%s/distribute_eval_log_%s.txt' % (test_dir, 'evaluator'))
+
return procs
@@ -442,10 +654,46 @@ def _multi_worker_mirror_train(pipeline_config_path, test_dir, num_worker):
return procs
-def test_distributed_train_eval(pipeline_config_path, test_dir, total_steps=50):
+def _multi_worker_hvd_train(pipeline_config_path, test_dir, num_worker):
+ gpus = get_available_gpus()
+ # not enough gpus, run on cpu only
+ if len(gpus) < num_worker:
+ gpus = ''
+ else:
+ gpus = ','.join(gpus)
+ set_gpu_id(gpus)
+ ports = _get_ports(num_worker)
+ hosts = ','.join(['localhost:%d' % ports[i] for i in range(num_worker)])
+ train_cmd = 'horovodrun -np %d --hosts %s python -m easy_rec.python.train_eval --pipeline_config_path %s' % (
+ num_worker, hosts, pipeline_config_path)
+ proc = run_cmd(train_cmd, '%s/log_hvd.txt' % test_dir)
+ proc_wait(proc, timeout=1200)
+ return proc.returncode == 0
+
+
+def test_distributed_train_eval(pipeline_config_path,
+ test_dir,
+ total_steps=50,
+ num_evaluator=0,
+ edit_config_json=None,
+ use_hvd=False,
+ fit_on_eval=False,
+ num_epoch=0):
logging.info('testing pipeline config %s' % pipeline_config_path)
pipeline_config = _load_config_for_test(pipeline_config_path, test_dir,
- total_steps)
+ total_steps, num_epoch)
+ if edit_config_json is not None:
+ config_util.edit_config(pipeline_config, edit_config_json)
+
+ if use_hvd:
+ pipeline_config.train_config.sync_replicas = False
+ if pipeline_config.train_config.train_distribute not in [
+ DistributionStrategy.EmbeddingParallelStrategy,
+ DistributionStrategy.SokStrategy
+ ]:
+ pipeline_config.train_config.train_distribute =\
+ DistributionStrategy.HorovodStrategy
+
train_config = pipeline_config.train_config
config_util.save_pipeline_config(pipeline_config, test_dir)
test_pipeline_config_path = os.path.join(test_dir, 'pipeline.config')
@@ -453,9 +701,17 @@ def test_distributed_train_eval(pipeline_config_path, test_dir, total_steps=50):
task_failed = None
procs = None
try:
+ if use_hvd:
+ return _multi_worker_hvd_train(test_pipeline_config_path, test_dir, 2)
if train_config.train_distribute == DistributionStrategy.NoStrategy:
num_worker = 2
- procs = _ps_worker_train(test_pipeline_config_path, test_dir, num_worker)
+ procs = _ps_worker_train(
+ test_pipeline_config_path,
+ test_dir,
+ num_worker,
+ num_evaluator,
+ fit_on_eval,
+ fit_on_eval_steps=int(total_steps // 2))
elif train_config.train_distribute == DistributionStrategy.MultiWorkerMirroredStrategy:
num_worker = 2
procs = _multi_worker_mirror_train(test_pipeline_config_path, test_dir,
@@ -509,3 +765,102 @@ def test_distributed_train_eval(pipeline_config_path, test_dir, total_steps=50):
logging.error('train %s failed' % pipeline_config_path)
return task_failed is None
+
+
+def test_distribute_eval_test(cur_eval_path, test_dir):
+ single_work_eval_path = os.path.join(cur_eval_path, 'eval_result.txt')
+ distribute_eval_path = os.path.join(test_dir, 'distribute_eval_result.txt')
+ if not os.path.exists(distribute_eval_path):
+ return False
+ single_data = read_data_from_json_path(single_work_eval_path)
+ distribute_data = read_data_from_json_path(distribute_eval_path)
+ single_ret = {
+ k: single_data[k]
+ for k in single_data.keys()
+ if 'loss' not in k and 'step' not in k
+ }
+ distribute_ret = {
+ k: distribute_data[k] for k in distribute_data.keys() if 'loss' not in k
+ }
+ difference_num = 0.00001
+ for k in single_ret.keys():
+ if (abs(single_ret[k] - distribute_ret[k]) > difference_num):
+ logging.error(
+          'distribute_eval difference[%.8f] larger than threshold[%.8f]' %
+ (abs(single_ret[k] - distribute_ret[k]), difference_num))
+ return False
+ return True
+
+
+def test_distributed_eval(pipeline_config_path,
+ checkpoint_path,
+ test_dir,
+ total_steps=50,
+ num_evaluator=0):
+ logging.info('testing pipeline config %s' % pipeline_config_path)
+ pipeline_config = _load_config_for_distribute_eval(pipeline_config_path,
+ test_dir)
+ train_config = pipeline_config.train_config
+ config_util.save_pipeline_config(pipeline_config, test_dir)
+ test_pipeline_config_path = os.path.join(test_dir, 'pipeline.config')
+
+ task_failed = None
+ procs = None
+ is_equal = False
+ try:
+ if train_config.train_distribute == DistributionStrategy.NoStrategy:
+ num_worker = 2
+ procs = _ps_worker_distribute_eval(test_pipeline_config_path,
+ checkpoint_path, test_dir, num_worker,
+ num_evaluator)
+ else:
+ raise NotImplementedError
+
+ # print proc info
+ assert len(procs) > 0, 'processes are empty'
+ for k, proc in procs.items():
+ logging.info('%s pid: %d' % (k, proc.pid))
+ task_finish_cnt = 0
+ task_has_finished = {k: False for k in procs.keys()}
+ while True:
+ for k, proc in procs.items():
+ if proc.poll() is None:
+ if task_failed is not None:
+ logging.error('task %s failed, %s quit' % (task_failed, k))
+ proc.terminate()
+ if k != 'ps':
+ task_has_finished[k] = True
+ task_finish_cnt += 1
+ logging.info('task_finish_cnt %d' % task_finish_cnt)
+ else:
+ if not task_has_finished[k]:
+ # process quit by itself
+ if k != 'ps':
+ task_finish_cnt += 1
+ task_has_finished[k] = True
+ logging.info('task_finish_cnt %d' % task_finish_cnt)
+ if proc.returncode != 0:
+ logging.error('%s failed' % k)
+ task_failed = k
+ else:
+            logging.info('%s ran successfully' % k)
+ if task_finish_cnt >= num_worker:
+ break
+ time.sleep(1)
+
+ is_equal = test_distribute_eval_test(checkpoint_path, test_dir)
+
+ except Exception as e:
+ logging.error('Exception: ' + str(e))
+ raise e
+ finally:
+ if procs is not None:
+ for k, proc in procs.items():
+ if proc.poll() is None:
+ logging.info('terminate %s' % k)
+ proc.terminate()
+ if task_failed is not None:
+ logging.error('eval %s failed[%s]' % (pipeline_config_path, task_failed))
+
+ eval_success = (task_failed is None) and is_equal
+ return eval_success
diff --git a/easy_rec/python/utils/tf_utils.py b/easy_rec/python/utils/tf_utils.py
new file mode 100644
index 000000000..24f47a94a
--- /dev/null
+++ b/easy_rec/python/utils/tf_utils.py
@@ -0,0 +1,56 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Common functions used for odps input."""
+import json
+import os
+
+import tensorflow as tf
+
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+def get_ps_num_from_tf_config():
+ tf_config = os.environ.get('TF_CONFIG')
+ if tf_config:
+ tf_config_json = json.loads(tf_config)
+ cluster = tf_config_json.get('cluster', {})
+ ps_hosts = cluster.get('ps', [])
+ return len(ps_hosts)
+ return 0
+
+
+def get_tf_type(field_type):
+ type_map = {
+ DatasetConfig.INT32: tf.int32,
+ DatasetConfig.INT64: tf.int64,
+ DatasetConfig.STRING: tf.string,
+ DatasetConfig.BOOL: tf.bool,
+ DatasetConfig.FLOAT: tf.float32,
+ DatasetConfig.DOUBLE: tf.double
+ }
+ assert field_type in type_map, 'invalid type: %s' % field_type
+ return type_map[field_type]
+
+
+def get_col_type(tf_type):
+ type_map = {
+ tf.int32: 'BIGINT',
+ tf.int64: 'BIGINT',
+ tf.string: 'STRING',
+ tf.float32: 'FLOAT',
+ tf.double: 'DOUBLE',
+ tf.bool: 'BOOLEAN'
+ }
+ assert tf_type in type_map, 'invalid type: %s' % tf_type
+ return type_map[tf_type]
+
+
+def add_elements_to_collection(elements, name):
+ collection = tf.get_collection_ref(name)
+ collection_set = set(collection)
+ for element in elements:
+ if element not in collection_set:
+ collection.append(element)
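+
+
+# Minimal sketches of the helpers above (illustrative only):
+#
+#   get_ps_num_from_tf_config()       # -> 0 when TF_CONFIG is not set
+#   get_tf_type(DatasetConfig.FLOAT)  # -> tf.float32
+#   get_col_type(tf.int64)            # -> 'BIGINT'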
diff --git a/easy_rec/version.py b/easy_rec/version.py
index 4dbed0c5a..fa13b61b9 100644
--- a/easy_rec/version.py
+++ b/easy_rec/version.py
@@ -1,3 +1,4 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
-__version__ = '0.3.1'
+
+__version__ = '0.8.6'
diff --git a/examples/configs/autoint_on_movielens.config b/examples/configs/autoint_on_movielens.config
new file mode 100644
index 000000000..cbf43729f
--- /dev/null
+++ b/examples/configs/autoint_on_movielens.config
@@ -0,0 +1,161 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/autoint_on_movieslen_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 1000
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_class: 'AutoInt'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ autoint {
+ multi_head_num: 2
+ multi_head_size: 32
+ interacting_layer_num: 3
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
+export_config {
+ multi_placeholder: true
+}
diff --git a/examples/configs/contrastive_learning_on_movielens.config b/examples/configs/contrastive_learning_on_movielens.config
new file mode 100644
index 000000000..11e45d317
--- /dev/null
+++ b/examples/configs/contrastive_learning_on_movielens.config
@@ -0,0 +1,260 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/contrastive_on_movieslen"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_name: "multi tower"
+ model_class: "RankModel"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'movie_id'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'user_tower'
+ inputs {
+ feature_group_name: 'user'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
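+    # 'item_tower' is declared as a package so it can be invoked more than once; its input_layer applies 10% feature dropout on each pass.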
+ packages {
+ name: 'item_tower'
+ blocks {
+ name: 'item'
+ inputs {
+ feature_group_name: 'item'
+ }
+ input_layer {
+ dropout_rate: 0.1
+ }
+ }
+ blocks {
+ name: 'item_encoder'
+ inputs {
+ block_name: 'item'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ }
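+    # Two invocations of the 'item_tower' package give two dropout-perturbed item representations; AuxiliaryLoss computes an InfoNCE loss between them (weight 0.1, temperature 0.2).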
+ blocks {
+ name: 'contrastive_learning'
+ inputs {
+ package_name: 'item_tower'
+ }
+ inputs {
+ package_name: 'item_tower'
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: 'AuxiliaryLoss'
+ st_params {
+ fields {
+ key: 'loss_type'
+ value: { string_value: 'info_nce' }
+ }
+ fields {
+ key: 'loss_weight'
+ value: { number_value: 0.1 }
+ }
+ fields {
+ key: 'temperature'
+ value: { number_value: 0.2 }
+ }
+ }
+ }
+ }
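+    # The contrastive_learning input is marked ignore_input, so the auxiliary loss stays in the graph without contributing features; top_mlp consumes the user tower output together with another item_tower pass invoked with reset_input.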
+ blocks {
+ name: 'top_mlp'
+ inputs {
+ block_name: 'contrastive_learning'
+ ignore_input: true
+ }
+ inputs {
+ block_name: 'user_tower'
+ }
+ inputs {
+ package_name: 'item_tower'
+ reset_input {
+ }
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 64]
+ }
+ }
+ }
+ concat_blocks: 'top_mlp'
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/dcn_backbone_on_movielens.config b/examples/configs/dcn_backbone_on_movielens.config
new file mode 100644
index 000000000..7be038dbf
--- /dev/null
+++ b/examples/configs/dcn_backbone_on_movielens.config
@@ -0,0 +1,201 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/dcn_backbone_on_movieslen"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: false
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_name: 'DCN v2'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "deep"
+ inputs {
+ feature_group_name: 'all'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
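+    # DCN v2 cross network: input_fn duplicates the input into [x0, x]; the Cross layer is applied recurrently for 3 steps while x0 (fixed_input_index: 0) stays fixed.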
+ blocks {
+ name: "dcn"
+ inputs {
+ feature_group_name: 'all'
+ input_fn: 'lambda x: [x, x]'
+ }
+ recurrent {
+ num_steps: 3
+ fixed_input_index: 0
+ keras_layer {
+ class_name: 'Cross'
+ }
+ }
+ }
+ concat_blocks: ['deep', 'dcn']
+ top_mlp {
+ hidden_units: [64, 32, 16]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/dcn_on_movielens.config b/examples/configs/dcn_on_movielens.config
new file mode 100644
index 000000000..09110c81d
--- /dev/null
+++ b/examples/configs/dcn_on_movielens.config
@@ -0,0 +1,182 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/dcn_on_movieslen_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_class: 'DCN'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
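+  # Classic DCN: a 5-layer cross tower plus a [256, 128, 64] deep tower, fused by final_dnn.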
+ dcn {
+ deep_tower {
+ input: "all"
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ cross_tower {
+ input: "all"
+ cross_num: 5
+ }
+ final_dnn {
+ hidden_units: [64, 32, 16]
+ }
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/deepfm_backbone_on_criteo.config b/examples/configs/deepfm_backbone_on_criteo.config
new file mode 100644
index 000000000..25fc5cfc6
--- /dev/null
+++ b/examples/configs/deepfm_backbone_on_criteo.config
@@ -0,0 +1,643 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/deepfm_backbone_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ embedding_dim: 16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_name: 'DeepFM'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "deep_features"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide_features"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:WIDE
+ }
+ backbone {
+ blocks {
+ name: 'wide_features'
+ inputs {
+ feature_group_name: 'wide_features'
+ }
+ input_layer {
+ wide_output_dim: 1
+ }
+ }
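+    # Sum the per-feature wide weights (wide_output_dim: 1) into a single wide logit per example.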
+ blocks {
+ name: 'wide_logit'
+ inputs {
+ block_name: 'wide_features'
+ }
+ lambda {
+ expression: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
+ }
+ }
+ blocks {
+ name: 'deep_features'
+ inputs {
+ feature_group_name: 'deep_features'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
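+    # output_2d_tensor_and_feature_list exposes both the concatenated 2D tensor (slice [0], fed to the deep MLP) and the per-feature embedding list (slice [1], fed to FM).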
+ blocks {
+ name: 'fm'
+ inputs {
+ block_name: 'deep_features'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ st_params {
+ fields {
+ key: 'use_variant'
+ value { bool_value: true }
+ }
+ }
+ }
+ }
+ blocks {
+ name: 'deep'
+ inputs {
+ block_name: 'deep_features'
+ input_slice: '[0]'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ concat_blocks: ['wide_logit', 'fm', 'deep']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/deepfm_backbone_on_criteo_with_autodis.config b/examples/configs/deepfm_backbone_on_criteo_with_autodis.config
new file mode 100644
index 000000000..e0c6ccb43
--- /dev/null
+++ b/examples/configs/deepfm_backbone_on_criteo_with_autodis.config
@@ -0,0 +1,759 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/deepfm_autodis_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ feature_name: "D1"
+ input_names: "F1"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ feature_name: "D2"
+ input_names: "F2"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ feature_name: "D3"
+ input_names: "F3"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ feature_name: "D4"
+ input_names: "F4"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ feature_name: "D5"
+ input_names: "F5"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ feature_name: "D6"
+ input_names: "F6"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ feature_name: "D7"
+ input_names: "F7"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ feature_name: "D8"
+ input_names: "F8"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ feature_name: "D9"
+ input_names: "F9"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ feature_name: "D10"
+ input_names: "F10"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ feature_name: "D11"
+ input_names: "F11"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ feature_name: "D12"
+ input_names: "F12"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ feature_name: "D13"
+ input_names: "F13"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+}
+model_config: {
+ model_name: 'DeepFM with AutoDis'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "numerical_features"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "categorical_features"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide_features"
+ feature_names: "D1"
+ feature_names: "D2"
+ feature_names: "D3"
+ feature_names: "D4"
+ feature_names: "D5"
+ feature_names: "D6"
+ feature_names: "D7"
+ feature_names: "D8"
+ feature_names: "D9"
+ feature_names: "D10"
+ feature_names: "D11"
+ feature_names: "D12"
+ feature_names: "D13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:WIDE
+ }
+ backbone {
+ blocks {
+ name: 'wide_features'
+ inputs {
+ feature_group_name: 'wide_features'
+ }
+ input_layer {
+ wide_output_dim: 1
+ }
+ }
+ blocks {
+ name: 'wide_logit'
+ inputs {
+ block_name: 'wide_features'
+ }
+ lambda {
+ expression: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
+ }
+ }
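+    # AutoDis soft-discretizes each of the 13 numerical features into 20 learnable bins and outputs 16-dim embeddings; output_tensor_list also exposes the per-feature list for the FM block.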
+ blocks {
+ name: 'num_emb'
+ inputs {
+ feature_group_name: 'numerical_features'
+ }
+ keras_layer {
+ class_name: 'AutoDisEmbedding'
+ auto_dis_embedding {
+ embedding_dim: 16
+ num_bins: 20
+ temperature: 0.815
+ output_tensor_list: true
+ }
+ }
+ }
+ blocks {
+ name: 'categorical_features'
+ inputs {
+ feature_group_name: 'categorical_features'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'fm'
+ inputs {
+ block_name: 'categorical_features'
+ input_slice: '[1]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ fm {
+ use_variant: true
+ }
+ }
+ }
+ blocks {
+ name: 'deep'
+ inputs {
+ block_name: 'categorical_features'
+ input_slice: '[0]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[0]'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+    # dropping 'wide_logit' from concat_blocks may give better performance
+ concat_blocks: ['wide_logit', 'fm', 'deep']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/deepfm_backbone_on_criteo_with_periodic.config b/examples/configs/deepfm_backbone_on_criteo_with_periodic.config
new file mode 100644
index 000000000..06753ad2c
--- /dev/null
+++ b/examples/configs/deepfm_backbone_on_criteo_with_periodic.config
@@ -0,0 +1,757 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/deepfm_periodic_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ feature_name: "D1"
+ input_names: "F1"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ feature_name: "D2"
+ input_names: "F2"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ feature_name: "D3"
+ input_names: "F3"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ feature_name: "D4"
+ input_names: "F4"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ feature_name: "D5"
+ input_names: "F5"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ feature_name: "D6"
+ input_names: "F6"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ feature_name: "D7"
+ input_names: "F7"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ feature_name: "D8"
+ input_names: "F8"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ feature_name: "D9"
+ input_names: "F9"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ feature_name: "D10"
+ input_names: "F10"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ feature_name: "D11"
+ input_names: "F11"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ feature_name: "D12"
+ input_names: "F12"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ feature_name: "D13"
+ input_names: "F13"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+}
+model_config: {
+ model_name: 'DeepFM with Periodic'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "numerical_features"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "categorical_features"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide_features"
+ feature_names: "D1"
+ feature_names: "D2"
+ feature_names: "D3"
+ feature_names: "D4"
+ feature_names: "D5"
+ feature_names: "D6"
+ feature_names: "D7"
+ feature_names: "D8"
+ feature_names: "D9"
+ feature_names: "D10"
+ feature_names: "D11"
+ feature_names: "D12"
+ feature_names: "D13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:WIDE
+ }
+ backbone {
+ blocks {
+ name: 'wide_features'
+ inputs {
+ feature_group_name: 'wide_features'
+ }
+ input_layer {
+ wide_output_dim: 1
+ }
+ }
+ blocks {
+ name: 'wide_logit'
+ inputs {
+ block_name: 'wide_features'
+ }
+ lambda {
+ expression: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
+ }
+ }
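+    # PeriodicEmbedding maps each numerical feature through sin/cos projections (scale set by sigma) into 16-dim embeddings; output_tensor_list exposes the per-feature list for FM.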
+ blocks {
+ name: 'num_emb'
+ inputs {
+ feature_group_name: 'numerical_features'
+ }
+ keras_layer {
+ class_name: 'PeriodicEmbedding'
+ periodic_embedding {
+ embedding_dim: 16
+ sigma: 0.005
+ output_tensor_list: true
+ }
+ }
+ }
+ blocks {
+ name: 'categorical_features'
+ inputs {
+ feature_group_name: 'categorical_features'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'fm'
+ inputs {
+ block_name: 'categorical_features'
+ input_slice: '[1]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ fm {
+ use_variant: true
+ }
+ }
+ }
+ blocks {
+ name: 'deep'
+ inputs {
+ block_name: 'categorical_features'
+ input_slice: '[0]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[0]'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ concat_blocks: ['wide_logit', 'fm', 'deep']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/deepfm_backbone_on_movielens.config b/examples/configs/deepfm_backbone_on_movielens.config
new file mode 100644
index 000000000..5e6ea9b8d
--- /dev/null
+++ b/examples/configs/deepfm_backbone_on_movielens.config
@@ -0,0 +1,243 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/deepfm_backbone_movieslen"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [8, 4, 4]
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_name: 'DeepFM'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'wide'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: WIDE
+ }
+ feature_groups: {
+ group_name: 'features'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ feature_names: 'title'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'wide'
+ inputs {
+ feature_group_name: 'wide'
+ }
+ input_layer {
+ wide_output_dim: 1
+ }
+ }
+ blocks {
+ name: 'features'
+ inputs {
+ feature_group_name: 'features'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'fm'
+ inputs {
+ block_name: 'features'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ }
+ }
+ blocks {
+ name: 'deep'
+ inputs {
+ block_name: 'features'
+ input_slice: '[0]'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64, 1]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
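+    # Final logit = wide logit (summed wide weights) + FM second-order term + deep MLP output, combined by the Add layer.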
+ blocks {
+ name: 'add'
+ inputs {
+ block_name: 'wide'
+ input_fn: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
+ }
+ inputs {
+ block_name: 'fm'
+ }
+ inputs {
+ block_name: 'deep'
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: 'Add'
+ }
+ }
+ concat_blocks: 'add'
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/deepfm_on_criteo.config b/examples/configs/deepfm_on_criteo.config
new file mode 100644
index 000000000..fc8537f0d
--- /dev/null
+++ b/examples/configs/deepfm_on_criteo.config
@@ -0,0 +1,589 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/deepfm_criteo_ckpt"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 1000
+ sync_replicas: True
+ num_steps: 20000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ embedding_dim:16
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_class: 'DeepFM'
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ final_dnn {
+ hidden_units: [256, 128, 64]
+ }
+ wide_regularization: 1e-4
+ dense_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/deepfm_on_movielens.config b/examples/configs/deepfm_on_movielens.config
new file mode 100644
index 000000000..a49a1988c
--- /dev/null
+++ b/examples/configs/deepfm_on_movielens.config
@@ -0,0 +1,184 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/deepfm_movieslen_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [8, 4, 4]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_class: 'DeepFM'
+ feature_groups: {
+ group_name: 'wide'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: WIDE
+ }
+ feature_groups: {
+ group_name: 'deep'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ feature_names: 'title'
+ wide_deep: DEEP
+ }
+ deepfm {
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/dlrm_backbone_on_criteo.config b/examples/configs/dlrm_backbone_on_criteo.config
new file mode 100644
index 000000000..bb7b2a92f
--- /dev/null
+++ b/examples/configs/dlrm_backbone_on_criteo.config
@@ -0,0 +1,578 @@
+# aligned with the original (raw) DLRM model
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/dlrm_backbone_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_name: 'DLRM'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "dense"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "sparse"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ backbone {
+ blocks {
+ name: 'bottom_mlp'
+ inputs {
+ feature_group_name: 'dense'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [64, 32, 16]
+ }
+ }
+ }
+ blocks {
+ name: 'sparse'
+ inputs {
+ feature_group_name: 'sparse'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'dot'
+ inputs {
+ block_name: 'bottom_mlp'
+ input_fn: 'lambda x: [x]'
+ }
+ inputs {
+ block_name: 'sparse'
+ input_fn: 'lambda x: x[1]'
+ }
+ keras_layer {
+ class_name: 'DotInteraction'
+ }
+ }
+ blocks {
+ name: 'sparse_2d'
+ inputs {
+ block_name: 'sparse'
+ input_fn: 'lambda x: x[0]'
+ }
+ }
+ concat_blocks: ['sparse_2d', 'dot']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/dlrm_on_criteo.config b/examples/configs/dlrm_on_criteo.config
new file mode 100644
index 000000000..e6c45d574
--- /dev/null
+++ b/examples/configs/dlrm_on_criteo.config
@@ -0,0 +1,534 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/dlrm_criteo_ckpt"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_class: 'DLRM'
+ feature_groups: {
+ group_name: "dense"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "sparse"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ dlrm {
+ bot_dnn {
+ hidden_units: [64, 32, 16]
+ }
+ top_dnn {
+ hidden_units: [256, 128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/dlrm_on_criteo_with_autodis.config b/examples/configs/dlrm_on_criteo_with_autodis.config
new file mode 100644
index 000000000..53de6a279
--- /dev/null
+++ b/examples/configs/dlrm_on_criteo_with_autodis.config
@@ -0,0 +1,587 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/dlrm_autodis_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_name: 'DLRM with autodis'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "dense"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "sparse"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ backbone {
+ blocks {
+ name: 'num_emb'
+ inputs {
+ feature_group_name: 'dense'
+ }
+ keras_layer {
+ class_name: 'AutoDisEmbedding'
+ auto_dis_embedding {
+ embedding_dim: 16
+ num_bins: 40
+ temperature: 0.815
+ output_tensor_list: true
+ }
+ }
+ }
+ blocks {
+ name: 'sparse'
+ inputs {
+ feature_group_name: 'sparse'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'dot'
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[1]'
+ }
+ inputs {
+ block_name: 'sparse'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'DotInteraction'
+ }
+ }
+ blocks {
+ name: 'sparse_2d'
+ inputs {
+ block_name: 'sparse'
+ input_slice: '[0]'
+ }
+ }
+ blocks {
+ name: 'num_emb_2d'
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[0]'
+ }
+ }
+ concat_blocks: ['num_emb_2d', 'dot', 'sparse_2d']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/dlrm_on_criteo_with_narydis.config b/examples/configs/dlrm_on_criteo_with_narydis.config
new file mode 100644
index 000000000..121fd98c2
--- /dev/null
+++ b/examples/configs/dlrm_on_criteo_with_narydis.config
@@ -0,0 +1,587 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/dlrm_narydis_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+  model_name: 'DLRM with narydis'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "dense"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "sparse"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ backbone {
+ blocks {
+ name: 'num_emb'
+ inputs {
+ feature_group_name: 'dense'
+ }
+ keras_layer {
+ class_name: 'NaryDisEmbedding'
+ nary_dis_embedding {
+ embedding_dim: 8
+ carries: [2, 9]
+ multiplier: 1e6
+ output_tensor_list: true
+ }
+ }
+ }
+ blocks {
+ name: 'sparse'
+ inputs {
+ feature_group_name: 'sparse'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'dot'
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[1]'
+ }
+ inputs {
+ block_name: 'sparse'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'DotInteraction'
+ }
+ }
+ blocks {
+ name: 'sparse_2d'
+ inputs {
+ block_name: 'sparse'
+ input_slice: '[0]'
+ }
+ }
+ blocks {
+ name: 'num_emb_2d'
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[0]'
+ }
+ }
+ concat_blocks: ['num_emb_2d', 'dot', 'sparse_2d']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/dlrm_on_criteo_with_periodic.config b/examples/configs/dlrm_on_criteo_with_periodic.config
new file mode 100644
index 000000000..36c120e95
--- /dev/null
+++ b/examples/configs/dlrm_on_criteo_with_periodic.config
@@ -0,0 +1,595 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/dlrm_periodic_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+  model_name: 'DLRM with periodic embedding'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "dense"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "sparse"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ backbone {
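+    # wiring note (as configured below): 'num_emb' (PeriodicEmbedding) and the
+    # 'sparse' input_layer each expose a pair of outputs, where index 0 is the
+    # flattened 2D tensor and index 1 the per-feature embedding list; 'dot'
+    # feeds the lists to DotInteraction, while 'num_emb_2d' and 'sparse_2d'
+    # forward the 2D tensors to concat_blocks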
+ blocks {
+ name: 'num_emb'
+ inputs {
+ feature_group_name: 'dense'
+ }
+ keras_layer {
+ class_name: 'PeriodicEmbedding'
+ st_params {
+ fields {
+ key: "output_tensor_list"
+ value { bool_value: true }
+ }
+ fields {
+ key: "embedding_dim"
+ value { number_value: 16 }
+ }
+ fields {
+ key: "sigma"
+ value { number_value: 0.005 }
+ }
+ }
+ }
+ }
+ blocks {
+ name: 'sparse'
+ inputs {
+ feature_group_name: 'sparse'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'dot'
+ inputs {
+ block_name: 'num_emb'
+ input_slice: '[1]'
+ }
+ inputs {
+ block_name: 'sparse'
+ input_fn: 'lambda x: x[1]'
+ }
+ keras_layer {
+ class_name: 'DotInteraction'
+ }
+ }
+ blocks {
+ name: 'sparse_2d'
+ inputs {
+ block_name: 'sparse'
+ input_slice: '[0]'
+ }
+ }
+ blocks {
+ name: 'num_emb_2d'
+ inputs {
+ block_name: 'num_emb'
+ input_fn: 'lambda x: x[0]'
+ }
+ }
+ concat_blocks: ['num_emb_2d', 'dot', 'sparse_2d']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/dlrm_senet_on_criteo.config b/examples/configs/dlrm_senet_on_criteo.config
new file mode 100644
index 000000000..961f41e73
--- /dev/null
+++ b/examples/configs/dlrm_senet_on_criteo.config
@@ -0,0 +1,583 @@
+# align with raw dlrm model
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/dlrm_senet_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+    min_val: 0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_name: 'DLRM'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "dense"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "sparse"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ backbone {
+ blocks {
+ name: 'bottom_mlp'
+ inputs {
+ feature_group_name: 'dense'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [64, 32, 16]
+ }
+ }
+ }
+ blocks {
+ name: 'sparse'
+ inputs {
+ feature_group_name: 'sparse'
+ }
+ input_layer {
+ only_output_feature_list: true
+ }
+ }
+ blocks {
+ name: 'senet'
+ inputs {
+ block_name: 'sparse'
+ }
+ keras_layer {
+ class_name: 'SENet'
+ senet {
+ reduction_ratio: 4
+ }
+ }
+ }
+ blocks {
+ name: 'dot'
+ inputs {
+ block_name: 'bottom_mlp'
+ input_fn: 'lambda x: [x]'
+ }
+ inputs {
+ block_name: 'senet'
+ input_fn: 'lambda x: tf.unstack(tf.reshape(x, [-1, 26, 16]), axis=1)'
+ }
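+    # the lambda above assumes the 26 sparse features with embedding_dim 16
+    # defined in feature_config; it restores the per-feature embedding list
+    # that DotInteraction expects from the flat SENet output
+    # (adjust the reshape if the feature count or embedding_dim changes)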
+ keras_layer {
+ class_name: 'DotInteraction'
+ }
+ }
+ concat_blocks: ['senet', 'dot']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/dlrm_standard_on_criteo.config b/examples/configs/dlrm_standard_on_criteo.config
new file mode 100644
index 000000000..720560693
--- /dev/null
+++ b/examples/configs/dlrm_standard_on_criteo.config
@@ -0,0 +1,568 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/dlrm_standard_criteo"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 20000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+    min_val: 0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+  model_name: 'Standard DLRM'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "dense"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "sparse"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ backbone {
+ blocks {
+ name: 'bottom_mlp'
+ inputs {
+ feature_group_name: 'dense'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [64, 32, 16]
+ }
+ }
+ }
+ blocks {
+ name: 'sparse'
+ inputs {
+ feature_group_name: 'sparse'
+ }
+ input_layer {
+ only_output_feature_list: true
+ }
+ }
+ blocks {
+ name: 'dot'
+ inputs {
+ block_name: 'bottom_mlp'
+ }
+ inputs {
+ block_name: 'sparse'
+ }
+ keras_layer {
+ class_name: 'DotInteraction'
+ }
+ }
+ concat_blocks: ['bottom_mlp', 'dot']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/dssm_on_books.config b/examples/configs/dssm_on_books.config
new file mode 100644
index 000000000..eebcdb295
--- /dev/null
+++ b/examples/configs/dssm_on_books.config
@@ -0,0 +1,114 @@
+train_input_path: "examples/data/amazon_books_data/amazon_train_data"
+eval_input_path: "examples/data/amazon_books_data/amazon_test_data"
+model_dir: "examples/ckpt/dssm_book_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ num_steps: 20000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name:'book_id_seq'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'book_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'label'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: "\t"
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500000
+ }
+ features: {
+ input_names: 'book_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 400000
+ }
+ features: {
+ input_names: 'book_id_seq'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 400000
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
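+    # seq_att_map below sets up target attention: the candidate 'book_id'
+    # is used as the query over the 'book_id_seq' history embeddings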
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ allow_key_search: true
+ seq_att_map: {
+ key: "book_id"
+ hist_seq: "book_id_seq"
+ }
+ }
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ item_tower {
+ id: "book_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
diff --git a/examples/configs/dssm_on_books_negative_sample.config b/examples/configs/dssm_on_books_negative_sample.config
new file mode 100644
index 000000000..8e3fe87f1
--- /dev/null
+++ b/examples/configs/dssm_on_books_negative_sample.config
@@ -0,0 +1,129 @@
+train_input_path: "examples/data/amazon_books_data/amazon_train_data"
+eval_input_path: "examples/data/amazon_books_data/amazon_test_data"
+model_dir: "examples/ckpt/dssm_book_negative_sample_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ num_steps: 20000
+}
+
+eval_config {
+ metrics_set {
+ recall_at_topk { topk: 1 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 10 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 100 }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name:'book_id_seq'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'book_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'label'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: "\t"
+
+ negative_sampler {
+ input_path: 'examples/data/amazon_books_data/negative_book_data'
+ num_sample: 1024
+ num_eval_sample: 1024
+ attr_fields: 'book_id'
+ item_id_field: 'book_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500000
+ }
+ features: {
+ input_names: 'book_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 400000
+ }
+ features: {
+ input_names: 'book_id_seq'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 400000
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ allow_key_search: true
+ need_key_feature:true
+ seq_att_map: {
+ key: "book_id"
+ hist_seq: "book_id_seq"
+ }
+ }
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ item_tower {
+ id: "book_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-6
+}
diff --git a/examples/configs/dssm_senet_on_taobao.config b/examples/configs/dssm_senet_on_taobao.config
new file mode 100644
index 000000000..7b8e0da1c
--- /dev/null
+++ b/examples/configs/dssm_senet_on_taobao.config
@@ -0,0 +1,283 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_senet_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: false
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_class: "DSSM_SENet"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep:DEEP
+ }
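+  # DSSM_SENet: each tower applies SENet re-weighting to its input feature
+  # embeddings before the tower DNN (see the senet blocks below)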
+ dssm_senet {
+ user_tower {
+ id: "user_id"
+ senet {
+ num_squeeze_group : 2
+ reduction_ratio: 4
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ senet {
+ num_squeeze_group : 2
+ reduction_ratio: 4
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
diff --git a/examples/configs/fibinet_on_movielens.config b/examples/configs/fibinet_on_movielens.config
new file mode 100644
index 000000000..b4ecaf613
--- /dev/null
+++ b/examples/configs/fibinet_on_movielens.config
@@ -0,0 +1,202 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/fibinet_on_movieslen_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: False
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_name: 'FiBiNet'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "all"
+ inputs {
+ feature_group_name: "all"
+ }
+ input_layer {
+ do_batch_norm: true
+ only_output_feature_list: true
+ }
+ }
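+    # FiBiNet (params below): an SENet branch re-weights the field embeddings
+    # and bilinear interaction layers model pairwise feature combinations
+    # before the final MLP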
+ blocks {
+ name: "fibinet"
+ inputs {
+ block_name: "all"
+ }
+ keras_layer {
+ class_name: 'FiBiNet'
+ fibinet {
+ senet {
+ reduction_ratio: 4
+ }
+ bilinear {
+ type: 'each'
+ num_output_units: 512
+ }
+ mlp {
+ hidden_units: [512, 256]
+ }
+ }
+ }
+ }
+ concat_blocks: ['fibinet']
+ }
+ model_params {
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/fm_on_criteo.config b/examples/configs/fm_on_criteo.config
new file mode 100644
index 000000000..a93cda194
--- /dev/null
+++ b/examples/configs/fm_on_criteo.config
@@ -0,0 +1,580 @@
+train_input_path: "examples/data/criteo/criteo_train_data"
+eval_input_path: "examples/data/criteo/criteo_test_data"
+model_dir: "examples/ckpt/fm_criteo_ckpt"
+
+train_config {
+ log_step_count_steps: 500
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: True
+ num_steps: 20000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: STRING
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: STRING
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ embedding_dim:10
+ feature_type: RawFeature
+    min_val: 0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+  }
+  features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_class: 'FM'
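+  # FM uses the 'wide' group for the first-order (linear) terms and the
+  # 'deep' group embeddings for the second-order interactions, which is why
+  # the same features appear in both groups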
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:WIDE
+ }
+ fm {
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/examples/configs/masknet_on_movielens.config b/examples/configs/masknet_on_movielens.config
new file mode 100644
index 000000000..04205ddd5
--- /dev/null
+++ b/examples/configs/masknet_on_movielens.config
@@ -0,0 +1,198 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/masknet_on_movieslen_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_name: 'MaskNet'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "mask_net"
+ inputs {
+ feature_group_name: "all"
+ }
+ keras_layer {
+ class_name: 'MaskNet'
+ masknet {
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mlp {
+ hidden_units: [512, 256]
+ }
+ }
+ }
+ }
+ concat_blocks: ['mask_net']
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/mind_on_books.config b/examples/configs/mind_on_books.config
new file mode 100644
index 000000000..d19eb5d96
--- /dev/null
+++ b/examples/configs/mind_on_books.config
@@ -0,0 +1,113 @@
+train_input_path: "examples/data/amazon_books_data/amazon_train_data"
+eval_input_path: "examples/data/amazon_books_data/amazon_test_data"
+model_dir: "examples/ckpt/mind_book_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ num_steps: 20000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name:'book_id_seq'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'book_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'label'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: "\t"
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500000
+ }
+ features: {
+ input_names: 'book_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 400000
+ }
+ features: {
+ input_names: 'book_id_seq'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 400000
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "MIND"
+ feature_groups: {
+ group_name: 'hist'
+ feature_names: 'book_id_seq'
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ mind {
+ user_dnn {
+ hidden_units: [128, 64, 32]
+ }
+ item_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ concat_dnn {
+ hidden_units: [64, 32]
+ }
+
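+    # interest capsules: at most max_k interests are extracted from the
+    # behavior sequence (truncated to max_seq_len), each of dimension high_dim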
+ capsule_config {
+ max_k: 3
+ max_seq_len: 50
+ high_dim: 64
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
diff --git a/examples/configs/mind_on_books_negative_sample.config b/examples/configs/mind_on_books_negative_sample.config
new file mode 100644
index 000000000..6058e6a2f
--- /dev/null
+++ b/examples/configs/mind_on_books_negative_sample.config
@@ -0,0 +1,130 @@
+train_input_path: "examples/data/amazon_books_data/amazon_train_data"
+eval_input_path: "examples/data/amazon_books_data/amazon_test_data"
+model_dir: "examples/ckpt/mind_book_negative_sample_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ num_steps: 20000
+}
+
+eval_config {
+ metrics_set {
+ recall_at_topk { topk: 1 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 10 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 100 }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name:'book_id_seq'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'book_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'label'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: "\t"
+
+ negative_sampler {
+ input_path: 'examples/data/amazon_books_data/negative_book_data'
+ num_sample: 1024
+ num_eval_sample: 1024
+ attr_fields: 'book_id'
+ item_id_field: 'book_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500000
+ }
+ features: {
+ input_names: 'book_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 400000
+ }
+ features: {
+ input_names: 'book_id_seq'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 400000
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "MIND"
+ feature_groups: {
+ group_name: 'hist'
+ feature_names: 'book_id_seq'
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ mind {
+ user_dnn {
+ hidden_units: [128, 64, 32]
+ }
+ item_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ concat_dnn {
+ hidden_units: [64, 32]
+ }
+
+ capsule_config {
+ max_k: 3
+ max_seq_len: 50
+ high_dim: 64
+ }
+ item_id: "book_id"
+ l2_regularization: 1e-6
+ ignore_in_batch_neg_sam: true
+ }
+ embedding_regularization: 5e-5
+ loss_type: SOFTMAX_CROSS_ENTROPY
+}
diff --git a/examples/configs/mlp_on_movielens.config b/examples/configs/mlp_on_movielens.config
new file mode 100644
index 000000000..814a09e00
--- /dev/null
+++ b/examples/configs/mlp_on_movielens.config
@@ -0,0 +1,241 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/mlp_movieslen"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_class: "RankModel"
+ feature_groups: {
+ group_name: 'features'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ feature_names: 'title'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'mlp'
+ inputs {
+ feature_group_name: 'features'
+ }
+ layers {
+ keras_layer {
+ class_name: 'Dense'
+ st_params {
+ fields {
+ key: 'units'
+ value: { number_value: 256 }
+ }
+ fields {
+ key: 'activation'
+ value: { string_value: 'relu' }
+ }
+ }
+ }
+ }
+ layers {
+ keras_layer {
+ class_name: 'Dropout'
+ st_params {
+ fields {
+ key: 'rate'
+ value: { number_value: 0.5 }
+ }
+ }
+ }
+ }
+ layers {
+ keras_layer {
+ class_name: 'Dense'
+ st_params {
+ fields {
+ key: 'units'
+ value: { number_value: 256 }
+ }
+ fields {
+ key: 'activation'
+ value: { string_value: 'relu' }
+ }
+ }
+ }
+ }
+ layers {
+ keras_layer {
+ class_name: 'Dropout'
+ st_params {
+ fields {
+ key: 'rate'
+ value: { number_value: 0.5 }
+ }
+ }
+ }
+ }
+ layers {
+ keras_layer {
+ class_name: 'Dense'
+ st_params {
+ fields {
+ key: 'units'
+ value: { number_value: 1 }
+ }
+ }
+ }
+ }
+ }
+ concat_blocks: 'mlp'
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/multi_tower_on_movielens.config b/examples/configs/multi_tower_on_movielens.config
new file mode 100644
index 000000000..aa6c27b87
--- /dev/null
+++ b/examples/configs/multi_tower_on_movielens.config
@@ -0,0 +1,214 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/multi_tower_movieslen"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_name: "multi tower"
+ model_class: "RankModel"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'movie_id'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'user_tower'
+ inputs {
+ feature_group_name: 'user'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ blocks {
+ name: 'item_tower'
+ inputs {
+ feature_group_name: 'item'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ blocks {
+ name: 'top_mlp'
+ inputs {
+ block_name: 'user_tower'
+ }
+ inputs {
+ block_name: 'item_tower'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 64]
+ }
+ }
+ }
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/readme.md b/examples/configs/readme.md
new file mode 100644
index 000000000..6df4d84a1
--- /dev/null
+++ b/examples/configs/readme.md
@@ -0,0 +1,126 @@
+# Config配置文件说明
+
+## 输入
+
+在我们的demo实验中,采用local环境的csv格式的文件。
+
+```
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/new_autoint_on_movieslen_ckpt"
+```
+
+其中,`train_input_path`是训练集路径,`eval_input_path`是测试集路径,`model_dir`是指定模型保存的路径。
+
+## 训练配置
+
+train_config用于配置一些训练时常用的参数,详细见[train.md](../../docs/source/train.md)。
+
+```
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+```
+
+## 评估配置
+
+eval_config用于配置训练过程中的评估指标(如AUC),详细见 [eval.md](../../docs/source/eval.md)。
+
+```
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+```
+
+## 数据配置
+
+data_config用于配置输入文件中各特征列的数据类型,详细见 [data.md](../../docs/source/feature/data.md)。
+
+```
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+}
+```
+
+## 特征配置
+
+feature_config用于配置特征字段。
+
+```
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+}
+```
+
+## 模型配置
+
+model_config用于配置模型类型以及模型网络具体参数信息等。
+
+```
+model_config: {
+ model_class: 'DeepFM'
+ feature_groups: {
+ group_name: 'wide'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ wide_deep: WIDE
+ }
+ feature_groups: {
+ group_name: 'deep'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ wide_deep: DEEP
+ }
+ deepfm {
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
+```
+
+## 导出配置
+
+export_config用于配置导出模型时的参数,详细见 [export.md](../../docs/source/export.md)。
diff --git a/examples/configs/wide_and_deep_backbone_on_movielens.config b/examples/configs/wide_and_deep_backbone_on_movielens.config
new file mode 100644
index 000000000..93957485c
--- /dev/null
+++ b/examples/configs/wide_and_deep_backbone_on_movielens.config
@@ -0,0 +1,220 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/wide_and_deep_movieslen"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_class: "RankModel"
+ feature_groups: {
+ group_name: 'wide'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: WIDE
+ }
+ feature_groups: {
+ group_name: 'deep'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'wide'
+ inputs {
+ feature_group_name: 'wide'
+ }
+ input_layer {
+ wide_output_dim: 1
+ only_output_feature_list: true
+ }
+ }
+ blocks {
+ name: 'deep_logit'
+ inputs {
+ feature_group_name: 'deep'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 256, 256, 1]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
+ blocks {
+ name: 'final_logit'
+ inputs {
+ block_name: 'wide'
+ input_fn: 'lambda x: tf.add_n(x)'
+ }
+ inputs {
+ block_name: 'deep_logit'
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: 'Add'
+ }
+ }
+ concat_blocks: 'final_logit'
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/configs/wide_and_deep_on_movielens.config b/examples/configs/wide_and_deep_on_movielens.config
new file mode 100644
index 000000000..94d00cade
--- /dev/null
+++ b/examples/configs/wide_and_deep_on_movielens.config
@@ -0,0 +1,188 @@
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/wide_and_deep_movieslen_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'job_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'year'
+ input_type: INT32
+ }
+
+ label_fields: 'label'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+ separator: '\t'
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'job_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'year'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+}
+model_config: {
+ model_class: "WideAndDeep"
+ feature_groups: {
+ group_name: 'wide'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: WIDE
+ }
+ feature_groups: {
+ group_name: 'deep'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ wide_and_deep {
+ wide_output_dim: 16
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+
+ final_dnn {
+ hidden_units: [64, 32, 16]
+ }
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/examples/data/amazon_books_data/download_and_process.sh b/examples/data/amazon_books_data/download_and_process.sh
new file mode 100644
index 000000000..be7a8ef59
--- /dev/null
+++ b/examples/data/amazon_books_data/download_and_process.sh
@@ -0,0 +1,11 @@
+#! /bin/bash
+if [ "$(uname)" == "Darwin" ]; then
+ curl -O https://ali-rec-sln.oss-cn-hangzhou.aliyuncs.com/resources/AmazonBooksData.tar.gz
+elif [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then
+ wget -c https://ali-rec-sln.oss-cn-hangzhou.aliyuncs.com/resources/AmazonBooksData.tar.gz
+elif [ "$(expr substr $(uname -s) 1 10)" == "MINGW32_NT" ]; then
+ curl -O https://ali-rec-sln.oss-cn-hangzhou.aliyuncs.com/resources/AmazonBooksData.tar.gz
+fi
+tar -zxvf AmazonBooksData.tar.gz
+mv AmazonBooksData.tar.gz AmazonBooksData/
+python process_amazon.py
diff --git a/examples/data/amazon_books_data/process_amazon.py b/examples/data/amazon_books_data/process_amazon.py
new file mode 100644
index 000000000..ba1fa7f4c
--- /dev/null
+++ b/examples/data/amazon_books_data/process_amazon.py
@@ -0,0 +1,106 @@
+import random
+
+import pandas as pd
+
+print('Start reading data...')
+title = ['UserID', 'BookID', 'Time']
+print('Reading train data...')
+train = pd.read_table(
+ 'AmazonBooksData/book_train.txt',
+ sep=',',
+ header=None,
+ names=title,
+ engine='python',
+ encoding='ISO-8859-1')
+print('Reading test data...')
+test = pd.read_table(
+ 'AmazonBooksData/book_test.txt',
+ sep=',',
+ header=None,
+ names=title,
+ engine='python',
+ encoding='ISO-8859-1')
+
+print('Start processing train data...')
+train_set = []
+for userID, hist in train.groupby('UserID'):
+ pos_list = hist['BookID'].tolist()
+
+ # generate negative samples randomly
+ def gen_neg():
+ neg = pos_list[0]
+ while neg in pos_list:
+ # 1~367982 is the range of book id
+ neg = random.randint(1, 367982)
+ return neg
+
+ neg_list_1 = [gen_neg() for i in range(len(pos_list))]
+ neg_list_2 = [gen_neg() for i in range(len(pos_list))]
+ neg_list_3 = [gen_neg() for i in range(len(pos_list))]
+ neg_list_4 = [gen_neg() for i in range(len(pos_list))]
+
+ for i in range(1, len(pos_list)):
+ # set the max sequence length to 50
+ hist = pos_list[:i][-50:]
+ hist_str = '|'.join(map(str, hist))
+ if i != len(pos_list):
+ # for each positive sample, random generate 4 negative samples
+ train_set.append((userID, hist_str, pos_list[i], 1))
+ train_set.append((userID, hist_str, neg_list_1[i], 0))
+ train_set.append((userID, hist_str, neg_list_2[i], 0))
+ train_set.append((userID, hist_str, neg_list_3[i], 0))
+ train_set.append((userID, hist_str, neg_list_4[i], 0))
+
+random.shuffle(train_set)
+
+print('Start processing test data...')
+test_set = []
+for userID, hist in test.groupby('UserID'):
+ pos_list = hist['BookID'].tolist()
+
+ # generate negative samples randomly
+ def gen_neg():
+ neg = pos_list[0]
+ while neg in pos_list:
+ # 1~367982 is the range of book id
+ neg = random.randint(1, 367982)
+ return neg
+
+ neg_list_1 = [gen_neg() for i in range(len(pos_list))]
+ neg_list_2 = [gen_neg() for i in range(len(pos_list))]
+ neg_list_3 = [gen_neg() for i in range(len(pos_list))]
+ neg_list_4 = [gen_neg() for i in range(len(pos_list))]
+ for i in range(1, len(pos_list)):
+ # set the max sequence length to 50
+ hist = pos_list[:i][-50:]
+ hist_str = '|'.join(map(str, hist))
+ if i != len(pos_list):
+ # for each positive sample, random generate 4 negative samples
+ test_set.append((userID, hist_str, pos_list[i], 1))
+ test_set.append((userID, hist_str, neg_list_1[i], 0))
+ test_set.append((userID, hist_str, neg_list_2[i], 0))
+ test_set.append((userID, hist_str, neg_list_3[i], 0))
+ test_set.append((userID, hist_str, neg_list_4[i], 0))
+random.shuffle(test_set)
+
+train_set_df = pd.DataFrame(train_set)
+test_set_df = pd.DataFrame(test_set)
+
+print('Start writing amazon_train_data...')
+train_set_df.to_csv(
+ r'amazon_train_data', index=False, sep='\t', mode='a', header=False)
+print('Start writing amazon_test_data...')
+test_set_df.to_csv(
+ r'amazon_test_data', index=False, sep='\t', mode='a', header=False)
+
+print('Negative Sampling')
+train_book = train[['BookID']].drop_duplicates()
+test_book = test[['BookID']].drop_duplicates()
+negative_book = pd.concat([train_book, test_book]).drop_duplicates()
+df_ones = pd.DataFrame(
+ 1, index=negative_book.index, columns=negative_book.columns)
+negative_book_data = pd.concat([negative_book, df_ones, negative_book], axis=1)
+new_header = ['id:int64', 'weight:float', 'feature:string']
+negative_book_data.to_csv(
+ r'negative_book_data', index=False, sep='\t', mode='a', header=new_header)
+print('Done.')
diff --git a/examples/data/amazon_books_data/readme.md b/examples/data/amazon_books_data/readme.md
new file mode 100644
index 000000000..273364f0d
--- /dev/null
+++ b/examples/data/amazon_books_data/readme.md
@@ -0,0 +1,92 @@
+# Amazon Books
+
+这是来自亚马逊的大量产品评论抓取数据集。该数据集包含来自约2000万用户的8283万条独立评论。
+
+- 基础描述:
+
+ ```
+ Ratings: 82.83 million
+ Users: 20.98 million
+ Items: 9.35 million
+ Timespan: May 1996 - July 2014
+ Metadata
+ reviews and ratings
+ item-to-item relationships (e.g. "people who bought X also bought Y")
+ timestamps
+ helpfulness votes
+ product image (and CNN features)
+ price
+ category
+ salesRank
+ ```
+
+- 下载:
+ 原始数据集:
+
+ http://jmcauley.ucsd.edu/data/amazon/index.html
+ https://tianchi.aliyun.com/dataset/dataDetail?dataId=649&userId=1
+
+ ComiRec处理后数据集:
+
+ Tsinghua Cloud: https://cloud.tsinghua.edu.cn/f/e5c4211255bc40cba828/?dl=1
+ Dropbox: https://www.dropbox.com/s/m41kahhhx0a5z0u/data.tar.gz?dl=1
+
+# 数据预处理
+
+我们基于[ComiRec](https://github.com/THUDM/ComiRec/tree/master)提供的AmazonBooks数据集进行进一步处理,使其适配EasyRec的召回模型样本格式。
+
+详细处理细节见 [process_amazon.py](process_amazon.py)
+
+也可跳过预处理,直接通过链接下载处理后的数据集: [amazon_train_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/amazon_books/amazon_train_data)、[amazon_test_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/amazon_books/amazon_test_data)、[negative_book_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/amazon_books/negative_book_data)。
+
+- 序列特征构造:
+
+  为丰富样本特征,充分利用EasyRec处理序列特征的能力,我们对数据集进一步处理。ComiRec数据集中每一行代表一次交互,包含三个字段:user_id、book_id、time_stamp。通过对用户进行分组,得到多条序列特征,用'|'分隔。为提高训练效率,我们设定序列特征的最大长度为50。
+
+ 例如,
+
+ ```
+ user_id book_id time_stamp
+ 0 a 0
+ 0 b 1
+ 0 c 2
+ 0 d 3
+
+ ---- process ----
+
+ user_id book_id_seq book_id label
+ 0 a b 1
+ 0 a|b c 1
+ 0 a|b|c d 1
+
+ ```
+
+- 负采样:
+
+ 原始数据集只包含正样本,为丰富样本,我们进行了随机负采样。对每一条样本,随机负采样4条没有出现在点击序列中的item。
+
+ ```
+ user_id book_id_seq book_id
+ 0 a b
+ 0 a|b c
+ 0 a|b|c d
+
+  ---- negative sampling ----
+
+ user_id book_id_seq book_id label
+ 0 a b 1
+ 0 a h 0
+ 0 a i 0
+ 0 a j 0
+ 0 a k 0
+ 0 a|b c 1
+ 0 a|b l 0
+ 0 a|b m 0
+ 0 a|b j 0
+ 0 a|b o 0
+ 0 a|b|c d 1
+ 0 a|b|c m 0
+ 0 a|b|c j 0
+ 0 a|b|c r 0
+ 0 a|b|c h 0
+ ```
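+
+以上两步处理(构造序列特征 + 随机负采样)可以用下面的简化代码示意(process_amazon.py 的精简版,变量名与取值仅为示例):
+
+```python
+import random
+
+# one user's positive book ids, ordered by time
+pos_list = ['a', 'b', 'c', 'd']
+all_books = ['a', 'b', 'c', 'd', 'h', 'i', 'j', 'k']
+
+samples = []
+for i in range(1, len(pos_list)):
+  hist = '|'.join(pos_list[:i][-50:])  # history sequence, max length 50
+  samples.append((hist, pos_list[i], 1))  # positive sample
+  for _ in range(4):  # 4 random negatives per positive
+    neg = random.choice(all_books)
+    while neg in pos_list:
+      neg = random.choice(all_books)
+    samples.append((hist, neg, 0))
+```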
diff --git a/examples/data/criteo/download_and_process.sh b/examples/data/criteo/download_and_process.sh
new file mode 100644
index 000000000..30061a862
--- /dev/null
+++ b/examples/data/criteo/download_and_process.sh
@@ -0,0 +1,12 @@
+#! /bin/bash
+if [ "$(uname)" == "Darwin" ]; then
+ curl -O https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/criteo_kaggle/kaggle-display-advertising-challenge-dataset.tar.gz
+elif [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then
+ wget -c https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/criteo_kaggle/kaggle-display-advertising-challenge-dataset.tar.gz
+elif [ "$(expr substr $(uname -s) 1 10)" == "MINGW32_NT" ]; then
+ curl -O https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/criteo_kaggle/kaggle-display-advertising-challenge-dataset.tar.gz
+fi
+mkdir criteo_kaggle_display
+tar -zxvf kaggle-display-advertising-challenge-dataset.tar.gz -C criteo_kaggle_display
+mv kaggle-display-advertising-challenge-dataset.tar.gz criteo_kaggle_display
+python process_criteo_kaggle.py
diff --git a/examples/data/criteo/process_criteo_kaggle.py b/examples/data/criteo/process_criteo_kaggle.py
new file mode 100644
index 000000000..5b9cb4f34
--- /dev/null
+++ b/examples/data/criteo/process_criteo_kaggle.py
@@ -0,0 +1,19 @@
+import pandas as pd
+
+category_features = ['F' + str(i) for i in range(1, 27)]
+dense_features = ['I' + str(i) for i in range(1, 14)]
+target_columns = ['label']
+columns = target_columns + dense_features + category_features
+
+data_train = pd.read_csv(
+ 'criteo_kaggle_display/train.txt', sep='\t', names=columns)
+
+samples_num = data_train.shape[0]
+print('samples_num:', samples_num, round(samples_num * 0.9))
+
+train_num = int(round(samples_num * 0.9))
+data_train[:train_num].to_csv(
+ r'criteo_train_data', index=False, sep='\t', mode='a', header=False)
+data_train[train_num:].to_csv(
+ r'criteo_test_data', index=False, sep='\t', mode='a', header=False)
+print('Done.')
diff --git a/examples/data/criteo/readme.md b/examples/data/criteo/readme.md
new file mode 100644
index 000000000..899a046f5
--- /dev/null
+++ b/examples/data/criteo/readme.md
@@ -0,0 +1,39 @@
+# Criteo Research Kaggle Display Advertising Challenge Dataset
+
+- 任务:CTR预估/排序
+
+- 简介:
+
+  该数据集由 Criteo 提供,包含数百万展示广告的特征值和点击反馈,目的是对点击率 (CTR) 预估算法进行基准测试。该数据集包含约4500万条点击记录,有13个连续特征和26个分类特征。
+
+- 下载:
+
+ [官网](https://ailab.criteo.com/ressources/)下载地址:https://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
+
+ 天池下载地址:https://tianchi.aliyun.com/dataset/144733
+
+- 详细描述:
+
+ 该数据集包含 2 个文件`train.txt` `test.txt`,对应数据的训练和测试部分。
+
+ 训练数据集`train.txt`包含7天内Criteo的一部分流量。每行对应Criteo投放的一个展示广告,第一列表示该广告是否被点击。正面(点击)和负面(未点击)的例子都被二次采样(但以不同的速率)以减少数据集的大小。
+
+  有13个采用整数值的特征(主要是计数特征)和26个分类特征。出于匿名目的,分类特征的值已散列到32位。这些特征的语义未公开。某些特征可能有缺失值。行按时间顺序排列。
+
+ 测试集`test.txt`的计算方式与训练集相同,但它对应于训练期后一天的事件。 第一列(标签)已被删除。
+
+- 格式:
+
+  数据列之间使用制表符作为分隔符: `<label> <integer feature 1> ... <integer feature 13> <categorical feature 1> ... <categorical feature 26>`
+
+ 当缺少一个值时,该字段只是空的。 测试集中没有标签字段。
+
+# 数据预处理
+
+参考[DeepFM论文](https://arxiv.org/abs/1703.04247)的方式。Criteo数据集包含约4500万条点击记录,有13个连续特征和26个分类特征。我们将训练数据集分成两部分:前90%用于训练,其余10%用于测试。
+
+详细处理细节见 [process_criteo_kaggle.py](process_criteo_kaggle.py)
+
+也可跳过预处理,直接通过链接下载处理后的数据集: [criteo_train_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/criteo_kaggle/criteo_train_data)、[criteo_test_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/criteo_kaggle/criteo_test_data)。
+
+注:由于测试集中没有标签,无法评估,故在我们的demo实验中没有使用。
diff --git a/examples/data/movielens_1m/download_and_process.sh b/examples/data/movielens_1m/download_and_process.sh
new file mode 100644
index 000000000..b7f158b4a
--- /dev/null
+++ b/examples/data/movielens_1m/download_and_process.sh
@@ -0,0 +1,11 @@
+#! /bin/bash
+if [ "$(uname)" == "Darwin" ]; then
+ curl -O http://files.grouplens.org/datasets/movielens/ml-1m.zip
+elif [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then
+ wget -c http://files.grouplens.org/datasets/movielens/ml-1m.zip
+elif [ "$(expr substr $(uname -s) 1 10)" == "MINGW32_NT" ]; then
+ curl -O http://files.grouplens.org/datasets/movielens/ml-1m.zip
+fi
+unzip ml-1m.zip
+mv ml-1m.zip ml-1m/
+python process_ml_1m.py
diff --git a/examples/data/movielens_1m/process_ml_1m.py b/examples/data/movielens_1m/process_ml_1m.py
new file mode 100644
index 000000000..8fca6292a
--- /dev/null
+++ b/examples/data/movielens_1m/process_ml_1m.py
@@ -0,0 +1,93 @@
+import re
+
+import pandas as pd
+from sklearn.utils import shuffle
+
+
+def process_data():
+ """Load Dataset from File."""
+ print('Start processing movielens-1m dataset.')
+ # Read user data
+ print('----User Data----')
+ users_title = ['UserID', 'Gender', 'Age', 'JobID', 'ZipCode']
+ users = pd.read_table(
+ 'ml-1m/users.dat',
+ sep='::',
+ header=None,
+ names=users_title,
+ engine='python',
+ encoding='ISO-8859-1')
+ users = users.filter(regex='UserID|Gender|Age|JobID|ZipCode')
+ # process the gender and age of user
+ gender_map = {'F': 0, 'M': 1}
+ users['Gender'] = users['Gender'].map(gender_map)
+
+ age_map = {val: ii for ii, val in enumerate(set(users['Age']))}
+ users['Age'] = users['Age'].map(age_map)
+
+ # read movie data
+ print('----Movie Data----')
+ movies_title = ['MovieID', 'Title', 'Genres']
+ movies = pd.read_table(
+ 'ml-1m/movies.dat',
+ sep='::',
+ header=None,
+ names=movies_title,
+ engine='python',
+ encoding='ISO-8859-1')
+
+ # split the title and year in Feature:'Title'
+ pattern = re.compile(r'^(.*)\((\d+)\)$')
+
+ title_map = {
+ val: pattern.match(val).group(1)
+ for ii, val in enumerate(set(movies['Title']))
+ }
+ year_map = {
+ val: pattern.match(val).group(2)
+ for ii, val in enumerate(set(movies['Title']))
+ }
+ movies['Year'] = movies['Title'].map(year_map)
+ movies['Title'] = movies['Title'].map(title_map)
+
+ # read rating data
+ print('----Rating Data----')
+ ratings_title = ['UserID', 'MovieID', 'ratings', 'timestamps']
+ ratings = pd.read_table(
+ 'ml-1m/ratings.dat',
+ sep='::',
+ header=None,
+ names=ratings_title,
+ engine='python',
+ encoding='ISO-8859-1')
+ ratings = ratings.filter(regex='UserID|MovieID|ratings')
+ # ratings of 4 and 5 are viewed as positive samples [label:1]
+  # ratings of 1 and 2 are viewed as negative samples [label:0]
+ # discard samples of rating = 3
+ label_map = {1: 0, 2: 0, 3: 2, 4: 1, 5: 1}
+ ratings['label'] = ratings['ratings'].map(label_map)
+
+ # concat users, movies and ratings
+ data = pd.merge(pd.merge(ratings, users), movies)
+
+  # move field 'label' to the first column
+ new_order = ['label'] + [col for col in data.columns if col != 'label']
+ data = data.reindex(columns=new_order)
+ # shuffle samples
+ data = shuffle(data)
+ print('Process Done.')
+ return ratings, users, movies, data
+
+
+ratings, users, movies, data = process_data()
+data_new = data[data['label'] < 2]
+print(data.count())
+print(data_new.count())
+
+# split train set and test set, and write to file
+print('Start writing to file.')
+data_new[:665110].to_csv(
+ r'movies_train_data', index=False, sep='\t', mode='a', header=False)
+data_new[665110:].to_csv(
+ r'movies_test_data', index=False, sep='\t', mode='a', header=False)
+print('Done.')
diff --git a/examples/data/movielens_1m/readme.md b/examples/data/movielens_1m/readme.md
new file mode 100644
index 000000000..e4c56290c
--- /dev/null
+++ b/examples/data/movielens_1m/readme.md
@@ -0,0 +1,123 @@
+# MovieLens-1M
+
+任务:CTR预估/排序
+
+数据集下载:http://files.grouplens.org/datasets/movielens/ml-1m.zip
+
+MovieLens 1M 数据集,包含6000个用户在近4000部电影上的约100万条评分记录。
+
+数据集分为三个文件:用户数据users.dat,电影数据movies.dat和评分数据ratings.dat。
+
+## 用户数据
+
+分别有用户ID、性别、年龄、职业ID和邮编等字段。
+
+数据中的格式:`UserID::Gender::Age::Occupation::Zip-code`
+
+可以看出UserID、Gender、Age和Occupation都是类别字段,其中邮编字段是我们不使用的。
+
+- 性别用“M”表示男性,“F”表示女性
+
+- 年龄来自以下范围:
+
+ ```
+ 1: "Under 18"
+ 18: "18-24"
+ 25: "25-34"
+ 35: "35-44"
+ 45: "45-49"
+ 50: "50-55"
+ 56: "56+"
+ ```
+
+- 职业包含以下几种:
+
+ ```
+ 0: "other" or not specified
+ 1: "academic/educator"
+ 2: "artist"
+ 3: "clerical/admin"
+ 4: "college/grad student"
+ 5: "customer service"
+ 6: "doctor/health care"
+ 7: "executive/managerial"
+ 8: "farmer"
+ 9: "homemaker"
+ 10: "K-12 student"
+ 11: "lawyer"
+ 12: "programmer"
+ 13: "retired"
+ 14: "sales/marketing"
+ 15: "scientist"
+ 16: "self-employed"
+ 17: "technician/engineer"
+ 18: "tradesman/craftsman"
+ 19: "unemployed"
+ 20: "writer"
+ ```
+
+## 电影数据
+
+分别有电影ID、电影标题和电影风格等字段。
+
+数据中的格式:`MovieID::Title::Genres`
+
+MovieID是类别字段,Title是文本,Genres也是类别字段
+
+- 标题与 IMDB 提供的标题相同(包括发行年份)
+
+- 电影风格类型有以下几种:
+
+ ```
+ Action
+ Adventure
+ Animation
+ Children's
+ Comedy
+ Crime
+ Documentary
+ Drama
+ Fantasy
+ Film-Noir
+ Horror
+ Musical
+ Mystery
+ Romance
+ Sci-Fi
+ Thriller
+ War
+ Western
+ ```
+
+## 评分数据
+
+分别有用户ID、电影ID、评分和时间戳等字段。
+
+数据中的格式:`UserID::MovieID::Rating::Timestamp`
+
+评分字段Rating就是我们要学习的label,时间戳字段我们不使用。
+
+- UserIDs 范围在 1 到 6040 之间
+
+- MovieIDs 范围在 1 到 3952 之间
+
+- 评级采用 5 星制(仅限全星评级)
+
+- 时间戳以 time(2) 返回的纪元以来的秒数表示
+
+- 每个用户至少有 20 个评分
+
+# 数据预处理
+
+我们参考了[AutoInt论文](https://dl.acm.org/doi/pdf/10.1145/3357384.3357925)中的处理方法,将评分小于 3 的样本视为负样本,因为低分表示用户不喜欢这部电影;将评分大于 3 的样本视为正样本;最后删除中性样本,即评分等于 3。
+
+详细处理细节见 [process_ml_1m.py](process_ml_1m.py)
+
+也可跳过预处理,直接通过链接下载处理后的数据集: [movies_train_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/movielens_1m/movies_train_data)、[movies_test_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/movielens_1m/movies_test_data)。
+
+- label:将评分大于3的作为正样本(label=1),将评分小于3的作为负样本(label=0),作为点击率预估任务的目标。
+- UserID、Occupation和MovieID不用变。
+- Gender字段:将‘F’和‘M’变换成0和1。
+- Age字段:把年龄离散化为0-6之间的数字(共7个数字)。
+- Genres字段:无需处理,直接转换为EasyRec的TagFeature
+- Title字段:将标题和年份拆开为两个特征,其中标题为SequenceFeature,年份为IDFeature。
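+
+Title 字段拆分为标题与年份的方式可以用下面的代码示意(process_ml_1m.py 中处理逻辑的简化版,示例取值来自 movies.dat):
+
+```python
+import re
+
+pattern = re.compile(r'^(.*)\((\d+)\)$')
+
+raw_title = 'Toy Story (1995)'
+match = pattern.match(raw_title)
+title, year = match.group(1).strip(), match.group(2)
+# title: 'Toy Story', year: '1995'
+```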
diff --git a/examples/match_model/dssm.md b/examples/match_model/dssm.md
new file mode 100644
index 000000000..0fa909e94
--- /dev/null
+++ b/examples/match_model/dssm.md
@@ -0,0 +1,79 @@
+# DSSM
+
+### 简介
+
+双塔召回模型,分为user塔和item塔。
+注:使用时需指定user id和item id。
+
+
+### 参考论文
+
+[DSSM.pdf](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/cikm2013_DSSM_fullversion.pdf)
+
+### 配置说明
+
+```protobuf
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ allow_key_search: true
+ seq_att_map: {
+ key: "book_id"
+ hist_seq: "book_id_seq"
+ }
+ }
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ item_tower {
+ id: "book_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
+```
+
+- model_class: 'DSSM', 不需要修改
+- feature_groups: 需要两个feature_group: user和item, **group name不能变**
+- sequence_features: 配置序列特征
+- dssm: dssm相关的参数,必须配置user_tower和item_tower
+- user_tower/item_tower:
+ - dnn: deep part的参数配置
+ - hidden_units: dnn每一层的channel数目,即神经元的数目
+ - id: 指定user_id/item_id列
+- simi_func: 向量相似度函数,包括\[COSINE, INNER_PRODUCT, EUCLID\],默认COSINE,建议使用INNER_PRODUCT,三种相似度的计算方式见下方示意代码
+- embedding_regularization: 对embedding部分加regularization,防止overfit
+
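+三种 simi_func 的计算方式可以用下面的 NumPy 代码示意(仅为说明性示例,与 EasyRec 内部实现细节不一定完全一致;其中 EUCLID 以负的欧氏距离表示相似度,属于本文的假设):
+
+```python
+import numpy as np
+
+
+def similarity(user_emb, item_emb, simi_func='COSINE'):
+  """Illustrative similarity between one user vector and one item vector."""
+  if simi_func == 'INNER_PRODUCT':
+    return float(np.dot(user_emb, item_emb))
+  if simi_func == 'COSINE':
+    norm = np.linalg.norm(user_emb) * np.linalg.norm(item_emb) + 1e-8
+    return float(np.dot(user_emb, item_emb) / norm)
+  if simi_func == 'EUCLID':
+    # negative euclidean distance: larger value means more similar (assumption)
+    return float(-np.linalg.norm(user_emb - item_emb))
+  raise ValueError('unknown simi_func: %s' % simi_func)
+```
+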
+支持的metric_set包括:
+
+- auc
+- mean_absolute_error
+- accuracy
+
+### 示例Config
+
+[dssm_on_books.config](../configs/dssm_on_books.config)
+
+### 效果评估
+
+[效果评估](https://easyrec.oss-cn-beijing.aliyuncs.com/docs/recall_eval.pdf)
diff --git a/examples/match_model/dssm_negative_sample.md b/examples/match_model/dssm_negative_sample.md
new file mode 100644
index 000000000..531081634
--- /dev/null
+++ b/examples/match_model/dssm_negative_sample.md
@@ -0,0 +1,149 @@
+# DSSM负采样版
+
+### 简介
+
+双塔召回模型,支持训练时负采样。
+
+
+
+当物品池很大上百万甚至是上亿的时候,双塔召回模型常常需要在物品池中针对每个正样本采样一千甚至一万的负样本才能达到比较好的召回效果,
+意味着正负样本比例达到了1: 1k,甚至是1: 1w, 要支持这个正负样本比例的训练,如果用离线构造样本的方式会导致离线存储和离线计算的压力都激增。
+该版本的DSSM支持运行时进行负采样,会以图存储的方式将物品的特征存储在Parameter Server节点上,并且Mini-Batch内的共享同一批负样本的计算,
+使得离线存储和离线计算的压力都大大降低。
+
+注:训练样本一般只需准备点击(正样本)的样本即可
+
+### 参考论文
+
+[DSSM.pdf](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/cikm2013_DSSM_fullversion.pdf)
+
+### 配置说明
+
+```protobuf
+eval_config {
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config: {
+ ...
+ negative_sampler {
+ input_path: 'examples/data/book_data/negative_book_data'
+ num_sample: 1024
+ num_eval_sample: 1024
+ attr_fields: 'book_id'
+ item_id_field: 'book_id'
+ }
+}
+
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ item_tower {
+ id: "book_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ simi_func: INNER_PRODUCT
+ scale_simi: false
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+```
+
+- eval_config: 评估配置,目前只支持recall_at_topk
+- data_config: 数据配置,其中需要配置负采样Sampler,负采样Sampler的配置详见下文[负采样配置](#负采样配置)
+- model_class: 'DSSM', 不需要修改
+- feature_groups: 需要两个feature_group: user和item, **group name不能变**
+- dssm: dssm相关的参数,必须配置user_tower和item_tower
+- user_tower/item_tower:
+ - dnn: deep part的参数配置
+ - hidden_units: dnn每一层的channel数目,即神经元的数目
+ - id: 指定user_id/item_id列
+- simi_func: 向量相似度函数,包括\[COSINE, INNER_PRODUCT, EUCLID\],默认COSINE,建议使用INNER_PRODUCT
+- scale_simi: 是否自动缩放相似度便于loss计算,建议设置成false
+- loss_type: 目前只支持SOFTMAX_CROSS_ENTROPY
+- embedding_regularization: 对embedding部分加regularization,防止overfit
+
+注意,DSSM负采样版目前仅支持recall_at_topk做评估指标。
+
+#### 负采样配置
+
+目前支持四种负采样Sampler:
+
+- negative_sampler:加权随机负采样,会排除Mini-Batch内的Item Id
+  - input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征,表的格式示例见本节末尾
+ - num_sample: 训练worker的负采样数
+  - num_eval_sample: 评估worker的负采样数
+ - attr_fields: Item特征名,顺序与Item的attr中特征的拼接顺序保持一致
+ - item_id_field: item_id列名
+- negative_sampler_v2:加权随机负采样,会跟排除Mini-Batch内的User有边的Item Id
+ - user_input_path: User表, Schema为: id:int64 | weight:float
+ - item_input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征
+ - pos_edge_input_path: Positive边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - user_id_field: user_id列名
+ - 其余同negative_sampler
+- hard_negative_sampler:加权随机负采样,会排除Mini-Batch内的Item Id,同时HardNegative边表中(一般为曝光未点击)进行负采样作为HardNegative
+ - user_input_path: User表, Schema为: id:int64 | weight:float
+ - item_input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征
+ - hard_neg_edge_input_path: HardNegative边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - num_hard_sample: hard negative的最大采样数目
+ - user_id_field: user_id列名
+ - 其余同negative_sampler
+- hard_negative_sampler_v2:加权随机负采样,会跟排除Mini-Batch内的User有边的Item Id,同时HardNegative边表中(一般为曝光未点击)进行负采样作为HardNegative
+ - user_input_path: User表, Schema为: id:int64 | weight:float
+ - item_input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征
+ - pos_edge_input_path: Positive边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - hard_neg_edge_input_path: HardNegative边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - num_hard_sample: hard negative的最大采样数目
+ - user_id_field: user_id列名
+ - 其余同negative_sampler
+ 一般用negative_sampler即可。
+
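+负采样Item表的格式可以参考下面的示意(与本仓库 process_amazon.py 生成的 negative_book_data 一致,列之间以制表符分隔;该demo中 attrs 只有 book_id 一列,具体取值仅为示例):
+
+```text
+id:int64    weight:float    feature:string
+63133       1               63133
+28012       1               28012
+```
+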
+### 示例Config
+
+[dssm_on_books_negative_sample.config](../configs/dssm_on_books_negative_sample.config)
+
+[DSSM_NegSampler.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/dssm_neg_sampler_on_taobao.config)
+
+[DSSM_NegSamplerV2.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/dssm_neg_sampler_v2_on_taobao.config)
+
+[DSSM_HardNegSampler.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/dssm_hard_neg_sampler_on_taobao.config)
+
+[DSSM_HardNegSamplerV2.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/dssm_hard_neg_sampler_v2_on_taobao.config)
+
+### 效果评估
+
+[效果评估](https://easyrec.oss-cn-beijing.aliyuncs.com/docs/recall_eval.pdf)
diff --git a/examples/match_model/mind.md b/examples/match_model/mind.md
new file mode 100644
index 000000000..b66c29419
--- /dev/null
+++ b/examples/match_model/mind.md
@@ -0,0 +1,186 @@
+# MIND
+
+### 简介
+
+mind召回模型, 在dssm的基础上加入了兴趣聚类功能,支持多兴趣召回,能够显著的提升召回层的效果.
+
+
+### 参考论文
+
+[MIND.pdf](https://arxiv.org/pdf/1904.08030.pdf)
+
+### 配置说明
+
+```protobuf
+ ...
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500000
+ }
+ features: {
+ input_names: 'book_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 400000
+ }
+ features: {
+ input_names: 'book_id_seq'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 400000
+ embedding_dim: 16
+ }
+
+model_config:{
+ model_class: "MIND"
+ feature_groups: {
+ group_name: 'hist'
+ feature_names: 'book_id_seq'
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ mind {
+ user_dnn {
+ hidden_units: [128, 64, 32]
+ }
+ item_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ concat_dnn {
+ hidden_units: [64, 32]
+ }
+
+ capsule_config {
+ max_k: 3
+ max_seq_len: 50
+ high_dim: 64
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
+```
+
+- model_class: 'MIND', 不需要修改
+- feature_groups: 需要三个feature_group: hist, user和item, **group name不能变**
+- mind: mind相关的参数,必须配置user_dnn和item_dnn
+- user_dnn: user侧的dnn参数
+ - dnn:
+ - hidden_units: dnn每一层的channel数
+ - use_bn: 是否使用batch_norm, 默认是true
+- item_dnn: item侧的dnn参数, 配置同user_dnn
+ - note: item侧不能用batch_norm
+- pre_capsule_dnn: 进入capsule之前的dnn的配置
+ - 可选, 配置同user_dnn和item_dnn
+- concat_dnn: hist seq 和 user feature融合后的dnn
+- capsule_config: 胶囊(动态路由)的配置
+ - max_k: 胶囊(兴趣)的个数
+ - max_seq_len: hist seq的最大长度
+ - high_dim: 兴趣向量的维度
+ - num_iters: 动态路由(兴趣聚类)的轮数
+ - routing_logits_scale: 放大routing logits, >0时生效;
+ - 一些场景显示设置为20时,兴趣向量比较分散, 即相似度比较低(0.8左右)
+ - routing_logits_stddev: routing_logits初始化的标准差
+  - squash_pow: 对squash加的power, 防止squash之后的向量值变得太小,squash 的计算可参考下方示意代码
+- simi_pow: 对相似度做的倍数, 放大interests之间的差异
+- embedding_regularization: 对embedding部分加regularization,防止overfit
+- user_seq_combine:
+ - CONCAT: 多个seq之间采取concat的方式融合
+ - SUM: 多个seq之间采取sum的方式融合, default是SUM
+
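+capsule 中的 squash 操作可以用下面的 NumPy 代码示意(仅为示意;这里假设 squash_pow 作用在压缩系数上,具体实现以 EasyRec 源码为准):
+
+```python
+import numpy as np
+
+
+def squash(s, squash_pow=1.0, eps=1e-8):
+  """Illustrative capsule squash: shrink the vector norm into (0, 1).
+
+  squash_pow < 1 keeps the squashed vectors from becoming too small
+  (assumed usage, not necessarily the exact easy_rec formula).
+  """
+  norm_sq = np.sum(np.square(s), axis=-1, keepdims=True)
+  scale = (norm_sq / (1.0 + norm_sq))**squash_pow
+  return scale * s / np.sqrt(norm_sq + eps)
+```
+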
+### 示例Config
+
+[mind_on_books.config](../configs/mind_on_books.config)
+
+### 效果评估
+
+离线的效果评估主要看在测试集上的hitrate. 可以参考文档[效果评估](https://easyrec.oss-cn-beijing.aliyuncs.com/docs/recall_eval.pdf)
+
+#### 评估sql
+
+```sql
+pai -name tensorflow1120_cpu_ext
+ -Dscript='oss://easyrec/deploy/easy_rec/python/tools/hit_rate_pai.py'
+ -Dbuckets='oss://easyrec/'
+ -Darn='acs:ram::xxx:role/aliyunodpspaidefaultrole'
+ -DossHost='oss-cn-beijing-internal.aliyuncs.com'
+ -Dtables='odps://pai_rec/tables/mind_item_embedding/dt=${ymd},odps://pai_rec/tables/mind_user_seq_and_embedding/dt=${eval_ymd}'
+ -Doutputs='odps://pai_rec/tables/mind_hitrate_details/dt=${ymd}/name=mind_top200,odps://pai_rec/tables/mind_total_hitrate/dt=${ymd}/name=mind_top200'
+ -Dcluster='{
+ \"ps\" : {
+ \"count\" : 1,
+ \"cpu\" : 800,
+ \"memory\" : 20000
+ },
+ \"worker\" : {
+ \"count\" : 16,
+ \"cpu\" : 800,
+ \"memory\" : 20000
+ }
+ }'
+ -DuserDefinedParameters='--recall_type=u2i --top_k=200 --emb_dim=32 --knn_metric=1 --knn_strict=False --batch_size=1024 --num_interests=3';
+```
+
+- mind_user_seq_and_embedding:
+ - user_id: string
+ - item_ids: string, ","分割
+ - user_emb: string, 多个向量之间用"|"分割, 向量内部用","分割
+ - user_emb_num: bigint, user兴趣向量的最大个数
+ - 说明: 不限制列名的定义,但是限制列的顺序: 0:user_id, 1:item_ids, 2:user_emb, 3:user_emb_num
+ - Local需要修改easy_rec/python/tools/hitrate.py
+- mind_item_embedding:
+ - item_id: bigint
+ - item_emb: string, item embedding, 向量内部用","分割
+ - 说明: 不限制列名的定义,但是限制列的顺序: 0:item_id, 1:item_emb
+ - Local可以按照下面的格式准备item embedding数据:
+ ```text
+ id:int64 feature:string
+ 63133 0.125,0.286,0.913,0.893
+ ```
+- num_interests: 最大的兴趣向量数
+- knn_strict: 是否使用精确的knn计算, 会导致计算量增加
+- knn_metric: 定义距离计算方式
+ - 0: L2 distance
+ - 1: Inner Product similarity
+- emb_dim: user / item表征向量的维度
+- top_k: knn检索取top_k计算hitrate
+- recall_type:
+ - u2i: user to item retrieval
+
+#### 评估结果
+
+输出下面两张表
+
+- mind_hitrate_details:
+
+ - 输出每一个user的hitrate = user_hits / user_recalls
+ - 格式如下:
+
+ ```text
+ id : bigint
+ topk_ids : string
+ topk_dists : string
+ hitrate : double
+ bad_ids : string
+ bad_dists : string
+ ```
+
+- mind_total_hitrate:
+
+ - 输出平均hitrate = SUM(user_hits) / SUM(user_recalls)
+ - 格式如下:
+
+ ```text
+ hitrate : double
+ ```
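+
+上述 hitrate 指标的计算逻辑可以用下面的本地简化代码示意(仅为示意,函数名与输入格式为假设,线上评估请使用 hit_rate_pai.py):
+
+```python
+def user_hitrate(topk_ids, gt_item_ids):
+  """hitrate = user_hits / user_recalls for a single user."""
+  topk = set(topk_ids)
+  hits = sum(1 for item_id in gt_item_ids if item_id in topk)
+  return float(hits) / max(len(gt_item_ids), 1)
+
+
+def total_hitrate(users):
+  """total hitrate = SUM(user_hits) / SUM(user_recalls).
+
+  users: list of (topk_ids, gt_item_ids) pairs, one per user.
+  """
+  hits = sum(
+      sum(1 for item_id in gt if item_id in set(topk)) for topk, gt in users)
+  recalls = sum(len(gt) for _, gt in users)
+  return float(hits) / max(recalls, 1)
+```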
diff --git a/examples/match_model/mind_negative_sample.md b/examples/match_model/mind_negative_sample.md
new file mode 100644
index 000000000..653a78b0d
--- /dev/null
+++ b/examples/match_model/mind_negative_sample.md
@@ -0,0 +1,154 @@
+# MIND负采样版
+
+### 简介
+
+mind召回模型, 在dssm的基础上加入了兴趣聚类功能,支持多兴趣召回,能够显著的提升召回层的效果,支持训练时负采样。
+
+
+### 参考论文
+
+[MIND.pdf](https://arxiv.org/pdf/1904.08030.pdf)
+
+### 配置说明
+
+```protobuf
+ ...
+eval_config {
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config: {
+ ...
+ negative_sampler {
+ input_path: 'examples/data/book_data/negative_book_data'
+ num_sample: 1024
+ num_eval_sample: 1024
+ attr_fields: 'book_id'
+ item_id_field: 'book_id'
+ }
+}
+model_config:{
+ model_class: "MIND"
+ feature_groups: {
+ group_name: 'hist'
+ feature_names: 'book_id_seq'
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'book_id'
+ wide_deep:DEEP
+ }
+ mind {
+ user_dnn {
+ hidden_units: [128, 64, 32]
+ }
+ item_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ concat_dnn {
+ hidden_units: [64, 32]
+ }
+
+ capsule_config {
+ max_k: 3
+ max_seq_len: 50
+ high_dim: 64
+ }
+ item_id: "book_id"
+ l2_regularization: 1e-6
+ ignore_in_batch_neg_sam: true
+ }
+ embedding_regularization: 5e-5
+ loss_type: SOFTMAX_CROSS_ENTROPY
+}
+```
+
+- model_class: 'MIND', 不需要修改
+- feature_groups: 需要三个feature_group: hist, user和item, **group name不能变**
+- mind: mind相关的参数,必须配置user_dnn和item_dnn
+- user_dnn: user侧的dnn参数
+ - dnn:
+ - hidden_units: dnn每一层的channel数
+ - use_bn: 是否使用batch_norm, 默认是true
+- item_dnn: item侧的dnn参数, 配置同user_dnn
+ - note: item侧不能用batch_norm
+- pre_capsule_dnn: 进入capsule之前的dnn的配置
+ - 可选, 配置同user_dnn和item_dnn
+- concat_dnn: hist seq 和 user feature融合后的dnn
+- capsule_config: 胶囊(动态路由)的配置
+ - max_k: 胶囊(兴趣)的个数
+ - max_seq_len: hist seq的最大长度
+ - high_dim: 兴趣向量的维度
+ - num_iters: 动态路由(兴趣聚类)的轮数
+ - routing_logits_scale: 放大routing logits, >0时生效;
+ - 一些场景显示设置为20时,兴趣向量比较分散, 即相似度比较低(0.8左右)
+ - routing_logits_stddev: routing_logits初始化的标准差
+ - squash_pow: 对squash加的power, 防止squash之后的向量值变得太小
+- simi_pow: 对相似度做的倍数, 放大interests之间的差异
+- embedding_regularization: 对embedding部分加regularization,防止overfit
+- user_seq_combine:
+ - CONCAT: 多个seq之间采取concat的方式融合
+ - SUM: 多个seq之间采取sum的方式融合, default是SUM
+
+注意,MIND负采样版目前仅支持recall_at_topk做评估指标。
+
+#### 负采样配置
+
+目前支持四种负采样Sampler:
+
+- negative_sampler:加权随机负采样,会排除Mini-Batch内的Item Id
+ - input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征
+ - num_sample: 训练worker的负采样数
+  - num_eval_sample: 评估worker的负采样数
+ - attr_fields: Item特征名,顺序与Item的attr中特征的拼接顺序保持一致
+ - item_id_field: item_id列名
+- negative_sampler_v2:加权随机负采样,会跟排除Mini-Batch内的User有边的Item Id
+ - user_input_path: User表, Schema为: id:int64 | weight:float
+ - item_input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征
+ - pos_edge_input_path: Positive边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - user_id_field: user_id列名
+ - 其余同negative_sampler
+- hard_negative_sampler:加权随机负采样,会排除Mini-Batch内的Item Id,同时HardNegative边表中(一般为曝光未点击)进行负采样作为HardNegative
+ - user_input_path: User表, Schema为: id:int64 | weight:float
+ - item_input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征
+ - hard_neg_edge_input_path: HardNegative边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - num_hard_sample: hard negative的最大采样数目
+ - user_id_field: user_id列名
+ - 其余同negative_sampler
+- hard_negative_sampler_v2:加权随机负采样,会跟排除Mini-Batch内的User有边的Item Id,同时HardNegative边表中(一般为曝光未点击)进行负采样作为HardNegative
+ - user_input_path: User表, Schema为: id:int64 | weight:float
+ - item_input_path: 负采样Item表, Schema为: id:int64 | weight:float | attrs:string,其中attr为":"分隔符拼接的Item特征
+ - pos_edge_input_path: Positive边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - hard_neg_edge_input_path: HardNegative边表, Schema为: userid:int64 | itemid:int64 | weight:float
+ - num_hard_sample: hard negative的最大采样数目
+ - user_id_field: user_id列名
+ - 其余同negative_sampler
+ 一般用negative_sampler即可。
+
+### 示例Config
+
+[mind_on_books_negative_sample.config](../configs/mind_on_books_negative_sample.config)
+
+### 效果评估
+
+可参考MIND的评估方式 [mind.md](mind.md)。
diff --git a/examples/match_model/readme.md b/examples/match_model/readme.md
new file mode 100644
index 000000000..a1e4d6b2e
--- /dev/null
+++ b/examples/match_model/readme.md
@@ -0,0 +1,37 @@
+# Introduction
+
+在召回任务的模型实验中,我们提供了一个公开数据集(Amazon Books)的模型demo。
+
+# Amazon Books 数据集
+
+在此数据集中, 提供了2个模型及其负采样版的demo示例 [DSSM](dssm.md) / [DSSM-Negative-Sample](dssm_negative_sample.md) / [MIND](mind.md) / [MIND-Negative-Sample](mind_negative_sample.md)。更多模型可参考[models](../../docs/source/models/)。
+
+- DSSM
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/dssm_on_books.config `
+
+- DSSM with Negative Sample
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/dssm_on_books_negative_sample.config `
+
+- MIND
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/mind_on_books.config `
+
+- MIND with Negative Sample
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/mind_on_books_negative_sample.config `
+
+### Results
+
+| DataSet | Model | Epoch | AUC |
+| ------------ | ----- | ----- | ------ |
+| Amazon-Books | DSSM | 2 | 0.8173 |
+| Amazon-Books | MIND | 2 | 0.7511 |
+
+
+
+注:评估召回模型及负采样版的效果建议参考HitRate指标,具体评估方法见[HitRate评估](https://easyrec.oss-cn-beijing.aliyuncs.com/docs/recall_eval.pdf)
diff --git a/examples/rank_model/DeepFM.md b/examples/rank_model/DeepFM.md
new file mode 100644
index 000000000..c5321d0e8
--- /dev/null
+++ b/examples/rank_model/DeepFM.md
@@ -0,0 +1,66 @@
+# DeepFM
+
+### 简介
+
+DeepFM是在WideAndDeep基础上加入了FM模块的改进模型。FM模块和DNN模块共享相同的特征,即相同的Embedding。
+
+
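+模型输出可以表示为(参考 DeepFM 论文,FM 部分与 DNN 部分共享同一套特征 embedding):
+
+$$\hat{y} = \mathrm{sigmoid}(y_{FM} + y_{DNN}),\qquad y_{FM} = \langle w, x \rangle + \sum_{i=1}^{d}\sum_{j=i+1}^{d} \langle V_i, V_j \rangle\, x_i x_j$$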
+
+### 参考论文
+
+[DeepFM](https://arxiv.org/abs/1703.04247)
+
+### 配置说明
+
+```protobuf
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ ...
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ ...
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
+```
+
+- model_class: 'DeepFM', 不需要修改
+
+- feature_groups:
+
+ 需要两个feature_group: wide group和deep group, **group name不能变**
+
+- deepfm: deepfm相关的参数
+
+- dnn: deep part的参数配置
+
+ - hidden_units: dnn每一层的channel数目,即神经元的数目
+
+- wide_output_dim: wide部分输出的大小
+
+- final_dnn: 整合wide part, fm part, deep part的参数输入, 可以选择是否使用
+
+ - hidden_units: dnn每一层的channel数目,即神经元的数目
+
+- embedding_regularization: 对embedding部分加regularization,防止overfit
+
+### 示例Config
+
+[deepfm_on_movielens.config](../configs/deepfm_on_movielens.config)
+[deepfm_on_criteo.config](../configs/deepfm_on_criteo.config)
diff --git a/examples/rank_model/autoint.md b/examples/rank_model/autoint.md
new file mode 100644
index 000000000..ed257d679
--- /dev/null
+++ b/examples/rank_model/autoint.md
@@ -0,0 +1,55 @@
+# AutoInt
+
+### 简介
+
+Automatic Feature Interaction Learning via Self-Attentive Neural Networks(AutoInt)通过将特征都映射到相同的低维空间中,然后利用带有残差连接的 Multi-head Self-Attention 机制显示构造高阶特征,对低维空间中的特征交互进行显式建模,有效提升了CTR预估的准确率。
+注意:AutoInt 模型要求所有输入特征的 embedding_dim 保持一致。
+
+
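+每个 Interacting Layer 中,特征 $m$ 对其他特征的注意力权重与输出可以简化表示为($e_m$ 为特征 embedding,多头结果拼接后再加残差并经过 ReLU,以下为简化写法):
+
+$$\alpha_{m,k}^{(h)} = \mathrm{softmax}_k\big(\langle W_Q^{(h)} e_m,\; W_K^{(h)} e_k \rangle\big),\qquad \tilde{e}_m^{(h)} = \sum_{k} \alpha_{m,k}^{(h)}\, W_V^{(h)} e_k$$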
+
+### 参考论文
+
+[AutoInt](https://dl.acm.org/doi/pdf/10.1145/3357384.3357925)
+
+### 配置说明
+
+```protobuf
+model_config: {
+ model_class: 'AutoInt'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ autoint {
+ multi_head_num: 2
+ multi_head_size: 32
+ interacting_layer_num: 3
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
+```
+
+- model_class: 'AutoInt', 不需要修改
+
+- feature_groups: 配置一个名为'all'的feature_group。
+
+- autoint: autoint相关的参数
+
+  - multi_head_num: Multi-head Self-attention 中 head 的个数(示例中为2)
+ - multi_head_size: Multi-head Self-attention 中的 head size,默认为1
+ - interacting_layer_num: 交叉层的层数,建议设在1到5之间,默认为1
+ - l2_regularization: L2正则,防止 overfit
+
+- embedding_regularization: 对embedding部分加regularization,防止overfit
+
+### 示例Config
+
+[autoint_on_movielens.config](../configs/autoint_on_movielens.config)
diff --git a/examples/rank_model/dcn.md b/examples/rank_model/dcn.md
new file mode 100644
index 000000000..e268ed640
--- /dev/null
+++ b/examples/rank_model/dcn.md
@@ -0,0 +1,73 @@
+# DCN
+
+### 简介
+
+Deep&Cross Network(DCN)是在DNN模型的基础上,引入了一种新型的交叉网络,该网络在学习某些特征交叉时效率更高。特别是,DCN显式地在每一层应用特征交叉,不需要人工特征工程,并且只增加了很小的额外复杂性。
+
+
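+交叉网络第 $l$ 层的计算为(参考 DCN 论文,$x_0$ 为输入特征 embedding 拼接后的向量,cross_num 即交叉层层数):
+
+$$x_{l+1} = x_0\, x_l^{\top} w_l + b_l + x_l$$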
+
+### 参考论文
+
+[DCN](https://arxiv.org/abs/1708.05123)
+
+### 配置说明
+
+```protobuf
+model_config: {
+ model_class: 'DCN'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'job_id'
+ feature_names: 'age'
+ feature_names: 'gender'
+ feature_names: 'year'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ dcn {
+ deep_tower {
+ input: "all"
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ cross_tower {
+ input: "all"
+ cross_num: 5
+ }
+ final_dnn {
+ hidden_units: [64, 32, 16]
+ }
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
+```
+
+- model_class: 'DCN', 不需要修改
+
+- feature_groups: 配置一个名为'all'的feature_group。
+
+- dcn: dcn相关的参数
+
+- deep_tower
+
+ - dnn: deep part的参数配置
+
+ - hidden_units: dnn每一层的channel数目,即神经元的数目
+
+- cross_tower
+
+ - cross_num: 交叉层层数,默认为3
+
+- final_dnn: 整合wide part, fm part, deep part的参数输入, 可以选择是否使用
+
+ - hidden_units: dnn每一层的channel数目,即神经元的数目
+
+- embedding_regularization: 对embedding部分加regularization,防止overfit
+
+### Example Config
+
+[dcn_on_movielens.config](../configs/dcn_on_movielens.config)
diff --git a/examples/rank_model/fm.md b/examples/rank_model/fm.md
new file mode 100644
index 000000000..c59e3bff2
--- /dev/null
+++ b/examples/rank_model/fm.md
@@ -0,0 +1,51 @@
+# FM
+
+### Introduction
+
+The main application scenario of the FM model is click-through-rate prediction; its goal is to model feature interactions when the data is high-dimensional and sparse.
+
+
+
+### Reference Paper
+
+[FM](https://www.csie.ntu.edu.tw/%7Eb97053/paper/Rendle2010FM.pdf)
+
+### Configuration
+
+```protobuf
+model_config: {
+ model_class: 'FM'
+ feature_groups: {
+ group_name: 'wide'
+ feature_names: 'F1'
+ feature_names: 'F2'
+ ...
+ wide_deep:WIDE
+ }
+ feature_groups: {
+ group_name: 'deep'
+ feature_names: 'F1'
+ feature_names: 'F2'
+ ...
+ wide_deep: DEEP
+ }
+
+ fm {
+ }
+ embedding_regularization: 1e-5
+}
+```
+
+- model_class: 'FM', no need to modify
+
+- feature_groups:
+
+  Two feature_groups are required (as in the config above): a wide group and a deep group; **the group names must not be changed**
+
+- embedding_regularization: regularization applied to the embedding part to prevent overfitting
+
+- input_type: when the job is submitted to a PAI-TF cluster and reads MaxCompute tables as input, data_config.input_type must be set to OdpsInputV2 (see the sketch below).
+
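+A hedged sketch of the relevant data_config fragment when reading MaxCompute tables on PAI (only the input_type line matters here; the other values are illustrative):
+
+```protobuf
+data_config {
+  input_type: OdpsInputV2
+  batch_size: 1024
+  label_fields: 'label'
+  ...
+}
+```
+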
+### Example Config
+
+[fm_on_criteo.config](../configs/fm_on_criteo.config)
diff --git a/examples/rank_model/readme.md b/examples/rank_model/readme.md
new file mode 100644
index 000000000..f6a2ba791
--- /dev/null
+++ b/examples/rank_model/readme.md
@@ -0,0 +1,57 @@
+# Introduction
+
+For the ranking task, we provide model demos on two public datasets (MovieLens-1M, Criteo Research Kaggle).
+
+# MovieLens-1M Dataset
+
+On the MovieLens-1M dataset we provide demos for 4 models. More models can be found in [models](../../docs/source/models/).
+
+[Wide&Deep](wide_and_deep.md) / [DeepFM](deepfm.md) / [DCN](dcn.md) / [AutoInt](autoint.md)
+
+- Wide & Deep
+
+  `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/wide_and_deep_on_movielens.config `
+
+- DeepFM
+
+  `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/deepfm_on_movielens.config `
+
+- DCN
+
+  `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/dcn_on_movielens.config `
+
+- AutoInt
+
+  `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/autoint_on_movielens.config `
+
+### Results
+
+| DataSet | Model | AUC |
+| ------------ | --------- | ------ |
+| MovieLens-1M | Wide&Deep | 0.8558 |
+| MovieLens-1M | DeepFM | 0.8688 |
+| MovieLens-1M | DCN | 0.8576 |
+| MovieLens-1M | AutoInt | 0.8513 |
+| MovieLens-1M | MaskNet | 0.8872 |
+| MovieLens-1M | FibiNet | 0.8879 |
+
+# Criteo Research Kaggle Dataset
+
+On the `Criteo Research Kaggle` dataset we provide demos for 2 models.
+
+[FM](fm.md) / [DeepFM](deepfm.md)
+
+- FM
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/fm_on_criteo.config`
+
+- DeepFM
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/deepfm_on_criteo.config`
+
+### Results
+
+| DataSet | Model | AUC |
+| --------------- | ------ | ------ |
+| Criteo-Research | FM | 0.7577 |
+| Criteo-Research | DeepFM | 0.7967 |
diff --git a/examples/rank_model/wide_and_deep.md b/examples/rank_model/wide_and_deep.md
new file mode 100644
index 000000000..2ba30d3df
--- /dev/null
+++ b/examples/rank_model/wide_and_deep.md
@@ -0,0 +1,72 @@
+# WideAndDeep
+
+### Introduction
+
+WideAndDeep consists of a Wide part and a Deep part: the Wide part is responsible for memorization and the Deep part for generalization. The Wide part performs explicit feature crosses, while the Deep part learns implicit, automatic feature crosses.
+
+
+
+### Reference Paper
+
+[WideAndDeep](https://arxiv.org/abs/1606.07792)
+
+### Configuration
+
+```protobuf
+model_config:{
+ model_class: "WideAndDeep"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "user_id"
+ feature_names: "movie_id"
+ ...
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "user_id"
+ feature_names: "movie_id"
+ ...
+ wide_deep:WIDE
+ }
+
+ wide_and_deep {
+ wide_output_dim: 16
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+
+ final_dnn {
+ hidden_units: [64, 32, 16]
+ }
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
+```
+
+- model_class: 'WideAndDeep', no need to modify
+
+- feature_groups:
+
+  Two feature_groups are required: a wide group and a deep group; **the group names must not be changed**
+
+- wide_and_deep: WideAndDeep-related parameters
+
+- dnn: configuration of the deep part
+
+  - hidden_units: number of channels (i.e. neurons) in each DNN layer
+
+- wide_output_dim: output dimension of the wide part
+
+- final_dnn: a DNN that combines the outputs of the wide part and the deep part; optional
+
+  - hidden_units: number of channels (i.e. neurons) in each DNN layer
+
+- embedding_regularization: regularization applied to the embedding part to prevent overfitting
+
+- input_type: when the job is submitted to a PAI-TF cluster and reads MaxCompute tables as input, data_config.input_type must be set to OdpsInputV2.
+
+### Example Config
+
+[wide_and_deep_on_movielens.config](../configs/wide_and_deep_on_movielens.config)
diff --git a/examples/readme.md b/examples/readme.md
new file mode 100644
index 000000000..dc2122d2e
--- /dev/null
+++ b/examples/readme.md
@@ -0,0 +1,303 @@
+# Introduction
+
+We provide a set of demos to help users quickly try out EasyRec and lower the barrier to getting started.
+
+The demos cover experiments with different models on public datasets, spanning the retrieval and ranking tasks of a recommender system, and include dataset download, preprocessing, model configuration, training and evaluation.
+
+# Install EasyRec
+
+We provide two options: `local Anaconda installation` and `starting from a Docker image`.
+
+## Local Anaconda Installation
+
+The demo experiments use the environment `python=3.6.8` + `tensorflow=1.12.0`.
+
+```bash
+conda create -n py36_tf12 python=3.6.8
+conda activate py36_tf12
+pip install tensorflow==1.12.0
+```
+
+```bash
+git clone https://github.com/alibaba/EasyRec.git
+cd EasyRec
+bash scripts/init.sh
+python setup.py install
+
+```
+
+## Start from a Docker Image
+
+### Option 1: Pull the published image (recommended)
+
+```bash
+git clone https://github.com/alibaba/EasyRec.git
+cd EasyRec
+
+# Docker environment options
+# (1) python=3.6.9 + tensorflow=1.15.5
+docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.8.5
+docker run -td --network host -v /local_path/EasyRec:/docker_path/EasyRec mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-0.8.5
+docker exec -it <container_id> bash
+
+# (2) python=3.8.10 + tensorflow=2.12.0
+docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-0.8.5
+docker run -td --network host -v /local_path/EasyRec:/docker_path/EasyRec mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-0.8.5
+docker exec -it <container_id> bash
+```
+
+### Option 2: Build the Docker image yourself
+
+```bash
+git clone https://github.com/alibaba/EasyRec.git
+cd EasyRec
+
+# Docker environment options
+# (1) python=3.6.9 + tensorflow=1.15.5
+bash scripts/build_docker_tf115.sh
+sudo docker run -td --network host -v /local_path:/docker_path mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-<version>
+
+# (2) python=3.8.10 + tensorflow=2.12.0
+bash scripts/build_docker_tf212.sh
+sudo docker run -td --network host -v /local_path:/docker_path mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-<version>
+
+sudo docker exec -it <container_id> bash
+```
+
+Note: `<version>` appended to the image tag must match the current EasyRec version (e.g. `0.8.5`).
+
+# Prepare Datasets
+
+Each dataset directory provides a `data/xxx/download_and_process.sh` script that downloads, extracts and preprocesses the data; after it finishes, two files, `xxx_train_data` and `xxx_test_data`, are produced in that directory.
+
+The download and preprocessing steps for the three commonly used datasets are listed below:
+
+- MovieLens-1M (see [data/movielens_1m/](data/movielens_1m/) for details. You can also skip preprocessing and download the processed dataset directly: [movies_train_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/movielens_1m/movies_train_data), [movies_test_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/movielens_1m/movies_test_data))
+
+ ```bash
+ cd examples/data/movielens_1m
+ sh download_and_process.sh
+ ```
+
+- Criteo-Research-Kaggle (see [data/criteo/](data/criteo/) for details. You can also skip preprocessing and download the processed dataset directly: [criteo_train_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/criteo_kaggle/criteo_train_data), [criteo_test_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/criteo_kaggle/criteo_test_data))
+
+ ```bash
+ cd examples/data/criteo
+ sh download_and_process.sh
+ ```
+
+- Amazon Books (see [data/amazon_books_data/](data/amazon_books_data/) for details. You can also skip preprocessing and download the processed dataset directly: [amazon_train_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/amazon_books/amazon_train_data), [amazon_test_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/amazon_books/amazon_test_data), [negative_book_data](https://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/amazon_books/negative_book_data))
+
+ ```bash
+ cd examples/data/amazon_books_data
+ sh download_and_process.sh
+ ```
+
+
+
+# Example Configs
+
+EasyRec training and evaluation are driven by a config file in prototxt format; for most tasks, writing a config file is all that is needed.
+
+Complete example config files for the demo experiments are provided in [configs](configs/); a structural sketch of such a config follows below.
+
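+The overall structure of a pipeline config looks roughly like the sketch below (section names follow the EasyRec pipeline config; paths and values are illustrative):
+
+```protobuf
+train_input_path: "examples/data/movielens_1m/movies_train_data"
+eval_input_path: "examples/data/movielens_1m/movies_test_data"
+model_dir: "examples/ckpt/deepfm_on_movielens"
+
+data_config {
+  input_type: CSVInput
+  batch_size: 1024
+  label_fields: "label"
+  ...
+}
+
+feature_config { ... }
+
+model_config {
+  model_class: "DeepFM"
+  ...
+}
+
+train_config { ... }
+eval_config { ... }
+export_config { ... }
+```
+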
+**Ranking task**
+
+- [wide_and_deep_on_movielens.config](configs/wide_and_deep_on_movielens.config)
+
+- [deepfm_on_movielens.config](configs/deepfm_on_movielens.config)
+
+- [deepfm_backbone_on_movielens.config](configs/deepfm_backbone_on_movielens.config)
+
+- [dcn_on_movielens.config](configs/dcn_on_movielens.config)
+
+- [autoint_on_movielens.config](configs/autoint_on_movielens.config)
+
+- [masknet_on_movielens.config](configs/masknet_on_movielens.config)
+
+- [fibinet_on_movielens.config](configs/fibinet_on_movielens.config)
+
+- [fm_on_criteo.config](configs/fm_on_criteo.config)
+
+- [deepfm_on_criteo.config](configs/deepfm_on_criteo.config)
+
+- [deepfm_backbone_on_criteo.config](configs/deepfm_backbone_on_criteo.config)
+
+**Retrieval task**
+
+- [dssm_on_books.config](configs/dssm_on_books.config)
+- [mind_on_books.config](configs/mind_on_books.config)
+- [dssm_on_books_negative_sample.config](configs/dssm_on_books_negative_sample.config)
+- [mind_on_books_negative_sample.config](configs/mind_on_books_negative_sample.config)
+
+# Training and Evaluation
+
+Start training and evaluation by passing the corresponding pipeline_config_path. More models can be found in [models](../docs/source/models/).
+
+### Ranking task + MovieLens-1M dataset
+
+On this dataset we provide demos for 4 models ([Wide&Deep](rank_model/wide_and_deep.md) / [DeepFM](rank_model/deepfm.md) / [DCN](rank_model/dcn.md) / [AutoInt](rank_model/autoint.md)).
+
+- Wide & Deep
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/wide_and_deep_on_movielens.config `
+
+- DeepFM
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/deepfm_on_movielens.config `
+
+- DCN
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/dcn_on_movielens.config `
+
+- AutoInt
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/autoint_on_movielens.config `
+
+### Ranking task + Criteo Research Kaggle dataset
+
+On this dataset we provide demos for 2 models ([FM](rank_model/fm.md) / [DeepFM](rank_model/deepfm.md)).
+
+- FM
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/fm_on_criteo.config`
+
+- DeepFM
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/deepfm_on_criteo.config`
+
+### Retrieval task + Amazon Books dataset
+
+On this dataset we provide demos for 2 models and their negative-sampling variants: [DSSM](match_model/dssm.md) / [MIND](match_model/mind.md) / [DSSM-Negative-Sample](match_model/dssm_negative_sample.md) / [MIND-Negative-Sample](match_model/mind_negative_sample.md).
+
+- DSSM
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/dssm_on_books.config `
+
+- MIND
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/mind_on_books.config `
+
+- DSSM with Negative Sample
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/dssm_on_books_negative_sample.config `
+
+- MIND with Negative Sample
+
+ `python -m easy_rec.python.train_eval --pipeline_config_path examples/configs/mind_on_books_negative_sample.config `
+
+#### Single machine, single GPU:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.train_eval --pipeline_config_path *.config
+```
+
+- --pipeline_config_path: config file used for training
+- --continue_train: whether to continue training from the existing checkpoint
+
+#### GPU PS training
+
+- the ps runs on CPU
+- the master runs on GPU:0
+- the worker runs on GPU:1
+- Note: locally only the ps/master/worker mode is supported; the ps/chief/worker/evaluator mode is not supported
+
+```bash
+wget https://easyrec.oss-cn-beijing.aliyuncs.com/scripts/train_2gpu.sh
+sh train_2gpu.sh *.config
+```
+
+
+
+
+
+# Evaluation and Export
+
+Evaluate and export the corresponding model by specifying its pipeline_config_path.
+
+- Model evaluation
+
+  `python -m easy_rec.python.eval --pipeline_config_path examples/configs/deepfm_on_criteo.config`
+
+- Model export
+
+ `python -m easy_rec.python.export --pipeline_config_path examples/configs/deepfm_on_criteo.config --export_dir examples/ckpt/export/deepfm_on_criteo`
+
+# Evaluation Results
+
+The demo experiments on the public datasets and their evaluation results are listed below, for reference only.
+
+### Ranking Models
+
+- MovieLens-1M
+
+ | Model | Epoch | AUC |
+ | -------------------- | ----- | ------ |
+ | MLP | 1 | 0.8616 |
+ | Wide&Deep | 1 | 0.8558 |
+ | Wide&Deep(Backbone) | 1 | 0.8854 |
+ | MultiTower(Backbone) | 1 | 0.8814 |
+ | DeepFM | 1 | 0.8867 |
+ | DeepFM(Backbone) | 1 | 0.8872 |
+ | DCN | 1 | 0.8576 |
+ | DCN_v2 | 1 | 0.8770 |
+ | AutoInt | 1 | 0.8513 |
+ | MaskNet | 1 | 0.8872 |
+ | FibiNet | 1 | 0.8893 |
+
+  Note: the `MovieLens-1M` dataset is small, so the evaluation metrics have relatively high variance; the results above are for reference only.
+
+- Criteo-Research
+
+ | Model | Epoch | AUC |
+ | ----------------- | ----- | ------- |
+ | FM | 1 | 0.7577 |
+ | DeepFM | 1 | 0.7970 |
+ | DeepFM (backbone) | 1 | 0.7970 |
+ | DeepFM (periodic) | 1 | 0.7979 |
+ | DeepFM (autodis) | 1 | 0.7982 |
+ | DLRM | 1 | 0.79785 |
+ | DLRM (backbone) | 1 | 0.7983 |
+ | DLRM (senet) | 1 | 0.7995 |
+ | DLRM (standard) | 1 | 0.7949 |
+ | DLRM (autodis) | 1 | 0.7989 |
+ | DLRM (periodic) | 1 | 0.7998 |
+
+### Retrieval Models
+
+- Amazon Books Data
+
+ | Model | Epoch | AUC |
+ | ----- | ----- | ------ |
+ | DSSM | 2 | 0.8173 |
+ | MIND | 2 | 0.7511 |
+
+
+
+Note: for retrieval models and their negative-sampling variants, the HitRate metric is recommended for evaluation; see [HitRate evaluation](https://easyrec.oss-cn-beijing.aliyuncs.com/docs/recall_eval.pdf) for the detailed procedure.
diff --git a/git-lfs/git_lfs.py b/git-lfs/git_lfs.py
new file mode 100644
index 000000000..4eefc73e8
--- /dev/null
+++ b/git-lfs/git_lfs.py
@@ -0,0 +1,534 @@
+# -*- encoding:utf-8 -*-
+# two config files are used: .git_bin_path .git_bin_url
+import hashlib
+import json
+import logging
+import os
+import re
+import subprocess
+import sys
+import traceback
+
+blank_split = re.compile('[\t ]')
+
+logging.basicConfig(
+ format='[%(levelname)s] %(asctime)s %(filename)s[%(lineno)d] : %(message)s',
+ level=logging.INFO)
+
+try:
+ import oss2
+except ImportError:
+ logging.error(
+ 'please install python_oss from https://github.com/aliyun/aliyun-oss-python-sdk.git'
+ )
+ sys.exit(1)
+
+git_bin_path = '.git_bin_path'
+git_bin_url_path = '.git_bin_url'
+# temporary storage path
+git_oss_cache_dir = '.git_oss_cache'
+
+
+# get project name by using git remote -v
+def get_proj_name():
+ proj_name = subprocess.check_output(['git', 'remote', '-v'])
+ proj_name = proj_name.decode('utf-8')
+ proj_name = proj_name.split('\n')[0]
+ proj_name = blank_split.split(proj_name)[1]
+ proj_name = proj_name.split('/')[-1]
+ proj_name = proj_name.replace('.git', '')
+ return proj_name
+
+
+# load .git_bin_url
+# local_path md5 remote_path
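+# each line is a json object, e.g. (illustrative values):
+#   {"leaf_path": "data/test", "sig": "<md5 of the files>", "remote_path": "<git_oss_data_dir>/data_test_<md5>"}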
+def load_git_url():
+ git_bin_url_map = {}
+ try:
+ with open(git_bin_url_path) as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ line_json = json.loads(line_str)
+ git_bin_url_map[line_json['leaf_path']] = (line_json['sig'],
+ line_json['remote_path'])
+ except Exception as ex:
+ logging.warning('exception: %s' % str(ex))
+ pass
+ return git_bin_url_map
+
+
+def save_git_url(http://github.com/git_bin_url_map):
+ with open(git_bin_url_path, 'w') as fout:
+ keys = list(git_bin_url_map.keys())
+ keys.sort()
+ for key in keys:
+ val = git_bin_url_map[key]
+ tmp_str = '{"leaf_path": "%s", "sig": "%s", "remote_path": "%s"}' % (
+ key, val[0], val[1])
+ fout.write('%s\n' % tmp_str)
+
+
+def path2name(path):
+  # flatten a path into a single name usable in an oss object name
+  name = path.replace('//', '/')
+  name = name.replace('/', '_')
+  if name[-1] == '_':
+    return name[:-1]
+  elif name == '.':
+    return 'curr_dir'
+  else:
+    return name
+
+
+def get_file_arr(path):
+ archive_files = []
+ if os.path.isdir(path):
+ for one_file in os.listdir(path):
+ one_path = path + '/' + one_file
+ if not os.path.isdir(one_path):
+ archive_files.append(one_path)
+ return archive_files
+ else: # just a file
+ archive_files.append(path)
+ return archive_files
+
+
+def load_git_bin():
+ file_arr = {}
+ if not os.path.exists(git_bin_path):
+ return file_arr
+
+ with open(git_bin_path, 'r') as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ try:
+ line_json = json.loads(line_str)
+ file_arr[line_json['leaf_name']] = line_json['leaf_file']
+ except Exception as ex:
+ logging.warning('%s is corrupted : %s' %
+                        (git_bin_path, traceback.format_exc()))
+ return file_arr
+
+
+def save_git_bin(git_arr):
+ leaf_paths = list(git_arr.keys())
+ leaf_paths.sort()
+ with open(git_bin_path, 'w') as fout:
+ for leaf_path in leaf_paths:
+ leaf_files = git_arr[leaf_path]
+ leaf_files.sort()
+ # make sure that leaf_name is in front of leaf_file
+ tmp_str = '{"leaf_name": "%s", "leaf_file": %s}' % (
+ leaf_path, json.dumps(leaf_files))
+ fout.write('%s\n' % tmp_str)
+
+
+def recheck_git_bin():
+ file_arr = load_git_bin()
+ update = False
+ del_arr = []
+ for leaf_path in file_arr:
+ leaf_files = file_arr[leaf_path]
+ good_leaf_files = [x for x in leaf_files if os.path.exists(x)]
+ if not os.path.exists(leaf_path):
+ del_arr.append(leaf_path)
+ update = True
+ elif len(good_leaf_files) != len(leaf_files):
+ file_arr[leaf_path] = good_leaf_files
+ update = True
+ for leaf_path in del_arr:
+ del file_arr[leaf_path]
+ if update:
+ save_git_bin(file_arr)
+ return file_arr
+
+
+# check whether a folder has changed by computing the md5 of its files' contents
+# (the tar archive itself is not hashed, because gzip output is not deterministic)
+# the md5 signatures are saved in .git_bin_url
+def get_local_sig(leaf_files):
+ if len(leaf_files) == 0:
+ logging.warning('no leaf files')
+ return None
+ leaf_files = sorted(leaf_files)
+ m = hashlib.md5()
+ block_size = 1024 * 1024 * 8
+ for one_file in leaf_files:
+ with open(one_file, 'rb') as fin:
+ for chunk in iter(lambda: fin.read(block_size), b''):
+ m.update(chunk)
+ return m.hexdigest()
+
+
+def list_leafs(curr_path):
+ bottom_dir = []
+ if os.path.isdir(curr_path):
+ for root, dirs, files in os.walk(curr_path, topdown=True):
+ if len(dirs) == 0 or len(files) > 0:
+ if root[-1] == '/':
+ root = root[:-1]
+ file_arr = get_file_arr(root)
+ bottom_dir.append((root, file_arr))
+ else: # a single file
+ curr_dir = os.path.dirname(curr_path)
+ if curr_dir == '':
+ curr_dir = '.'
+ bottom_dir.append((curr_dir, [curr_path]))
+ return bottom_dir
+
+
+# check whether lst0 and lst1 contain the same string elements
+def lst_eq(lst0, lst1):
+ if len(lst0) != len(lst1):
+ return False
+ for x in lst1:
+ if x not in lst0:
+ return False
+ return True
+
+
+def merge_lst(lst0, lst1):
+ for a in lst1:
+ if a not in lst0:
+ lst0.append(a)
+ return lst0
+
+
+def has_conflict(leaf_path, leaf_files):
+ if not os.path.exists(leaf_path):
+ return False
+ for leaf_file in leaf_files:
+ if os.path.exists(leaf_file):
+ return True
+ return False
+
+
+def get_yes_no(msg):
+  # returns False when the user just presses enter (default No)
+  update = False
+  while True:
+    logging.info(msg)
+    tmp_op = sys.stdin.readline()
+    tmp_op = tmp_op.strip()
+    if len(tmp_op) == 0:
+      break
+    elif tmp_op[0] == 'Y' or tmp_op[0] == 'y':
+      update = True
+      break
+    elif tmp_op[0] == 'N' or tmp_op[0] == 'n':
+      update = False
+      break
+  return update
+
+
+if __name__ == '__main__':
+ if len(sys.argv) < 2:
+ logging.error(
+        'usage: python git_lfs.py [pull] [push] [add filename] [remove filename] [resolve_conflict]'
+ )
+ sys.exit(1)
+ home_directory = os.path.expanduser('~')
+ with open('.git_oss_config_pub', 'r') as fin:
+ git_oss_data_dir = None
+ host = None
+ bucket_name = None
+ git_oss_private_path = None
+ enable_accelerate = 0
+ accl_endpoint = None
+ for line_str in fin:
+ line_str = line_str.strip()
+ if len(line_str) == 0:
+ continue
+ if line_str.startswith('#'):
+ continue
+ line_str = line_str.replace('~/', home_directory + '/')
+ line_str = line_str.replace('${TMPDIR}/',
+ os.environ.get('TMPDIR', '/tmp/'))
+ line_str = line_str.replace('${PROJECT_NAME}', get_proj_name())
+ line_tok = [x.strip() for x in line_str.split('=') if x != '']
+ if line_tok[0] == 'host':
+ host = line_tok[1]
+ elif line_tok[0] == 'git_oss_data_dir':
+ git_oss_data_dir = line_tok[1].strip('/')
+ elif line_tok[0] == 'bucket_name':
+ bucket_name = line_tok[1]
+ elif line_tok[0] == 'git_oss_private_config':
+ git_oss_private_path = line_tok[1]
+ if git_oss_private_path.startswith('~/'):
+ git_oss_private_path = os.path.join(home_directory,
+ git_oss_private_path[2:])
+ elif line_tok[0] == 'git_oss_cache_dir':
+ git_oss_cache_dir = line_tok[1]
+ elif line_tok[0] == 'accl_endpoint':
+ accl_endpoint = line_tok[1]
+
+ logging.info('git_oss_data_dir=%s, host=%s, bucket_name=%s' %
+ (git_oss_data_dir, host, bucket_name))
+
+ logging.info('git_oss_cache_dir: %s' % git_oss_cache_dir)
+
+ if not os.path.exists(git_oss_cache_dir):
+ os.makedirs(git_oss_cache_dir)
+
+ logging.info('git_oss_private_config=%s' % git_oss_private_path)
+ if git_oss_private_path is not None and os.path.exists(git_oss_private_path):
+ # load oss configs
+ with open(git_oss_private_path, 'r') as fin:
+ for line_str in fin:
+ line_str = line_str.strip()
+ line_tok = [x.strip() for x in line_str.split('=') if x != '']
+ if line_tok[0] in ['accessid', 'accessKeyID']:
+ accessid = line_tok[1]
+ elif line_tok[0] in ['accesskey', 'accessKeySecret']:
+ accesskey = line_tok[1]
+ oss_auth = oss2.Auth(accessid, accesskey)
+ oss_bucket = oss2.Bucket(oss_auth, host, bucket_name)
+ else:
+ logging.info('git_oss_private_path[%s] is not found, read-only mode' %
+ git_oss_private_path)
+ # pull only mode
+ oss_auth = None
+ oss_bucket = None
+
+ if sys.argv[1] == 'push':
+ updated = False
+ git_bin_arr = recheck_git_bin()
+ git_bin_url = load_git_url()
+ for leaf_path in git_bin_arr:
+ leaf_files = git_bin_arr[leaf_path]
+ # empty directory will not be push to oss
+ if len(leaf_files) == 0:
+ continue
+ file_name = path2name(leaf_path)
+ new_sig = get_local_sig(leaf_files)
+ if new_sig is None:
+ continue
+ if leaf_path in git_bin_url and git_bin_url[leaf_path][0] == new_sig:
+ continue
+ # build tar file and push to oss
+ file_name_with_sig = file_name + '_' + new_sig
+ tar_out_path = '%s/%s.tar.gz' % (git_oss_cache_dir, file_name_with_sig)
+ subprocess.check_output(['tar', '-czf', tar_out_path] + leaf_files)
+ save_path = '%s/%s' % (git_oss_data_dir, file_name_with_sig)
+ oss_bucket.put_object_from_file(save_path, tar_out_path)
+ oss_bucket.put_object_acl(save_path, oss2.OBJECT_ACL_PUBLIC_READ)
+ git_bin_url[leaf_path] = (new_sig, save_path)
+ logging.info('pushed %s' % leaf_path)
+ updated = True
+ for leaf_path in list(git_bin_url.keys()):
+ if leaf_path not in git_bin_arr:
+ del git_bin_url[leaf_path]
+ logging.info('dropped %s' % leaf_path)
+ updated = True
+ if updated:
+ save_git_url(http://github.com/git_bin_url)
+ logging.info('push succeed.')
+ else:
+ logging.warning('nothing to push')
+ subprocess.check_output(['git', 'add', git_bin_url_path])
+ elif sys.argv[1] == 'pull':
+ # pull images from remote
+ any_update = False
+ git_bin_arr = load_git_bin()
+ git_bin_url = load_git_url()
+ for leaf_path in git_bin_arr:
+ leaf_files = git_bin_arr[leaf_path]
+ if len(leaf_files) == 0:
+ if os.path.isfile(leaf_path):
+ logging.error('conflicts: %s is a file, but was a dir' % leaf_path)
+ elif not os.path.isdir(leaf_path):
+ os.makedirs(leaf_path)
+ continue
+ # newly add files
+ if leaf_path not in git_bin_url:
+ continue
+ file_name = path2name(leaf_path)
+ all_file_exist = True
+ for tmp in leaf_files:
+ if not os.path.exists(tmp):
+ all_file_exist = False
+ remote_sig = git_bin_url[leaf_path][0]
+ if all_file_exist:
+ local_sig = get_local_sig(leaf_files)
+ if local_sig == remote_sig:
+ continue
+ else:
+ local_sig = ''
+
+ update = False
+ if len(sys.argv) > 2 and (sys.argv[2] == '-f' or
+ sys.argv[2] == '--force'):
+ update = True
+ else:
+ if has_conflict(leaf_path, leaf_files):
+ update = get_yes_no(
+ 'update %s using remote file[remote_sig=%s local_sig=%s]?[N/Y]' %
+ (leaf_path, remote_sig, local_sig))
+ else:
+ update = True
+ if not update:
+ continue
+ # pull from remote oss
+ remote_path = git_bin_url[leaf_path][1]
+ _, file_name_with_sig = os.path.split(remote_path)
+ tar_tmp_path = '%s/%s.tar.gz' % (git_oss_cache_dir, file_name_with_sig)
+ max_retry = 5
+ while max_retry > 0:
+ try:
+ if not os.path.exists(tar_tmp_path):
+ in_cache = False
+ if oss_bucket:
+ oss_bucket.get_object_to_file(remote_path, tar_tmp_path)
+ else:
+ url = 'http://%s.%s/%s' % (bucket_name, host, remote_path)
+ # subprocess.check_output(['wget', url, '-O', tar_tmp_path])
+ if sys.platform.startswith('linux'):
+ subprocess.check_output(['wget', url, '-O', tar_tmp_path])
+ elif sys.platform.startswith('darwin'):
+ subprocess.check_output(['curl', url, '--output', tar_tmp_path])
+ elif sys.platform.startswith('win'):
+ subprocess.check_output(['curl', url, '--output', tar_tmp_path])
+ else:
+ in_cache = True
+ logging.info('%s is in cache' % file_name_with_sig)
+ subprocess.check_output(['tar', '-zxf', tar_tmp_path])
+ local_sig = get_local_sig(leaf_files)
+ if local_sig == remote_sig:
+ break
+ if in_cache:
+ logging.warning('cache invalid, will download from remote')
+ os.remove(tar_tmp_path)
+ continue
+ logging.warning('download failed, local_sig(%s) != remote_sig(%s)' %
+ (local_sig, remote_sig))
+ except subprocess.CalledProcessError as ex:
+ logging.error('exception: %s' % str(ex))
+ except oss2.exceptions.RequestError as ex:
+ logging.error('exception: %s' % str(ex))
+
+ os.remove(tar_tmp_path)
+ if accl_endpoint is not None and host != accl_endpoint:
+ logging.info('will try accelerate endpoint: %s' % accl_endpoint)
+ host = accl_endpoint
+ if oss_auth:
+ oss_bucket = oss2.Bucket(oss_auth, host, bucket_name)
+ max_retry -= 1
+
+ logging.info('%s updated' % leaf_path)
+ any_update = True
+ if not any_update:
+ logging.info('nothing to be updated')
+ elif sys.argv[1] == 'add':
+ add_path = sys.argv[2]
+ if not os.path.exists(add_path):
+ raise ValueError('add path %s does not exist' % add_path)
+ bin_file_map = {}
+ try:
+ bin_file_map = load_git_bin()
+ except Exception as ex:
+    logging.warning('load_git_bin exception: %s' % traceback.format_exc())
+ pass
+ leaf_dirs = list_leafs(add_path)
+ any_new = False
+ for leaf_path, leaf_files in leaf_dirs:
+ for leaf_file in leaf_files:
+ tmp_out = subprocess.check_output(['git', 'ls-files', leaf_file])
+ if len(tmp_out.strip()) > 0:
+ subprocess.check_output(['git', 'rm', '--cached', leaf_file])
+ if leaf_path not in bin_file_map:
+ bin_file_map[leaf_path] = leaf_files
+ any_new = True
+ else: # check whether the files are the same
+ old_leaf_files = bin_file_map[leaf_path]
+ if not lst_eq(old_leaf_files, leaf_files):
+ bin_file_map[leaf_path] = merge_lst(old_leaf_files, leaf_files)
+ any_new = True
+ if any_new:
+ # write back to .git_bin_path
+ save_git_bin(bin_file_map)
+ logging.info('added %s' % add_path)
+ else:
+ logging.info('already add %s' % add_path)
+ subprocess.check_output(['git', 'add', '.git_bin_path'])
+ elif sys.argv[1] == 'remove':
+ del_path = sys.argv[2]
+ try:
+ bin_file_map = load_git_bin()
+ except Exception as ex:
+    logging.warning('load_git_bin exception: %s' % traceback.format_exc())
+ pass
+ leaf_dirs = list_leafs(del_path)
+ any_update = False
+ for leaf_path, leaf_files in leaf_dirs:
+ if leaf_path in bin_file_map:
+ for leaf_file in leaf_files:
+ if leaf_file in bin_file_map[leaf_path]:
+ tmp_id = bin_file_map[leaf_path].index(leaf_file)
+ del bin_file_map[leaf_path][tmp_id]
+ any_update = True
+ if len(bin_file_map[leaf_path]) == 0:
+ del bin_file_map[leaf_path]
+ if any_update:
+ save_git_bin(bin_file_map)
+ logging.info('remove %s' % del_path)
+ elif sys.argv[1] == 'resolve_conflict':
+ git_objs = {}
+ with open(git_bin_path, 'r') as fin:
+ merge_start = 0
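+      # merge_start tracks the position inside a git conflict block:
+      # 0 = outside a conflict, 1 = "ours" section (after <<<<<<<),
+      # 2 = "theirs" section (after =======)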
+ for line_str in fin:
+ if line_str.startswith('<<<<<<<'):
+ merge_start = 1
+ elif line_str.startswith('======='):
+ merge_start = 2
+ elif line_str.startswith('>>>>>>>'):
+ merge_start = 0
+ elif merge_start == 0:
+ tmp_obj = json.loads(line_str)
+ leaf_name = tmp_obj['leaf_name']
+ leaf_file = tmp_obj['leaf_file']
+ git_objs[leaf_name] = leaf_file
+ elif merge_start == 1:
+ tmp_obj = json.loads(line_str)
+ leaf_name = tmp_obj['leaf_name']
+ leaf_file = tmp_obj['leaf_file']
+ git_objs[leaf_name] = leaf_file
+ elif merge_start == 2:
+ tmp_obj = json.loads(line_str)
+ leaf_name = tmp_obj['leaf_name']
+ leaf_file = tmp_obj['leaf_file']
+ if leaf_name in git_objs:
+ union = git_objs[leaf_name]
+ for tmp in leaf_file:
+ if tmp not in union:
+ union.append(tmp)
+ logging.info('add %s to %s' % (tmp, leaf_name))
+ git_objs[leaf_name] = union
+ else:
+ git_objs[leaf_name] = leaf_file
+ else:
+ logging.warning('invalid state: merge_start = %d, line_str = %s' %
+ (merge_start, line_str))
+ save_git_bin(git_objs)
+
+ git_bin_url_map = {}
+ with open(git_bin_url_path, 'r') as fin:
+ merge_start = 0
+ for line_str in fin:
+ if line_str.startswith('<<<<<<<'):
+ merge_start = 1
+ elif line_str.startswith('======='):
+ merge_start = 2
+ elif line_str.startswith('>>>>>>>'):
+ merge_start = 0
+ elif merge_start in [0, 1, 2]:
+ line_json = json.loads(line_str)
+ if line_json['leaf_path'] in git_objs:
+ git_bin_url_map[line_json['leaf_path']] = (line_json['sig'],
+ line_json['remote_path'])
+ else:
+ logging.warning('invalid state: merge_start = %d, line_str = %s' %
+ (merge_start, line_str))
+ save_git_url(http://github.com/git_bin_url_map)
+ logging.info('all conflicts fixed.')
+ else:
+ logging.warning('invalid cmd: %s' % sys.argv[1])
+ logging.warning(
+ 'choices are: %s' %
+ ','.join(['push', 'pull', 'add', 'remove', 'resolve_conflict']))
diff --git a/pai_jobs/deploy.sh b/pai_jobs/deploy.sh
old mode 100644
new mode 100755
index e1f10f6f5..77b1065b6
--- a/pai_jobs/deploy.sh
+++ b/pai_jobs/deploy.sh
@@ -92,7 +92,18 @@ fi
cp easy_rec/__init__.py easy_rec/__init__.py.bak
sed -i -e "s/\[VERSION\]/$VERSION/g" easy_rec/__init__.py
find -L easy_rec -name "*.pyc" | xargs rm -rf
-tar -cvzhf $RES_PATH easy_rec run.py
+echo "tensorflow-probability==0.5.0" > requirements.txt
+
+if [ ! -d "datahub" ]
+then
+ wget http://easyrec.oss-cn-beijing.aliyuncs.com/third_party/pydatahub.tar.gz
+ if [ $? -ne 0 ]
+ then
+ echo "datahub download failed."
+ fi
+ tar -zvxf pydatahub.tar.gz
+fi
+tar -cvzhf $RES_PATH easy_rec run.py requirements.txt
mv easy_rec/__init__.py.bak easy_rec/__init__.py
# 2 means generate only
diff --git a/pai_jobs/deploy_ext.sh b/pai_jobs/deploy_ext.sh
old mode 100644
new mode 100755
index cadc4962a..ff17f1760
--- a/pai_jobs/deploy_ext.sh
+++ b/pai_jobs/deploy_ext.sh
@@ -22,7 +22,9 @@ ODPSCMD=odpscmd
mode=0
odps_config=""
-while getopts 'V:C:OGc:' OPT; do
+is_tf15=0
+is_py3=0
+while getopts 'V:C:OGc:D:P' OPT; do
case $OPT in
V)
VERSION="$OPTARG";;
@@ -34,12 +36,18 @@ while getopts 'V:C:OGc:' OPT; do
mode=1;;
G)
mode=2;;
+ D)
+ is_tf15=1;;
+ P)
+ is_py3=1;;
?)
echo "Usage: `basename $0` -V VERSION [-C odpscmd_path] [-c odps_config_path] [-O]"
echo " -O: only update easy_rec resource file"
echo " -G: generate resource file and xflow, but not deploy"
echo " -c: odps_config file path"
echo " -C: odpscmd file path, default to: odpscmd, so in default odpscmd must be in PATH"
+ echo " -D: use tf1.15 or deeprec"
+ echo " -P: use tf1.12_py3"
echo " -V: algorithm version, chars must be in [0-9A-Za-z_-], default: version info in easy_rec/version.py"
exit 1
esac
@@ -85,20 +93,84 @@ cd $curr_dir
RES_PATH=easy_rec_ext_${VERSION}_res.tar.gz
-if [ ! -e easy_rec ]
+if [ -e easy_rec ]
then
- ln -s $root_dir/easy_rec ./
+ rm -rf easy_rec
fi
-cp easy_rec/__init__.py easy_rec/__init__.py.bak
+cp -R $root_dir/easy_rec ./easy_rec
sed -i -e "s/\[VERSION\]/$VERSION/g" easy_rec/__init__.py
find -L easy_rec -name "*.pyc" | xargs rm -rf
-tar -cvzhf $RES_PATH easy_rec run.py
-mv easy_rec/__init__.py.bak easy_rec/__init__.py
+
+if [ ! -d "datahub" ]
+then
+ if [ ! -e "pydatahub.tar.gz" ]
+ then
+ wget http://easyrec.oss-cn-beijing.aliyuncs.com/third_party/pydatahub.tar.gz
+ if [ $? -ne 0 ]
+ then
+ echo "datahub download failed."
+ fi
+ fi
+ tar -zvxf pydatahub.tar.gz
+ rm -rf pydatahub.tar.gz
+fi
+
+if [ ! -d "kafka" ]
+then
+ if [ ! -e "kafka.tar.gz" ]
+ then
+ wget http://easyrec.oss-cn-beijing.aliyuncs.com/third_party/kafka.tar.gz
+ if [ $? -ne 0 ]
+ then
+ echo "kafka download failed."
+ fi
+ fi
+ tar -zvxf kafka.tar.gz
+ rm -rf kafka.tar.gz
+fi
+
+if [ ! -d "faiss" ]
+then
+ if [ ! -e "faiss.tar.gz" ]
+ then
+ wget http://easyrec.oss-cn-beijing.aliyuncs.com/third_party/faiss.tar.gz
+ if [ $? -ne 0 ]
+ then
+ echo "faiss download failed."
+ fi
+ fi
+ tar -zvxf faiss.tar.gz
+ rm -rf faiss.tar.gz
+fi
+
+if [ -d "tensorflow_probability" ]
+then
+ rm -rf tensorflow_probability
+fi
+if [ $is_tf15 -gt 0 ]; then
+ tfp_version='0.8.0'
+else
+ tfp_version='0.5.0'
+fi
+if [ ! -e "tensorflow_probability" ]
+then
+ wget http://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/probability-${tfp_version}.tar.gz
+ if [ $? -ne 0 ]
+ then
+ echo "tensorflow_probability download failed."
+ fi
+fi
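+# unpack only the tensorflow_probability package from the release tarball; docs and examples are dropped below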
+tar -xzvf probability-${tfp_version}.tar.gz --strip-components=1 probability-${tfp_version}/tensorflow_probability
+rm -rf tensorflow_probability/examples
+rm -rf tensorflow_probability/g3doc
+rm -rf probability-${tfp_version}.tar.gz
+
+tar -cvzhf $RES_PATH easy_rec datahub lz4 cprotobuf kafka faiss tensorflow_probability run.py
# 2 means generate only
if [ $mode -ne 2 ]
then
- ${ODPSCMD} --config=$odps_config -e "add file $RES_PATH -f;"
+ ${ODPSCMD} --config=$odps_config -e "add archive $RES_PATH -f;"
if [ $? -ne 0 ]
then
echo "add $RES_PATH failed"
@@ -117,8 +189,25 @@ fi
cd easy_rec_flow_ex
sed -i -e "s/parameter name=\"version\" use=\"optional\" default=\"[0-9A-Za-z_-]\+\"/parameter name=\"version\" use=\"optional\" default=\"$VERSION\"/g" easy_rec_ext.xml
+
+if [ $is_tf15 -gt 0 ]
+then
+ echo "will deploy DeepRec(TF1.15) version"
+ sed -i -e "s/name=\"easy_rec_ext\"/name=\"easy_rec_ext15\"/g" easy_rec_ext.xml
+ sed -i -e "s/tensorflow1120_ext/tensorflow1150_ext/g" easy_rec_ext.xml
+fi
+
+if [ $is_py3 -gt 0 ]
+then
+ echo "will deploy TF1.12_py3 version"
+ sed -i -e "s/name=\"easy_rec_ext\"/name=\"easy_rec_py3_ext\"/g" easy_rec_ext.xml
+ sed -i -e "s/tensorflow1120_ext/tensorflow1120_py3_ext/g" easy_rec_ext.xml
+fi
+
tar -cvzf easy_rec_flow_ex.tar.gz easy_rec_ext.lua easy_rec_ext.xml
+git checkout ./easy_rec_ext.xml
+
if [ $mode -ne 2 ]
then
cd ../xflow-deploy
diff --git a/pai_jobs/easy_rec_flow/easy_rec.lua b/pai_jobs/easy_rec_flow/easy_rec.lua
index 87a165c85..0dacd2b5e 100644
--- a/pai_jobs/easy_rec_flow/easy_rec.lua
+++ b/pai_jobs/easy_rec_flow/easy_rec.lua
@@ -79,7 +79,7 @@ function check_run_mode(cluster, gpuRequired)
end
end
-function getHyperParams(config, cmd, checkpoint_path,
+function getHyperParams(config, cmd, checkpoint_path, fine_tune_checkpoint,
eval_result_path, export_dir, gpuRequired,
cpuRequired, memRequired, cluster, continue_train,
distribute_strategy, with_evaluator, eval_method,
@@ -128,7 +128,7 @@ function getHyperParams(config, cmd, checkpoint_path,
checkTable(output_table)
if extra_params ~= nil and extra_params ~= '' then
- hyperParameters = hyperParameters .. extra_params
+ hyperParameters = hyperParameters .. " " .. extra_params
end
return hyperParameters, cluster, tables, output_table
end
@@ -161,7 +161,7 @@ function getHyperParams(config, cmd, checkpoint_path,
hyperParameters = hyperParameters .. ' --knn_compress_dim=' .. knn_compress_dim
end
if extra_params ~= nil and extra_params ~= '' then
- hyperParameters = hyperParameters .. extra_params
+ hyperParameters = hyperParameters .. " " .. extra_params
end
return hyperParameters, cluster, tables, output_table
end
@@ -188,7 +188,7 @@ function getHyperParams(config, cmd, checkpoint_path,
if eval_tables ~= "" and eval_tables ~= nil then
hyperParameters = hyperParameters .. " --eval_tables " .. eval_tables
end
- elseif cmd == 'export' then
+ elseif cmd == 'export' or cmd == 'export_checkpoint' then
hyperParameters = hyperParameters .. " --checkpoint_path=" .. checkpoint_path
hyperParameters = hyperParameters .. " --export_dir=" .. export_dir
elseif cmd == 'train' then
@@ -199,6 +199,9 @@ function getHyperParams(config, cmd, checkpoint_path,
if with_evaluator ~= "" and tonumber(with_evaluator) ~= 0 then
hyperParameters = hyperParameters .. " --with_evaluator"
end
+ if fine_tune_checkpoint ~= nil and fine_tune_checkpoint ~= '' then
+ hyperParameters = hyperParameters .. " --fine_tune_checkpoint=" .. fine_tune_checkpoint
+ end
if eval_method ~= 'none' and eval_method ~= 'separate' and eval_method ~= 'master' then
error('invalid eval_method ' .. eval_method)
end
@@ -255,7 +258,7 @@ function getHyperParams(config, cmd, checkpoint_path,
num_gpus_per_worker)
if extra_params ~= nil and extra_params ~= '' then
- hyperParameters = hyperParameters .. extra_params
+ hyperParameters = hyperParameters .. " " .. extra_params
end
return hyperParameters, cluster, tables, output_table
@@ -349,12 +352,18 @@ function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols,
-- all_cols, all_col_types, selected_cols, reserved_cols,
-- create_table_sql, add_partition_sql, tables parameter to runTF
if cmd ~= 'train' and cmd ~= 'evaluate' and cmd ~= 'predict' and cmd ~= 'export'
+ and cmd ~= 'export_checkpoint'
and cmd ~= 'evaluate' and cmd ~= 'custom' and cmd ~= 'vector_retrieve' then
error('invalid cmd: ' .. cmd .. ', should be one of train, evaluate, predict, evaluate, export, custom, vector_retrieve')
end
-- for export
- if cmd == 'export' or cmd == 'custom' then
+ if cmd == 'export' or cmd == 'custom' or cmd == 'export_checkpoint' then
+ return "", "", "", "", "select 1;", "select 1;", tables
+ end
+
+ -- for online train or train with oss input
+ if cmd == 'train' and (tables == nil or tables == '') and (trainTables == nil or trainTables == '') then
return "", "", "", "", "select 1;", "select 1;", tables
end
diff --git a/pai_jobs/easy_rec_flow/easy_rec.xml b/pai_jobs/easy_rec_flow/easy_rec.xml
index 0dc2dea31..d9dc17ddc 100644
--- a/pai_jobs/easy_rec_flow/easy_rec.xml
+++ b/pai_jobs/easy_rec_flow/easy_rec.xml
@@ -25,6 +25,7 @@
+
@@ -63,7 +64,8 @@
-
+
+
@@ -142,6 +144,7 @@
+
@@ -222,7 +225,8 @@
-
+
+
tensorflow1120
algo_public
diff --git a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua
index eea9ec32e..38d76d85c 100644
--- a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua
+++ b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua
@@ -101,7 +101,7 @@ function check_run_mode(cluster, gpuRequired)
end
end
-function getHyperParams(config, cmd, checkpoint_path,
+function getHyperParams(config, cmd, checkpoint_path, fine_tune_checkpoint,
eval_result_path, export_dir, gpuRequired,
cpuRequired, memRequired, cluster, continue_train,
distribute_strategy, with_evaluator, eval_method,
@@ -150,7 +150,7 @@ function getHyperParams(config, cmd, checkpoint_path,
checkTable(output_table)
if extra_params ~= nil and extra_params ~= '' then
- hyperParameters = hyperParameters .. extra_params
+ hyperParameters = hyperParameters .. " " .. extra_params
end
return hyperParameters, cluster, tables, output_table
end
@@ -183,7 +183,7 @@ function getHyperParams(config, cmd, checkpoint_path,
hyperParameters = hyperParameters .. ' --knn_compress_dim=' .. knn_compress_dim
end
if extra_params ~= nil and extra_params ~= '' then
- hyperParameters = hyperParameters .. extra_params
+ hyperParameters = hyperParameters .. " " .. extra_params
end
return hyperParameters, cluster, tables, output_table
end
@@ -210,7 +210,7 @@ function getHyperParams(config, cmd, checkpoint_path,
if eval_tables ~= "" and eval_tables ~= nil then
hyperParameters = hyperParameters .. " --eval_tables " .. eval_tables
end
- elseif cmd == 'export' then
+ elseif cmd == 'export' or cmd == 'export_checkpoint' then
hyperParameters = hyperParameters .. " --checkpoint_path=" .. checkpoint_path
hyperParameters = hyperParameters .. " --export_dir=" .. export_dir
elseif cmd == 'train' then
@@ -221,6 +221,9 @@ function getHyperParams(config, cmd, checkpoint_path,
if with_evaluator ~= "" and tonumber(with_evaluator) ~= 0 then
hyperParameters = hyperParameters .. " --with_evaluator"
end
+ if fine_tune_checkpoint ~= nil and fine_tune_checkpoint ~= '' then
+ hyperParameters = hyperParameters .. " --fine_tune_checkpoint=" .. fine_tune_checkpoint
+ end
if eval_method ~= 'none' and eval_method ~= 'separate' and eval_method ~= 'master' then
error('invalid eval_method ' .. eval_method)
end
@@ -277,7 +280,7 @@ function getHyperParams(config, cmd, checkpoint_path,
num_gpus_per_worker)
if extra_params ~= nil and extra_params ~= '' then
- hyperParameters = hyperParameters .. extra_params
+ hyperParameters = hyperParameters .. " " .. extra_params
end
return hyperParameters, cluster, tables, output_table
@@ -371,12 +374,18 @@ function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols,
-- all_cols, all_col_types, selected_cols, reserved_cols,
-- create_table_sql, add_partition_sql, tables parameter to runTF
if cmd ~= 'train' and cmd ~= 'evaluate' and cmd ~= 'predict' and cmd ~= 'export'
+ and cmd ~= 'export_checkpoint'
and cmd ~= 'evaluate' and cmd ~= 'custom' and cmd ~= 'vector_retrieve' then
error('invalid cmd: ' .. cmd .. ', should be one of train, evaluate, predict, evaluate, export, custom, vector_retrieve')
end
-- for export
- if cmd == 'export' or cmd == 'custom' then
+ if cmd == 'export' or cmd == 'custom' or cmd == 'export_checkpoint' then
+ return "", "", "", "", "select 1;", "select 1;", tables
+ end
+
+ -- for online train or train with oss input
+ if cmd == 'train' and (tables == nil or tables == '') and (trainTables == nil or trainTables == '') then
return "", "", "", "", "select 1;", "select 1;", tables
end
diff --git a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml
index f03cb855f..10ae05d0d 100644
--- a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml
+++ b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml
@@ -27,6 +27,7 @@
+
@@ -66,6 +67,7 @@
+
@@ -144,6 +146,7 @@
+
@@ -227,6 +230,7 @@
+
tensorflow1120_ext
algo_public
diff --git a/pai_jobs/run.py b/pai_jobs/run.py
index 163eecc72..309ec4e7a 100644
--- a/pai_jobs/run.py
+++ b/pai_jobs/run.py
@@ -2,25 +2,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import print_function
-import json
import logging
# use few threads to avoid oss error
import os
+import time
import tensorflow as tf
+import yaml
+from tensorflow.python.platform import gfile
import easy_rec
-from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.inference.odps_predictor import ODPSPredictor
from easy_rec.python.inference.vector_retrieve import VectorRetrieve
+from easy_rec.python.tools.pre_check import run_check
from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import hpo_util
from easy_rec.python.utils import pai_util
from easy_rec.python.utils.distribution_utils import DistributionStrategyMap
from easy_rec.python.utils.distribution_utils import set_distribution_config
-from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num # NOQA
+os.environ['IS_ON_PAI'] = '1'
+from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num # NOQA
os.environ['OENV_MultiWriteThreadsNum'] = '4'
os.environ['OENV_MultiCopyThreadsNum'] = '4'
@@ -64,6 +70,8 @@
tf.app.flags.DEFINE_string('eval_tables', '', 'tables used for evaluation')
tf.app.flags.DEFINE_string('boundary_table', '', 'tables used for boundary')
tf.app.flags.DEFINE_string('sampler_table', '', 'tables used for sampler')
+tf.app.flags.DEFINE_string('fine_tune_checkpoint', None,
+ 'finetune checkpoint path')
tf.app.flags.DEFINE_string('query_table', '',
'table used for retrieve vector neighbours')
tf.app.flags.DEFINE_string('doc_table', '',
@@ -101,6 +109,11 @@
# flags used for export
tf.app.flags.DEFINE_string('export_dir', '',
'directory where model should be exported to')
+tf.app.flags.DEFINE_bool('clear_export', False, 'remove export_dir if exists')
+tf.app.flags.DEFINE_string('export_done_file', '',
+ 'a flag file to signal that export model is done')
+tf.app.flags.DEFINE_integer('max_wait_ckpt_ts', 0,
+ 'max wait time in seconds for checkpoints')
tf.app.flags.DEFINE_boolean('continue_train', True,
'use the same model to continue train or not')
@@ -155,14 +168,22 @@
tf.app.flags.DEFINE_string('oss_embedding_version', '', 'oss embedding version')
tf.app.flags.DEFINE_bool('verbose', False, 'print more debug information')
+tf.app.flags.DEFINE_bool('place_embedding_on_cpu', False,
+ 'whether to place embedding variables on cpu')
# for automl hyper parameter tuning
tf.app.flags.DEFINE_string('model_dir', None, 'model directory')
+tf.app.flags.DEFINE_bool('clear_model', False,
+ 'remove model directory if exists')
tf.app.flags.DEFINE_string('hpo_param_path', None,
'hyperparameter tuning param path')
tf.app.flags.DEFINE_string('hpo_metric_save_path', None,
'hyperparameter save metric path')
tf.app.flags.DEFINE_string('asset_files', None, 'extra files to add to export')
+tf.app.flags.DEFINE_bool('check_mode', False, 'is use check mode')
+tf.app.flags.DEFINE_string('fg_json_path', None, '')
+tf.app.flags.DEFINE_bool('enable_avx_str_split', False,
+ 'enable avx str split to speedup')
FLAGS = tf.app.flags.FLAGS
@@ -193,8 +214,38 @@ def set_selected_cols(pipeline_config, selected_cols, all_cols, all_col_types):
pipeline_config.data_config.selected_col_types)
+def _wait_ckpt(ckpt_path, max_wait_ts):
+ logging.info('will wait %s seconds for checkpoint' % max_wait_ts)
+ start_ts = time.time()
+ if '/model.ckpt-' not in ckpt_path:
+ while time.time() - start_ts < max_wait_ts:
+ tmp_ckpt = estimator_utils.latest_checkpoint(ckpt_path)
+ if tmp_ckpt is None:
+ logging.info('wait for checkpoint in directory[%s]' % ckpt_path)
+ time.sleep(30)
+ else:
+ logging.info('find checkpoint[%s] in directory[%s]' %
+ (tmp_ckpt, ckpt_path))
+ break
+ else:
+ while time.time() - start_ts < max_wait_ts:
+ if not gfile.Exists(ckpt_path + '.index'):
+ logging.info('wait for checkpoint[%s]' % ckpt_path)
+ time.sleep(30)
+ else:
+ logging.info('find checkpoint[%s]' % ckpt_path)
+ break
+
+
def main(argv):
pai_util.set_on_pai()
+ if FLAGS.enable_avx_str_split:
+ constant.enable_avx_str_split()
+ logging.info('will enable avx str split: %s' %
+ constant.is_avx_str_split_enabled())
+
+ if FLAGS.distribute_eval:
+ os.environ['distribute_eval'] = 'True'
# load lookup op
try:
@@ -215,9 +266,14 @@ def main(argv):
len(FLAGS.worker_hosts.split(',')))
pipeline_config = config_util.get_configs_from_pipeline_file(config, False)
+ # should be in front of edit_config_json step
+ # otherwise data_config and feature_config are not ready
+ if pipeline_config.fg_json_path:
+ fg_util.load_fg_json_to_config(pipeline_config)
+
if FLAGS.edit_config_json:
print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
- config_json = json.loads(FLAGS.edit_config_json)
+ config_json = yaml.safe_load(FLAGS.edit_config_json)
config_util.edit_config(pipeline_config, config_json)
if FLAGS.model_dir:
@@ -227,10 +283,29 @@ def main(argv):
assert pipeline_config.model_dir.startswith(
'oss://'), 'invalid model_dir format: %s' % pipeline_config.model_dir
+ if FLAGS.asset_files:
+ pipeline_config.export_config.asset_files.extend(
+ FLAGS.asset_files.split(','))
+
+ if FLAGS.config:
+ if not pipeline_config.model_dir.endswith('/'):
+ pipeline_config.model_dir += '/'
+
+ if FLAGS.clear_model:
+ if gfile.IsDirectory(
+ pipeline_config.model_dir) and estimator_utils.is_chief():
+ gfile.DeleteRecursively(pipeline_config.model_dir)
+
+ if FLAGS.max_wait_ckpt_ts > 0:
+ if FLAGS.checkpoint_path:
+ _wait_ckpt(FLAGS.checkpoint_path, FLAGS.max_wait_ckpt_ts)
+ else:
+ _wait_ckpt(pipeline_config.model_dir, FLAGS.max_wait_ckpt_ts)
+
if FLAGS.cmd == 'train':
assert FLAGS.config, 'config should not be empty when training!'
- if not FLAGS.train_tables:
+ if not FLAGS.train_tables and FLAGS.tables:
tables = FLAGS.tables.split(',')
assert len(
tables
@@ -239,19 +314,23 @@ def main(argv):
if FLAGS.train_tables:
pipeline_config.train_input_path = FLAGS.train_tables
- else:
+ elif FLAGS.tables:
pipeline_config.train_input_path = FLAGS.tables.split(',')[0]
if FLAGS.eval_tables:
pipeline_config.eval_input_path = FLAGS.eval_tables
- else:
+ elif FLAGS.tables:
pipeline_config.eval_input_path = FLAGS.tables.split(',')[1]
print('[run.py] train_tables: %s' % pipeline_config.train_input_path)
print('[run.py] eval_tables: %s' % pipeline_config.eval_input_path)
- if pipeline_config.fg_json_path:
- fg_util.load_fg_json_to_config(pipeline_config)
+ if FLAGS.fine_tune_checkpoint:
+ pipeline_config.train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint
+
+ if pipeline_config.train_config.HasField('fine_tune_checkpoint'):
+ pipeline_config.train_config.fine_tune_checkpoint = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
+ pipeline_config.train_config.fine_tune_checkpoint, False)
if FLAGS.boundary_table:
logging.info('Load boundary_table: %s' % FLAGS.boundary_table)
@@ -261,17 +340,21 @@ def main(argv):
if FLAGS.sampler_table:
pipeline_config.data_config.negative_sampler.input_path = FLAGS.sampler_table
- # parse selected_cols
- set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
- FLAGS.all_col_types)
+ if FLAGS.train_tables or FLAGS.tables:
+ # parse selected_cols
+ set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
+ FLAGS.all_col_types)
+ else:
+ pipeline_config.data_config.selected_cols = ''
+ pipeline_config.data_config.selected_col_types = ''
distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]
# update params specified by automl if hpo_param_path is specified
if FLAGS.hpo_param_path:
logging.info('hpo_param_path = %s' % FLAGS.hpo_param_path)
- with tf.gfile.GFile(FLAGS.hpo_param_path, 'r') as fin:
- hpo_config = json.load(fin)
+ with gfile.GFile(FLAGS.hpo_param_path, 'r') as fin:
+ hpo_config = yaml.safe_load(fin)
hpo_params = hpo_config['param']
config_util.edit_config(pipeline_config, hpo_params)
config_util.auto_expand_share_feature_configs(pipeline_config)
@@ -281,8 +364,11 @@ def main(argv):
assert FLAGS.eval_method in [
'none', 'master', 'separate'
], 'invalid evalaute_method: %s' % FLAGS.eval_method
+
+  # with_evaluator is deprecated, kept for compatibility
if FLAGS.with_evaluator:
FLAGS.eval_method = 'separate'
+
num_worker = set_tf_config_and_get_train_worker_num(
FLAGS.ps_hosts,
FLAGS.worker_hosts,
@@ -292,25 +378,29 @@ def main(argv):
eval_method=FLAGS.eval_method)
set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
distribute_strategy)
+ logging.info('run.py check_mode: %s .' % FLAGS.check_mode)
train_and_evaluate_impl(
- pipeline_config, continue_train=FLAGS.continue_train)
+ pipeline_config,
+ continue_train=FLAGS.continue_train,
+ check_mode=FLAGS.check_mode)
if FLAGS.hpo_metric_save_path:
hpo_util.save_eval_metrics(
pipeline_config.model_dir,
metric_save_path=FLAGS.hpo_metric_save_path,
- has_evaluator=FLAGS.with_evaluator)
+ has_evaluator=(FLAGS.eval_method == 'separate'))
+
elif FLAGS.cmd == 'evaluate':
check_param('config')
# TODO: support multi-worker evaluation
if not FLAGS.distribute_eval:
assert len(
- FLAGS.worker_hosts.split(',')) == 1, 'evaluate only need 1 woker'
+ FLAGS.worker_hosts.split(',')) == 1, 'evaluate only need 1 worker'
config_util.auto_expand_share_feature_configs(pipeline_config)
if FLAGS.eval_tables:
pipeline_config.eval_input_path = FLAGS.eval_tables
- else:
+ elif FLAGS.tables:
pipeline_config.eval_input_path = FLAGS.tables.split(',')[0]
distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]
@@ -323,19 +413,35 @@ def main(argv):
set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
distribute_strategy)
- # parse selected_cols
- set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
- FLAGS.all_col_types)
+ if FLAGS.eval_tables or FLAGS.tables:
+ # parse selected_cols
+ set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
+ FLAGS.all_col_types)
+ else:
+ pipeline_config.data_config.selected_cols = ''
+ pipeline_config.data_config.selected_col_types = ''
+
if FLAGS.distribute_eval:
+ os.environ['distribute_eval'] = 'True'
+ logging.info('will_use_distribute_eval')
+ distribute_eval = os.environ.get('distribute_eval')
+ logging.info('distribute_eval = {}'.format(distribute_eval))
easy_rec.distribute_evaluate(pipeline_config, FLAGS.checkpoint_path, None,
FLAGS.eval_result_path)
else:
+ os.environ['distribute_eval'] = 'False'
+ logging.info('will_use_eval')
+ distribute_eval = os.environ.get('distribute_eval')
+ logging.info('distribute_eval = {}'.format(distribute_eval))
easy_rec.evaluate(pipeline_config, FLAGS.checkpoint_path, None,
FLAGS.eval_result_path)
elif FLAGS.cmd == 'export':
check_param('export_dir')
check_param('config')
-
+ if FLAGS.place_embedding_on_cpu:
+ os.environ['place_embedding_on_cpu'] = 'True'
+ else:
+ os.environ['place_embedding_on_cpu'] = 'False'
redis_params = {}
if FLAGS.redis_url:
redis_params['redis_url'] = FLAGS.redis_url
@@ -382,10 +488,23 @@ def main(argv):
assert len(FLAGS.worker_hosts.split(',')) == 1, 'export only need 1 woker'
config_util.auto_expand_share_feature_configs(pipeline_config)
+ export_dir = FLAGS.export_dir
+ if not export_dir.endswith('/'):
+ export_dir = export_dir + '/'
+ if FLAGS.clear_export:
+ if gfile.IsDirectory(export_dir):
+ gfile.DeleteRecursively(export_dir)
+
extra_params = redis_params
extra_params.update(oss_params)
- easy_rec.export(FLAGS.export_dir, pipeline_config, FLAGS.checkpoint_path,
- FLAGS.asset_files, FLAGS.verbose, **extra_params)
+ export_out_dir = easy_rec.export(export_dir, pipeline_config,
+ FLAGS.checkpoint_path, FLAGS.asset_files,
+ FLAGS.verbose, **extra_params)
+ if FLAGS.export_done_file:
+ flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
+ logging.info('create export done file: %s' % flag_file)
+ with gfile.GFile(flag_file, 'w') as fout:
+ fout.write('ExportDone')
elif FLAGS.cmd == 'predict':
check_param('tables')
check_param('saved_model_dir')
@@ -397,7 +516,12 @@ def main(argv):
profiling_file = FLAGS.profiling_file if FLAGS.task_index == 0 else None
if profiling_file is not None:
print('profiling_file = %s ' % profiling_file)
- predictor = Predictor(FLAGS.saved_model_dir, profiling_file=profiling_file)
+ predictor = ODPSPredictor(
+ FLAGS.saved_model_dir,
+ fg_json_path=FLAGS.fg_json_path,
+ profiling_file=profiling_file,
+ all_cols=FLAGS.all_cols,
+ all_col_types=FLAGS.all_col_types)
input_table, output_table = FLAGS.tables, FLAGS.outputs
logging.info('input_table = %s, output_table = %s' %
(input_table, output_table))
@@ -405,14 +529,28 @@ def main(argv):
predictor.predict_impl(
input_table,
output_table,
- all_cols=FLAGS.all_cols,
- all_col_types=FLAGS.all_col_types,
- selected_cols=FLAGS.selected_cols,
reserved_cols=FLAGS.reserved_cols,
output_cols=FLAGS.output_cols,
batch_size=FLAGS.batch_size,
slice_id=FLAGS.task_index,
slice_num=worker_num)
+ elif FLAGS.cmd == 'export_checkpoint':
+ check_param('export_dir')
+ check_param('config')
+ set_tf_config_and_get_train_worker_num(
+ FLAGS.ps_hosts,
+ FLAGS.worker_hosts,
+ FLAGS.task_index,
+ FLAGS.job_name,
+ eval_method='none')
+    assert len(FLAGS.worker_hosts.split(',')) == 1, 'export only need 1 worker'
+ config_util.auto_expand_share_feature_configs(pipeline_config)
+ easy_rec.export_checkpoint(
+ pipeline_config,
+ export_path=FLAGS.export_dir + '/model',
+ checkpoint_path=FLAGS.checkpoint_path,
+ asset_files=FLAGS.asset_files,
+ verbose=FLAGS.verbose)
elif FLAGS.cmd == 'vector_retrieve':
check_param('knn_distance')
assert FLAGS.knn_feature_dims is not None, '`knn_feature_dims` should not be None'
@@ -442,9 +580,12 @@ def main(argv):
m=FLAGS.knn_compress_dim)
worker_hosts = FLAGS.worker_hosts.split(',')
knn(FLAGS.knn_num_neighbours, FLAGS.task_index, len(worker_hosts))
+ elif FLAGS.cmd == 'check':
+ run_check(pipeline_config, FLAGS.tables)
else:
raise ValueError(
- 'cmd should be one of train/evaluate/export/predict/vector_retrieve')
+ 'cmd should be one of train/evaluate/export/predict/export_checkpoint/vector_retrieve'
+ )
if __name__ == '__main__':
diff --git a/requirements.txt b/requirements.txt
index ec4ca05eb..c6e294bac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1 @@
-r requirements/runtime.txt
--r requirements/tests.txt
diff --git a/requirements/docs.txt b/requirements/docs.txt
index cddfc09c6..9e81da2c6 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,4 +1,8 @@
+alabaster>=0.7,<0.8,!=0.7.5
+commonmark==0.8.1
+http://easyrec.oss-cn-beijing.aliyuncs.com/tools/Markdown-3.4.1-py3-none-any.whl
recommonmark==0.6.0
-sphinx==3.3.1
-sphinx_markdown_tables==0.0.15
-sphinx_rtd_theme==0.5.0
+sphinx==5.1.1
+sphinx_markdown_tables==0.0.17
+sphinx_rtd_theme
+tensorflow-probability==0.11.0
diff --git a/requirements/extra.txt b/requirements/extra.txt
new file mode 100644
index 000000000..ac085bb2e
--- /dev/null
+++ b/requirements/extra.txt
@@ -0,0 +1,15 @@
+# extra packages for some features
+
+# cprotobuf is required by pydatahub
+cprotobuf==0.1.9
+
+# for data input from datahub
+pydatahub
+# for data input from hive tables
+pyhive
+
+# for datahub test
+pyodps
+sasl==0.3.1
+thrift
+thrift_sasl
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index d3bc757fc..402e16b6e 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,10 +1,13 @@
-cprotobuf==0.1.9
+eas_prediction == 0.24; python_version < '3.0'
+eas_prediction; python_version >= '3.0'
future
matplotlib
+numpy <= 1.23
+oss2
pandas
psutil
+pyarrow
+pyodps
+PyYAML
scikit-learn
xlrd >= 0.9.0
-# pydatahub
-# cprotobuf is required by pydatahub
-# cprotobuf==0.1.9
diff --git a/requirements/tests.txt b/requirements/tests.txt
index 8567fc6c0..bceb7b1c5 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -1,5 +1,4 @@
configparser
-configparser
docformatter
flake8
flake8-docstrings
@@ -7,6 +6,4 @@ isort==4.3.21
mdformat;python_version>"3"
mdformat-tables;python_version>"3"
pre-commit
-pyodps
-PyYAML
yapf
diff --git a/samples/demo_script/process_lbl.py b/samples/demo_script/process_lbl.py
new file mode 100644
index 000000000..2be68a0e7
--- /dev/null
+++ b/samples/demo_script/process_lbl.py
@@ -0,0 +1,6 @@
+import numpy as np
+
+
+def remap_lbl(labels):
+ res = np.where(labels < 5, 0, 1)
+ return res
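
remap_lbl binarizes integer labels at a threshold of 5: values below 5 map to 0, everything else to 1. A small usage sketch (the sample values are illustrative):

import numpy as np

# mirrors remap_lbl as defined in samples/demo_script/process_lbl.py
def remap_lbl(labels):
  return np.where(labels < 5, 0, 1)

labels = np.array([0, 3, 4, 5, 7, 9])
print(remap_lbl(labels))  # [0 0 0 1 1 1]
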
diff --git a/samples/dh_script/configs/deepfm.config b/samples/dh_script/configs/deepfm.config
index 9902eb2ef..29ed5e8cc 100644
--- a/samples/dh_script/configs/deepfm.config
+++ b/samples/dh_script/configs/deepfm.config
@@ -14,7 +14,7 @@ train_config {
}
use_moving_average: false
}
- log_step_count_steps: 200L
+ log_step_count_steps: 200
sync_replicas: true
}
@@ -27,21 +27,18 @@ eval_config {
datahub_train_input{
akId:"{DH_ID}"
akSecret:"{DH_KEY}"
- region:"{DH_REG}"
+ endpoint:"{DH_REG}"
project:"{DH_PRO}"
topic:"{DH_TOPIC}"
- shard_num:3
- life_cycle:7
+ offset_time: '20220627 09:22:00'
}
datahub_eval_input{
akId:"{DH_ID}"
akSecret:"{DH_KEY}"
- region:"{DH_REG}"
+ endpoint:"{DH_REG}"
project:"{DH_PRO}"
topic:"{DH_TOPIC}"
- shard_num:3
- life_cycle:7
}
data_config {
diff --git a/samples/dh_script/configs/dh.config b/samples/dh_script/configs/dh.config
deleted file mode 100644
index 1aa936bab..000000000
--- a/samples/dh_script/configs/dh.config
+++ /dev/null
@@ -1,6 +0,0 @@
-[datahub]
-access_id=
-access_key=
-endpoint=https://dh-cn-beijing.aliyuncs.com
-topic_name=pf_test
-project=tmf_easy
diff --git a/samples/dh_script/deep_fm/create_external_deepfm_table.sql b/samples/dh_script/deep_fm/create_external_deepfm_table.sql
deleted file mode 100644
index 031af15dd..000000000
--- a/samples/dh_script/deep_fm/create_external_deepfm_table.sql
+++ /dev/null
@@ -1,65 +0,0 @@
-drop TABLE IF EXISTS external_deepfm_train_{TIME_STAMP} ;
-create EXTERNAL table external_deepfm_train_{TIME_STAMP}(
- label BIGINT
- ,`hour` string
- ,c1 STRING
- ,banner_pos STRING
- ,site_id STRING
- ,site_domain STRING
- ,site_category STRING
- ,app_id STRING
- ,app_domain STRING
- ,app_category STRING
- ,device_id STRING
- ,device_ip STRING
- ,device_model STRING
- ,device_type STRING
- ,device_conn_type STRING
- ,c14 STRING
- ,c15 STRING
- ,c16 STRING
- ,c17 STRING
- ,c18 STRING
- ,c19 STRING
- ,c20 STRING
- ,c21 STRING
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_ENDPOINT_INTERNAL}/{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/train/'
-;
-
-drop TABLE IF EXISTS external_deepfm_test_{TIME_STAMP};
-create EXTERNAL table external_deepfm_test_{TIME_STAMP}(
- label BIGINT
- ,`hour` string
- ,c1 STRING
- ,banner_pos STRING
- ,site_id STRING
- ,site_domain STRING
- ,site_category STRING
- ,app_id STRING
- ,app_domain STRING
- ,app_category STRING
- ,device_id STRING
- ,device_ip STRING
- ,device_model STRING
- ,device_type STRING
- ,device_conn_type STRING
- ,c14 STRING
- ,c15 STRING
- ,c16 STRING
- ,c17 STRING
- ,c18 STRING
- ,c19 STRING
- ,c20 STRING
- ,c21 STRING
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_ENDPOINT_INTERNAL}/{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/test/'
-;
diff --git a/samples/dh_script/deep_fm/create_inner_deepfm_table.sql b/samples/dh_script/deep_fm/create_inner_deepfm_table.sql
index bf70f0ac1..1b7955013 100644
--- a/samples/dh_script/deep_fm/create_inner_deepfm_table.sql
+++ b/samples/dh_script/deep_fm/create_inner_deepfm_table.sql
@@ -26,11 +26,7 @@ create table deepfm_train_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE deepfm_train_{TIME_STAMP}
-select * from external_deepfm_train_{TIME_STAMP} ;
-
-desc deepfm_train_{TIME_STAMP};
-desc external_deepfm_train_{TIME_STAMP};
+tunnel upload {TEST_DATA_DIR}/train_{TIME_STAMP} deepfm_train_{TIME_STAMP};
drop TABLE IF EXISTS deepfm_test_{TIME_STAMP};
create table deepfm_test_{TIME_STAMP}(
@@ -60,7 +56,4 @@ create table deepfm_test_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE deepfm_test_{TIME_STAMP}
-select * from external_deepfm_test_{TIME_STAMP};
-desc deepfm_test_{TIME_STAMP};
-desc external_deepfm_test_{TIME_STAMP};
+tunnel upload {TEST_DATA_DIR}/test_{TIME_STAMP} deepfm_test_{TIME_STAMP};
diff --git a/samples/emr_script/deep_fm/create_external_deepfm_table.sql b/samples/emr_script/deep_fm/create_external_deepfm_table.sql
deleted file mode 100644
index 031af15dd..000000000
--- a/samples/emr_script/deep_fm/create_external_deepfm_table.sql
+++ /dev/null
@@ -1,65 +0,0 @@
-drop TABLE IF EXISTS external_deepfm_train_{TIME_STAMP} ;
-create EXTERNAL table external_deepfm_train_{TIME_STAMP}(
- label BIGINT
- ,`hour` string
- ,c1 STRING
- ,banner_pos STRING
- ,site_id STRING
- ,site_domain STRING
- ,site_category STRING
- ,app_id STRING
- ,app_domain STRING
- ,app_category STRING
- ,device_id STRING
- ,device_ip STRING
- ,device_model STRING
- ,device_type STRING
- ,device_conn_type STRING
- ,c14 STRING
- ,c15 STRING
- ,c16 STRING
- ,c17 STRING
- ,c18 STRING
- ,c19 STRING
- ,c20 STRING
- ,c21 STRING
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_ENDPOINT_INTERNAL}/{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/train/'
-;
-
-drop TABLE IF EXISTS external_deepfm_test_{TIME_STAMP};
-create EXTERNAL table external_deepfm_test_{TIME_STAMP}(
- label BIGINT
- ,`hour` string
- ,c1 STRING
- ,banner_pos STRING
- ,site_id STRING
- ,site_domain STRING
- ,site_category STRING
- ,app_id STRING
- ,app_domain STRING
- ,app_category STRING
- ,device_id STRING
- ,device_ip STRING
- ,device_model STRING
- ,device_type STRING
- ,device_conn_type STRING
- ,c14 STRING
- ,c15 STRING
- ,c16 STRING
- ,c17 STRING
- ,c18 STRING
- ,c19 STRING
- ,c20 STRING
- ,c21 STRING
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_ENDPOINT_INTERNAL}/{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/test/'
-;
diff --git a/samples/emr_script/mmoe/mmoe_census_income.config b/samples/emr_script/mmoe/mmoe_census_income.config
new file mode 100644
index 000000000..ad3e7644e
--- /dev/null
+++ b/samples/emr_script/mmoe/mmoe_census_income.config
@@ -0,0 +1,567 @@
+hive_train_input {
+ host: "192.168.0.1"
+ username: "admin"
+ table_name: "census_income_train_simple"
+ limit_num: 500
+ fetch_size: 1024
+}
+
+hive_eval_input {
+ host: "192.168.0.1"
+ username: "admin"
+ table_name: "census_income_train_simple"
+ limit_num: 500
+ fetch_size: 1024
+}
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1290
+ decay_factor: 0.5
+ min_learning_rate: 1e-06
+ }
+ }
+ }
+ }
+ num_steps: 25
+ sync_replicas: true
+ log_step_count_steps: 10
+ save_checkpoints_steps: 25
+}
+
+eval_config {
+ metrics_set {
+ auc {}
+ }
+}
+
+model_config {
+ model_class: "MMoE"
+ mmoe {
+ experts {
+ expert_name: "expert_1"
+ dnn {
+ hidden_units: [128, 64, 32, 16]
+ dropout_ratio: [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ experts {
+ expert_name: "expert_2"
+ dnn {
+ hidden_units: [128, 64, 32, 16]
+ dropout_ratio: [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ experts {
+ expert_name: "expert_3"
+ dnn {
+ hidden_units: [128, 64, 32, 16]
+ dropout_ratio: [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ experts {
+ expert_name: "expert_4"
+ dnn {
+ hidden_units: [128, 64, 32, 16]
+ dropout_ratio: [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ task_towers {
+ tower_name: "task1"
+ label_name: "label_1"
+ metrics_set {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 192, 128, 64]
+ dropout_ratio: [0.1, 0.1, 0.1, 0.1]
+ }
+ loss_type: CLASSIFICATION
+ num_class: 1
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "task2"
+ label_name: "label_2"
+ dnn {
+ hidden_units: [256, 192, 128, 64]
+ dropout_ratio: [0.1, 0.1, 0.1, 0.1]
+ }
+ loss_type: CLASSIFICATION
+ num_class: 1
+ weight: 1.0
+ metrics_set {
+ auc {}
+ }
+ }
+ l2_regularization: 1e-06
+ }
+ embedding_regularization: 5e-05
+ feature_groups {
+ group_name: "all"
+ feature_names:"age"
+ feature_names:"detailed_household_and_family_stat"
+ feature_names:"detailed_household_summary_in_household"
+ feature_names:"migration_code_change_in_msa"
+ feature_names:"migration_code_change_in_reg"
+ feature_names:"migration_code_move_within_reg"
+ feature_names:"live_in_this_house_1_year_ago"
+ feature_names:"migration_prev_res_in_sunbelt"
+ feature_names:"num_persons_worked_for_employer"
+ feature_names:"citizenship"
+ feature_names:"mace"
+ feature_names:"hispanic_origin"
+ feature_names:"sex"
+ feature_names:"region_of_previous_residence"
+ feature_names:"instance_weight"
+ feature_names:"family_members_under_18"
+ feature_names:"country_of_birth_father"
+ feature_names:"country_of_birth_mother"
+ feature_names:"country_of_birth_self"
+ feature_names:"year"
+ feature_names:"class_of_worker"
+ feature_names:"industry_code"
+ feature_names:"occupation_code"
+ feature_names:"education"
+ feature_names:"major_industry"
+ feature_names:"major_occupation"
+ feature_names:"wage_per_hour"
+ feature_names:"enrolled_in_edu_inst_last_wk"
+ feature_names:"member_of_a_labor_union"
+ feature_names:"reason_for_unemployment"
+ feature_names:"full_or_part_time_employment_stat"
+ feature_names:"capital_gains"
+ feature_names:"capital_losses"
+ feature_names:"divdends_from_stocks"
+ feature_names:"tax_filer_status"
+ feature_names:"state_of_previous_residence"
+ feature_names:"own_business_or_self_employed"
+ feature_names:"fill_inc_questionnaire_for_veteran_s_admin"
+ feature_names:"veterans_benefits"
+ feature_names:"weeks_worked_in_year"
+ wide_deep: DEEP
+ }
+}
+
+data_config {
+ batch_size: 10
+ label_fields: "label_1"
+ label_fields: "label_2"
+ num_epochs: 1
+ prefetch_size: 4
+ input_type: HiveInput
+ input_fields {
+ input_name:'label_1'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label_2'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "class_of_worker"
+ }
+ input_fields {
+ input_name: "industry_code"
+ }
+ input_fields {
+ input_name: "occupation_code"
+ }
+ input_fields {
+ input_name: "education"
+ }
+ input_fields {
+ input_name: "wage_per_hour"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "enrolled_in_edu_inst_last_wk"
+ }
+ input_fields {
+ input_name: "major_industry"
+ }
+ input_fields {
+ input_name: "major_occupation"
+ }
+ input_fields {
+ input_name: "mace"
+ }
+ input_fields {
+ input_name: "hispanic_origin"
+ }
+ input_fields {
+ input_name: "sex"
+ }
+ input_fields {
+ input_name: "member_of_a_labor_union"
+ }
+ input_fields {
+ input_name: "reason_for_unemployment"
+ }
+ input_fields {
+ input_name: "full_or_part_time_employment_stat"
+ }
+ input_fields {
+ input_name: "capital_gains"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "capital_losses"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "divdends_from_stocks"
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: "tax_filer_status"
+ }
+ input_fields {
+ input_name: "region_of_previous_residence"
+ }
+ input_fields {
+ input_name: "state_of_previous_residence"
+ }
+ input_fields {
+ input_name: "detailed_household_and_family_stat"
+ }
+ input_fields {
+ input_name: "detailed_household_summary_in_household"
+ }
+ input_fields {
+ input_name: "instance_weight"
+ }
+ input_fields {
+ input_name: "migration_code_change_in_msa"
+ }
+ input_fields {
+ input_name: "migration_code_change_in_reg"
+ }
+ input_fields {
+ input_name: "migration_code_move_within_reg"
+ }
+ input_fields {
+ input_name: "live_in_this_house_1_year_ago"
+ }
+ input_fields {
+ input_name: "migration_prev_res_in_sunbelt"
+ }
+ input_fields {
+ input_name: "num_persons_worked_for_employer"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "family_members_under_18"
+ }
+ input_fields {
+ input_name: "country_of_birth_father"
+ }
+ input_fields {
+ input_name: "country_of_birth_mother"
+ }
+ input_fields {
+ input_name: "country_of_birth_self"
+ }
+ input_fields {
+ input_name: "citizenship"
+ }
+ input_fields {
+ input_name: "own_business_or_self_employed"
+ }
+ input_fields {
+ input_name: "fill_inc_questionnaire_for_veteran_s_admin"
+ }
+ input_fields {
+ input_name: "veterans_benefits"
+ }
+ input_fields {
+ input_name: "weeks_worked_in_year"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "year"
+ }
+}
+
+feature_configs {
+ input_names: "age"
+ feature_type: RawFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+}
+feature_configs {
+ input_names: "class_of_worker"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "industry_code"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "occupation_code"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "education"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "wage_per_hour"
+ feature_type: RawFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+}
+feature_configs {
+ input_names: "enrolled_in_edu_inst_last_wk"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "major_industry"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "major_occupation"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "mace"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "hispanic_origin"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "sex"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "member_of_a_labor_union"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "reason_for_unemployment"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "full_or_part_time_employment_stat"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "capital_gains"
+ feature_type: RawFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+}
+feature_configs {
+ input_names: "capital_losses"
+ feature_type: RawFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+}
+feature_configs {
+ input_names: "divdends_from_stocks"
+ feature_type: RawFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+}
+feature_configs {
+ input_names: "tax_filer_status"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "region_of_previous_residence"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "state_of_previous_residence"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "detailed_household_and_family_stat"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "detailed_household_summary_in_household"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "instance_weight"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "migration_code_change_in_msa"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "migration_code_change_in_reg"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "migration_code_move_within_reg"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "live_in_this_house_1_year_ago"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "migration_prev_res_in_sunbelt"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "num_persons_worked_for_employer"
+ feature_type: RawFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+}
+feature_configs {
+ input_names: "family_members_under_18"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "country_of_birth_father"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "country_of_birth_mother"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "country_of_birth_self"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "citizenship"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "own_business_or_self_employed"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "fill_inc_questionnaire_for_veteran_s_admin"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "veterans_benefits"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
+feature_configs {
+ input_names: "weeks_worked_in_year"
+ feature_type: RawFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+}
+feature_configs {
+ input_names: "year"
+ feature_type: IdFeature
+ embedding_dim: 9
+ hash_bucket_size: 400
+ embedding_name: "feature"
+}
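
The config above wires a two-task MMoE over HiveInput. A minimal sketch for loading and inspecting such a pipeline config via the protobuf text format is shown below; it assumes the compiled EasyRec protos are importable as easy_rec.python.protos.pipeline_pb2, and the printed values simply mirror the config above.

from google.protobuf import text_format
from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig

# parse the text-format pipeline config into the EasyRecConfig proto
with open('samples/emr_script/mmoe/mmoe_census_income.config') as f:
  pipeline_config = text_format.Merge(f.read(), EasyRecConfig())

print(pipeline_config.model_config.model_class)        # 'MMoE'
print(len(pipeline_config.model_config.mmoe.experts))  # 4
print(pipeline_config.data_config.batch_size)          # 10
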
diff --git a/samples/hpo/hpo_param_v13.json b/samples/hpo/hpo_param_v13.json
new file mode 100755
index 000000000..9acf7acec
--- /dev/null
+++ b/samples/hpo/hpo_param_v13.json
@@ -0,0 +1,5 @@
+{
+ "param": {
+ "export_config.multi_placeholder": "false"
+ }
+}
diff --git a/samples/hpo/hpo_param_v14.json b/samples/hpo/hpo_param_v14.json
new file mode 100755
index 000000000..b4af1c916
--- /dev/null
+++ b/samples/hpo/hpo_param_v14.json
@@ -0,0 +1,5 @@
+{
+ "param": {
+ "feature_config.features[input_names[0]=hour].feature_type": "RawFeature"
+ }
+}
diff --git a/samples/hpo/pipeline.config b/samples/hpo/pipeline.config
new file mode 100644
index 000000000..e5b35660b
--- /dev/null
+++ b/samples/hpo/pipeline.config
@@ -0,0 +1,18 @@
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.1
+ min_learning_rate: 1e-7
+ }
+ }
+ }
+ }
+ sync_replicas: true
+ save_summary_steps: 20
+ log_step_count_steps: 20
+}
diff --git a/samples/hpo/pipeline_finetune.config b/samples/hpo/pipeline_finetune.config
new file mode 100644
index 000000000..82ef1a9df
--- /dev/null
+++ b/samples/hpo/pipeline_finetune.config
@@ -0,0 +1,18 @@
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 1e-06
+ decay_steps: 1000
+ decay_factor: 0.1
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ sync_replicas: true
+ save_summary_steps: 20
+ log_step_count_steps: 20
+}
diff --git a/samples/hpo/search_space.json b/samples/hpo/search_space.json
new file mode 100644
index 000000000..7bf06cfc4
--- /dev/null
+++ b/samples/hpo/search_space.json
@@ -0,0 +1,69 @@
+{
+ "model_config.embedding_regularization":{"_type":"uniform","_value":[0.000001, 0.0001]},
+ "model_config.dbmtl.l2_regularization":{"_type":"uniform","_value":[0.000001, 0.0001]},
+ "model_config.dbmtl.bottom_dnn.hidden_units[0]":{"_type":"choice","_value":[2048, 1024,512,256]},
+ "model_config.dbmtl.bottom_dnn.hidden_units[1]":{"_type":"choice","_value":[1024,512,256,128]},
+ "model_config.dbmtl.task_towers[0].dnn.hidden_units[0]":{"_type":"choice","_value":[1024,512,256,128]},
+ "model_config.dbmtl.task_towers[0].dnn.hidden_units[1]":{"_type":"choice","_value":[1024,512,256,128,64]},
+ "model_config.dbmtl.task_towers[0].dnn.hidden_units[2]":{"_type":"choice","_value":[1024,512,256,128,64,32]},
+ "model_config.dbmtl.task_towers[0].dnn.hidden_units[3]":{"_type":"choice","_value":[1024,512,256,128,64,32,16]},
+ "model_config.dbmtl.task_towers[0].relation_dnn.hidden_units[0]":{"_type":"choice","_value":[128,64,32,16,8]},
+ "model_config.dbmtl.task_towers[0].weight":{"_type":"choice","_value":[1,2]},
+ "model_config.dbmtl.task_towers[1].dnn.hidden_units[0]":{"_type":"choice","_value":[1024,512,256,128]},
+ "model_config.dbmtl.task_towers[1].dnn.hidden_units[1]":{"_type":"choice","_value":[1024,512,256,128,64]},
+ "model_config.dbmtl.task_towers[1].dnn.hidden_units[2]":{"_type":"choice","_value":[1024,512,256,128,64,32]},
+ "model_config.dbmtl.task_towers[1].dnn.hidden_units[3]":{"_type":"choice","_value":[1024,512,256,128,64,32,16]},
+ "model_config.dbmtl.task_towers[1].relation_dnn.hidden_units[0]":{"_type":"choice","_value":[128,64,32,16,8]},
+ "model_config.dbmtl.task_towers[1].weight":{"_type":"choice","_value":[0.5,1,2]},
+ "model_config.dbmtl.task_towers[2].dnn.hidden_units[0]":{"_type":"choice","_value":[1024,512,256,128]},
+ "model_config.dbmtl.task_towers[2].dnn.hidden_units[1]":{"_type":"choice","_value":[1024,512,256,128,64]},
+ "model_config.dbmtl.task_towers[2].dnn.hidden_units[2]":{"_type":"choice","_value":[1024,512,256,128,64,32]},
+ "model_config.dbmtl.task_towers[2].dnn.hidden_units[3]":{"_type":"choice","_value":[1024,512,256,128,64,32,16]},
+ "model_config.dbmtl.task_towers[2].relation_dnn.hidden_units[0]":{"_type":"choice","_value":[128,64,32,16,8]},
+ "model_config.dbmtl.task_towers[2].weight":{"_type":"choice","_value":[0.5,1]},
+ "model_config.dbmtl.task_towers[3].dnn.hidden_units[0]":{"_type":"choice","_value":[1024,512,256,128]},
+ "model_config.dbmtl.task_towers[3].dnn.hidden_units[1]":{"_type":"choice","_value":[1024,512,256,128,64]},
+ "model_config.dbmtl.task_towers[3].dnn.hidden_units[2]":{"_type":"choice","_value":[1024,512,256,128,64,32]},
+ "model_config.dbmtl.task_towers[3].dnn.hidden_units[3]":{"_type":"choice","_value":[1024,512,256,128,64,32,16]},
+ "model_config.dbmtl.task_towers[3].relation_dnn.hidden_units[0]":{"_type":"choice","_value":[128,64,32,16,8]},
+ "model_config.dbmtl.task_towers[3].weight":{"_type":"choice","_value":[0.5,1]},
+ "data_config.batch_size":{"_type":"choice","_value":[2048,4096]},
+ "feature_configs[:].embedding_dim": {"_type": "randint", "_value": [4,16]},
+"feature_configs[0].hash_bucket_size": {"_type": "randint", "_value": [1260935, 2521870]},
+"feature_configs[1].hash_bucket_size": {"_type": "randint", "_value": [6847640, 13695280]},
+"feature_configs[2].hash_bucket_size": {"_type": "randint", "_value": [120, 240]},
+"feature_configs[3].hash_bucket_size": {"_type": "randint", "_value": [35, 70]},
+"feature_configs[4].hash_bucket_size": {"_type": "randint", "_value": [20855, 41710]},
+"feature_configs[5].hash_bucket_size": {"_type": "randint", "_value": [4385, 8770]},
+"feature_configs[8].hash_bucket_size": {"_type": "randint", "_value": [10, 20]},
+"feature_configs[9].hash_bucket_size": {"_type": "randint", "_value": [695, 1390]},
+"feature_configs[12].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[13].hash_bucket_size": {"_type": "randint", "_value": [255, 510]},
+"feature_configs[14].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[15].hash_bucket_size": {"_type": "randint", "_value": [29515, 59030]},
+"feature_configs[16].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[17].hash_bucket_size": {"_type": "randint", "_value": [65, 130]},
+"feature_configs[18].hash_bucket_size": {"_type": "randint", "_value": [180, 360]},
+"feature_configs[34].hash_bucket_size": {"_type": "randint", "_value": [2831005, 5662010]},
+"feature_configs[35].hash_bucket_size": {"_type": "randint", "_value": [35770, 71540]},
+"feature_configs[36].hash_bucket_size": {"_type": "randint", "_value": [264395, 528790]},
+"feature_configs[37].hash_bucket_size": {"_type": "randint", "_value": [860, 1720]},
+"feature_configs[614].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[615].hash_bucket_size": {"_type": "randint", "_value": [309265, 618530]},
+"feature_configs[616].hash_bucket_size": {"_type": "randint", "_value": [250, 500]},
+"feature_configs[617].hash_bucket_size": {"_type": "randint", "_value": [30, 60]},
+"feature_configs[618].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[619].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[620].hash_bucket_size": {"_type": "randint", "_value": [40, 80]},
+"feature_configs[628].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[629].hash_bucket_size": {"_type": "randint", "_value": [270, 540]},
+"feature_configs[630].hash_bucket_size": {"_type": "randint", "_value": [3875, 7750]},
+"feature_configs[631].hash_bucket_size": {"_type": "randint", "_value": [255, 510]},
+"feature_configs[632].hash_bucket_size": {"_type": "randint", "_value": [65, 130]},
+"feature_configs[633].hash_bucket_size": {"_type": "randint", "_value": [15, 30]},
+"feature_configs[634].hash_bucket_size": {"_type": "randint", "_value": [175, 350]},
+"feature_configs[650].hash_bucket_size": {"_type": "randint", "_value": [665, 1330]},
+"model_config.dbmtl.task_towers[2].relation_tower_names":{"_type":"choice","_value":[["is_valid_play"],["is_valid_play","ln_play_time"]]},
+"model_config.dbmtl.task_towers[3].relation_tower_names":{"_type":"choice","_value":[["is_valid_play"],["is_valid_play","ln_play_time"],["is_valid_play","is_like"],["is_valid_play","ln_play_time","is_like"]]},
+"train_config.optimizer_config[0].adam_optimizer.learning_rate.exponential_decay_learning_rate.initial_learning_rate":{"_type":"choice","_value":[1e-3,1e-4]}
+}
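
The search space above follows the NNI-style {"_type", "_value"} convention (choice / uniform / randint). Below is a small sketch of sampling one trial from it; the half-open [low, high) interpretation of randint is an assumption based on NNI semantics.

import json
import random

def sample_param(spec):
  t, v = spec['_type'], spec['_value']
  if t == 'choice':
    return random.choice(v)
  if t == 'uniform':
    return random.uniform(v[0], v[1])
  if t == 'randint':
    return random.randint(v[0], v[1] - 1)  # [low, high), assumed NNI semantics
  raise ValueError('unsupported _type: %s' % t)

with open('samples/hpo/search_space.json') as f:
  space = json.load(f)

trial = {name: sample_param(spec) for name, spec in space.items()}
print(trial['data_config.batch_size'])  # 2048 or 4096
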
diff --git a/samples/model_config/aitm_on_taobao.config b/samples/model_config/aitm_on_taobao.config
new file mode 100644
index 000000000..9131a41a7
--- /dev/null
+++ b/samples/model_config/aitm_on_taobao.config
@@ -0,0 +1,334 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/aitm_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 500
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_config: {
+ features {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "tag_category_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "tag_brand_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config {
+ model_name: "AITM"
+ model_class: "MultiTaskModel"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "share_bottom"
+ inputs {
+ feature_group_name: "all"
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [512, 256]
+ }
+ }
+ }
+ blocks {
+ name: "ctr_tower"
+ inputs {
+ block_name: "share_bottom"
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: 128
+ }
+ }
+ }
+ blocks {
+ name: "cvr_tower"
+ inputs {
+ block_name: "share_bottom"
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: 128
+ }
+ }
+ }
+ blocks {
+ name: "cvr_aitm"
+ inputs {
+ block_name: "cvr_tower"
+ }
+ inputs {
+ block_name: "ctr_tower"
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: "AITMTower"
+ aitm {
+ transfer_mlp {
+ hidden_units: 128
+ }
+ }
+ }
+ }
+ output_blocks: ["ctr_tower", "cvr_aitm"]
+ }
+ model_params {
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: 64
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ losses {
+ loss_type: CLASSIFICATION
+ }
+ losses {
+ loss_type: ORDER_CALIBRATE_LOSS
+ }
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: 64
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/autoint_on_sequence_feature_taobao.config b/samples/model_config/autoint_on_sequence_feature_taobao.config
index 384e30eeb..82ab8c6c8 100644
--- a/samples/model_config/autoint_on_sequence_feature_taobao.config
+++ b/samples/model_config/autoint_on_sequence_feature_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: True
- num_steps: 2500
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/autoint_on_taobao.config b/samples/model_config/autoint_on_taobao.config
index b7addad9b..d7fcae52b 100644
--- a/samples/model_config/autoint_on_taobao.config
+++ b/samples/model_config/autoint_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: True
- num_steps: 2500
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/bst_backbone_on_taobao.config b/samples/model_config/bst_backbone_on_taobao.config
new file mode 100644
index 000000000..b801f87ef
--- /dev/null
+++ b/samples/model_config/bst_backbone_on_taobao.config
@@ -0,0 +1,317 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/bst_backbone_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_name: 'BST'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'normal'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'sequence'
+ feature_names: "cate_id"
+ feature_names: "brand"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'deep'
+ inputs {
+ feature_group_name: 'normal'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ blocks {
+ name: 'seq_input'
+ inputs {
+ feature_group_name: 'sequence'
+ }
+ input_layer {
+ output_seq_and_normal_feature: true
+ }
+ }
+ blocks {
+ name: 'BST'
+ inputs {
+ block_name: 'seq_input'
+ }
+ keras_layer {
+ class_name: 'BST'
+ bst {
+ hidden_size: 128
+ num_attention_heads: 2
+ num_hidden_layers: 2
+ intermediate_size: 128
+ hidden_act: 'gelu'
+ max_position_embeddings: 50
+ hidden_dropout_prob: 0.1
+ attention_probs_dropout_prob: 0
+ }
+ }
+ }
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 0
+ }
+ embedding_regularization: 0
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/bst_on_taobao.config b/samples/model_config/bst_on_taobao.config
index 0ff2fac81..28f06fb97 100644
--- a/samples/model_config/bst_on_taobao.config
+++ b/samples/model_config/bst_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: True
- num_steps: 2500
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/cdn_on_taobao.config b/samples/model_config/cdn_on_taobao.config
new file mode 100644
index 000000000..d73b44cee
--- /dev/null
+++ b/samples/model_config/cdn_on_taobao.config
@@ -0,0 +1,333 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/cdn_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_name: 'Cross Decoupling Network'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'memorize'
+ feature_names: 'user_id'
+ feature_names: 'adgroup_id'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'frequency'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "mem_expert"
+ inputs {
+ feature_group_name: "memorize"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: [512, 256]
+ }
+ }
+ }
+ blocks {
+ name: "gen_experts"
+ inputs {
+ feature_group_name: "general"
+ input_fn: "lambda x: [x, x]"
+ }
+ repeat {
+ num_repeat: 3
+ keras_layer {
+ class_name: "MaskBlock"
+ mask_block {
+ output_size: 256
+ aggregation_size: 1024
+ }
+ }
+ }
+ }
+ blocks {
+ name: "gate_weight"
+ inputs {
+ feature_group_name: "frequency"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: 4
+ use_final_bn: false
+ final_activation: "softmax"
+ }
+ }
+ }
+ blocks {
+ name: "gate"
+ inputs {
+ block_name: "gate_weight"
+ input_fn: "lambda x: [x]"
+ }
+ inputs {
+ block_name: "mem_expert"
+ input_fn: "lambda x: [x]"
+ }
+ inputs {
+ block_name: "gen_experts"
+ }
+ keras_layer {
+ class_name: "Gate"
+ }
+ }
+ top_mlp {
+ hidden_units: [128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/samples/model_config/cl4srec_on_taobao.config b/samples/model_config/cl4srec_on_taobao.config
new file mode 100644
index 000000000..7887b9ece
--- /dev/null
+++ b/samples/model_config/cl4srec_on_taobao.config
@@ -0,0 +1,375 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/cl4srec_on_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ embedding_name: 'cate_id'
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ embedding_name: 'brand'
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ embedding_name: 'cate_id'
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ embedding_name: 'brand'
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_name: 'CL4SRec'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'campaign_id'
+ feature_names: 'cate_id'
+ feature_names: 'brand'
+ feature_names: 'customer'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'user_seq'
+ feature_names: "tag_brand_list"
+ feature_names: "tag_category_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'user_seq'
+ inputs {
+ feature_group_name: 'user_seq'
+ }
+ input_layer {
+ output_seq_and_normal_feature: true
+ }
+ }
+ packages {
+ name: 'seq_augment'
+ blocks {
+ name: 'augment'
+ inputs {
+ block_name: 'user_seq'
+ }
+ keras_layer {
+ class_name: 'SeqAugment'
+ seq_aug {
+ mask_rate: 0.6
+ crop_rate: 0.2
+ reorder_rate: 0.6
+ }
+ }
+ }
+ }
+ packages {
+ name: 'seq_encoder'
+ blocks {
+ name: 'BST'
+ inputs {
+ use_package_input: true
+ }
+ keras_layer {
+ class_name: 'BST'
+ bst {
+ hidden_size: 128
+ num_attention_heads: 2
+ num_hidden_layers: 2
+ intermediate_size: 128
+ hidden_act: 'gelu'
+ max_position_embeddings: 50
+ hidden_dropout_prob: 0.1
+ attention_probs_dropout_prob: 0
+ output_all_token_embeddings: false
+ }
+ }
+ }
+ }
+ blocks {
+ name: 'contrastive'
+ inputs {
+ package_name: 'seq_encoder'
+ package_input: 'seq_augment'
+ }
+ inputs {
+ package_name: 'seq_encoder'
+ package_input: 'seq_augment'
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: 'AuxiliaryLoss'
+ st_params {
+ fields {
+ key: 'loss_type'
+ value: { string_value: 'nce_loss' }
+ }
+ fields {
+ key: 'loss_weight'
+ value: { number_value: 0.1 }
+ }
+ fields {
+ key: 'temperature'
+ value: { number_value: 0.15 }
+ }
+ }
+ }
+ }
+ blocks {
+ name: 'main'
+ inputs {
+ package_name: 'seq_encoder'
+ package_input: 'user_seq'
+ }
+ inputs {
+ feature_group_name: 'user'
+ }
+ inputs {
+ feature_group_name: 'item'
+ }
+ }
+ concat_blocks: 'main'
+ top_mlp {
+ hidden_units: [256, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 0
+ }
+ embedding_regularization: 0
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/cl4srec_on_taobao_with_custom_op.config b/samples/model_config/cl4srec_on_taobao_with_custom_op.config
new file mode 100644
index 000000000..18ee40dde
--- /dev/null
+++ b/samples/model_config/cl4srec_on_taobao_with_custom_op.config
@@ -0,0 +1,375 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/cl4srec_on_taobao_custom_op_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 1024
+ num_epochs: 1
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ embedding_name: 'cate_id'
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ embedding_name: 'brand'
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ embedding_name: 'cate_id'
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ embedding_name: 'brand'
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_name: 'CL4SRec'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'campaign_id'
+ feature_names: 'cate_id'
+ feature_names: 'brand'
+ feature_names: 'customer'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'user_seq'
+ feature_names: "tag_brand_list"
+ feature_names: "tag_category_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'user_seq'
+ inputs {
+ feature_group_name: 'user_seq'
+ }
+ input_layer {
+ output_seq_and_normal_feature: true
+ }
+ }
+ packages {
+ name: 'seq_augment'
+ blocks {
+ name: 'augment'
+ inputs {
+ block_name: 'user_seq'
+ }
+ keras_layer {
+ class_name: 'SeqAugmentOps'
+ seq_aug {
+ mask_rate: 0.6
+ crop_rate: 0.2
+ reorder_rate: 0.6
+ }
+ }
+ }
+ }
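+  # The 'seq_augment' package wraps the SeqAugmentOps custom op, which randomly masks,
+  # crops and reorders the raw behavior sequence with the rates configured above; each
+  # block that consumes this package presumably receives a freshly augmented view.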
+ packages {
+ name: 'seq_encoder'
+ blocks {
+ name: 'BST'
+ inputs {
+ use_package_input: true
+ }
+ keras_layer {
+ class_name: 'BST'
+ bst {
+ hidden_size: 128
+ num_attention_heads: 2
+ num_hidden_layers: 2
+ intermediate_size: 128
+ hidden_act: 'gelu'
+ max_position_embeddings: 50
+ hidden_dropout_prob: 0.1
+ attention_probs_dropout_prob: 0
+ output_all_token_embeddings: false
+ }
+ }
+ }
+ }
+ blocks {
+ name: 'contrastive'
+ inputs {
+ package_name: 'seq_encoder'
+ package_input: 'seq_augment'
+ }
+ inputs {
+ package_name: 'seq_encoder'
+ package_input: 'seq_augment'
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: 'AuxiliaryLoss'
+ st_params {
+ fields {
+ key: 'loss_type'
+ value: { string_value: 'nce_loss' }
+ }
+ fields {
+ key: 'loss_weight'
+ value: { number_value: 0.1 }
+ }
+ fields {
+ key: 'temperature'
+ value: { number_value: 0.15 }
+ }
+ }
+ }
+ }
+ blocks {
+ name: 'main'
+ inputs {
+ package_name: 'seq_encoder'
+ package_input: 'user_seq'
+ }
+ inputs {
+ feature_group_name: 'user'
+ }
+ inputs {
+ feature_group_name: 'item'
+ }
+ }
+ concat_blocks: 'main'
+ top_mlp {
+ hidden_units: [256, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 0
+ }
+ embedding_regularization: 0
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/cmbf_on_movielens.config b/samples/model_config/cmbf_on_movielens.config
new file mode 100644
index 000000000..3a433a928
--- /dev/null
+++ b/samples/model_config/cmbf_on_movielens.config
@@ -0,0 +1,238 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/cmbf_movielens_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'CMBF'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
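+  # CMBF fuses three inputs here: the 'image' group (a 512-dim precomputed embedding),
+  # the 'text' group (title/genres token sequences) and the 'general' group of id/side
+  # features, using the cross-modal attention stack configured below.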
+ cmbf {
+ config {
+ multi_head_num: 2
+ text_multi_head_num: 4
+ image_multi_head_num: 3
+ image_head_size: 16
+ text_head_size: 8
+ image_feature_dim: 64
+ image_self_attention_layer_num: 0
+ text_self_attention_layer_num: 2
+ cross_modal_layer_num: 3
+ image_cross_head_size: 16
+ text_cross_head_size: 16
+ max_position_embeddings: 16
+ use_token_type: true
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/cmbf_on_movielens_has_other_feature.config b/samples/model_config/cmbf_on_movielens_has_other_feature.config
new file mode 100644
index 000000000..f7d6b9ccd
--- /dev/null
+++ b/samples/model_config/cmbf_on_movielens_has_other_feature.config
@@ -0,0 +1,245 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/cmbf_movielens_other_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'CMBF'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'other'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ cmbf {
+ config {
+ multi_head_num: 2
+ text_multi_head_num: 4
+ image_multi_head_num: 2
+ image_head_size: 8
+ text_head_size: 8
+ image_feature_dim: 64
+ image_self_attention_layer_num: 2
+ text_self_attention_layer_num: 2
+ cross_modal_layer_num: 3
+ image_cross_head_size: 16
+ text_cross_head_size: 16
+ max_position_embeddings: 16
+ use_token_type: true
+ other_feature_dnn: {
+ hidden_units: 64
+ }
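+      # The 'other' feature group (score_year_diff, score_time) presumably bypasses the
+      # cross-modal attention and is encoded by this small DNN before the final_dnn.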
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/cmbf_on_movielens_only_image_feature.config b/samples/model_config/cmbf_on_movielens_only_image_feature.config
new file mode 100644
index 000000000..06bbb4586
--- /dev/null
+++ b/samples/model_config/cmbf_on_movielens_only_image_feature.config
@@ -0,0 +1,144 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/cmbf_movielens_only_img_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'CMBF'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ cmbf {
+ config {
+ multi_head_num: 2
+ image_multi_head_num: 2
+ image_head_size: 8
+ text_head_size: 8
+ image_feature_dim: 64
+ image_self_attention_layer_num: 2
+ text_self_attention_layer_num: 2
+ cross_modal_layer_num: 3
+ image_cross_head_size: 8
+ text_cross_head_size: 16
+ max_position_embeddings: 16
+ use_token_type: true
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+  best_exporter_metric: "auc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/cmbf_on_movielens_only_text_feature.config b/samples/model_config/cmbf_on_movielens_only_text_feature.config
new file mode 100644
index 000000000..2111a42ef
--- /dev/null
+++ b/samples/model_config/cmbf_on_movielens_only_text_feature.config
@@ -0,0 +1,226 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/cmbf_movielens_only_txt_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_class: 'CMBF'
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ cmbf {
+ config {
+ multi_head_num: 2
+ text_multi_head_num: 2
+ image_head_size: 8
+ text_head_size: 16
+ image_feature_dim: 64
+ image_self_attention_layer_num: 2
+ text_self_attention_layer_num: 2
+ cross_modal_layer_num: 3
+ image_cross_head_size: 8
+ text_cross_head_size: 16
+ max_position_embeddings: 16
+ use_token_type: true
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/cmbf_with_multi_loss.config b/samples/model_config/cmbf_with_multi_loss.config
new file mode 100644
index 000000000..feaa88955
--- /dev/null
+++ b/samples/model_config/cmbf_with_multi_loss.config
@@ -0,0 +1,247 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/cmbf_movielens_multi_loss_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'CMBF'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ cmbf {
+ config {
+ multi_head_num: 2
+ image_head_size: 8
+ text_head_size: 8
+ image_feature_dim: 64
+ image_self_attention_layer_num: 2
+ text_self_attention_layer_num: 2
+ cross_modal_layer_num: 3
+ image_cross_head_size: 8
+ text_cross_head_size: 16
+ max_position_embeddings: 16
+ use_token_type: true
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+ losses {
+ loss_type: F1_REWEIGHTED_LOSS
+ weight: 1.0
+ f1_reweighted_loss {
+ f1_beta_square: 0.5625
+ }
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
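+  # The final objective combines an F1-reweighted classification loss
+  # (f1_beta_square = 0.5625, i.e. beta = 0.75) with a pairwise ranking loss, each weighted 1.0.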
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/custom_early_stop_on_taobao.config b/samples/model_config/custom_early_stop_on_taobao.config
index 2e4cccba2..8dab099ff 100644
--- a/samples/model_config/custom_early_stop_on_taobao.config
+++ b/samples/model_config/custom_early_stop_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 50
sync_replicas: True
- num_steps: 500
+ num_steps: 200
}
eval_config {
diff --git a/samples/model_config/dat_on_taobao.config b/samples/model_config/dat_on_taobao.config
new file mode 100644
index 000000000..c113df2ed
--- /dev/null
+++ b/samples/model_config/dat_on_taobao.config
@@ -0,0 +1,320 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dat_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 200
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 4000
+ sync_replicas: false
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ recall_at_topk {
+ topk: 50
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 512
+ prefetch_size: 32
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 2048
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
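+  # Negatives are drawn by weighted random sampling from the item table above, 2048 per
+  # batch for both training and evaluation, keyed by 'adgroup_id'; the attr fields are
+  # carried along so sampled items can presumably be fed through the item tower.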
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_class: "DAT"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: 'user_id_augment'
+ feature_names: 'user_id'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: 'item_id_augment'
+ feature_names: 'adgroup_id'
+ wide_deep:DEEP
+ }
+
+ dat {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [ 128, 32]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [ 128, 32]
+ }
+ }
+ simi_func: COSINE
+ temperature: 0.01
+ l2_regularization: 1e-6
+ }
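+  # The 'user_id_augment' / 'item_id_augment' groups supply DAT's per-id augmented vectors;
+  # tower outputs are compared with cosine similarity scaled by temperature 0.01 before the
+  # softmax cross-entropy loss configured below.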
+ embedding_regularization: 5e-5
+ loss_type: SOFTMAX_CROSS_ENTROPY
+}
+
+export_config {
+}
diff --git a/samples/model_config/dbmtl_backbone_on_taobao.config b/samples/model_config/dbmtl_backbone_on_taobao.config
new file mode 100644
index 000000000..876823bc9
--- /dev/null
+++ b/samples/model_config/dbmtl_backbone_on_taobao.config
@@ -0,0 +1,316 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_backbone_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 100
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_config: {
+ features {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "tag_category_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "tag_brand_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config {
+ model_name: "DBMTL"
+ model_class: "MultiTaskModel"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "mask_net"
+ inputs {
+ feature_group_name: "all"
+ }
+ keras_layer {
+ class_name: 'MaskNet'
+ masknet {
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mlp {
+ hidden_units: [512, 256]
+ }
+ }
+ }
+ }
+ }
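+  # A single MaskNet block serves as the shared bottom of this multi-task model; its output
+  # feeds the ctr and cvr task towers defined in model_params below.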
+ model_params {
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_cmbf_on_movielens.config b/samples/model_config/dbmtl_cmbf_on_movielens.config
new file mode 100644
index 000000000..73e9ff028
--- /dev/null
+++ b/samples/model_config/dbmtl_cmbf_on_movielens.config
@@ -0,0 +1,272 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/dbmtl_cmbf_movielens_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ label_fields: 'rating'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'DBMTL'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ dbmtl {
+ bottom_cmbf {
+ multi_head_num: 2
+ image_multi_head_num: 2
+ text_multi_head_num: 2
+ image_feature_patch_num: 8
+ image_head_size: 32
+ text_head_size: 8
+ image_self_attention_layer_num: 2
+ text_self_attention_layer_num: 2
+ cross_modal_layer_num: 3
+ image_cross_head_size: 32
+ text_cross_head_size: 8
+ max_position_embeddings: 16
+ use_token_type: true
+ }
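+    # bottom_cmbf replaces the usual shared bottom_dnn: the image/text/general groups are
+    # fused by cross-modal attention and the fused representation is shared by both task towers.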
+ task_towers {
+ tower_name: "classify"
+ label_name: "label"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "rating"
+ label_name: "rating"
+ loss_type: L2_LOSS
+ metrics_set: {
+ mean_squared_error {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ relation_tower_names: ["classify"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/dbmtl_mmoe_on_taobao.config b/samples/model_config/dbmtl_mmoe_on_taobao.config
index f5e03eeb9..de817d1be 100644
--- a/samples/model_config/dbmtl_mmoe_on_taobao.config
+++ b/samples/model_config/dbmtl_mmoe_on_taobao.config
@@ -16,7 +16,7 @@ train_config {
}
use_moving_average: false
}
- num_steps: 5000
+ num_steps: 100
sync_replicas: true
save_checkpoints_steps: 100
log_step_count_steps: 100
diff --git a/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform.config b/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform.config
new file mode 100644
index 000000000..de7a57424
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform.config
@@ -0,0 +1,346 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 100
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ allow_key_transform:true
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
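+    # allow_key_transform presumably lets the 'brand' key embedding be projected to match
+    # the dimension of the bucketized numeric hist sequences before target attention.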
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform_dnn.config b/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform_dnn.config
new file mode 100644
index 000000000..6beaef9fc
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform_dnn.config
@@ -0,0 +1,357 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 1000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+      allow_key_transform: true
+      transform_dnn: true
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ sequence_features: {
+ group_name: "seq_fea_2"
+ tf_summary: false
+      allow_key_transform: true
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform_multi_seq.config b/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform_multi_seq.config
new file mode 100644
index 000000000..d4b2785e1
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform_multi_seq.config
@@ -0,0 +1,356 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 1000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+      allow_key_transform: true
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ sequence_features: {
+ group_name: "seq_fea_2"
+ tf_summary: false
+      allow_key_transform: true
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_numeric_boundary_need_key_feature_taobao.config b/samples/model_config/dbmtl_on_multi_numeric_boundary_need_key_feature_taobao.config
new file mode 100644
index 000000000..6584391eb
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_boundary_need_key_feature_taobao.config
@@ -0,0 +1,347 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+      need_key_feature: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_numeric_boundary_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_multi_numeric_boundary_sequence_feature_taobao.config
new file mode 100644
index 000000000..26a9a615b
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_boundary_sequence_feature_taobao.config
@@ -0,0 +1,346 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_numeric_hash_bucket_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_multi_numeric_hash_bucket_sequence_feature_taobao.config
new file mode 100644
index 000000000..90a26ce93
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_hash_bucket_sequence_feature_taobao.config
@@ -0,0 +1,302 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_numeric_num_buckets_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_multi_numeric_num_buckets_sequence_feature_taobao.config
new file mode 100644
index 000000000..1d9c6852f
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_num_buckets_sequence_feature_taobao.config
@@ -0,0 +1,308 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ num_buckets: 15
+ max_val: 100000
+ min_val: 0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ num_buckets: 15
+ max_val: 100000
+ min_val: 0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_numeric_raw_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_multi_numeric_raw_sequence_feature_taobao.config
new file mode 100644
index 000000000..c2d005f8b
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_numeric_raw_sequence_feature_taobao.config
@@ -0,0 +1,304 @@
+train_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+eval_input_path: "data/test/tb_data/taobao_multi_seq_train_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ sub_feature_type: RawFeature
+  sequence_length: 50
+ embedding_dim: 16
+ raw_input_dim: 4
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ sub_feature_type: RawFeature
+  sequence_length: 50
+ embedding_dim: 16
+ raw_input_dim: 4
+ separator: "|"
+ seq_multi_sep: ";"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_multi_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_multi_sequence_feature_taobao.config
new file mode 100644
index 000000000..33a213fc8
--- /dev/null
+++ b/samples/model_config/dbmtl_on_multi_sequence_feature_taobao.config
@@ -0,0 +1,311 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea_1"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ }
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ }
+ }
+
+ sequence_features: {
+ group_name: "seq_fea_2"
+ tf_summary: false
+ seq_att_map: {
+ key: "cate_id"
+ hist_seq: "tag_category_list"
+ }
+ seq_att_map: {
+ key: "cate_id"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_aux_hist_seq_taobao.config b/samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_aux_hist_seq_taobao.config
new file mode 100644
index 000000000..ea279e2eb
--- /dev/null
+++ b/samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_aux_hist_seq_taobao.config
@@ -0,0 +1,343 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ aux_hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_taobao.config
new file mode 100644
index 000000000..5303a7af4
--- /dev/null
+++ b/samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_taobao.config
@@ -0,0 +1,344 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ boundaries: 15.0
+ boundaries: 20.0
+ boundaries: 21.0
+ boundaries: 23.0
+ boundaries: 30.0
+ boundaries: 32.0
+ boundaries: 40.0
+ boundaries: 47.0
+ boundaries: 66.0
+ boundaries: 70.0
+ boundaries: 77.0
+ boundaries: 87.0
+ boundaries: 99.0
+ boundaries: 120.0
+ boundaries: 148.0
+ boundaries: 188.0
+ boundaries: 199.0
+ boundaries: 235.0
+ boundaries: 301.0
+ boundaries: 443.0
+ boundaries: 597.0
+ boundaries: 1314.0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+}
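+# The two SequenceFeature configs above show the numeric-boundary variant:
+# each "|"-separated raw value is presumably discretized by the explicit
+# boundaries (N boundaries giving N+1 embedding buckets) before the lookup,
+# and sequence_length caps the number of elements kept per sequence.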
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_numeric_hash_bucket_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_numeric_hash_bucket_sequence_feature_taobao.config
new file mode 100644
index 000000000..83e682c0b
--- /dev/null
+++ b/samples/model_config/dbmtl_on_numeric_hash_bucket_sequence_feature_taobao.config
@@ -0,0 +1,300 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
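+# Numeric hash-bucket variant: with sub_feature_type IdFeature, each "|"-separated
+# sequence element is hashed into one of hash_bucket_size ids, so no boundaries
+# or num_buckets are needed here.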
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_numeric_num_buckets_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_numeric_num_buckets_sequence_feature_taobao.config
new file mode 100644
index 000000000..d5d2f304f
--- /dev/null
+++ b/samples/model_config/dbmtl_on_numeric_num_buckets_sequence_feature_taobao.config
@@ -0,0 +1,306 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ num_buckets: 15
+ max_val: 100000
+ min_val: 0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ num_buckets: 15
+ max_val: 100000
+ min_val: 0
+ sub_feature_type: RawFeature
+ sequence_length: 300
+ separator: "|"
+}
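+# num_buckets variant: each sequence value is presumably bounded by
+# [min_val, max_val] and mapped into num_buckets integer buckets before the
+# embedding lookup (an assumption about the bucketization behavior).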
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_numeric_raw_sequence_feature_taobao.config b/samples/model_config/dbmtl_on_numeric_raw_sequence_feature_taobao.config
new file mode 100644
index 000000000..6b2a3559d
--- /dev/null
+++ b/samples/model_config/dbmtl_on_numeric_raw_sequence_feature_taobao.config
@@ -0,0 +1,300 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 1
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ sub_feature_type: RawFeature
+  sequence_length: 50
+ embedding_dim: 16
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ sub_feature_type: RawFeature
+  sequence_length: 50
+ embedding_dim: 16
+ separator: "|"
+}
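+# Raw variant: no boundaries or num_buckets are given, so the "|"-separated
+# values are consumed directly as numeric sequence inputs.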
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_on_taobao_with_multi_loss.config b/samples/model_config/dbmtl_on_taobao_with_multi_loss.config
new file mode 100644
index 000000000..d04564b02
--- /dev/null
+++ b/samples/model_config/dbmtl_on_taobao_with_multi_loss.config
@@ -0,0 +1,312 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_config: {
+ features {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "tag_category_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "tag_brand_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ activation: "dice"
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ losses {
+ loss_type: F1_REWEIGHTED_LOSS
+ weight: 1.0
+ f1_reweighted_loss {
+ f1_beta_square: 1.0
+ }
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ activation: "dice"
+ }
+ relation_dnn {
+ hidden_units: [32]
+ activation: "dice"
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ num_class: 2
+ losses {
+ loss_type: JRC_LOSS
+ jrc_loss {
+ session_name: "user_id"
+ alpha: 0.5
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ activation: "dice"
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ activation: "dice"
+ }
+ weight: 1.0
+ }
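+    # Multi-loss sketch: the ctr tower above sums F1_REWEIGHTED_LOSS and
+    # PAIR_WISE_LOSS with weight 1.0 each, while the cvr tower uses JRC_LOSS
+    # (num_class: 2), presumably grouping its listwise term by the user_id
+    # session with alpha balancing the two loss components.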
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dbmtl_uniter_on_movielens.config b/samples/model_config/dbmtl_uniter_on_movielens.config
new file mode 100644
index 000000000..ce620d67f
--- /dev/null
+++ b/samples/model_config/dbmtl_uniter_on_movielens.config
@@ -0,0 +1,275 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/dbmtl_cmbf_movielens_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ label_fields: 'rating'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'DBMTL'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'other'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ dbmtl {
+ bottom_uniter {
+ hidden_size: 512
+ num_attention_heads: 4
+ num_hidden_layers: 2
+ intermediate_size: 512
+ hidden_act: 'swish'
+ max_position_embeddings: 16
+ hidden_dropout_prob: 0.1
+ attention_probs_dropout_prob: 0
+ other_feature_dnn: {
+ hidden_units: 256
+ hidden_units: 128
+ }
+ }
+ task_towers {
+ tower_name: "classify"
+ label_name: "label"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "rating"
+ label_name: "rating"
+ loss_type: L2_LOSS
+ metrics_set: {
+ mean_squared_error {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64]
+ }
+ relation_tower_names: ["classify"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/dbmtl_variational_dropout_on_sequence_feature_taobao.config b/samples/model_config/dbmtl_variational_dropout_on_sequence_feature_taobao.config
new file mode 100644
index 000000000..df848bd1a
--- /dev/null
+++ b/samples/model_config/dbmtl_variational_dropout_on_sequence_feature_taobao.config
@@ -0,0 +1,300 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_configs {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+}
+feature_configs {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config {
+ model_class: "DBMTL"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ tf_summary: false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ dbmtl {
+ bottom_dnn {
+ hidden_units: [1024, 512, 256]
+ }
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ relation_tower_names: ["ctr"]
+ relation_dnn {
+ hidden_units: [32]
+ }
+ weight: 1.0
+ }
+ l2_regularization: 1e-6
+ }
+ variational_dropout {
+    regularization_lambda: 0.01
+    embedding_wise_variational_dropout: true
+ }
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dcn_backbone_on_taobao.config b/samples/model_config/dcn_backbone_on_taobao.config
new file mode 100644
index 000000000..86f4e2462
--- /dev/null
+++ b/samples/model_config/dcn_backbone_on_taobao.config
@@ -0,0 +1,291 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dcn_backbone_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+  sync_replicas: true
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "deep"
+ inputs {
+ feature_group_name: "all"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ blocks {
+ name: "cross"
+ inputs {
+ feature_group_name: "all"
+ input_fn: "lambda x: [x, x]"
+ }
+ recurrent {
+ num_steps: 5
+ fixed_input_index: 0
+ keras_layer {
+ class_name: "Cross"
+ }
+ }
+ }
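+    # The "cross" block above sketches a DCN-style cross network: input_fn
+    # duplicates the feature vector into [x0, x], and the recurrent wrapper
+    # applies the Cross layer num_steps times, with fixed_input_index: 0
+    # presumably keeping x0 fixed across iterations.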
+ concat_blocks: ['deep', 'cross']
+ top_mlp {
+ hidden_units: [64, 32, 16]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/samples/model_config/dead_line_stop.config b/samples/model_config/dead_line_stop.config
new file mode 100644
index 000000000..c3017dd4b
--- /dev/null
+++ b/samples/model_config/dead_line_stop.config
@@ -0,0 +1,300 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_demo"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 400
+ sync_replicas: false
+ log_step_count_steps: 200
+ # stop after dead_line time
+ dead_line: '20220516 23:59:59'
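+  # the timestamp format appears to be 'YYYYMMDD HH:MM:SS'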
+ save_checkpoints_steps: 10
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ separator: ""
+ }
+}
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+}
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/deepfm_combo_avazu_kafka.config b/samples/model_config/deepfm_combo_avazu_kafka.config
new file mode 100644
index 000000000..6e5c1fd5d
--- /dev/null
+++ b/samples/model_config/deepfm_combo_avazu_kafka.config
@@ -0,0 +1,387 @@
+# data/test/dwd_avazu_ctr_deepmodel_10w.csv
+
+kafka_train_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_train'
+ offset_info: '{"0": 5, "1": 10}'
+}
+
+kafka_eval_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_test'
+ offset_info: '{"0":20, "1":30}'
+}
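+# offset_info presumably maps each Kafka partition id to a starting offset,
+# e.g. the train input above starts partition "0" at offset 5 and "1" at 10.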
+
+model_dir: "experiments/dwd_avazu_out_test_combo_kafka"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ num_examples: 10240
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: KafkaInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+    hash_bucket_size: 1000
+ embedding_dim: 16
+ }
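+  # ComboFeature above crosses site_id with app_id; the crossed value is
+  # hashed into hash_bucket_size ids and exposed as "site_id_app_id" in the
+  # deep feature group below.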
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
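
Note: the sample configs added in this change are protobuf text-format messages. As a quick sanity check outside of training, such a file can be parsed with the protobuf text_format parser; the sketch below assumes the root message is EasyRecConfig in easy_rec.python.protos.pipeline_pb2 and uses a hypothetical file path.

    from google.protobuf import text_format
    from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig  # assumed module/message names

    config = EasyRecConfig()
    with open('samples/model_config/some_deepfm_kafka.config') as f:  # hypothetical path
        text_format.Merge(f.read(), config)
    print(config.data_config.batch_size)   # 1024 for the samples in this change
    print(config.data_config.input_type)   # enum value (KafkaInput / CSVInput)
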
diff --git a/samples/model_config/deepfm_combo_avazu_kafka_chief_redundant.config b/samples/model_config/deepfm_combo_avazu_kafka_chief_redundant.config
new file mode 100644
index 000000000..74712f406
--- /dev/null
+++ b/samples/model_config/deepfm_combo_avazu_kafka_chief_redundant.config
@@ -0,0 +1,388 @@
+# data/test/dwd_avazu_ctr_deepmodel_10w.csv
+
+kafka_train_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_train'
+ offset_info: '{"0": 5, "1": 10}'
+}
+
+kafka_eval_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_test'
+ offset_info: '{"0":20, "1":30}'
+}
+
+model_dir: "experiments/dwd_avazu_out_test_combo_kafka_chief_redundant"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: false
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ num_examples: 10240
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ chief_redundant: true
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: KafkaInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
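
Note: this sample pins the Kafka start position per partition. offset_info is a JSON object keyed by partition id (as a string) with the starting offset as the value, and chief_redundant: true turns on redundant input reading on the chief node (see the DatasetConfig definition for its exact behavior). A minimal sketch of building the offset_info string:

    import json

    # Partition id (string) -> starting offset, matching offset_info above.
    start_offsets = {"0": 5, "1": 10}
    offset_info = json.dumps(start_offsets)
    print(offset_info)  # {"0": 5, "1": 10}
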
diff --git a/samples/model_config/deepfm_combo_avazu_kafka_time_offset.config b/samples/model_config/deepfm_combo_avazu_kafka_time_offset.config
new file mode 100644
index 000000000..0dfa19629
--- /dev/null
+++ b/samples/model_config/deepfm_combo_avazu_kafka_time_offset.config
@@ -0,0 +1,389 @@
+# data/test/dwd_avazu_ctr_deepmodel_10w.csv
+
+kafka_train_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_train'
+ # timestamp in seconds
+ offset_time: '1650765731'
+}
+
+kafka_eval_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_test'
+ # timestamp in seconds
+ offset_time: '1650765731'
+}
+
+model_dir: "experiments/dwd_avazu_out_time_offset"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ num_examples: 10240
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: KafkaInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
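
Note: instead of explicit per-partition offsets, this variant seeks the Kafka consumer by time: offset_time is a UNIX timestamp in seconds, passed as a string. For example, to start roughly one hour back from the current time (illustrative only):

    import time

    offset_time = str(int(time.time()) - 3600)  # epoch seconds, one hour ago
    print(offset_time)                          # a '1650765731'-style value
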
diff --git a/samples/model_config/deepfm_combo_avazu_kafka_time_offset2.config b/samples/model_config/deepfm_combo_avazu_kafka_time_offset2.config
new file mode 100644
index 000000000..656a7b967
--- /dev/null
+++ b/samples/model_config/deepfm_combo_avazu_kafka_time_offset2.config
@@ -0,0 +1,389 @@
+# data/test/dwd_avazu_ctr_deepmodel_10w.csv
+
+kafka_train_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_train'
+ # offset by date time
+ offset_time: '20220517 20:00:00'
+}
+
+kafka_eval_input {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_op_test_topic'
+ group: 'kafka_test'
+ # offset by date time
+ offset_time: '20220517 20:00:00'
+}
+
+model_dir: "experiments/dwd_avazu_out_time_offset"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ num_examples: 10240
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: KafkaInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
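
Note: offset_time also accepts a human-readable datetime string, as in this config. Converting it to epoch seconds shows how the two forms relate (the result depends on the local timezone):

    from datetime import datetime

    dt = datetime.strptime('20220517 20:00:00', '%Y%m%d %H:%M:%S')
    print(int(dt.timestamp()))  # epoch seconds, usable as the other offset_time form
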
diff --git a/samples/model_config/deepfm_combo_on_avazu_ctr.config b/samples/model_config/deepfm_combo_on_avazu_ctr.config
index 4d637c62c..a2a7fa765 100644
--- a/samples/model_config/deepfm_combo_on_avazu_ctr.config
+++ b/samples/model_config/deepfm_combo_on_avazu_ctr.config
@@ -291,7 +291,7 @@ feature_config: {
input_names: ["site_id", "app_id"]
feature_name: "site_id_app_id"
feature_type: ComboFeature
- hash_bucket_size: 1000,
+ hash_bucket_size: 1000
embedding_dim: 16
}
@@ -364,7 +364,7 @@ model_config:{
}
l2_regularization: 1e-5
}
- embedding_regularization: 1e-7
+ # embedding_regularization: 1e-7
}
export_config {
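
Note: the change above drops the trailing comma on hash_bucket_size and disables embedding_regularization for this sample. For reference, an embedding L2 penalty of this kind amounts to adding a scaled tf.nn.l2_loss over the embedding tables to the training loss; a minimal TF1.x-style sketch (EasyRec wires this up internally, the variable name here is illustrative):

    import tensorflow as tf  # TF1.x graph mode

    emb = tf.get_variable('site_id_embedding', shape=[10000, 16])
    embedding_regularization = 1e-7
    reg_loss = embedding_regularization * tf.nn.l2_loss(emb)  # added to the model loss
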
diff --git a/samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config b/samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config
new file mode 100644
index 000000000..a4a920137
--- /dev/null
+++ b/samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config
@@ -0,0 +1,383 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "experiments/dwd_avazu_out_test_combo_embedding_adagrad"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config {
+ adagrad_optimizer {
+ learning_rate {
+ constant_learning_rate {
+ learning_rate: 0.05
+ }
+ }
+ initial_accumulator_value: 1.0
+ }
+ }
+
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ # embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
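
Note: this sample declares two optimizer_config blocks, an Adagrad optimizer with a constant 0.05 learning rate plus the usual Adam schedule. With multiple optimizers, EasyRec splits the trainable variables between them (typically the embedding tables get one optimizer and the dense weights the other; check the train config documentation for the exact routing). A hedged TF1.x sketch of the underlying idea, with toy variables standing in for embedding and DNN weights:

    import tensorflow as tf  # TF1.x graph mode

    emb = tf.get_variable('embedding/table', shape=[100, 16])
    dnn_w = tf.get_variable('dnn/kernel', shape=[16, 1])
    loss = tf.reduce_sum(tf.matmul(emb, dnn_w))

    adagrad = tf.train.AdagradOptimizer(0.05, initial_accumulator_value=1.0)
    adam = tf.train.AdamOptimizer(0.0001)
    train_op = tf.group(
        adagrad.minimize(loss, var_list=[emb]),    # sparse/embedding weights
        adam.minimize(loss, var_list=[dnn_w]))     # dense weights
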
diff --git a/samples/model_config/deepfm_combo_on_avazu_eval_online_gauc_ctr.config b/samples/model_config/deepfm_combo_on_avazu_eval_online_gauc_ctr.config
new file mode 100644
index 000000000..70014d209
--- /dev/null
+++ b/samples/model_config/deepfm_combo_on_avazu_eval_online_gauc_ctr.config
@@ -0,0 +1,373 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "experiments/dwd_avazu_out_test_combo_eval_online"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ gauc {uid_field: "device_id"}
+ }
+ eval_online: true
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
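
Note: the eval_config here switches from plain AUC to grouped AUC keyed by device_id (uid_field) and enables eval_online. Conceptually, GAUC is the per-user AUC averaged with weights proportional to each user's sample count, skipping users whose samples are all one class; a small self-contained sketch of that computation (not EasyRec's exact implementation):

    from collections import defaultdict

    def auc(labels, scores):
        # Rank-based AUC: probability a random positive scores above a random negative.
        pos = [s for s, y in zip(scores, labels) if y == 1]
        neg = [s for s, y in zip(scores, labels) if y == 0]
        if not pos or not neg:
            return None  # undefined for single-class groups
        hits = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
        return hits / (len(pos) * len(neg))

    def gauc(labels, scores, uids):
        # Group samples by uid (here device_id), average per-group AUC weighted by group size.
        groups = defaultdict(lambda: ([], []))
        for y, s, u in zip(labels, scores, uids):
            groups[u][0].append(y)
            groups[u][1].append(s)
        num, den = 0.0, 0
        for ys, ss in groups.values():
            a = auc(ys, ss)
            if a is not None:
                num += a * len(ys)
                den += len(ys)
        return num / den if den else 0.0

    print(gauc([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.6], ['u1', 'u1', 'u2', 'u2']))  # 0.5
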
diff --git a/samples/model_config/deepfm_combo_on_avazu_feature_name.config b/samples/model_config/deepfm_combo_on_avazu_feature_name.config
new file mode 100644
index 000000000..65aff377d
--- /dev/null
+++ b/samples/model_config/deepfm_combo_on_avazu_feature_name.config
@@ -0,0 +1,396 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "experiments/dwd_avazu_out_test_combo_feature_name"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ feature_name: "banner_pos_v2"
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,3,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "c19"]
+ feature_name: "site_id_c19"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["c21", "c19"]
+ feature_name: "c19_c21"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "banner_pos_v2"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ feature_names: "site_id_c19"
+ feature_names: "c19_c21"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ # embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
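
Note: this sample uses feature_name to derive several features from the same inputs: banner_pos is bucketized twice with different boundaries (banner_pos vs banner_pos_v2), and site_id/c19/c21 feed three ComboFeatures. The boundaries list is ordinary bucketization; a quick stand-alone illustration (edge handling follows the usual half-open bucket convention, which EasyRec's implementation may refine):

    import bisect

    def bucketize(value, boundaries):
        # Bucket index in 0..len(boundaries); values >= the last boundary land in the top bucket.
        return bisect.bisect_right(boundaries, value)

    print(bucketize(4, [1, 2, 3, 4, 5, 6]))  # 4  (banner_pos)
    print(bucketize(4, [1, 3, 6]))           # 2  (banner_pos_v2)
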
diff --git a/samples/model_config/deepfm_combo_v2_on_avazu_ctr.config b/samples/model_config/deepfm_combo_v2_on_avazu_ctr.config
new file mode 100644
index 000000000..6eaeed9d0
--- /dev/null
+++ b/samples/model_config/deepfm_combo_v2_on_avazu_ctr.config
@@ -0,0 +1,372 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "experiments/dwd_avazu_out_test_combo_v2"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0, 900.0, 1000.0,1100.0,1200.0, 1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ combo_join_sep: "X"
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ # embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
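
A note on the sample above: the cross feature "site_id_app_id" is produced by a ComboFeature over site_id and app_id, where combo_join_sep gives the string placed between the input values when they are joined before hashing (that reading follows from the field name and is an assumption, not stated in the diff). A minimal sketch of just that block, values copied from the sample:

  features: {
    input_names: ["site_id", "app_id"]   # the two columns being crossed
    feature_name: "site_id_app_id"       # name referenced later in feature_groups
    feature_type: ComboFeature
    combo_join_sep: "X"                  # joined as <site_id>X<app_id> before hashing (assumed)
    hash_bucket_size: 1000
    embedding_dim: 16
  }
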
diff --git a/samples/model_config/deepfm_combo_v3_on_avazu_ctr.config b/samples/model_config/deepfm_combo_v3_on_avazu_ctr.config
new file mode 100644
index 000000000..7b3b21cf3
--- /dev/null
+++ b/samples/model_config/deepfm_combo_v3_on_avazu_ctr.config
@@ -0,0 +1,383 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "experiments/dwd_avazu_out_test_combo_v3"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+    boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0,900.0,1000.0,1100.0,1200.0,1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ combo_join_sep: "X"
+ combo_input_seps: [",", ""]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["hour", "app_id"]
+ feature_name: "hour_app_id"
+ feature_type: ComboFeature
+ hash_bucket_size: 1000
+ combo_join_sep: "X"
+ combo_input_seps: ["", ","]
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ feature_names: "hour_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ # embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
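
The v3 variant above adds combo_input_seps to the combo features, one entry per input. From the sample it appears each entry is the delimiter used to split the corresponding input before the cross is formed, with an empty string meaning the input is taken as a single value; treat that as an assumption rather than documented behaviour. A minimal sketch of one of the two blocks, values copied from the sample:

  features: {
    input_names: ["hour", "app_id"]
    feature_name: "hour_app_id"
    feature_type: ComboFeature
    combo_join_sep: "X"
    combo_input_seps: ["", ","]    # hour taken as-is, app_id split on "," (assumed semantics)
    hash_bucket_size: 1000
    embedding_dim: 16
  }
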
diff --git a/samples/model_config/deepfm_combo_variational_dropout_on_avazu_ctr.config b/samples/model_config/deepfm_combo_variational_dropout_on_avazu_ctr.config
new file mode 100644
index 000000000..a21c72052
--- /dev/null
+++ b/samples/model_config/deepfm_combo_variational_dropout_on_avazu_ctr.config
@@ -0,0 +1,376 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "experiments/dwd_avazu_out_test_combo"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+    boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0,900.0,1000.0,1100.0,1200.0,1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "site_id_app_id"
+ feature_type: ComboFeature
+    hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "site_id_app_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ variational_dropout {
+ regularization_lambda:0.01
+ embedding_wise_variational_dropout:true
+ }
+ embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
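
Relative to the plain combo sample, the config above only changes the model_config tail: a variational_dropout block is attached to the DeepFM model and embedding_regularization is enabled instead of being left commented out. A minimal sketch of the added fields, values copied from the sample (the comment on the second field follows its name and is an assumption):

  variational_dropout {
    regularization_lambda: 0.01                # weight of the dropout regularization term
    embedding_wise_variational_dropout: true   # dropout per embedding dimension rather than per feature (assumed)
  }
  embedding_regularization: 1e-7
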
diff --git a/samples/model_config/deepfm_distribute_eval_combo_on_avazu_ctr.config b/samples/model_config/deepfm_distribute_eval_combo_on_avazu_ctr.config
new file mode 100644
index 000000000..eaf1f6e3f
--- /dev/null
+++ b/samples/model_config/deepfm_distribute_eval_combo_on_avazu_ctr.config
@@ -0,0 +1,385 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "data/test/distribute_eval_test/dwd_distribute_eval_avazu_out_test_combo"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ gauc {
+ uid_field:'device_ip'
+ }
+ }
+ metrics_set: {
+ session_auc {
+ session_id_field: 'device_ip'
+ }
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ num_buckets: 24
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+    boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0,900.0,1000.0,1100.0,1200.0,1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: ["site_id", "app_id"]
+ feature_name: "app_id_X_site_id"
+ feature_type: ComboFeature
+    hash_bucket_size: 1000
+ embedding_dim: 16
+ }
+
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "app_id_X_site_id"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-7
+}
+
+export_config {
+ multi_placeholder: false
+}
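
The distribute-eval sample above reuses the DeepFM combo setup and mainly widens eval_config: every metric gets its own metrics_set block, and gauc/session_auc are grouped by device_ip. A minimal sketch of the metric section, copied from the sample; any per-user or per-session id column could stand in for device_ip:

  eval_config {
    metrics_set: { auc {} }
    metrics_set: { gauc { uid_field: 'device_ip' } }               # AUC grouped per user id
    metrics_set: { session_auc { session_id_field: 'device_ip' } } # AUC grouped per session id
    metrics_set: { max_f1 {} }
  }
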
diff --git a/samples/model_config/deepfm_distribute_eval_multi_cls_on_avazu_ctr.config b/samples/model_config/deepfm_distribute_eval_multi_cls_on_avazu_ctr.config
new file mode 100644
index 000000000..e5ab92f19
--- /dev/null
+++ b/samples/model_config/deepfm_distribute_eval_multi_cls_on_avazu_ctr.config
@@ -0,0 +1,364 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ accuracy {}
+ }
+ metrics_set: {
+ recall_at_topk { topk: 2}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: IdFeature
+ hash_bucket_size: 12
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+    boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0,900.0,1000.0,1100.0,1200.0,1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "DeepFM"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
+
+ deepfm {
+ wide_output_dim: 16
+
+ dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ final_dnn {
+ hidden_units: [128, 64]
+ }
+ l2_regularization: 1e-5
+ }
+
+ num_class: 5
+ embedding_regularization: 1e-7
+}
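
The multi-class distribute-eval sample above drops the combo feature, evaluates with accuracy and recall_at_topk instead of AUC, sets num_class: 5 on the model, and omits export_config. A minimal sketch of the fields that differ, values copied from the sample:

  eval_config {
    metrics_set: { accuracy {} }
    metrics_set: { recall_at_topk { topk: 2 } }
  }

  # inside model_config, after the deepfm block:
  num_class: 5                   # switches the model from binary to 5-way classification
  embedding_regularization: 1e-7
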
diff --git a/samples/model_config/deepfm_multi_cls_on_avazu_ctr.config b/samples/model_config/deepfm_multi_cls_on_avazu_ctr.config
index a755413bb..84e2fb9d4 100644
--- a/samples/model_config/deepfm_multi_cls_on_avazu_ctr.config
+++ b/samples/model_config/deepfm_multi_cls_on_avazu_ctr.config
@@ -26,7 +26,9 @@ train_config {
eval_config {
metrics_set: {
- accuracy {}
+ accuracy {}
+ }
+ metrics_set: {
recall_at_topk { topk: 2}
}
}
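
The small hunk above corrects the existing multi-class sample so that accuracy and recall_at_topk sit in separate metrics_set blocks instead of being nested in one. After the change the section reads as in this sketch (indentation normalized, since the diff context does not preserve it):

  eval_config {
    metrics_set: {
      accuracy {}
    }
    metrics_set: {
      recall_at_topk { topk: 2 }
    }
  }
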
diff --git a/samples/model_config/deepfm_on_criteo_with_autodis.config b/samples/model_config/deepfm_on_criteo_with_autodis.config
new file mode 100755
index 000000000..41975c090
--- /dev/null
+++ b/samples/model_config/deepfm_on_criteo_with_autodis.config
@@ -0,0 +1,786 @@
+train_input_path: "data/test/criteo_sample.tfrecord"
+eval_input_path: "data/test/criteo_sample.tfrecord"
+
+model_dir: "experiments/deepfm_with_autodis"
+
+train_config {
+ log_step_count_steps: 20
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: INT64
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 8096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: TFRecordInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ feature_name: "D1"
+ input_names: "F1"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ feature_name: "D2"
+ input_names: "F2"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ feature_name: "D3"
+ input_names: "F3"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ feature_name: "D4"
+ input_names: "F4"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ feature_name: "D5"
+ input_names: "F5"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ feature_name: "D6"
+ input_names: "F6"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ feature_name: "D7"
+ input_names: "F7"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ feature_name: "D8"
+ input_names: "F8"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ feature_name: "D9"
+ input_names: "F9"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ feature_name: "D10"
+ input_names: "F10"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ feature_name: "D11"
+ input_names: "F11"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ feature_name: "D12"
+ input_names: "F12"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ feature_name: "D13"
+ input_names: "F13"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+}
+model_config:{
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "numerical_features"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "categorical_features"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide_features"
+ feature_names: "D1"
+ feature_names: "D2"
+ feature_names: "D3"
+ feature_names: "D4"
+ feature_names: "D5"
+ feature_names: "D6"
+ feature_names: "D7"
+ feature_names: "D8"
+ feature_names: "D9"
+ feature_names: "D10"
+ feature_names: "D11"
+ feature_names: "D12"
+ feature_names: "D13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:WIDE
+ }
+ backbone {
+ blocks {
+ name: 'wide_features'
+ inputs {
+ feature_group_name: 'wide_features'
+ }
+ input_layer {
+ wide_output_dim: 1
+ }
+ }
+ blocks {
+ name: 'wide_logit'
+ inputs {
+ block_name: 'wide_features'
+ }
+ lambda {
+ expression: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
+ }
+ }
+ blocks {
+ name: 'num_emb'
+ inputs {
+ feature_group_name: 'numerical_features'
+ }
+ keras_layer {
+ class_name: 'AutoDisEmbedding'
+ auto_dis_embedding {
+ embedding_dim: 10
+ num_bins: 20
+ temperature: 0.815
+ output_tensor_list: true
+ }
+ }
+ }
+ blocks {
+ name: 'categorical_features'
+ inputs {
+ feature_group_name: 'categorical_features'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'fm'
+ inputs {
+ block_name: 'categorical_features'
+ input_fn: 'lambda x: x[1]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_fn: 'lambda x: x[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ fm {
+ use_variant: true
+ }
+ }
+ }
+ blocks {
+ name: 'deep'
+ inputs {
+ block_name: 'categorical_features'
+ input_fn: 'lambda x: x[0]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_fn: 'lambda x: x[0]'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ concat_blocks: ['wide_logit', 'fm', 'deep']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
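
Unlike the DeepFM samples, the AutoDis config above uses the generic RankModel with a backbone graph: an input_layer block feeds a lambda that reduces the wide features to a logit, AutoDisEmbedding embeds the 13 numerical features, and the FM and MLP blocks consume the resulting tensor lists before concat_blocks and top_mlp assemble the head. A minimal sketch of the numerical-embedding block alone, hyper-parameters copied from the sample; the comments follow the AutoDis paper's terminology and are not taken from the diff:

  blocks {
    name: 'num_emb'
    inputs {
      feature_group_name: 'numerical_features'
    }
    keras_layer {
      class_name: 'AutoDisEmbedding'
      auto_dis_embedding {
        embedding_dim: 10          # width of each learned numeric embedding
        num_bins: 20               # number of soft buckets per numeric field
        temperature: 0.815         # softmax temperature used when mixing bucket embeddings
        output_tensor_list: true   # also expose per-feature tensors, consumed by the FM block
      }
    }
  }

The periodic variant that follows appears identical apart from the num_emb block, which uses class_name: 'PeriodicEmbedding' with embedding_dim: 10 and sigma: 0.005.
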
diff --git a/samples/model_config/deepfm_on_criteo_with_periodic.config b/samples/model_config/deepfm_on_criteo_with_periodic.config
new file mode 100755
index 000000000..081fbf2cf
--- /dev/null
+++ b/samples/model_config/deepfm_on_criteo_with_periodic.config
@@ -0,0 +1,785 @@
+train_input_path: "data/test/criteo_sample.tfrecord"
+eval_input_path: "data/test/criteo_sample.tfrecord"
+
+model_dir: "experiments/deepfm_with_periodic"
+
+train_config {
+ log_step_count_steps: 20
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: "\t"
+ input_fields: {
+ input_name: "label"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F1"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F2"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F3"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F4"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F5"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F6"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F7"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F8"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F9"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F10"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F11"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F12"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "F13"
+ input_type: FLOAT
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "C1"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C2"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C3"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C4"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C5"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C6"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C7"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C8"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C9"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C10"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C11"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C12"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C13"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C14"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C15"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C16"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C17"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C18"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C19"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C20"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C21"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C22"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C23"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C24"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C25"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "C26"
+ input_type: INT64
+ default_val:""
+ }
+ label_fields: "label"
+
+ batch_size: 8096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: TFRecordInput
+}
+
+feature_config: {
+ features: {
+ input_names: "F1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "F2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "F3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "F4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "F5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "F6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "F7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "F8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "F9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "F10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "F11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "F12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "F13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "C1"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C2"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C3"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C4"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C5"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C6"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C7"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C8"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C9"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C10"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C11"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C12"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C13"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C14"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C15"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C16"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C17"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C18"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C19"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C20"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C21"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C22"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C23"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C24"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C25"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ input_names: "C26"
+ hash_bucket_size: 1000000
+ feature_type: IdFeature
+ embedding_dim: 10
+ embedding_name: "vocab_embed"
+ }
+ features: {
+ feature_name: "D1"
+ input_names: "F1"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ feature_name: "D2"
+ input_names: "F2"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ feature_name: "D3"
+ input_names: "F3"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ feature_name: "D4"
+ input_names: "F4"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ feature_name: "D5"
+ input_names: "F5"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ feature_name: "D6"
+ input_names: "F6"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ feature_name: "D7"
+ input_names: "F7"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ feature_name: "D8"
+ input_names: "F8"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ feature_name: "D9"
+ input_names: "F9"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ feature_name: "D10"
+ input_names: "F10"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ feature_name: "D11"
+ input_names: "F11"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ feature_name: "D12"
+ input_names: "F12"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ feature_name: "D13"
+ input_names: "F13"
+ embedding_dim:10
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+}
+model_config:{
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: "numerical_features"
+ feature_names: "F1"
+ feature_names: "F2"
+ feature_names: "F3"
+ feature_names: "F4"
+ feature_names: "F5"
+ feature_names: "F6"
+ feature_names: "F7"
+ feature_names: "F8"
+ feature_names: "F9"
+ feature_names: "F10"
+ feature_names: "F11"
+ feature_names: "F12"
+ feature_names: "F13"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "categorical_features"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide_features"
+ feature_names: "D1"
+ feature_names: "D2"
+ feature_names: "D3"
+ feature_names: "D4"
+ feature_names: "D5"
+ feature_names: "D6"
+ feature_names: "D7"
+ feature_names: "D8"
+ feature_names: "D9"
+ feature_names: "D10"
+ feature_names: "D11"
+ feature_names: "D12"
+ feature_names: "D13"
+ feature_names: "C1"
+ feature_names: "C2"
+ feature_names: "C3"
+ feature_names: "C4"
+ feature_names: "C5"
+ feature_names: "C6"
+ feature_names: "C7"
+ feature_names: "C8"
+ feature_names: "C9"
+ feature_names: "C10"
+ feature_names: "C11"
+ feature_names: "C12"
+ feature_names: "C13"
+ feature_names: "C14"
+ feature_names: "C15"
+ feature_names: "C16"
+ feature_names: "C17"
+ feature_names: "C18"
+ feature_names: "C19"
+ feature_names: "C20"
+ feature_names: "C21"
+ feature_names: "C22"
+ feature_names: "C23"
+ feature_names: "C24"
+ feature_names: "C25"
+ feature_names: "C26"
+ wide_deep:WIDE
+ }
+ backbone {
+ blocks {
+ name: 'wide_features'
+ inputs {
+ feature_group_name: 'wide_features'
+ }
+ input_layer {
+ wide_output_dim: 1
+ }
+ }
+ blocks {
+ name: 'wide_logit'
+ inputs {
+ block_name: 'wide_features'
+ }
+ lambda {
+ expression: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
+ }
+ }
+ blocks {
+ name: 'num_emb'
+ inputs {
+ feature_group_name: 'numerical_features'
+ }
+ keras_layer {
+ class_name: 'PeriodicEmbedding'
+ periodic_embedding {
+ embedding_dim: 10
+ sigma: 0.005
+ output_tensor_list: true
+ }
+ }
+ }
+ blocks {
+ name: 'categorical_features'
+ inputs {
+ feature_group_name: 'categorical_features'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'fm'
+ inputs {
+ block_name: 'categorical_features'
+ input_fn: 'lambda x: x[1]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_fn: 'lambda x: x[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ fm {
+ use_variant: true
+ }
+ }
+ }
+ blocks {
+ name: 'deep'
+ inputs {
+ block_name: 'categorical_features'
+ input_fn: 'lambda x: x[0]'
+ }
+ inputs {
+ block_name: 'num_emb'
+ input_fn: 'lambda x: x[0]'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ concat_blocks: ['wide_logit', 'fm', 'deep']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-5
+ }
+ embedding_regularization: 1e-5
+}
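
Note on the backbone above: each feature in the 'wide_features' group is embedded into a 1-dim weight (wide_output_dim: 1), and the 'wide_logit' lambda simply sums those weights into one logit per example. A minimal sketch of that reduction (illustrative shapes only, not EasyRec internals):

import tensorflow as tf

batch_size, num_wide_features = 4, 39   # 13 dense + 26 categorical features in this config
wide_weights = tf.random.normal([batch_size, num_wide_features])  # stand-in for the 1-dim wide embeddings
wide_logit = tf.reduce_sum(wide_weights, axis=1, keepdims=True)   # same expression as the lambda; shape [batch_size, 1]
print(wide_logit.shape)
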
diff --git a/samples/model_config/din_backbone_on_taobao.config b/samples/model_config/din_backbone_on_taobao.config
new file mode 100644
index 000000000..7cb48ac56
--- /dev/null
+++ b/samples/model_config/din_backbone_on_taobao.config
@@ -0,0 +1,315 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/din_backbone_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ max_seq_len: 50
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_name: 'DIN'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'normal'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'sequence'
+ feature_names: "cate_id"
+ feature_names: "brand"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'deep'
+ inputs {
+ feature_group_name: 'normal'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ blocks {
+ name: 'seq_input'
+ inputs {
+ feature_group_name: 'sequence'
+ }
+ input_layer {
+ output_seq_and_normal_feature: true
+ }
+ }
+ blocks {
+ name: 'DIN'
+ inputs {
+ block_name: 'seq_input'
+ }
+ keras_layer {
+ class_name: 'DIN'
+ din {
+ attention_dnn {
+ hidden_units: 32
+ hidden_units: 1
+ activation: "dice"
+ }
+ need_target_feature: true
+ }
+ }
+ }
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 0
+ }
+ embedding_regularization: 0
+}
+
+export_config {
+ multi_placeholder: false
+}
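
Note on the 'DIN' block above: the attention_dnn (32 -> 1, dice activation) scores each behavior in the user's sequence against the target item, and the sequence is pooled with those scores. A conceptual numpy sketch of target attention (illustrative only, not EasyRec's DIN layer; ReLU stands in for dice and the weights are random):

import numpy as np

def din_attention(hist, target, w1, b1, w2, b2):
    """hist: [T, D] behavior embeddings, target: [D] target-item embedding."""
    tgt = np.tile(target, (hist.shape[0], 1))                              # [T, D]
    att_in = np.concatenate([hist, tgt, hist - tgt, hist * tgt], axis=1)   # [T, 4D]
    h = np.maximum(att_in @ w1 + b1, 0.0)                                  # hidden layer of the attention MLP
    scores = h @ w2 + b2                                                   # [T, 1] attention scores
    return (scores * hist).sum(axis=0)                                     # weighted sum pooling -> [D]

T, D = 5, 16
rng = np.random.default_rng(0)
pooled = din_attention(rng.normal(size=(T, D)), rng.normal(size=D),
                       rng.normal(size=(4 * D, 32)), np.zeros(32),
                       rng.normal(size=(32, 1)), np.zeros(1))
print(pooled.shape)  # (16,)
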
diff --git a/samples/model_config/din_on_gzip_data.config b/samples/model_config/din_on_gzip_data.config
new file mode 100644
index 000000000..f3bb5e924
--- /dev/null
+++ b/samples/model_config/din_on_gzip_data.config
@@ -0,0 +1,298 @@
+train_input_path: "data/test/tb_data/taobao_test_data_compress.gz"
+eval_input_path: "data/test/tb_data/taobao_test_data_compress.gz"
+model_dir: "experiments/din_on_gzip_data"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTowerDIN'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ seq_att_groups: {
+ group_name: "din"
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ }
+ seq_att_map: {
+ key: "cate_id"
+ hist_seq: "tag_category_list"
+ }
+ }
+
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ din_towers {
+ input: "din"
+ dnn {
+ hidden_units: [128, 64, 32, 1]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 5e-7
+ }
+ embedding_regularization: 5e-5
+}
+
+export_config {
+ multi_placeholder: false
+}
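
Note on din_on_gzip_data.config: the train/eval paths point at a .gz file while input_type stays CSVInput, relying on CSVInput's gzip support (plain .gz only, not .tar.gz). For reference, a minimal sketch of how a gzip-compressed CSV can be read with tf.data (not EasyRec's input pipeline):

import tensorflow as tf

ds = tf.data.TextLineDataset("data/test/tb_data/taobao_test_data_compress.gz",
                             compression_type="GZIP")
for line in ds.take(1):
    print(line.numpy()[:80])  # first 80 bytes of the first decompressed CSV line
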
diff --git a/samples/model_config/din_on_taobao_latest_export.config b/samples/model_config/din_on_taobao_latest_export.config
new file mode 100644
index 000000000..928431f92
--- /dev/null
+++ b/samples/model_config/din_on_taobao_latest_export.config
@@ -0,0 +1,301 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/din_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 10
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTowerDIN'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ seq_att_groups: {
+ group_name: "din"
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ }
+ seq_att_map: {
+ key: "cate_id"
+ hist_seq: "tag_category_list"
+ }
+ }
+
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ din_towers {
+ input: "din"
+ dnn {
+ hidden_units: [128, 64, 32, 1]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 5e-7
+ }
+ embedding_regularization: 5e-5
+}
+
+export_config {
+ multi_placeholder: false
+ exporter_type: "latest"
+ exports_to_keep: 10
+ asset_files: "samples/rtp_fg/fg.json"
+}
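
Note on the export_config above: exporter_type "latest" with exports_to_keep: 10 retains only the ten most recent timestamped SavedModel exports. A sketch of what that retention amounts to (illustration only; the framework performs this cleanup itself):

import os, shutil

def keep_latest_exports(export_dir, n=10):
    # exported SavedModels live in timestamp-named subdirectories; keep the newest n
    versions = sorted((d for d in os.listdir(export_dir) if d.isdigit()), key=int)
    for stale in versions[:-n]:
        shutil.rmtree(os.path.join(export_dir, stale))
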
diff --git a/samples/model_config/din_varitional_dropout_on_taobao.config b/samples/model_config/din_varitional_dropout_on_taobao.config
new file mode 100644
index 000000000..4406fa68e
--- /dev/null
+++ b/samples/model_config/din_varitional_dropout_on_taobao.config
@@ -0,0 +1,302 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/din_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTowerDIN'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ seq_att_groups: {
+ group_name: "din"
+ seq_att_map: {
+ key: "brand"
+ hist_seq: "tag_brand_list"
+ }
+ seq_att_map: {
+ key: "cate_id"
+ hist_seq: "tag_category_list"
+ }
+ }
+
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ din_towers {
+ input: "din"
+ dnn {
+ hidden_units: [128, 64, 32, 1]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 5e-7
+ }
+ variational_dropout {
+ regularization_lambda:0.01
+ embedding_wise_variational_dropout:true
+ }
+ embedding_regularization: 5e-5
+}
+
+export_config {
+ multi_placeholder: false
+}
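
Note on the variational_dropout block above: with embedding_wise_variational_dropout, a drop probability is learned per feature embedding and a penalty weighted by regularization_lambda pushes uninformative embeddings toward being dropped, which gives a feature-importance ranking. A rough conceptual sketch, not EasyRec's exact formulation (the real layer uses a differentiable relaxation of the mask):

import numpy as np

def variational_dropout(features, logit_p, lam=0.01, rng=None):
    """features: [B, F, D]; logit_p: [F] learnable per-feature drop logits (conceptual sketch)."""
    rng = rng or np.random.default_rng(0)
    p_drop = 1.0 / (1.0 + np.exp(-logit_p))             # sigmoid -> drop probability per feature
    mask = rng.uniform(size=p_drop.shape) > p_drop       # sample a keep/drop mask per feature
    out = features * mask[None, :, None]
    reg_loss = lam * np.sum(1.0 - p_drop)                 # penalize the expected number of kept features
    return out, reg_loss

B, F, D = 2, 5, 8
out, reg = variational_dropout(np.random.randn(B, F, D), np.zeros(F))
print(out.shape, reg)
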
diff --git a/samples/model_config/dlrm_backbone_on_taobao.config b/samples/model_config/dlrm_backbone_on_taobao.config
new file mode 100644
index 000000000..a66f1d190
--- /dev/null
+++ b/samples/model_config/dlrm_backbone_on_taobao.config
@@ -0,0 +1,299 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dlrm_backbone_taobao_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ log_step_count_steps: 10
+ sync_replicas: true
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: DOUBLE
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: RawFeature
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: RawFeature
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: RawFeature
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: RawFeature
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: RawFeature
+ }
+}
+model_config {
+ model_class: 'RankModel'
+
+ feature_groups {
+ group_name: 'dense'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'new_user_class_level'
+ feature_names: 'price'
+
+ wide_deep: DEEP
+ }
+
+ feature_groups {
+ group_name: 'sparse'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'occupation'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'bottom_mlp'
+ inputs {
+ feature_group_name: 'dense'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [64, 32, 16]
+ }
+ }
+ }
+ blocks {
+ name: 'sparse'
+ inputs {
+ feature_group_name: 'sparse'
+ }
+ input_layer {
+ only_output_feature_list: true
+ }
+ }
+ blocks {
+ name: 'dot'
+ inputs {
+ block_name: 'bottom_mlp'
+ input_fn: 'lambda x: [x]'
+ }
+ inputs {
+ block_name: 'sparse'
+ }
+ keras_layer {
+ class_name: 'DotInteraction'
+ }
+ }
+ concat_blocks: ['bottom_mlp', 'dot']
+ top_mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ model_params {
+ }
+ embedding_regularization: 1e-5
+}
+
+export_config {
+}
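
Note on the 'dot' block above: DotInteraction takes the bottom-MLP output plus each sparse embedding (all 16-dim in this config) and computes pairwise dot products between them. A small numpy sketch of that interaction, assuming all vectors share one dimension:

import numpy as np

def dot_interaction(vectors):
    """vectors: list of [B, D] arrays -> [B, F*(F-1)//2] pairwise dot products."""
    x = np.stack(vectors, axis=1)           # [B, F, D]
    dots = x @ x.transpose(0, 2, 1)         # [B, F, F] all pairwise dot products
    iu = np.triu_indices(x.shape[1], k=1)   # keep each unordered pair once
    return dots[:, iu[0], iu[1]]

B, D = 2, 16
out = dot_interaction([np.random.randn(B, D) for _ in range(13)])  # 12 sparse embeddings + bottom MLP output
print(out.shape)  # (2, 78)
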
diff --git a/samples/model_config/dlrm_on_criteo.config b/samples/model_config/dlrm_on_criteo.config
new file mode 100644
index 000000000..30973b29d
--- /dev/null
+++ b/samples/model_config/dlrm_on_criteo.config
@@ -0,0 +1,571 @@
+binary_train_input {
+ category_path: 'criteo_data/category.bin'
+ dense_path: 'criteo_data/dense.bin'
+ label_path: 'criteo_data/label.bin'
+}
+binary_eval_input {
+ category_path: 'criteo_data/category.bin'
+ dense_path: 'criteo_data/dense.bin'
+ label_path: 'criteo_data/label.bin'
+}
+model_dir: "experiments/dlrm_criteo"
+
+train_config {
+ optimizer_config: {
+ momentum_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 1e-4
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 10000
+ log_step_count_steps: 10
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+
+ input_fields: {
+ input_name: "f1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f2"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f3"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f4"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f5"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f6"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f7"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f8"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f9"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f10"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f11"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f12"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f13"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c2"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c3"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c4"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c5"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c6"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c7"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c8"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c9"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c10"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c11"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c12"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c13"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c22"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c23"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c24"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c25"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c26"
+ input_type: INT64
+ default_val:""
+ }
+
+
+ label_fields: 'label'
+ batch_size: 4096
+ num_epochs: 1000
+ prefetch_size: 32
+ input_type: CriteoInput
+}
+
+feature_config: {
+ features: {
+ input_names: "f1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "f2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "f3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "f4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "f5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "f6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "f7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "f8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "f9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "f10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "f11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "f12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "f13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "c1"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c2"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c3"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c4"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c5"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c6"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c7"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c8"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c9"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c10"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3853295
+ }
+ features: {
+ input_names: "c11"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c12"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c13"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c19"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c20"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3997977
+ }
+ features: {
+ input_names: "c21"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 2564129
+ }
+ features: {
+ input_names: "c22"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3966498
+ }
+ features: {
+ input_names: "c23"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c24"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c25"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c26"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000000
+ embedding_name: "embedding"
+ }
+}
+model_config {
+ model_class: 'DLRM'
+
+ feature_groups {
+ group_name: 'dense'
+ feature_names: "f1"
+ feature_names: "f2"
+ feature_names: "f3"
+ feature_names: "f4"
+ feature_names: "f5"
+ feature_names: "f6"
+ feature_names: "f7"
+ feature_names: "f8"
+ feature_names: "f9"
+ feature_names: "f10"
+ feature_names: "f11"
+ feature_names: "f12"
+ feature_names: "f13"
+
+ wide_deep: DEEP
+ }
+
+ feature_groups {
+ group_name: 'sparse'
+ feature_names: "c1"
+ feature_names: "c2"
+ feature_names: "c3"
+ feature_names: "c4"
+ feature_names: "c5"
+ feature_names: "c6"
+ feature_names: "c7"
+ feature_names: "c8"
+ feature_names: "c9"
+ feature_names: "c10"
+ feature_names: "c11"
+ feature_names: "c12"
+ feature_names: "c13"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "c22"
+ feature_names: "c23"
+ feature_names: "c24"
+ feature_names: "c25"
+ feature_names: "c26"
+
+ wide_deep: DEEP
+ }
+
+ dlrm {
+ bot_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ top_dnn {
+ hidden_units: [256, 128, 128, 64]
+ }
+ }
+
+ embedding_regularization: 1e-5
+}
+
+export_config {
+}
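
Note on dlrm_on_criteo.config: most c* features set embedding_name: "embedding" with the same hash_bucket_size, which lets them share a single embedding table. A minimal sketch of the shared-table idea (illustrative sizes; EasyRec wires this up from embedding_name, this is not its code):

import tensorflow as tf

shared = tf.keras.layers.Embedding(input_dim=1000, output_dim=32)  # one table, small dims for the sketch
c1_ids = tf.constant([[3], [7]])
c2_ids = tf.constant([[42], [9]])
c1_emb = shared(c1_ids)   # both lookups hit the same weights
c2_emb = shared(c2_ids)
print(c1_emb.shape, c2_emb.shape)  # (2, 1, 32) (2, 1, 32)
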
diff --git a/samples/model_config/dlrm_on_criteo_parquet.config b/samples/model_config/dlrm_on_criteo_parquet.config
new file mode 100644
index 000000000..3641cffbe
--- /dev/null
+++ b/samples/model_config/dlrm_on_criteo_parquet.config
@@ -0,0 +1,564 @@
+parquet_train_input: "data/test/criteo_parquet/*.parquet,data/test/criteo_parquet/*.parquet"
+parquet_eval_input: "data/test/criteo_parquet/*.parquet"
+
+model_dir: "experiments/dlrm_criteo_parquet/"
+
+train_config {
+ optimizer_config: {
+ lazy_adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 1e-4
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 10000
+ log_step_count_steps: 10
+ sync_replicas: false
+ # is_profiling: true
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'is_click'
+ input_type: INT32
+ }
+
+ input_fields: {
+ input_name: "f1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f2"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f3"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f4"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f5"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f6"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f7"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f8"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f9"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f10"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f11"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f12"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f13"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c2"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c3"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c4"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c5"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c6"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c7"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c8"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c9"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c10"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c11"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c12"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c13"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c22"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c23"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c24"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c25"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c26"
+ input_type: INT64
+ default_val:""
+ }
+
+ label_fields: 'is_click'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: ParquetInputV2
+}
+
+feature_config: {
+ features: {
+ input_names: "f1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "f2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "f3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "f4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "f5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "f6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "f7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "f8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "f9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "f10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "f11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "f12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "f13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "c1"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c2"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c3"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c4"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c5"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c6"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c7"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c8"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c9"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c10"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3853295
+ }
+ features: {
+ input_names: "c11"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c12"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c13"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c19"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c20"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3997977
+ }
+ features: {
+ input_names: "c21"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 2564129
+ }
+ features: {
+ input_names: "c22"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3966498
+ }
+ features: {
+ input_names: "c23"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c24"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c25"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c26"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 100000
+ embedding_name: "embedding"
+ }
+}
+model_config {
+ model_class: 'DLRM'
+
+ feature_groups {
+ group_name: 'dense'
+ feature_names: "f1"
+ feature_names: "f2"
+ feature_names: "f3"
+ feature_names: "f4"
+ feature_names: "f5"
+ feature_names: "f6"
+ feature_names: "f7"
+ feature_names: "f8"
+ feature_names: "f9"
+ feature_names: "f10"
+ feature_names: "f11"
+ feature_names: "f12"
+ feature_names: "f13"
+
+ wide_deep: DEEP
+ }
+
+ feature_groups {
+ group_name: 'sparse'
+ feature_names: "c1"
+ feature_names: "c2"
+ feature_names: "c3"
+ feature_names: "c4"
+ feature_names: "c5"
+ feature_names: "c6"
+ feature_names: "c7"
+ feature_names: "c8"
+ feature_names: "c9"
+ feature_names: "c10"
+ feature_names: "c11"
+ feature_names: "c12"
+ feature_names: "c13"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "c22"
+ feature_names: "c23"
+ feature_names: "c24"
+ feature_names: "c25"
+ feature_names: "c26"
+
+ wide_deep: DEEP
+ }
+
+ dlrm {
+ bot_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ top_dnn {
+ hidden_units: [256, 128, 128, 64]
+ }
+ }
+
+ embedding_regularization: 0 #1e-5
+}
+
+export_config {
+}
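
Note on the parquet variant above: most c* features switch from hash_bucket_size to num_buckets, which assumes the ids are already integer-encoded into [0, num_buckets), whereas hash_bucket_size hashes raw ids into a fixed range. A quick sketch of the two styles (illustrative values, not EasyRec internals):

import tensorflow as tf

raw_ids = tf.constant(["a13", "b07"])
hashed = tf.strings.to_hash_bucket_fast(raw_ids, num_buckets=100000000)  # hash_bucket_size style

int_ids = tf.constant([12345, 678], dtype=tf.int64)
bucketed = int_ids % 100000                                              # num_buckets style (pre-encoded ids assumed)
print(hashed.numpy(), bucketed.numpy())
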
diff --git a/samples/model_config/dlrm_on_criteo_parquet_ep.config b/samples/model_config/dlrm_on_criteo_parquet_ep.config
new file mode 100644
index 000000000..7ec394ed4
--- /dev/null
+++ b/samples/model_config/dlrm_on_criteo_parquet_ep.config
@@ -0,0 +1,569 @@
+parquet_train_input: "data/test/criteo_parquet/*.parquet"
+parquet_eval_input: "data/test/criteo_parquet/*.parquet"
+
+model_dir: "experiments/dlrm_criteo_parquet_ep/"
+
+train_config {
+ optimizer_config: {
+ lazy_adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 1e-4
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 10000
+ log_step_count_steps: 10
+ sync_replicas: False
+ train_distribute: EmbeddingParallelStrategy
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'is_click'
+ input_type: INT32
+ }
+
+ input_fields: {
+ input_name: "f1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f2"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f3"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f4"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f5"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f6"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f7"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f8"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f9"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f10"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f11"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f12"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f13"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c2"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c3"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c4"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c5"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c6"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c7"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c8"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c9"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c10"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c11"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c12"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c13"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c22"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c23"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c24"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c25"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c26"
+ input_type: INT64
+ default_val:""
+ }
+
+
+ label_fields: 'is_click'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: ParquetInput
+}
+
+feature_config: {
+ features: {
+ input_names: "f1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "f2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "f3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "f4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "f5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "f6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "f7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "f8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "f9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "f10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "f11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "f12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "f13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "c1"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c2"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c3"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c4"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c5"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c6"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c7"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c8"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c9"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c10"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c11"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c12"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c13"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c19"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c20"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c21"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c22"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c23"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c24"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c25"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c26"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+}
+model_config {
+ model_class: 'DLRM'
+
+ feature_groups {
+ group_name: 'dense'
+ feature_names: "f1"
+ feature_names: "f2"
+ feature_names: "f3"
+ feature_names: "f4"
+ feature_names: "f5"
+ feature_names: "f6"
+ feature_names: "f7"
+ feature_names: "f8"
+ feature_names: "f9"
+ feature_names: "f10"
+ feature_names: "f11"
+ feature_names: "f12"
+ feature_names: "f13"
+
+ wide_deep: DEEP
+ }
+
+ feature_groups {
+ group_name: 'sparse'
+ feature_names: "c1"
+ feature_names: "c2"
+ feature_names: "c3"
+ feature_names: "c4"
+ feature_names: "c5"
+ feature_names: "c6"
+ feature_names: "c7"
+ feature_names: "c8"
+ feature_names: "c9"
+ feature_names: "c10"
+ feature_names: "c11"
+ feature_names: "c12"
+ feature_names: "c13"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "c22"
+ feature_names: "c23"
+ feature_names: "c24"
+ feature_names: "c25"
+ feature_names: "c26"
+
+ wide_deep: DEEP
+ }
+
+ dlrm {
+ bot_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ top_dnn {
+ hidden_units: [256, 128, 128, 64]
+ }
+ }
+
+ embedding_regularization: 1e-5
+}
+
+export_config {
+}
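
A quick sanity check on the sample above: all 26 categorical features reuse embedding_name: "embedding" with num_buckets: 10000000 and embedding_dim: 32. Assuming a shared embedding_name maps onto a single table (the apparent intent of the sample), the sketch below estimates the table size from the numbers in this file only; that footprint is also why the config turns on train_distribute: EmbeddingParallelStrategy.

    # Rough size of the shared DLRM embedding table, taken from the config above.
    # Assumption: features with the same embedding_name share one table.
    num_buckets = 10_000_000   # num_buckets of c1..c26
    embedding_dim = 32         # embedding_dim of c1..c26

    params = num_buckets * embedding_dim
    print(f"{params:,} parameters, ~{params * 4 / 2**30:.1f} GiB in fp32")
    # 320,000,000 parameters, ~1.2 GiB in fp32
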
diff --git a/samples/model_config/dlrm_on_criteo_parquet_ep_v2.config b/samples/model_config/dlrm_on_criteo_parquet_ep_v2.config
new file mode 100644
index 000000000..1ba861528
--- /dev/null
+++ b/samples/model_config/dlrm_on_criteo_parquet_ep_v2.config
@@ -0,0 +1,569 @@
+parquet_train_input: "data/test/criteo_parquet/*.parquet"
+parquet_eval_input: "data/test/criteo_parquet/*.parquet"
+
+model_dir: "experiments/dlrm_criteo_parquet_ep_v2/"
+
+train_config {
+ optimizer_config: {
+ lazy_adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 1e-4
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 10000
+ log_step_count_steps: 10
+  sync_replicas: false
+ train_distribute: EmbeddingParallelStrategy
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'is_click'
+ input_type: INT32
+ }
+
+ input_fields: {
+ input_name: "f1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f2"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f3"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f4"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f5"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f6"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f7"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f8"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f9"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f10"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f11"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f12"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f13"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c2"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c3"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c4"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c5"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c6"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c7"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c8"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c9"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c10"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c11"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c12"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c13"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c22"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c23"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c24"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c25"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c26"
+ input_type: INT64
+ default_val:""
+ }
+
+
+ label_fields: 'is_click'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: ParquetInputV2
+}
+
+feature_config: {
+ features: {
+ input_names: "f1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "f2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "f3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "f4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "f5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "f6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "f7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "f8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "f9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "f10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "f11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "f12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "f13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "c1"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c2"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c3"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c4"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c5"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c6"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c7"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c8"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c9"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c10"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c11"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c12"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c13"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c19"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c20"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c21"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c22"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c23"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c24"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c25"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c26"
+ feature_type: IdFeature
+ embedding_dim: 32
+ num_buckets: 10000000
+ embedding_name: "embedding"
+ }
+}
+model_config {
+ model_class: 'DLRM'
+
+ feature_groups {
+ group_name: 'dense'
+ feature_names: "f1"
+ feature_names: "f2"
+ feature_names: "f3"
+ feature_names: "f4"
+ feature_names: "f5"
+ feature_names: "f6"
+ feature_names: "f7"
+ feature_names: "f8"
+ feature_names: "f9"
+ feature_names: "f10"
+ feature_names: "f11"
+ feature_names: "f12"
+ feature_names: "f13"
+
+ wide_deep: DEEP
+ }
+
+ feature_groups {
+ group_name: 'sparse'
+ feature_names: "c1"
+ feature_names: "c2"
+ feature_names: "c3"
+ feature_names: "c4"
+ feature_names: "c5"
+ feature_names: "c6"
+ feature_names: "c7"
+ feature_names: "c8"
+ feature_names: "c9"
+ feature_names: "c10"
+ feature_names: "c11"
+ feature_names: "c12"
+ feature_names: "c13"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "c22"
+ feature_names: "c23"
+ feature_names: "c24"
+ feature_names: "c25"
+ feature_names: "c26"
+
+ wide_deep: DEEP
+ }
+
+ dlrm {
+ bot_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ top_dnn {
+ hidden_units: [256, 128, 128, 64]
+ }
+ }
+
+ embedding_regularization: 1e-5
+}
+
+export_config {
+}
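
Aside from model_dir and data_config.input_type (ParquetInputV2 instead of ParquetInput), the v2 sample above mirrors dlrm_on_criteo_parquet_ep.config line for line. A small parse check like the one below can keep such near-duplicate samples honest; note that the EasyRecConfig import path is an assumption about the package layout, not something stated in this change.

    from google.protobuf import text_format
    from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig  # assumed import path

    def load_config(path):
        """Parse a sample pipeline config written in protobuf text format."""
        config = EasyRecConfig()
        with open(path) as f:
            text_format.Merge(f.read(), config)
        return config

    cfg = load_config("samples/model_config/dlrm_on_criteo_parquet_ep_v2.config")
    print(cfg.data_config.input_type)        # enum value of ParquetInputV2
    print(len(cfg.feature_config.features))  # 39 = 13 dense + 26 sparse features
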
diff --git a/samples/model_config/dlrm_on_taobao.config b/samples/model_config/dlrm_on_taobao.config
index f94c78005..ac8880d14 100644
--- a/samples/model_config/dlrm_on_taobao.config
+++ b/samples/model_config/dlrm_on_taobao.config
@@ -18,7 +18,7 @@ train_config {
}
save_checkpoints_steps: 100
log_step_count_steps: 10
- sync_replicas: True
+ sync_replicas: true
num_steps: 2500
}
diff --git a/samples/model_config/deepfm_on_sequence_feature_taobao.config b/samples/model_config/dropoutnet_distribute_eval_on_taobao.config
similarity index 80%
rename from samples/model_config/deepfm_on_sequence_feature_taobao.config
rename to samples/model_config/dropoutnet_distribute_eval_on_taobao.config
index 059e33d7b..83682f5e3 100644
--- a/samples/model_config/deepfm_on_sequence_feature_taobao.config
+++ b/samples/model_config/dropoutnet_distribute_eval_on_taobao.config
@@ -1,6 +1,6 @@
train_input_path: "data/test/tb_data/taobao_train_data"
eval_input_path: "data/test/tb_data/taobao_test_data"
-model_dir: "experiments/deepfm_on_taobao_ckpt"
+model_dir: "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt"
train_config {
log_step_count_steps: 100
@@ -18,8 +18,8 @@ train_config {
use_moving_average: false
}
save_checkpoints_steps: 100
- sync_replicas: True
- num_steps: 2500
+ sync_replicas: false
+ num_steps: 1000
}
eval_config {
@@ -209,14 +209,14 @@ feature_configs : {
}
feature_configs : {
input_names: 'tag_category_list'
- feature_type: SequenceFeature
+ feature_type: TagFeature
separator: '|'
- hash_bucket_size: 10000
+ hash_bucket_size: 100000
embedding_dim: 16
}
feature_configs : {
input_names: 'tag_brand_list'
- feature_type: SequenceFeature
+ feature_type: TagFeature
separator: '|'
hash_bucket_size: 100000
embedding_dim: 16
@@ -227,64 +227,76 @@ feature_configs : {
embedding_dim: 16
num_buckets: 50
}
-
model_config: {
- model_class: 'DeepFM'
+ model_class: "DropoutNet"
feature_groups: {
- group_name: 'wide'
+ group_name: 'user_content'
feature_names: 'user_id'
feature_names: 'cms_segid'
feature_names: 'cms_group_id'
feature_names: 'age_level'
- feature_names: 'pvalue_level'
- feature_names: 'shopping_level'
- feature_names: 'occupation'
- feature_names: 'new_user_class_level'
- feature_names: 'adgroup_id'
- feature_names: 'cate_id'
- feature_names: 'campaign_id'
- feature_names: 'customer'
- feature_names: 'brand'
- feature_names: 'price'
- feature_names: 'pid'
- wide_deep: WIDE
+ wide_deep:DEEP
}
feature_groups: {
- group_name: 'deep'
- feature_names: 'user_id'
- feature_names: 'cms_segid'
- feature_names: 'cms_group_id'
- feature_names: 'age_level'
+ group_name: 'user_preference'
feature_names: 'pvalue_level'
feature_names: 'shopping_level'
feature_names: 'occupation'
feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item_content"
feature_names: 'adgroup_id'
feature_names: 'cate_id'
feature_names: 'campaign_id'
feature_names: 'customer'
feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item_preference"
feature_names: 'price'
feature_names: 'pid'
- wide_deep: DEEP
- sequence_features: {
- group_name: "seq_fea"
- tf_summary: false
- seq_att_map: {
- key: "brand"
- key: "cate_id"
- hist_seq: "tag_brand_list"
- hist_seq: "tag_category_list"
- }
+ wide_deep:DEEP
+ }
+ losses {
+ loss_type: CLASSIFICATION
+ weight: 1.0
}
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
}
- deepfm {
- dnn {
- hidden_units: [256, 256, 256]
+ dropoutnet {
+ user_content {
+ hidden_units: [256]
+ }
+ item_content {
+ hidden_units: [256]
+ }
+ user_preference {
+ hidden_units: [512]
+ }
+ item_preference {
+ hidden_units: [512]
+ }
+ user_tower {
+ hidden_units: [128, 64]
+ use_bn: false
+ }
+ item_tower {
+ hidden_units: [128, 64]
+ use_bn: false
}
- l2_regularization: 1e-4
+ user_dropout_rate: 0.5
+ item_dropout_rate: 0.5
+ l2_regularization: 1e-06
}
- embedding_regularization: 1e-5
+ # sample_weight_field: "weight"
+ embedding_regularization: 5e-5
}
export_config {
diff --git a/samples/model_config/dssm_distribute_eval_listwise_hard_neg_sampler_on_taobao.config b/samples/model_config/dssm_distribute_eval_listwise_hard_neg_sampler_on_taobao.config
new file mode 100644
index 000000000..66fedda32
--- /dev/null
+++ b/samples/model_config/dssm_distribute_eval_listwise_hard_neg_sampler_on_taobao.config
@@ -0,0 +1,309 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "data/test/distribute_eval_test/dssm_distribute_eval_listwise_hard_neg_sampler_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 10
+ save_checkpoints_steps: 100
+ sync_replicas: false
+ num_steps: 200
+}
+
+eval_config {
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+
+ hard_negative_sampler {
+ user_input_path: 'data/test/tb_data/taobao_user_profile_gl'
+ item_input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ hard_neg_edge_input_path: 'data/test/tb_data/taobao_noclk_edge_gl'
+ num_sample: 1024
+ num_hard_sample: 20
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ user_id_field: 'user_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ simi_func: INNER_PRODUCT
+ scale_simi: false
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
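
The hard_negative_sampler block above points the user, item, and no-click edge inputs at the taobao test data and lists five attr_fields plus item_id_field: 'adgroup_id'. Each of those names also appears in feature_config here, which is what lets the sampled negative items be embedded by the item tower. A hedged consistency check for that kind of wiring, reusing the load_config helper sketched after the DLRM v2 config (same assumed import path), might look like:

    cfg = load_config(
        "samples/model_config/"
        "dssm_distribute_eval_listwise_hard_neg_sampler_on_taobao.config")

    sampler = cfg.data_config.hard_negative_sampler
    feature_names = {f.input_names[0] for f in cfg.feature_config.features}

    # every sampler attr must have a matching feature definition
    missing = [a for a in sampler.attr_fields if a not in feature_names]
    assert not missing, f"sampler attrs without a feature_config entry: {missing}"
    assert sampler.item_id_field in feature_names
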
diff --git a/samples/model_config/dssm_distribute_eval_pointwise_classification_on_taobao.config b/samples/model_config/dssm_distribute_eval_pointwise_classification_on_taobao.config
new file mode 100644
index 000000000..69ff6fcd4
--- /dev/null
+++ b/samples/model_config/dssm_distribute_eval_pointwise_classification_on_taobao.config
@@ -0,0 +1,279 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "data/test/distribute_eval_test/dssm_distribute_eval_pointwise_classification_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: false
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/dssm_distribute_eval_reg_on_taobao.config b/samples/model_config/dssm_distribute_eval_reg_on_taobao.config
new file mode 100644
index 000000000..093d7c4a5
--- /dev/null
+++ b/samples/model_config/dssm_distribute_eval_reg_on_taobao.config
@@ -0,0 +1,282 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "data/test/distribute_eval_test/dssm_distribute_eval_reg_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: false
+ num_steps: 1000
+}
+
+eval_config {
+ metrics_set: {
+ mean_absolute_error {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ simi_func: INNER_PRODUCT
+ }
+ embedding_regularization: 5e-5
+ loss_type: L2_LOSS
+
+}
+
+export_config {
+}
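
One detail worth noting in the regression variant above: the binary clk label is fit directly with loss_type: L2_LOSS, and evaluation reports mean_absolute_error rather than AUC. A tiny worked example of that metric on a 0/1 target (plain NumPy, independent of EasyRec):

    import numpy as np

    labels = np.array([0., 1., 1., 0.])
    preds  = np.array([0.2, 0.7, 0.4, 0.1])

    mae = np.abs(labels - preds).mean()
    print(mae)  # (0.2 + 0.3 + 0.6 + 0.1) / 4 = 0.3
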
diff --git a/samples/model_config/dssm_hard_neg_sampler_on_taobao.config b/samples/model_config/dssm_hard_neg_sampler_on_taobao.config
index d7b3f056f..4bf4cab03 100644
--- a/samples/model_config/dssm_hard_neg_sampler_on_taobao.config
+++ b/samples/model_config/dssm_hard_neg_sampler_on_taobao.config
@@ -17,6 +17,7 @@ train_config {
}
use_moving_average: false
}
+ log_step_count_steps: 10
save_checkpoints_steps: 100
sync_replicas: false
num_steps: 2500
@@ -133,7 +134,7 @@ data_config {
item_input_path: 'data/test/tb_data/taobao_ad_feature_gl'
hard_neg_edge_input_path: 'data/test/tb_data/taobao_noclk_edge_gl'
num_sample: 1024
- num_hard_sample: 200
+ num_hard_sample: 20
num_eval_sample: 2048
attr_fields: 'adgroup_id'
attr_fields: 'cate_id'
diff --git a/samples/model_config/dssm_hard_neg_sampler_regular_on_taobao.config b/samples/model_config/dssm_hard_neg_sampler_regular_on_taobao.config
new file mode 100644
index 000000000..5eb030824
--- /dev/null
+++ b/samples/model_config/dssm_hard_neg_sampler_regular_on_taobao.config
@@ -0,0 +1,309 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_hard_neg_sampler_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 10
+ save_checkpoints_steps: 100
+ sync_replicas: false
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+
+ hard_negative_sampler {
+ user_input_path: 'data/test/tb_data/hard_negative_sampler_user/taobao_user_profile_gl.csv'
+ item_input_path: 'data/test/tb_data/hard_negative_sampler_item/taobao_ad_feature_gl.csv'
+ hard_neg_edge_input_path: 'data/test/tb_data/hard_negative_sampler_edge/taobao_noclk_edge_gl.csv'
+ num_sample: 1024
+ num_hard_sample: 20
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ user_id_field: 'user_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ simi_func: INNER_PRODUCT
+ scale_simi: false
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config b/samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config
index 1db820391..63dee4e30 100644
--- a/samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config
+++ b/samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config
@@ -134,7 +134,7 @@ data_config {
pos_edge_input_path: 'data/test/tb_data/taobao_clk_edge_gl'
hard_neg_edge_input_path: 'data/test/tb_data/taobao_noclk_edge_gl'
num_sample: 1024
- num_hard_sample: 200
+ num_hard_sample: 10
num_eval_sample: 2048
attr_fields: 'adgroup_id'
attr_fields: 'cate_id'
diff --git a/samples/model_config/dssm_neg_sampler_need_key_feature.config b/samples/model_config/dssm_neg_sampler_need_key_feature.config
new file mode 100644
index 000000000..ea7cc8342
--- /dev/null
+++ b/samples/model_config/dssm_neg_sampler_need_key_feature.config
@@ -0,0 +1,302 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_neg_sampler_sequence_feature"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 1000
+ sync_replicas: false
+ save_checkpoints_steps: 100
+ log_step_count_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ recall_at_topk { topk: 3 }
+ }
+ # currently not supported
+ # metrics_set: {
+ # gauc {
+ # uid_field: "user_id"
+ # }
+ # }
+}
+
+data_config {
+ batch_size: 1024
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ num_epochs: 5
+ prefetch_size: 4
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 256
+ num_eval_sample: 4096
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_configs : {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs : {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs : {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep:DEEP
+ sequence_features: {
+ group_name: "seq_fea"
+ allow_key_search: true
+ need_key_feature:false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-6
+}
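
Unlike the earlier DSSM samples that model tag_category_list and tag_brand_list as TagFeature bags, this config declares them as SequenceFeature (sub_feature_type: IdFeature, separator '|') and pairs them with the brand and cate_id keys through seq_att_map, with need_key_feature: false; judging by the name, the key embeddings themselves are then left out of the group output and only the attended sequence representation reaches the user tower, while the sibling sample that follows flips need_key_feature to true. The snippet below only illustrates how a '|'-separated history column decomposes into an ordered sequence; the padding scheme is an illustrative assumption, not a statement about EasyRec internals.

    SEPARATOR = "|"  # separator declared in the SequenceFeature configs above

    def split_history(raw, max_len=5, pad=""):
        """Split a raw history string into an ordered, right-padded sequence."""
        items = raw.split(SEPARATOR) if raw else []
        return (items + [pad] * max_len)[:max_len]

    print(split_history("brand_12|brand_7|brand_12"))
    # ['brand_12', 'brand_7', 'brand_12', '', '']
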
diff --git a/samples/model_config/dssm_neg_sampler_on_taobao.config b/samples/model_config/dssm_neg_sampler_on_taobao.config
index 6d4b9f9c4..ba5f59887 100644
--- a/samples/model_config/dssm_neg_sampler_on_taobao.config
+++ b/samples/model_config/dssm_neg_sampler_on_taobao.config
@@ -293,7 +293,7 @@ model_config:{
}
}
simi_func: INNER_PRODUCT
- scale_simi: false
+ scale_simi: true
l2_regularization: 1e-6
}
loss_type: SOFTMAX_CROSS_ENTROPY
diff --git a/samples/model_config/dssm_neg_sampler_sequence_feature.config b/samples/model_config/dssm_neg_sampler_sequence_feature.config
new file mode 100644
index 000000000..5ce0af37f
--- /dev/null
+++ b/samples/model_config/dssm_neg_sampler_sequence_feature.config
@@ -0,0 +1,302 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_neg_sampler_sequence_feature"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 1000
+ sync_replicas: false
+ save_checkpoints_steps: 100
+ log_step_count_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ recall_at_topk { topk: 3 }
+ }
+ # currently not supported
+ # metrics_set: {
+ # gauc {
+ # uid_field: "user_id"
+ # }
+ # }
+}
+
+data_config {
+ batch_size: 1024
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ num_epochs: 5
+ prefetch_size: 4
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 256
+ num_eval_sample: 4096
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_configs : {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs : {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs : {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep:DEEP
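+    # Target attention over user behavior sequences: the ad-side keys 'brand'
+    # and 'cate_id' attend over 'tag_brand_list' and 'tag_category_list'.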
+ sequence_features: {
+ group_name: "seq_fea"
+ allow_key_search: true
+ need_key_feature:true
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/dssm_neg_sampler_with_sample_weight.config b/samples/model_config/dssm_neg_sampler_with_sample_weight.config
new file mode 100644
index 000000000..725c40294
--- /dev/null
+++ b/samples/model_config/dssm_neg_sampler_with_sample_weight.config
@@ -0,0 +1,298 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_neg_sampler_with_sample_weight_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: false
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: FLOAT
+ }
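+  # 'price' doubles as the per-sample weight applied to the training loss.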
+ sample_weight: 'price'
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 1024
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+}
+model_config:{
+ model_class: "DSSM"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ dssm {
+ user_tower {
+ id: "user_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ simi_func: INNER_PRODUCT
+ scale_simi: true
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/dssm_on_sequence_feature_taobao.config b/samples/model_config/dssm_on_sequence_feature_taobao.config
index b2c223745..554226e87 100644
--- a/samples/model_config/dssm_on_sequence_feature_taobao.config
+++ b/samples/model_config/dssm_on_sequence_feature_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: false
- num_steps: 2500
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/dssm_on_taobao.config b/samples/model_config/dssm_on_taobao.config
index 8545419ef..1789a541f 100644
--- a/samples/model_config/dssm_on_taobao.config
+++ b/samples/model_config/dssm_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: false
- num_steps: 2500
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/dssm_on_taobao_backbone.config b/samples/model_config/dssm_on_taobao_backbone.config
new file mode 100644
index 000000000..4a7f69679
--- /dev/null
+++ b/samples/model_config/dssm_on_taobao_backbone.config
@@ -0,0 +1,356 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_taobao_backbone_ckpt"
+
+train_config {
+ log_step_count_steps: 200
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 4000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 4000
+ sync_replicas: false
+ num_steps: 100000
+}
+
+eval_config {
+ metrics_set: {
+ recall_at_topk {
+ topk: 50
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 1024
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_name: 'DSSM'
+ model_class: 'MatchModel'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
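+  # Backbone formulation of DSSM: the 'user' and 'item' feature groups are
+  # wrapped in input blocks and passed through separate MLP towers;
+  # output_blocks exposes the two tower embeddings, which model_params maps
+  # to the user/item towers of the MatchModel head via *_tower_idx_in_output.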
+ backbone {
+ blocks {
+ name: 'user'
+ inputs {
+ feature_group_name: 'user'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: false
+ }
+ }
+ blocks {
+ name: 'item'
+ inputs {
+ feature_group_name: 'item'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: false
+ }
+ }
+
+ blocks {
+ name: 'user_tower'
+ inputs {
+ block_name: 'user'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 32]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
+
+ blocks {
+ name: 'item_tower'
+ inputs {
+ block_name: 'item'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 32]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
+ output_blocks: ['user_tower', 'item_tower']
+
+ }
+ model_params {
+ l2_regularization: 1e-4
+ user_tower_idx_in_output: 0
+ item_tower_idx_in_output: 1
+ scale_simi: false
+ simi_func: COSINE
+ temperature: 0.01
+ }
+ embedding_regularization: 5e-5
+ loss_type: SOFTMAX_CROSS_ENTROPY
+
+}
+
+export_config {
+}
diff --git a/samples/model_config/dssm_reg_on_taobao.config b/samples/model_config/dssm_reg_on_taobao.config
index e3e55dcba..7ceef545a 100644
--- a/samples/model_config/dssm_reg_on_taobao.config
+++ b/samples/model_config/dssm_reg_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: false
- num_steps: 2500
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/dssm_senet_on_taobao.config b/samples/model_config/dssm_senet_on_taobao.config
new file mode 100644
index 000000000..3c059f6e2
--- /dev/null
+++ b/samples/model_config/dssm_senet_on_taobao.config
@@ -0,0 +1,321 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_senet_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 200
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ # initial_learning_rate: 0.001
+ initial_learning_rate: 0.0001
+ decay_steps: 4000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 4000
+ sync_replicas: false
+ num_steps: 100
+}
+
+eval_config {
+
+ metrics_set: {
+ recall_at_topk {
+ topk: 50
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 1024
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_class: "DSSM_SENet"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
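+  # DSSM_SENet: each tower re-weights its field embeddings with a SENet
+  # (squeeze-and-excitation) block before its DNN; the towers are compared
+  # with cosine similarity at temperature 0.01.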
+ dssm_senet {
+ user_tower {
+ id: "user_id"
+ senet {
+ num_squeeze_group : 2
+ reduction_ratio: 4
+ }
+ dnn {
+ hidden_units: [ 128, 32]
+ }
+ }
+ item_tower {
+ id: "adgroup_id"
+ senet {
+ num_squeeze_group : 2
+ reduction_ratio: 4
+ }
+ dnn {
+ hidden_units: [128, 32]
+ }
+ }
+ simi_func: COSINE
+ scale_simi: false
+ temperature: 0.01
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/dssm_senet_on_taobao_backbone.config b/samples/model_config/dssm_senet_on_taobao_backbone.config
new file mode 100644
index 000000000..35fb47844
--- /dev/null
+++ b/samples/model_config/dssm_senet_on_taobao_backbone.config
@@ -0,0 +1,387 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dssm_senet_backbone_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 200
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ # initial_learning_rate: 0.001
+ initial_learning_rate: 0.0001
+ decay_steps: 4000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 4000
+ sync_replicas: false
+ num_steps: 100000
+}
+
+eval_config {
+
+ metrics_set: {
+ recall_at_topk {
+ topk: 50
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 1024
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_name: "DSSM_SENet"
+ model_class: 'MatchModel'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
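+  # Backbone formulation of DSSM_SENet: per tower, the input_layer emits the
+  # per-feature embedding list, a SENet block re-weights it, and an MLP
+  # produces the tower embedding; output_blocks hands both embeddings to the
+  # MatchModel head configured in model_params.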
+ backbone {
+ blocks {
+ name: 'user'
+ inputs {
+ feature_group_name: 'user'
+ }
+ input_layer {
+ only_output_feature_list: true
+ }
+ }
+ blocks {
+ name: 'item'
+ inputs {
+ feature_group_name: 'item'
+ }
+ input_layer {
+ only_output_feature_list: true
+ }
+ }
+
+ blocks {
+ name: 'user_senet'
+ inputs {
+ block_name: 'user'
+ }
+ keras_layer {
+ class_name: 'SENet'
+ senet {
+ num_squeeze_group: 2
+ reduction_ratio: 4
+ use_skip_connection: false
+ use_output_layer_norm: false
+ }
+ }
+ }
+
+ blocks {
+ name: 'item_senet'
+ inputs {
+ block_name: 'item'
+ }
+ keras_layer {
+ class_name: 'SENet'
+ senet {
+ num_squeeze_group: 2
+ reduction_ratio: 4
+ use_skip_connection: false
+ use_output_layer_norm: false
+ }
+ }
+ }
+
+ blocks {
+ name: 'user_dnn'
+ inputs {
+ block_name: 'user_senet'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 32]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
+ blocks {
+ name: 'item_dnn'
+ inputs {
+ block_name: 'item_senet'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 32]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
+ output_blocks: ['user_dnn', 'item_dnn']
+
+ }
+ model_params {
+ l2_regularization: 1e-4
+ user_tower_idx_in_output: 0
+ item_tower_idx_in_output: 1
+ simi_func: COSINE
+ scale_simi: false
+ temperature: 0.01
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/dssm_with_sample_weight.config b/samples/model_config/dssm_with_sample_weight.config
index f2a412a4d..5dbe96fac 100755
--- a/samples/model_config/dssm_with_sample_weight.config
+++ b/samples/model_config/dssm_with_sample_weight.config
@@ -3,7 +3,7 @@ eval_input_path: "data/test/test_sample_weight.txt"
model_dir: 'experiments/dssm_with_sample_weight/'
train_config {
- log_step_count_steps: 200
+ log_step_count_steps: 100
optimizer_config: {
adam_optimizer: {
learning_rate: {
@@ -18,7 +18,7 @@ train_config {
use_moving_average: false
}
- save_checkpoints_steps: 500
+ save_checkpoints_steps: 100
}
eval_config {
diff --git a/samples/model_config/esmm_distribute_eval_on_taobao.config b/samples/model_config/esmm_distribute_eval_on_taobao.config
new file mode 100644
index 000000000..ec9677c65
--- /dev/null
+++ b/samples/model_config/esmm_distribute_eval_on_taobao.config
@@ -0,0 +1,313 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "data/test/distribute_eval_test/esmm_distribute_eval_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+ label_fields: 'buy'
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'ESMM'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
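+  # ESMM multi-task setup: the user/item/combo groups feed shared group DNNs;
+  # the CTR tower is trained on 'clk' and the CVR tower on 'buy', with CVR
+  # learned over the full impression space through the pCTCVR = pCTR * pCVR
+  # decomposition.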
+ esmm {
+ groups {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ groups {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ groups {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ cvr_tower {
+ tower_name: "cvr"
+ label_name: "buy"
+ dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {
+ num_thresholds: 10000
+ }
+ }
+ }
+ ctr_tower {
+ tower_name: "ctr"
+ label_name: "clk"
+ dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
diff --git a/samples/model_config/esmm_on_sequence_feature_taobao.config b/samples/model_config/esmm_on_sequence_feature_taobao.config
index 9f70c3613..5eb5b0a30 100644
--- a/samples/model_config/esmm_on_sequence_feature_taobao.config
+++ b/samples/model_config/esmm_on_sequence_feature_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: True
- num_steps: 5000
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/esmm_on_taobao.config b/samples/model_config/esmm_on_taobao.config
index d1c96c810..c69786db2 100644
--- a/samples/model_config/esmm_on_taobao.config
+++ b/samples/model_config/esmm_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: True
- num_steps: 5000
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/esmm_variational_dropout_on_taobao.config b/samples/model_config/esmm_variational_dropout_on_taobao.config
new file mode 100644
index 000000000..4a2979a1c
--- /dev/null
+++ b/samples/model_config/esmm_variational_dropout_on_taobao.config
@@ -0,0 +1,317 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/esmm_varitional_dropout_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 200
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+ label_fields: 'buy'
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'ESMM'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ esmm {
+ groups {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ groups {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ groups {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ cvr_tower {
+ tower_name: "cvr"
+ label_name: "buy"
+ dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {
+ num_thresholds: 10000
+ }
+ }
+ }
+ ctr_tower {
+ tower_name: "ctr"
+ label_name: "clk"
+ dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ l2_regularization: 1e-6
+ }
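+  # Variational dropout learns drop probabilities on the input embeddings for
+  # automatic feature selection; embedding_wise_variational_dropout: true makes
+  # the rates per embedding dimension, and regularization_lambda weights the
+  # dropout regularization term in the loss.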
+ variational_dropout{
+ regularization_lambda:0.01
+ embedding_wise_variational_dropout:true
+ }
+ embedding_regularization: 5e-5
+}
diff --git a/samples/model_config/export_filter_input.config b/samples/model_config/export_filter_input.config
new file mode 100644
index 000000000..42a37ff8e
--- /dev/null
+++ b/samples/model_config/export_filter_input.config
@@ -0,0 +1,294 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/multi_tower_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 200
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+# features: {
+# input_names: 'tag_category_list'
+# feature_type: TagFeature
+# separator: '|'
+# hash_bucket_size: 100000
+# embedding_dim: 16
+# }
+# features: {
+# input_names: 'tag_brand_list'
+# feature_type: TagFeature
+# separator: '|'
+# hash_bucket_size: 100000
+# embedding_dim: 16
+# }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTower'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+# feature_names: 'tag_category_list'
+# feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
+
+export_config {
+ multi_placeholder: true
+ filter_inputs: false
+}
diff --git a/samples/model_config/fg_ev_v2.config b/samples/model_config/fg_ev_v2.config
index cde4f279d..19c7b61c8 100644
--- a/samples/model_config/fg_ev_v2.config
+++ b/samples/model_config/fg_ev_v2.config
@@ -3392,7 +3392,10 @@ model_config {
}
l2_regularization: 9.9999997e-05
}
- use_embedding_variable: true
+ ev_params {
+ filter_freq: 5
+ steps_to_live: 1000
+ }
}
export_config{
multi_placeholder: false
diff --git a/samples/model_config/fg_train.config b/samples/model_config/fg_train.config
new file mode 100644
index 000000000..698c41ac8
--- /dev/null
+++ b/samples/model_config/fg_train.config
@@ -0,0 +1,112 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/rtp_fg_demo_v1"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 400
+ sync_replicas: true
+ log_step_count_steps: 200
+}
+
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+
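+# RTP feature-generation setup: raw-feature transformations are described by
+# fg.json, and the RTP-formatted input uses selected_cols to pick the label
+# and feature columns out of each record.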
+fg_json_path: "samples/rtp_fg/fg.json"
+
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+
+ rtp_separator: ";"
+}
+
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/fg_train_ev.config b/samples/model_config/fg_train_ev.config
new file mode 100644
index 000000000..9f9c5b4de
--- /dev/null
+++ b/samples/model_config/fg_train_ev.config
@@ -0,0 +1,112 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/rtp_fg_demo_ev"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 400
+ sync_replicas: true
+ log_step_count_steps: 200
+}
+
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+
+fg_json_path: "samples/rtp_fg/fg_ev.json"
+
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+
+ rtp_separator: ";"
+}
+
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/fibinet_on_taobao.config b/samples/model_config/fibinet_on_taobao.config
new file mode 100644
index 000000000..05736d118
--- /dev/null
+++ b/samples/model_config/fibinet_on_taobao.config
@@ -0,0 +1,293 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/fibinet_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "all"
+ inputs {
+ feature_group_name: "all"
+ }
+ input_layer {
+ do_batch_norm: true
+ only_output_feature_list: true
+ }
+ }
+ blocks {
+ name: "fibinet"
+ inputs {
+ block_name: "all"
+ }
+ keras_layer {
+ class_name: 'FiBiNet'
+ fibinet {
+ senet {
+ reduction_ratio: 4
+ }
+ bilinear {
+ type: 'each'
+ num_output_units: 512
+ }
+ mlp {
+ hidden_units: [512, 256]
+ }
+ }
+ }
+ }
+ concat_blocks: ['fibinet']
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/samples/model_config/fm_variational_dropout_on_taobao.config b/samples/model_config/fm_variational_dropout_on_taobao.config
new file mode 100644
index 000000000..c41d0988c
--- /dev/null
+++ b/samples/model_config/fm_variational_dropout_on_taobao.config
@@ -0,0 +1,287 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/bst_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+
+}
+model_config: {
+ model_class: 'FM'
+ feature_groups: {
+ group_name: 'wide'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: WIDE
+ }
+ feature_groups: {
+ group_name: 'deep'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ fm {
+ }
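+  # variational dropout learns dropout rates that can be used to rank feature importance;
+  # with embedding_wise_variational_dropout the rate is learned per embedding dimension rather than per feature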
+  variational_dropout {
+    regularization_lambda: 0.01
+    embedding_wise_variational_dropout: true
+ }
+ embedding_regularization: 1e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/highway_on_movielens.config b/samples/model_config/highway_on_movielens.config
new file mode 100644
index 000000000..191d386aa
--- /dev/null
+++ b/samples/model_config/highway_on_movielens.config
@@ -0,0 +1,248 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/highway_movielens_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: TagFeature
+ separator: '|'
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ sequence_combiner: {
+ text_cnn: {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ }
+ }
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_name: 'HighWayNetwork'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ feature_names: 'title'
+ feature_names: 'genres'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'highway'
+ inputs {
+ feature_group_name: 'image'
+ }
+ keras_layer {
+ class_name: 'Highway'
+ }
+ }
+ blocks {
+ name: 'top_mlp'
+ inputs {
+ feature_group_name: 'general'
+ }
+ inputs {
+ block_name: 'highway'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/masknet_on_taobao.config b/samples/model_config/masknet_on_taobao.config
new file mode 100644
index 000000000..26b8c8262
--- /dev/null
+++ b/samples/model_config/masknet_on_taobao.config
@@ -0,0 +1,288 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/masknet_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "mask_net"
+ inputs {
+ feature_group_name: "all"
+ }
+ keras_layer {
+ class_name: 'MaskNet'
+ masknet {
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mask_blocks {
+ aggregation_size: 512
+ output_size: 256
+ }
+ mlp {
+ hidden_units: [512, 256]
+ }
+ }
+ }
+ }
+ concat_blocks: ['mask_net']
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/samples/model_config/metric_learning_on_taobao.config b/samples/model_config/metric_learning_on_taobao.config
index 6801017b7..d713f1384 100644
--- a/samples/model_config/metric_learning_on_taobao.config
+++ b/samples/model_config/metric_learning_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: false
- num_steps: 2500
+ num_steps: 100
}
eval_config {
diff --git a/samples/model_config/mind_on_taobao.config b/samples/model_config/mind_on_taobao.config
index ccc3fe398..6bdb0c6f6 100644
--- a/samples/model_config/mind_on_taobao.config
+++ b/samples/model_config/mind_on_taobao.config
@@ -264,6 +264,10 @@ model_config:{
hidden_units: [256, 128, 64, 32]
}
+ concat_dnn {
+ hidden_units: [64, 32]
+ }
+
capsule_config {
max_k: 5
max_seq_len: 64
diff --git a/samples/model_config/mind_on_taobao_hard_neg_sam.config b/samples/model_config/mind_on_taobao_hard_neg_sam.config
new file mode 100644
index 000000000..813ad8925
--- /dev/null
+++ b/samples/model_config/mind_on_taobao_hard_neg_sam.config
@@ -0,0 +1,307 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/mind_taobao_hard_neg_sam"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set {
+ recall_at_topk { topk: 1 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 5 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 10 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 20 }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 2
+ prefetch_size: 32
+ input_type: CSVInput
+
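+  # draw random negatives from the item table and hard negatives from the no-click edge table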
+ hard_negative_sampler {
+ user_input_path: 'data/test/tb_data/taobao_user_profile_gl'
+ item_input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ hard_neg_edge_input_path: 'data/test/tb_data/taobao_noclk_edge_gl'
+ num_sample: 1024
+ num_hard_sample: 5
+ num_eval_sample: 2048
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ user_id_field: 'user_id'
+ }
+
+}
+
+feature_configs : {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs : {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+}
+feature_configs : {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+}
+feature_configs : {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config:{
+ model_class: "MIND"
+ feature_groups: {
+ group_name: 'hist'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
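+  # MIND extracts multiple user-interest capsules from the behavior sequences in the 'hist' group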
+ mind {
+ user_dnn {
+ hidden_units: [256, 128, 64]
+ }
+ item_dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ concat_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ capsule_config {
+ max_k: 5
+ max_seq_len: 64
+ high_dim: 64
+ }
+ l2_regularization: 1e-6
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/fm_on_sequence_feature_taobao.config b/samples/model_config/mind_on_taobao_neg_sam.config
similarity index 78%
rename from samples/model_config/fm_on_sequence_feature_taobao.config
rename to samples/model_config/mind_on_taobao_neg_sam.config
index eb6096acb..c4df17e84 100644
--- a/samples/model_config/fm_on_sequence_feature_taobao.config
+++ b/samples/model_config/mind_on_taobao_neg_sam.config
@@ -1,9 +1,9 @@
train_input_path: "data/test/tb_data/taobao_train_data"
eval_input_path: "data/test/tb_data/taobao_test_data"
-model_dir: "experiments/fm_taobao_ckpt"
+model_dir: "experiments/mind_taobao_neg_sam"
train_config {
- log_step_count_steps: 100
+ log_step_count_steps: 5
optimizer_config: {
adam_optimizer: {
learning_rate: {
@@ -17,15 +17,22 @@ train_config {
}
use_moving_average: false
}
- save_checkpoints_steps: 100
- sync_replicas: True
- num_steps: 2500
+ save_summary_steps: 5
}
eval_config {
metrics_set: {
auc {}
}
+ metrics_set {
+ recall_at_topk { topk: 1 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 5 }
+ }
+ metrics_set {
+ recall_at_topk { topk: 10 }
+ }
}
data_config {
@@ -112,9 +119,22 @@ data_config {
label_fields: 'clk'
batch_size: 4096
- num_epochs: 10000
+ num_epochs: 2
prefetch_size: 32
input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 256
+ num_eval_sample: 4096
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+
}
feature_configs : {
@@ -211,7 +231,7 @@ feature_configs : {
input_names: 'tag_category_list'
feature_type: SequenceFeature
separator: '|'
- hash_bucket_size: 10000
+ hash_bucket_size: 100000
embedding_dim: 16
}
feature_configs : {
@@ -227,30 +247,15 @@ feature_configs : {
embedding_dim: 16
num_buckets: 50
}
-
-model_config: {
- model_class: 'FM'
+model_config:{
+ model_class: "MIND"
feature_groups: {
- group_name: 'wide'
- feature_names: 'user_id'
- feature_names: 'cms_segid'
- feature_names: 'cms_group_id'
- feature_names: 'age_level'
- feature_names: 'pvalue_level'
- feature_names: 'shopping_level'
- feature_names: 'occupation'
- feature_names: 'new_user_class_level'
- feature_names: 'adgroup_id'
- feature_names: 'cate_id'
- feature_names: 'campaign_id'
- feature_names: 'customer'
- feature_names: 'brand'
- feature_names: 'price'
- feature_names: 'pid'
- wide_deep: WIDE
+ group_name: 'hist'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
}
feature_groups: {
- group_name: 'deep'
+ group_name: 'user'
feature_names: 'user_id'
feature_names: 'cms_segid'
feature_names: 'cms_group_id'
@@ -259,29 +264,41 @@ model_config: {
feature_names: 'shopping_level'
feature_names: 'occupation'
feature_names: 'new_user_class_level'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
feature_names: 'adgroup_id'
feature_names: 'cate_id'
feature_names: 'campaign_id'
feature_names: 'customer'
feature_names: 'brand'
- feature_names: 'price'
- feature_names: 'pid'
- wide_deep: DEEP
- sequence_features: {
- group_name: "seq_fea"
- tf_summary: false
- allow_key_search:true
- seq_att_map: {
- key: "brand"
- key: "cate_id"
- hist_seq: "tag_brand_list"
- hist_seq: "tag_category_list"
- }
- }
+ wide_deep:DEEP
}
- fm {
+ mind {
+ user_dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ item_dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ concat_dnn {
+ hidden_units: [64, 32]
+ }
+
+ capsule_config {
+ max_k: 5
+ max_seq_len: 64
+ high_dim: 64
+ squash_pow: 0.5
+ }
+ l2_regularization: 1e-6
+ item_id: "adgroup_id"
+ ignore_in_batch_neg_sam: true
+ max_interests_simi: 0.75
}
- embedding_regularization: 1e-5
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
}
export_config {
diff --git a/samples/model_config/mind_on_taobao_with_time.config b/samples/model_config/mind_on_taobao_with_time.config
index 13d17b2e7..3f02ee140 100644
--- a/samples/model_config/mind_on_taobao_with_time.config
+++ b/samples/model_config/mind_on_taobao_with_time.config
@@ -275,6 +275,9 @@ model_config:{
item_dnn {
hidden_units: [256, 128, 64, 32]
}
+ concat_dnn {
+ hidden_units: [64, 32]
+ }
capsule_config {
max_k: 5
@@ -282,6 +285,8 @@ model_config:{
high_dim: 64
}
l2_regularization: 1e-6
+ time_id_fea: "time_id"
+ ignore_in_batch_neg_sam: true
}
embedding_regularization: 5e-5
}
diff --git a/samples/model_config/mlp_on_taobao_with_ziln_loss.config b/samples/model_config/mlp_on_taobao_with_ziln_loss.config
new file mode 100644
index 000000000..1f05afa91
--- /dev/null
+++ b/samples/model_config/mlp_on_taobao_with_ziln_loss.config
@@ -0,0 +1,279 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/mlp_ziln_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "deep"
+ inputs {
+ feature_group_name: "all"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
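+  # ZILN (zero-inflated lognormal) loss predicts 3 values per sample (positive probability, mu, sigma), hence num_class: 3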
+ num_class: 3
+ losses {
+ loss_type: ZILN_LOSS
+ weight: 1.0
+ loss_name: 'LTV'
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/samples/model_config/mmoe_backbone_on_taobao.config b/samples/model_config/mmoe_backbone_on_taobao.config
new file mode 100644
index 000000000..39018342c
--- /dev/null
+++ b/samples/model_config/mmoe_backbone_on_taobao.config
@@ -0,0 +1,316 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/mmoe_backbone_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 200
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_config: {
+ features {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "tag_category_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "tag_brand_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config {
+ model_name: "MMoE"
+ model_class: "MultiTaskModel"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'all'
+ inputs {
+ feature_group_name: 'all'
+ }
+ input_layer {
+ only_output_feature_list: true
+ }
+ }
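+    # re-weight the feature embeddings with SENet before feeding the MMoE experts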
+ blocks {
+ name: "senet"
+ inputs {
+ block_name: "all"
+ }
+ keras_layer {
+ class_name: 'SENet'
+ senet {
+ reduction_ratio: 4
+ }
+ }
+ }
+ blocks {
+ name: "mmoe"
+ inputs {
+ block_name: "senet"
+ }
+ keras_layer {
+ class_name: 'MMoE'
+ mmoe {
+ num_task: 2
+ num_expert: 3
+ expert_mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ }
+ }
+ model_params {
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ dnn {
+ hidden_units: [128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ dnn {
+ hidden_units: [128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ l2_regularization: 1e-06
+ }
+ embedding_regularization: 5e-05
+}
diff --git a/samples/model_config/mmoe_mirrored_strategy_on_taobao.config b/samples/model_config/mmoe_mirrored_strategy_on_taobao.config
new file mode 100644
index 000000000..761f3e739
--- /dev/null
+++ b/samples/model_config/mmoe_mirrored_strategy_on_taobao.config
@@ -0,0 +1,318 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/mmoe_mirrored_strategy_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ train_distribute: MultiWorkerMirroredStrategy
+ num_gpus_per_worker: 2
+ num_steps: 200
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_config: {
+ features {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "tag_category_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "tag_brand_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config {
+ model_name: "MMoE"
+ model_class: "MultiTaskModel"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: 'all'
+ inputs {
+ feature_group_name: 'all'
+ }
+ input_layer {
+ only_output_feature_list: true
+ }
+ }
+ blocks {
+ name: "senet"
+ inputs {
+ block_name: "all"
+ }
+ keras_layer {
+ class_name: 'SENet'
+ senet {
+ reduction_ratio: 4
+ }
+ }
+ }
+ blocks {
+ name: "mmoe"
+ inputs {
+ block_name: "senet"
+ }
+ keras_layer {
+ class_name: 'MMoE'
+ mmoe {
+ num_task: 2
+ num_expert: 3
+ expert_mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ }
+ }
+ model_params {
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ dnn {
+ hidden_units: [128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ dnn {
+ hidden_units: [128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ l2_regularization: 1e-06
+ }
+ embedding_regularization: 5e-05
+}
diff --git a/samples/model_config/mmoe_on_taobao_with_multi_loss.config b/samples/model_config/mmoe_on_taobao_with_multi_loss.config
new file mode 100644
index 000000000..101785244
--- /dev/null
+++ b/samples/model_config/mmoe_on_taobao_with_multi_loss.config
@@ -0,0 +1,293 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/mmoe_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 5000
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_config: {
+ features {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "tag_category_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "tag_brand_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config {
+ model_class: "MMoE"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ mmoe {
+ expert_dnn {
+ hidden_units: [256, 192, 128, 64]
+ }
+ num_expert: 4
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ dnn {
+ hidden_units: [256, 192, 128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ losses {
+ loss_type: CLASSIFICATION
+ weight: 1.0
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
+ metrics_set: {
+ auc {}
+ }
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ dnn {
+ hidden_units: [256, 192, 128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ l2_regularization: 1e-06
+ }
+ embedding_regularization: 5e-05
+}
diff --git a/samples/model_config/multi_tower_backbone_on_taobao.config b/samples/model_config/multi_tower_backbone_on_taobao.config
new file mode 100644
index 000000000..d76795e33
--- /dev/null
+++ b/samples/model_config/multi_tower_backbone_on_taobao.config
@@ -0,0 +1,339 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/multi_tower_backbone_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 200
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_name: 'MultiTower'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
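+ # two training objectives are combined below: an F1-reweighted
+ # classification loss and a pairwise ranking loss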
+ losses {
+ loss_type: F1_REWEIGHTED_LOSS
+ weight: 1.0
+ f1_reweighted_loss {
+ f1_beta_square: 2.25
+ }
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
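+ # backbone: one MLP tower per feature group (user / item / combo),
+ # whose outputs are concatenated and fed into a final top MLP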
+ backbone {
+ blocks {
+ name: "user_tower"
+ inputs {
+ feature_group_name: "user"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ blocks {
+ name: "item_tower"
+ inputs {
+ feature_group_name: "item"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ blocks {
+ name: "combo_tower"
+ inputs {
+ feature_group_name: "combo"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: [256, 128]
+ }
+ }
+ }
+ blocks {
+ name: "top_mlp"
+ inputs {
+ block_name: "user_tower"
+ }
+ inputs {
+ block_name: "item_tower"
+ }
+ inputs {
+ block_name: "combo_tower"
+ }
+ keras_layer {
+ class_name: "MLP"
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/multi_tower_on_taobao.config b/samples/model_config/multi_tower_on_taobao.config
index 1249b49f4..78dc12e2f 100644
--- a/samples/model_config/multi_tower_on_taobao.config
+++ b/samples/model_config/multi_tower_on_taobao.config
@@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 100
sync_replicas: True
- num_steps: 200
+ num_steps: 100
}
eval_config {
@@ -260,7 +260,17 @@ model_config: {
feature_names: 'tag_brand_list'
wide_deep: DEEP
}
-
+ losses {
+ loss_type: F1_REWEIGHTED_LOSS
+ weight: 1.0
+ f1_reweighted_loss {
+ f1_beta_square: 2.25
+ }
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
multi_tower {
towers {
input: "user"
diff --git a/samples/model_config/multi_tower_on_taobao_for_expr.config b/samples/model_config/multi_tower_on_taobao_for_expr.config
new file mode 100644
index 000000000..92649ba57
--- /dev/null
+++ b/samples/model_config/multi_tower_on_taobao_for_expr.config
@@ -0,0 +1,343 @@
+train_input_path: "data/test/tb_data/taobao_train_data_for_expr"
+eval_input_path: "data/test/tb_data/taobao_test_data_for_expr"
+model_dir: "experiments/multi_tower_taobao_ckpt_for_expr"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 200
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'item_age_level'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'item_age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
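+ # the ExprFeature entries below evaluate boolean expressions over the
+ # numeric inputs age_level and item_age_level; each result is fed to
+ # the model as an extra 0/1-valued feature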
+ features {
+ feature_name: "age_satisfy1"
+ input_names: "age_level"
+ input_names: "item_age_level"
+ feature_type: ExprFeature
+ expression: "age_level>=1"
+ }
+ features {
+ feature_name: "age_satisfy2"
+ input_names: "age_level"
+ input_names: "item_age_level"
+ feature_type: ExprFeature
+ expression: "age_level<(item_age_level+5)"
+ }
+ features {
+ feature_name: "age_satisfy3"
+ input_names: "age_level"
+ input_names: "item_age_level"
+ feature_type: ExprFeature
+ expression: "age_level==item_age_level"
+ }
+ features {
+ feature_name: "age_satisfy4"
+ input_names: "age_level"
+ input_names: "item_age_level"
+ feature_type: ExprFeature
+ expression: "(age_level>=item_age_level) & (age_level<=(item_age_level+5))"
+ }
+ features {
+ feature_name: "age_satisfy5"
+ input_names: "age_level"
+ input_names: "item_age_level"
+ feature_type: ExprFeature
+ expression: "(age_level>=item_age_level) | (age_level>5)"
+ }
+}
+model_config: {
+ model_class: 'MultiTower'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ feature_names: 'age_satisfy1'
+ feature_names: 'age_satisfy2'
+ feature_names: 'age_satisfy3'
+ feature_names: 'age_satisfy4'
+ feature_names: 'age_satisfy5'
+ wide_deep: DEEP
+ }
+
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/multi_tower_on_taobao_session_auc.config b/samples/model_config/multi_tower_on_taobao_session_auc.config
new file mode 100644
index 000000000..8070b71b0
--- /dev/null
+++ b/samples/model_config/multi_tower_on_taobao_session_auc.config
@@ -0,0 +1,294 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/multi_tower_taobao_gauc_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ session_auc {
+ session_id_field: 'user_id'
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTower'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/multi_tower_on_taobao_sok.config b/samples/model_config/multi_tower_on_taobao_sok.config
new file mode 100644
index 000000000..7dfe7c987
--- /dev/null
+++ b/samples/model_config/multi_tower_on_taobao_sok.config
@@ -0,0 +1,325 @@
+# train_input_path: "taobao_train_data_8192"
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/multi_tower_taobao_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 1000000
+ log_step_count_steps: 10
+ sync_replicas: False
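+ # EmbeddingParallelStrategy: SOK-style embedding-parallel training,
+ # where embedding tables are partitioned across GPUs rather than replicated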
+ train_distribute: EmbeddingParallelStrategy
+ num_steps: 200000
+ # is_profiling: true
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ # default_val: "430548_1007"
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ # default_val: "620392"
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ # default_val: "6519"
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ # default_val: "336811"
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ # default_val: "252462"
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ # default_val: "247789"
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ # default_val: "100002"
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ # default_val: "5"
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ # default_val: "2"
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ # default_val: "2"
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ # default_val: "2"
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ # default_val: "1"
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ # default_val: "3"
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ # default_val: "1"
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ # default_val: "2"
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ # default_val: '4281|4281|4526'
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ # default_val: '283837|283837|367594'
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ # default_val: '23'
+ }
+
+ label_fields: 'clk'
+ batch_size: 8192
+ num_epochs: 1000000
+ prefetch_size: 64
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTower'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ losses {
+ loss_type: F1_REWEIGHTED_LOSS
+ weight: 1.0
+ f1_reweighted_loss {
+ f1_beta_square: 2.25
+ }
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/multi_tower_on_taobao_test_udf.config b/samples/model_config/multi_tower_on_taobao_test_udf.config
new file mode 100644
index 000000000..9e37221ae
--- /dev/null
+++ b/samples/model_config/multi_tower_on_taobao_test_udf.config
@@ -0,0 +1,306 @@
+train_input_path: "data/test/tb_data/taobao_train_data_remap_label"
+eval_input_path: "data/test/tb_data/taobao_test_data_remap_label"
+model_dir: "experiments/multi_tower_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 200
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT64
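+ # the raw 'clk' values are remapped by the user-defined function below
+ # before being used as the training label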
+ user_define_fn: 'remap_lbl'
+ user_define_fn_path: 'samples/demo_script/process_lbl.py'
+ user_define_fn_res_type: INT64
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTower'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ losses {
+ loss_type: F1_REWEIGHTED_LOSS
+ weight: 1.0
+ f1_reweighted_loss {
+ f1_beta_square: 2.25
+ }
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/multi_tower_on_taobao_unblanace.config b/samples/model_config/multi_tower_on_taobao_unblanace.config
new file mode 100644
index 000000000..8105483c2
--- /dev/null
+++ b/samples/model_config/multi_tower_on_taobao_unblanace.config
@@ -0,0 +1,304 @@
+train_input_path: "data/test/tb_data/taobao_train_data,data/test/tb_data/taobao_test_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/multi_tower_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 200
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
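+ # file_shard: shard the input by file (rather than by sample) across workers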
+ file_shard: true
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTower'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'combo'
+ feature_names: 'pid'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ losses {
+ loss_type: F1_REWEIGHTED_LOSS
+ weight: 1.0
+ f1_reweighted_loss {
+ f1_beta_square: 2.25
+ }
+ }
+ losses {
+ loss_type: PAIR_WISE_LOSS
+ weight: 1.0
+ }
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: [128, 96, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-4
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/multi_tower_recall_neg_sampler_only_sequence_feature.config b/samples/model_config/multi_tower_recall_neg_sampler_only_sequence_feature.config
new file mode 100644
index 000000000..2028cb600
--- /dev/null
+++ b/samples/model_config/multi_tower_recall_neg_sampler_only_sequence_feature.config
@@ -0,0 +1,304 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/multi_tower_recall_neg_sampler_only_sequence_feature"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 6
+ sync_replicas: false
+ save_checkpoints_steps: 100
+ log_step_count_steps: 2
+}
+
+eval_config {
+ metrics_set: {
+ auc {
+ }
+ }
+ metrics_set: {
+ gauc {
+ uid_field: "user_id"
+ }
+ }
+}
+
+data_config {
+ batch_size: 16
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ num_epochs: 5
+ prefetch_size: 4
+ input_type: CSVInput
+
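+ # draw weighted random negative items per batch from the item feature
+ # file configured below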
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 4
+ num_eval_sample: 4
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_configs : {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs : {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs : {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config:{
+ model_class: "MultiTowerRecall"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep:DEEP
+ negative_sampler:true
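+ # sequence features: pair the candidate item's brand / cate_id with the
+ # user's tag_brand_list / tag_category_list history sequences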
+ sequence_features: {
+ group_name: "seq_fea"
+ allow_key_search: true
+ need_key_feature:false
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ multi_tower_recall {
+ user_tower {
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ loss_type: CLASSIFICATION
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config b/samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config
new file mode 100644
index 000000000..a51260b59
--- /dev/null
+++ b/samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config
@@ -0,0 +1,304 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/multi_tower_recall_neg_sampler_sequence_feature"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 6
+ sync_replicas: false
+ save_checkpoints_steps: 100
+ log_step_count_steps: 2
+}
+
+eval_config {
+ metrics_set: {
+ auc {
+ }
+ }
+ metrics_set: {
+ gauc {
+ uid_field: "user_id"
+ }
+ }
+}
+
+data_config {
+ batch_size: 16
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ num_epochs: 5
+ prefetch_size: 4
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 4
+ num_eval_sample: 4
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_configs : {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs : {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs {
+ input_names: "tag_category_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs {
+ input_names: "tag_brand_list"
+ feature_type: SequenceFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ sub_feature_type: IdFeature
+ separator: "|"
+}
+feature_configs : {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+model_config:{
+ model_class: "MultiTowerRecall"
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep:DEEP
+ negative_sampler:true
+ sequence_features: {
+ group_name: "seq_fea"
+ allow_key_search: true
+ need_key_feature:true
+ seq_att_map: {
+ key: "brand"
+ key: "cate_id"
+ hist_seq: "tag_brand_list"
+ hist_seq: "tag_category_list"
+ }
+ }
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ multi_tower_recall {
+ user_tower {
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
+ }
+ }
+ item_tower {
+ dnn {
+ hidden_units: [256, 128, 64, 32]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 1e-6
+ }
+ loss_type: CLASSIFICATION
+ embedding_regularization: 5e-6
+}
diff --git a/samples/model_config/parallel_dssm_on_taobao_backbone.config b/samples/model_config/parallel_dssm_on_taobao_backbone.config
new file mode 100644
index 000000000..467594837
--- /dev/null
+++ b/samples/model_config/parallel_dssm_on_taobao_backbone.config
@@ -0,0 +1,589 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/parallel_dssm_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 200
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ # initial_learning_rate: 0.001
+ initial_learning_rate: 0.0001
+ decay_steps: 4000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 2000
+ sync_replicas: false
+ num_steps: 2000
+}
+
+eval_config {
+
+ metrics_set: {
+ recall_at_topk {
+ topk: 50
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 10
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 5
+ }
+ }
+ metrics_set: {
+ recall_at_topk {
+ topk: 1
+ }
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 2048
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+
+ negative_sampler {
+ input_path: 'data/test/tb_data/taobao_ad_feature_gl'
+ num_sample: 512
+ num_eval_sample: 512
+ attr_fields: 'adgroup_id'
+ attr_fields: 'cate_id'
+ attr_fields: 'campaign_id'
+ attr_fields: 'customer'
+ attr_fields: 'brand'
+ item_id_field: 'adgroup_id'
+ }
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config:{
+ model_name: "Parallel_DSSM"
+ model_class: 'MatchModel'
+ feature_groups: {
+ group_name: 'user_mlp_feature'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ feature_names: 'final_gender_code'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: 'user_dcn_feature'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ feature_names: 'final_gender_code'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: 'user_fm_feature'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ feature_names: 'final_gender_code'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: 'user_cin_feature'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ feature_names: 'final_gender_code'
+ wide_deep:DEEP
+ }
+
+ feature_groups: {
+ group_name: "item_mlp_feature"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item_dcn_feature"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item_fm_feature"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item_cin_feature"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ #feature_names: 'price'
+ #feature_names: 'pid'
+ wide_deep:DEEP
+ }
+
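+ # the backbone builds four parallel user sub-towers (MLP, DCN cross, FM,
+ # CIN) and four matching item sub-towers, then concatenates each side
+ # into a single tower embedding for the match model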
+ backbone {
+ blocks {
+ name: 'user_mlp'
+ inputs {
+ feature_group_name: 'user_mlp_feature'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 32]
+ }
+ }
+ }
+ blocks {
+ name: 'user_dcn'
+ inputs {
+ feature_group_name: 'user_dcn_feature'
+ input_fn: 'lambda x: [x, x]'
+ }
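+ # run the Cross layer 3 times, keeping the first input (x0) fixed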
+ recurrent {
+ num_steps: 3
+ fixed_input_index: 0
+ keras_layer {
+ class_name: 'Cross'
+ }
+ }
+ }
+ blocks {
+ name: 'user_dcn_out'
+ inputs {
+ block_name: 'user_dcn'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [32]
+ }
+ }
+ }
+
+ blocks {
+ name: 'user_fm_feature'
+ inputs {
+ feature_group_name: 'user_fm_feature'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'user_fm'
+ inputs {
+ block_name: 'user_fm_feature'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ }
+ }
+
+ blocks {
+ name: 'user_cin_feature'
+ inputs {
+ feature_group_name: 'user_cin_feature'
+ }
+ input_layer {
+ only_output_3d_tensor: true
+ }
+ }
+
+ blocks {
+ name: 'user_cin'
+ inputs {
+ block_name: 'user_cin_feature'
+ }
+ keras_layer {
+ class_name: 'CIN'
+ cin {
+ hidden_feature_sizes: [16, 16, 16]
+ }
+ }
+ }
+
+ blocks {
+ name: 'item_mlp'
+ inputs {
+ feature_group_name: 'item_mlp_feature'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 32]
+ }
+ }
+ }
+ blocks {
+ name: 'item_dcn'
+ inputs {
+ feature_group_name: 'item_dcn_feature'
+ input_fn: 'lambda x: [x, x]'
+ }
+ recurrent {
+ num_steps: 3
+ fixed_input_index: 0
+ keras_layer {
+ class_name: 'Cross'
+ }
+ }
+ }
+ blocks {
+ name: 'item_dcn_out'
+ inputs {
+ block_name: 'item_dcn'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [32]
+ }
+ }
+ }
+
+ blocks {
+ name: 'item_fm_feature'
+ inputs {
+ feature_group_name: 'item_fm_feature'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
+ blocks {
+ name: 'item_fm'
+ inputs {
+ block_name: 'item_fm_feature'
+ input_slice: '[1]'
+ }
+ keras_layer {
+ class_name: 'FM'
+ }
+ }
+ blocks {
+ name: 'item_cin_feature'
+ inputs {
+ feature_group_name: 'item_cin_feature'
+ }
+ input_layer {
+ only_output_3d_tensor: true
+ }
+ }
+ blocks {
+ name: 'item_cin'
+ inputs {
+ block_name: 'item_cin_feature'
+ }
+ keras_layer {
+ class_name: 'CIN'
+ cin {
+ hidden_feature_sizes: [16, 16, 16]
+ }
+ }
+ }
+
+ blocks {
+ name: 'user_tower_embedding'
+ inputs {
+ block_name: 'user_mlp'
+ }
+ inputs {
+ block_name: 'user_dcn_out'
+ }
+ inputs {
+ block_name: 'user_fm'
+ }
+ inputs {
+ block_name: 'user_cin'
+ }
+
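+ # collect the four user sub-tower outputs into a list and concatenate
+ # them into the user tower embedding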
+ merge_inputs_into_list: true
+ lambda {
+ expression: 'lambda x: tf.concat(x, axis=1)'
+ }
+ }
+ blocks {
+ name: 'item_tower_embedding'
+ inputs {
+ block_name: 'item_mlp'
+ }
+ inputs {
+ block_name: 'item_dcn_out'
+ }
+ inputs {
+ block_name: 'item_fm'
+ }
+ inputs {
+ block_name: 'item_cin'
+ }
+ merge_inputs_into_list: true
+ lambda {
+ expression: 'lambda x: tf.concat(x, axis=1)'
+ }
+ }
+
+ output_blocks: ['user_tower_embedding', 'item_tower_embedding']
+ }
+ model_params {
+ l2_regularization: 1e-4
+ user_tower_idx_in_output: 0
+ item_tower_idx_in_output: 1
+ scale_simi: false
+ simi_func: INNER_PRODUCT
+ }
+ loss_type: SOFTMAX_CROSS_ENTROPY
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/pdn_on_taobao.config b/samples/model_config/pdn_on_taobao.config
new file mode 100644
index 000000000..dbd72db32
--- /dev/null
+++ b/samples/model_config/pdn_on_taobao.config
@@ -0,0 +1,327 @@
+# Note: this is just a demo using fake data.
+
+train_input_path: "data/test/tb_data/taobao_pdn_fake_train_data"
+eval_input_path: "data/test/tb_data/taobao_pdn_fake_test_data"
+model_dir: "experiments/pdn_on_taobao"
+
+train_config {
+ log_step_count_steps: 50
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_summary_steps: 50
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "i2i_rnk"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "event_type_seq"
+ input_type: STRING
+ }
+
+ label_fields: 'clk'
+ batch_size: 512
+ num_epochs: 1
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_configs : {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+}
+feature_configs : {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+}
+feature_configs : {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+}
+feature_configs : {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+}
+feature_configs : {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+}
+feature_configs : {
+ input_names: 'event_type_seq'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+}
+feature_configs : {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+}
+feature_configs : {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+}
+feature_configs {
+ input_names: 'i2i_rnk'
+ feature_type: SequenceFeature
+ separator: '|'
+ num_buckets: 128
+ embedding_dim: 16
+}
+
+model_config:{
+ model_class: "PDN"
+ feature_groups: {
+ group_name: 'u2i_seq'
+ feature_names: 'event_type_seq'
+ }
+ feature_groups: {
+ group_name: 'i_seq'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ }
+ feature_groups: {
+ group_name: 'i2i_seq'
+ feature_names: 'i2i_rnk'
+ }
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "item"
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ wide_deep:DEEP
+ }
+ feature_groups {
+ group_name: 'bias'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ }
+
+ pdn {
+ user_dnn {
+ hidden_units: [128, 64, 32]
+ }
+ item_dnn {
+ hidden_units: [128, 64, 32]
+ }
+ u2i_dnn {
+ hidden_units: [64, 32]
+ }
+ trigger_dnn {
+ hidden_units: [64, 32, 1]
+ }
+ i2i_dnn {
+ hidden_units: [128, 64, 32]
+ }
+ sim_dnn {
+ hidden_units: [64, 32, 1]
+ }
+ bias_dnn {
+ hidden_units: [32, 32, 1]
+ }
+
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 5e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/ple_on_sequence_feature_taobao.config b/samples/model_config/ple_on_sequence_feature_taobao.config
index 82e34586e..9ac3a9400 100644
--- a/samples/model_config/ple_on_sequence_feature_taobao.config
+++ b/samples/model_config/ple_on_sequence_feature_taobao.config
@@ -16,7 +16,7 @@ train_config {
}
use_moving_average: false
}
- num_steps: 200
+ num_steps: 100
sync_replicas: true
save_checkpoints_steps: 100
log_step_count_steps: 100
diff --git a/samples/model_config/ple_on_taobao.config b/samples/model_config/ple_on_taobao.config
index b24a3b9c9..ea4fdd0dc 100644
--- a/samples/model_config/ple_on_taobao.config
+++ b/samples/model_config/ple_on_taobao.config
@@ -16,7 +16,7 @@ train_config {
}
use_moving_average: false
}
- num_steps: 200
+ num_steps: 100
sync_replicas: true
save_checkpoints_steps: 100
log_step_count_steps: 100
diff --git a/samples/model_config/ppnet_on_taobao.config b/samples/model_config/ppnet_on_taobao.config
new file mode 100644
index 000000000..6bf4ba212
--- /dev/null
+++ b/samples/model_config/ppnet_on_taobao.config
@@ -0,0 +1,289 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/ppnet_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_name: 'PPNet'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'memorize'
+ feature_names: 'user_id'
+ feature_names: 'adgroup_id'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "ppnet"
+ inputs {
+ feature_group_name: "general"
+ }
+ inputs {
+ feature_group_name: "memorize"
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: "PPNet"
+ ppnet {
+ mlp {
+ hidden_units: [512, 256]
+ }
+ mode: "lazy"
+ full_gate_input: true
+ }
+ }
+ }
+ top_mlp {
+ hidden_units: [128, 64]
+ }
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-5
+}
diff --git a/samples/model_config/rocket_launching_with_rtp_input.config b/samples/model_config/rocket_launching_with_rtp_input.config
new file mode 100644
index 000000000..09750cb26
--- /dev/null
+++ b/samples/model_config/rocket_launching_with_rtp_input.config
@@ -0,0 +1,266 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_demo"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 400
+ sync_replicas: true
+ log_step_count_steps: 200
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ separator: ""
+ }
+}
+model_config: {
+ model_class: 'RocketLaunching'
+ feature_groups: {
+ group_name: 'all'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+ rocket_launching {
+ share_dnn {
+ hidden_units: [128, 96, 64]
+ }
+ booster_dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+    light_dnn {
+      hidden_units: [128, 64]
+    }
+    l2_regularization: 1e-6
+    feature_based_distillation: false
+    feature_distillation_function: COSINE
+  }
+  embedding_regularization: 5e-6
+ num_class: 2
+
+}
+export_config {
+ multi_placeholder: false
+}
\ No newline at end of file
diff --git a/samples/model_config/share_embedding_not_used.config b/samples/model_config/share_embedding_not_used.config
new file mode 100644
index 000000000..35efd1a2d
--- /dev/null
+++ b/samples/model_config/share_embedding_not_used.config
@@ -0,0 +1,292 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/share_not_used"
+
+train_config {
+ log_step_count_steps: 100
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 100
+ sync_replicas: True
+ num_steps: 2500
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 4096
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ embedding_name: "item_id"
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ embedding_name: "cate_id"
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ embedding_name: "brand_id"
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 10000
+ embedding_dim: 16
+ embedding_name: "cate_id"
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: SequenceFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ embedding_name: "brand_id"
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config: {
+ model_class: 'MultiTower'
+ feature_groups: {
+ group_name: 'user'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'item'
+ # feature_names: 'adgroup_id'
+ # feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep: DEEP
+ }
+
+ multi_tower {
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: [256, 128, 96, 64]
+ }
+ }
+ din_towers {
+ input: "din"
+ dnn {
+ hidden_units: [128, 64, 32, 1]
+ }
+ }
+ final_dnn {
+ hidden_units: [128, 96, 64, 32, 16]
+ }
+ l2_regularization: 5e-7
+ }
+ embedding_regularization: 5e-5
+}
+
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/share_not_used.config b/samples/model_config/share_not_used.config
new file mode 100644
index 000000000..4eea99ec4
--- /dev/null
+++ b/samples/model_config/share_not_used.config
@@ -0,0 +1,571 @@
+binary_train_input {
+ category_path: 'data/test/criteo_data/category.bin'
+ dense_path: 'data/test/criteo_data/dense.bin'
+ label_path: 'data/test/criteo_data/label.bin'
+}
+binary_eval_input {
+ category_path: 'data/test/criteo_data/category.bin'
+ dense_path: 'data/test/criteo_data/dense.bin'
+ label_path: 'data/test/criteo_data/label.bin'
+}
+model_dir: "experiments/dlrm_criteo"
+
+train_config {
+ optimizer_config: {
+ momentum_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 1e-4
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 10000
+ log_step_count_steps: 10
+ sync_replicas: True
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+
+ input_fields: {
+ input_name: "f1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f2"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f3"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f4"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f5"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f6"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f7"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f8"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f9"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f10"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f11"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f12"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "f13"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c2"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c3"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c4"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c5"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c6"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c7"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c8"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c9"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c10"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c11"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c12"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c13"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c22"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c23"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c24"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c25"
+ input_type: INT64
+ default_val:""
+ }
+ input_fields: {
+ input_name: "c26"
+ input_type: INT64
+ default_val:""
+ }
+
+
+ label_fields: 'label'
+ batch_size: 4096
+ num_epochs: 1000
+ prefetch_size: 32
+ input_type: CriteoInput
+}
+
+feature_config: {
+ features: {
+ input_names: "f1"
+ feature_type: RawFeature
+ min_val:0.0
+ max_val: 5775.0
+ }
+ features: {
+ input_names: "f2"
+ feature_type: RawFeature
+ min_val: -3.0
+ max_val: 257675.0
+ }
+ features: {
+ input_names: "f3"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 65535.0
+ }
+ features: {
+ input_names: "f4"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 969.0
+ }
+ features: {
+ input_names: "f5"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 23159456.0
+ }
+ features: {
+ input_names: "f6"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 431037.0
+ }
+ features: {
+ input_names: "f7"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 56311.0
+ }
+ features: {
+ input_names: "f8"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 6047.0
+ }
+ features: {
+ input_names: "f9"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 29019.0
+ }
+ features: {
+ input_names: "f10"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 46.0
+ }
+ features: {
+ input_names: "f11"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 231.0
+ }
+ features: {
+ input_names: "f12"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 4008.0
+ }
+ features: {
+ input_names: "f13"
+ feature_type: RawFeature
+ min_val: 0.0
+ max_val: 7393.0
+ }
+ features: {
+ input_names: "c1"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c2"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c3"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c4"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c5"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c6"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c7"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c8"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c9"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c10"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3853295
+ }
+ features: {
+ input_names: "c11"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c12"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c13"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c19"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c20"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3997977
+ }
+ features: {
+ input_names: "c21"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 2564129
+ }
+ features: {
+ input_names: "c22"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 3966498
+ }
+ features: {
+ input_names: "c23"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding"
+ }
+ features: {
+ input_names: "c24"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding1"
+ }
+ features: {
+ input_names: "c25"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding1"
+ }
+ features: {
+ input_names: "c26"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ embedding_name: "embedding1"
+ }
+}
+model_config {
+ model_class: 'DLRM'
+
+ feature_groups {
+ group_name: 'dense'
+ feature_names: "f1"
+ feature_names: "f2"
+ feature_names: "f3"
+ feature_names: "f4"
+ feature_names: "f5"
+ feature_names: "f6"
+ feature_names: "f7"
+ feature_names: "f8"
+ feature_names: "f9"
+ feature_names: "f10"
+ feature_names: "f11"
+ feature_names: "f12"
+ feature_names: "f13"
+
+ wide_deep: DEEP
+ }
+
+ feature_groups {
+ group_name: 'sparse'
+ feature_names: "c1"
+ feature_names: "c2"
+ feature_names: "c3"
+ feature_names: "c4"
+ feature_names: "c5"
+ feature_names: "c6"
+ feature_names: "c7"
+ feature_names: "c8"
+ feature_names: "c9"
+ feature_names: "c10"
+ feature_names: "c11"
+ feature_names: "c12"
+ feature_names: "c13"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ feature_names: "c22"
+ feature_names: "c23"
+ # feature_names: "c24"
+ # feature_names: "c25"
+ # feature_names: "c26"
+
+ wide_deep: DEEP
+ }
+
+ dlrm {
+ bot_dnn {
+ hidden_units: [128, 64, 32]
+ }
+
+ top_dnn {
+ hidden_units: [256, 128, 128, 64]
+ }
+ }
+
+ embedding_regularization: 1e-5
+}
+
+export_config {
+}
diff --git a/samples/model_config/simple_multi_task_backbone_on_taobao.config b/samples/model_config/simple_multi_task_backbone_on_taobao.config
new file mode 100644
index 000000000..9737e8193
--- /dev/null
+++ b/samples/model_config/simple_multi_task_backbone_on_taobao.config
@@ -0,0 +1,291 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/simple_multi_task_backbone_taobao_ckpt"
+
+train_config {
+ optimizer_config {
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 1000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ num_steps: 200
+ sync_replicas: true
+ save_checkpoints_steps: 100
+ log_step_count_steps: 100
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 4096
+ label_fields: "clk"
+ label_fields: "buy"
+ prefetch_size: 32
+ input_type: CSVInput
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "buy"
+ input_type: INT32
+ }
+ input_fields {
+ input_name: "pid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cate_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "campaign_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "customer"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "brand"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "user_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_segid"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "final_gender_code"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "age_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "shopping_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "occupation"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_category_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "tag_brand_list"
+ input_type: STRING
+ }
+ input_fields {
+ input_name: "price"
+ input_type: INT32
+ }
+}
+feature_config: {
+ features {
+ input_names: "pid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "customer"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "brand"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "user_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features {
+ input_names: "final_gender_code"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "age_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "occupation"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features {
+ input_names: "tag_category_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "tag_brand_list"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: "|"
+ }
+ features {
+ input_names: "price"
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+model_config {
+ model_name: "SimpleMultiTask"
+ model_class: "MultiTaskModel"
+ feature_groups {
+ group_name: "all"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ feature_names: "tag_category_list"
+ feature_names: "tag_brand_list"
+ wide_deep: DEEP
+ }
+ backbone {
+ blocks {
+ name: "identity"
+ inputs {
+ feature_group_name: "all"
+ }
+ }
+ }
+ model_params {
+ task_towers {
+ tower_name: "ctr"
+ label_name: "clk"
+ dnn {
+ hidden_units: [256, 192, 128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ task_towers {
+ tower_name: "cvr"
+ label_name: "buy"
+ dnn {
+ hidden_units: [256, 192, 128, 64]
+ }
+ num_class: 1
+ weight: 1.0
+ loss_type: CLASSIFICATION
+ metrics_set: {
+ auc {}
+ }
+ }
+ l2_regularization: 1e-07
+ }
+ embedding_regularization: 5e-06
+}
diff --git a/samples/model_config/simple_multi_task_on_sequence_feature_taobao.config b/samples/model_config/simple_multi_task_on_sequence_feature_taobao.config
index 2e7a2c12e..bec43d5a8 100644
--- a/samples/model_config/simple_multi_task_on_sequence_feature_taobao.config
+++ b/samples/model_config/simple_multi_task_on_sequence_feature_taobao.config
@@ -16,7 +16,7 @@ train_config {
}
use_moving_average: false
}
- num_steps: 5000
+ num_steps: 100
sync_replicas: true
save_checkpoints_steps: 100
log_step_count_steps: 100
diff --git a/samples/model_config/simple_multi_task_on_taobao.config b/samples/model_config/simple_multi_task_on_taobao.config
index 46ffbeb9a..d9a6734e0 100644
--- a/samples/model_config/simple_multi_task_on_taobao.config
+++ b/samples/model_config/simple_multi_task_on_taobao.config
@@ -16,7 +16,7 @@ train_config {
}
use_moving_average: false
}
- num_steps: 5000
+ num_steps: 100
sync_replicas: true
save_checkpoints_steps: 100
log_step_count_steps: 100
diff --git a/samples/model_config/taobao_fg.config b/samples/model_config/taobao_fg.config
index bfe7b78e2..e69bd7fab 100644
--- a/samples/model_config/taobao_fg.config
+++ b/samples/model_config/taobao_fg.config
@@ -16,7 +16,7 @@ train_config {
}
}
}
- num_steps: 400
+ num_steps: 30
sync_replicas: true
log_step_count_steps: 200
}
diff --git a/samples/model_config/taobao_fg_ev.config b/samples/model_config/taobao_fg_ev.config
index 3f8b1bf38..f490aa6d2 100644
--- a/samples/model_config/taobao_fg_ev.config
+++ b/samples/model_config/taobao_fg_ev.config
@@ -290,7 +290,9 @@ model_config {
}
l2_regularization: 0.0001
}
- use_embedding_variable: true
+ ev_params {
+ filter_freq: 5
+ }
}
export_config {
multi_placeholder: false
diff --git a/samples/model_config/taobao_fg_ev_v2.config b/samples/model_config/taobao_fg_ev_v2.config
index 58e1ca9fa..3a2f03b77 100644
--- a/samples/model_config/taobao_fg_ev_v2.config
+++ b/samples/model_config/taobao_fg_ev_v2.config
@@ -291,7 +291,10 @@ model_config {
}
l2_regularization: 0.0001
}
- use_embedding_variable: true
+ ev_params {
+ filter_freq: 5
+ steps_to_live: 1000
+ }
}
export_config {
multi_placeholder: false
diff --git a/samples/model_config/taobao_fg_incr_save.config b/samples/model_config/taobao_fg_incr_save.config
new file mode 100644
index 000000000..093bbf4ac
--- /dev/null
+++ b/samples/model_config/taobao_fg_incr_save.config
@@ -0,0 +1,313 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_incr_save"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ momentum_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 50
+ sync_replicas: false
+ log_step_count_steps: 200
+ save_checkpoints_steps: 50
+
+ incr_save_config {
+ dense_save_steps: 10
+ sparse_save_steps: 10
+ kafka {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_model_20220408'
+ consumer {
+ offset:0
+ }
+ }
+ debug_save_update: true
+ }
+
+ # enable_oss_stop_signal: true
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ separator: ""
+ }
+}
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+}
+export_config {
+ multi_placeholder: true
+}
diff --git a/samples/model_config/taobao_fg_incr_save_ev.config b/samples/model_config/taobao_fg_incr_save_ev.config
new file mode 100644
index 000000000..b8924d47d
--- /dev/null
+++ b/samples/model_config/taobao_fg_incr_save_ev.config
@@ -0,0 +1,315 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_incr_save_ev"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_async_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 20
+ sync_replicas: false
+ log_step_count_steps: 5
+ save_checkpoints_steps: 100
+
+ incr_save_config {
+ dense_save_steps: 10
+ sparse_save_steps: 10
+ kafka {
+ server: '127.0.0.1:9092'
+ topic: 'kafka_model_20220408'
+ consumer {
+ offset:0
+ }
+ }
+ debug_save_update: true
+ }
+
+ # enable_oss_stop_signal: true
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ separator: ""
+ }
+}
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+ ev_params {
+ }
+}
+export_config {
+ multi_placeholder: true
+}
diff --git a/samples/model_config/taobao_fg_incr_save_ev_local.config b/samples/model_config/taobao_fg_incr_save_ev_local.config
new file mode 100644
index 000000000..da18733cf
--- /dev/null
+++ b/samples/model_config/taobao_fg_incr_save_ev_local.config
@@ -0,0 +1,309 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_incr_save_ev"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_async_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 20
+ sync_replicas: false
+ log_step_count_steps: 5
+ save_checkpoints_steps: 100
+
+ incr_save_config {
+ dense_save_steps: 10
+ sparse_save_steps: 10
+ fs {
+ }
+ }
+
+ # enable_oss_stop_signal: true
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ separator: ""
+ }
+}
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+ ev_params {
+ }
+}
+export_config {
+ multi_placeholder: true
+}
diff --git a/samples/model_config/taobao_fg_incr_save_local.config b/samples/model_config/taobao_fg_incr_save_local.config
new file mode 100644
index 000000000..67869dae4
--- /dev/null
+++ b/samples/model_config/taobao_fg_incr_save_local.config
@@ -0,0 +1,308 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_incr_save"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ momentum_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 50
+ sync_replicas: false
+ log_step_count_steps: 200
+ save_checkpoints_steps: 100
+ keep_checkpoint_max: 50
+
+ incr_save_config {
+ dense_save_steps: 10
+ sparse_save_steps: 10
+ fs {
+ }
+ }
+
+ # enable_oss_stop_signal: true
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ separator: ""
+ }
+}
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+}
+export_config {
+ multi_placeholder: true
+}
diff --git a/samples/model_config/taobao_fg_incr_save_share_ev_local.config b/samples/model_config/taobao_fg_incr_save_share_ev_local.config
new file mode 100644
index 000000000..fddeeb2b5
--- /dev/null
+++ b/samples/model_config/taobao_fg_incr_save_share_ev_local.config
@@ -0,0 +1,318 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_incr_save_ev"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_async_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 20
+ sync_replicas: false
+ log_step_count_steps: 5
+ save_checkpoints_steps: 100
+ keep_checkpoint_max: 100
+
+ incr_save_config {
+ dense_save_steps: 10
+ sparse_save_steps: 10
+ fs {
+ }
+ }
+
+ # enable_oss_stop_signal: true
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
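+ # adgroup_id, cate_id, customer and brand below set embedding_name: "item_embedding",
+ # so they share a single embedding table (their embedding_dim and hash_bucket_size match).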
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ embedding_name: "item_embedding"
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ embedding_name: "item_embedding"
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ embedding_name: "item_embedding"
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ embedding_name: "item_embedding"
+ combiner: "sum"
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ combiner: "mean"
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ combiner: "mean"
+ separator: ""
+ }
+}
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
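+ # ev_params enables EmbeddingVariable (dynamically sized embeddings on PAI-TF);
+ # filter_freq: 5 admits an id into the embedding only after it has been seen at least 5 times.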
+ ev_params {
+ filter_freq: 5
+ }
+}
+export_config {
+ multi_placeholder: true
+}
diff --git a/samples/model_config/taobao_fg_signal_stop.config b/samples/model_config/taobao_fg_signal_stop.config
new file mode 100644
index 000000000..9bd01afd0
--- /dev/null
+++ b/samples/model_config/taobao_fg_signal_stop.config
@@ -0,0 +1,299 @@
+train_input_path: "data/test/rtp/taobao_train_feature.txt"
+eval_input_path: "data/test/rtp/taobao_test_feature.txt"
+model_dir: "experiments/taobao_fg_demo"
+
+train_config {
+ optimizer_config {
+ use_moving_average: false
+ adam_optimizer {
+ learning_rate {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 100000
+ decay_factor: 0.5
+ min_learning_rate: 1e-07
+ }
+ }
+ }
+ }
+ num_steps: 400
+ sync_replicas: false
+ log_step_count_steps: 200
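+ # stop training gracefully once a stop-signal file is written under model_dir on OSS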
+ enable_oss_stop_signal: true
+ save_checkpoints_steps: 10
+}
+eval_config {
+ metrics_set {
+ auc {
+ }
+ }
+}
+data_config {
+ batch_size: 1024
+ label_fields: "clk"
+ input_type: RTPInput
+ separator: ""
+ selected_cols: "0,3"
+ input_fields {
+ input_name: "clk"
+ input_type: INT32
+ default_val: "0"
+ }
+ input_fields {
+ input_name: "user_id"
+ }
+ input_fields {
+ input_name: "cms_segid"
+ }
+ input_fields {
+ input_name: "cms_group_id"
+ }
+ input_fields {
+ input_name: "age_level"
+ }
+ input_fields {
+ input_name: "pvalue_level"
+ }
+ input_fields {
+ input_name: "shopping_level"
+ }
+ input_fields {
+ input_name: "occupation"
+ }
+ input_fields {
+ input_name: "new_user_class_level"
+ }
+ input_fields {
+ input_name: "adgroup_id"
+ }
+ input_fields {
+ input_name: "cate_id"
+ }
+ input_fields {
+ input_name: "campaign_id"
+ }
+ input_fields {
+ input_name: "customer"
+ }
+ input_fields {
+ input_name: "brand"
+ }
+ input_fields {
+ input_name: "price"
+ input_type: DOUBLE
+ default_val: "0.0"
+ }
+ input_fields {
+ input_name: "pid"
+ }
+ input_fields {
+ input_name: "user_tag_cate"
+ }
+ input_fields {
+ input_name: "combo_brand"
+ }
+ input_fields {
+ input_name: "combo_cate_id"
+ }
+ rtp_separator: ";"
+}
+feature_config: {
+ features {
+ input_names: "user_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ max_partitions: 4
+ separator: ""
+ }
+ features {
+ input_names: "cms_segid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "cms_group_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ separator: ""
+ }
+ features {
+ input_names: "age_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "pvalue_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "shopping_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "occupation"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "new_user_class_level"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ separator: ""
+ }
+ features {
+ input_names: "adgroup_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "campaign_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "customer"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "price"
+ feature_type: RawFeature
+ separator: ""
+ }
+ features {
+ input_names: "pid"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "user_tag_cate"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_brand"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ separator: ""
+ }
+ features {
+ input_names: "combo_cate_id"
+ feature_type: TagFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ separator: ""
+ }
+}
+model_config {
+ model_class: "MultiTower"
+ feature_groups {
+ group_name: "item"
+ feature_names: "adgroup_id"
+ feature_names: "cate_id"
+ feature_names: "campaign_id"
+ feature_names: "customer"
+ feature_names: "brand"
+ feature_names: "price"
+ feature_names: "pid"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "user"
+ feature_names: "user_id"
+ feature_names: "cms_segid"
+ feature_names: "cms_group_id"
+ feature_names: "age_level"
+ feature_names: "pvalue_level"
+ feature_names: "shopping_level"
+ feature_names: "occupation"
+ feature_names: "new_user_class_level"
+ feature_names: "user_tag_cate"
+ wide_deep: DEEP
+ }
+ feature_groups {
+ group_name: "combo"
+ feature_names: "combo_brand"
+ feature_names: "combo_cate_id"
+ wide_deep: DEEP
+ }
+ embedding_regularization: 1e-05
+ multi_tower {
+ towers {
+ input: "item"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "user"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ towers {
+ input: "combo"
+ dnn {
+ hidden_units: 192
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ }
+ }
+ final_dnn {
+ hidden_units: 256
+ hidden_units: 192
+ hidden_units: 128
+ hidden_units: 64
+ }
+ l2_regularization: 0.0001
+ }
+}
+export_config {
+ multi_placeholder: false
+}
diff --git a/samples/model_config/text_cnn_on_movielens.config b/samples/model_config/text_cnn_on_movielens.config
new file mode 100644
index 000000000..87dbee5ed
--- /dev/null
+++ b/samples/model_config/text_cnn_on_movielens.config
@@ -0,0 +1,162 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/text_cnn_movielens_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+}
+model_config: {
+ model_name: 'TextCNN'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'text_seq'
+ feature_names: 'title'
+ wide_deep: DEEP
+ }
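+ # backbone wiring: the 'text_seq' input_layer keeps the title tokens as a sequence
+ # (embeddings plus lengths) instead of pooling them; TextCNN then applies filters of
+ # widths 2/3/4 over the padded sequence and feeds the pooled features to its MLP.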
+ backbone {
+ blocks {
+ name: 'text_seq'
+ inputs {
+ feature_group_name: 'text_seq'
+ }
+ input_layer {
+ output_seq_and_normal_feature: true
+ }
+ }
+ blocks {
+ name: 'textcnn'
+ inputs {
+ block_name: 'text_seq'
+ }
+ keras_layer {
+ class_name: 'TextCNN'
+ text_cnn {
+ filter_sizes: [2, 3, 4]
+ num_filters: [16, 8, 8]
+ pad_sequence_length: 14
+ mlp {
+ hidden_units: [256, 128, 64]
+ }
+ }
+ }
+ }
+ }
+ model_params {
+ l2_regularization: 1e-6
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/uniter_on_movielens.config b/samples/model_config/uniter_on_movielens.config
new file mode 100644
index 000000000..42f20e8e4
--- /dev/null
+++ b/samples/model_config/uniter_on_movielens.config
@@ -0,0 +1,241 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/uniter_movielens_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'Uniter'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'other'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
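+ # Uniter fuses the 'image', 'text' and 'general' groups with a shared transformer;
+ # the 'other' group bypasses the transformer and goes through other_feature_dnn
+ # before being concatenated with the fused representation for final_dnn.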
+ uniter {
+ config {
+ hidden_size: 512
+ num_attention_heads: 4
+ num_hidden_layers: 2
+ intermediate_size: 512
+ hidden_act: 'swish'
+ max_position_embeddings: 16
+ hidden_dropout_prob: 0.1
+ attention_probs_dropout_prob: 0
+ other_feature_dnn: {
+ hidden_units: 256
+ hidden_units: 128
+ }
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/uniter_on_movielens_only_image_feature.config b/samples/model_config/uniter_on_movielens_only_image_feature.config
new file mode 100644
index 000000000..45719ee62
--- /dev/null
+++ b/samples/model_config/uniter_on_movielens_only_image_feature.config
@@ -0,0 +1,139 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/cmbf_movielens_only_img_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'embedding'
+ feature_type: RawFeature
+ separator: '|'
+ raw_input_dim: 512
+ }
+}
+model_config: {
+ model_class: 'Uniter'
+ feature_groups: {
+ group_name: 'image'
+ feature_names: 'embedding'
+ wide_deep: DEEP
+ }
+ uniter {
+ config {
+ hidden_size: 512
+ num_attention_heads: 4
+ num_hidden_layers: 2
+ intermediate_size: 512
+ max_position_embeddings: 16
+ hidden_dropout_prob: 0.1
+ attention_probs_dropout_prob: 0
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "auc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/uniter_on_movielens_only_text_feature.config b/samples/model_config/uniter_on_movielens_only_text_feature.config
new file mode 100644
index 000000000..6c4bd1016
--- /dev/null
+++ b/samples/model_config/uniter_on_movielens_only_text_feature.config
@@ -0,0 +1,222 @@
+train_input_path: "data/test/movielens_1m/ml_train_data"
+eval_input_path: "data/test/movielens_1m/ml_test_data"
+model_dir: "experiments/cmbf_movielens_only_txt_ckpt"
+
+train_config {
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ constant_learning_rate {
+ learning_rate: 0.0001
+ }
+ }
+ beta1: 0.9
+ beta2: 0.999
+ }
+ use_moving_average: false
+ }
+ log_step_count_steps: 100
+ save_checkpoints_steps: 100
+ sync_replicas: true
+ num_steps: 10
+}
+
+eval_config {
+ metrics_set: {
+ gauc {
+ uid_field: 'user_id'
+ }
+ }
+ metrics_set: {
+ auc {}
+ }
+ metrics_set: {
+ max_f1 {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'rating'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'label'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'user_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'movie_id'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'gender'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'age'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'zip_id'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'genres'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'title'
+ input_type: STRING
+ default_val: 'unknown'
+ }
+ input_fields {
+ input_name: 'movie_year_bin'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'score_year_diff'
+ input_type: INT32
+ default_val: '0'
+ }
+ input_fields {
+ input_name: 'score_time'
+ input_type: DOUBLE
+ }
+ input_fields {
+ input_name: 'embedding'
+ input_type: STRING
+ default_val: ''
+ }
+
+ label_fields: 'label'
+ batch_size: 128
+ num_epochs: 10000
+ prefetch_size: 1
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 12000
+ }
+ features: {
+ input_names: 'movie_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 6000
+ }
+ features: {
+ input_names: 'gender'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 2
+ }
+ features: {
+ input_names: 'zip_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 3405
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 21
+ }
+ features: {
+ input_names: 'age'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 7
+ }
+ features: {
+ input_names: 'genres'
+ feature_type: SequenceFeature
+ separator: '|'
+ embedding_dim: 16
+ max_seq_len: 8
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'title'
+ feature_type: SequenceFeature
+ separator: ' '
+ max_seq_len: 16
+ embedding_dim: 16
+ hash_bucket_size: 20000
+ }
+ features: {
+ input_names: 'movie_year_bin'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 36
+ }
+ features: {
+ input_names: 'score_year_diff'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 83
+ }
+ features: {
+ input_names: 'score_time'
+ feature_type: RawFeature
+ embedding_dim: 16
+ }
+}
+model_config: {
+ model_class: 'Uniter'
+ feature_groups: {
+ group_name: 'general'
+ feature_names: 'user_id'
+ feature_names: 'movie_id'
+ feature_names: 'gender'
+ feature_names: 'age'
+ feature_names: 'occupation'
+ feature_names: 'zip_id'
+ feature_names: 'movie_year_bin'
+ feature_names: 'score_year_diff'
+ feature_names: 'score_time'
+ wide_deep: DEEP
+ }
+ feature_groups: {
+ group_name: 'text'
+ feature_names: 'title'
+ feature_names: 'genres'
+ wide_deep: DEEP
+ }
+ uniter {
+ config {
+ hidden_size: 512
+ num_attention_heads: 4
+ num_hidden_layers: 2
+ intermediate_size: 512
+ hidden_act: 'tanh'
+ max_position_embeddings: 16
+ hidden_dropout_prob: 0.1
+ attention_probs_dropout_prob: 0
+ }
+ final_dnn: {
+ hidden_units: 256
+ hidden_units: 64
+ }
+ }
+ embedding_regularization: 1e-6
+}
+export_config {
+ exporter_type: "best"
+ best_exporter_metric: "gauc"
+ exports_to_keep: 1
+}
diff --git a/samples/model_config/wide_and_deep_backbone_on_avazau.config b/samples/model_config/wide_and_deep_backbone_on_avazau.config
new file mode 100755
index 000000000..59de34076
--- /dev/null
+++ b/samples/model_config/wide_and_deep_backbone_on_avazau.config
@@ -0,0 +1,391 @@
+train_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+eval_input_path: "data/test/dwd_avazu_ctr_deepmodel_10w.csv"
+model_dir: "experiments/wide_and_deep_backbone_on_avazu"
+
+train_config {
+ log_step_count_steps: 200
+ # fine_tune_checkpoint: ""
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.0001
+ decay_steps: 10000
+ decay_factor: 0.5
+ min_learning_rate: 0.0000001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+
+ sync_replicas: true
+ save_checkpoints_steps: 500
+ num_steps: 100
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ separator: ","
+ input_fields: {
+ input_name: "label"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "hour"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c1"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "banner_pos"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "site_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_domain"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "app_category"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_id"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_ip"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_model"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "device_conn_type"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c14"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c15"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c16"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c17"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c18"
+ input_type: STRING
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c19"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c20"
+ input_type: INT64
+ default_val:"0"
+ }
+ input_fields: {
+ input_name: "c21"
+ input_type: INT64
+ default_val:"0"
+ }
+ label_fields: "label"
+
+ batch_size: 1024
+ num_epochs: 10000
+ prefetch_size: 32
+ input_type: CSVInput
+}
+
+feature_config: {
+ features: {
+ input_names: "hour"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c1"
+ feature_type: RawFeature
+ boundaries: [1000.0,1001.0,1002.0,1003.0,1004.0,1005.0,1006.0,1007.0,1008.0,1009.0,1010.0,1011.0,1012.0,1013.0,1014.0,1015.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "banner_pos"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "site_id"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "site_domain"
+ feature_type: IdFeature
+ embedding_dim: 20
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "site_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "app_id"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "app_domain"
+ feature_type: IdFeature
+ embedding_dim: 20
+ hash_bucket_size: 1000
+ }
+ features: {
+ input_names: "app_category"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: "device_id"
+ feature_type: IdFeature
+ embedding_dim: 64
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_ip"
+ feature_type: IdFeature
+ embedding_dim: 64
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: "device_model"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: "device_type"
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "device_conn_type"
+ feature_type: IdFeature
+ embedding_dim: 32
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: "c14"
+ feature_type: IdFeature
+ embedding_dim: 20
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c15"
+ feature_type: IdFeature
+ embedding_dim: 20
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c16"
+ feature_type: IdFeature
+ embedding_dim: 20
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c17"
+ feature_type: IdFeature
+ embedding_dim: 20
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c18"
+ feature_type: IdFeature
+ embedding_dim: 20
+ hash_bucket_size: 500
+ }
+ features: {
+ input_names: "c19"
+ feature_type: RawFeature
+ boundaries: [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150,160,170,180,190]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c20"
+ feature_type: RawFeature
+ boundaries: [100.0,200.0,300.0,400.0,500.0,600.0,700.0,800.0,900.0,1000.0,1100.0,1200.0,1300.0,1400.0]
+ embedding_dim: 16
+ }
+ features: {
+ input_names: "c21"
+ feature_type: RawFeature
+ boundaries: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25]
+ embedding_dim: 16
+ }
+}
+model_config {
+ model_class: "RankModel"
+ feature_groups: {
+ group_name: "deep"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:DEEP
+ }
+ feature_groups: {
+ group_name: "wide"
+ feature_names: "hour"
+ feature_names: "c1"
+ feature_names: "banner_pos"
+ feature_names: "site_id"
+ feature_names: "site_domain"
+ feature_names: "site_category"
+ feature_names: "app_id"
+ feature_names: "app_domain"
+ feature_names: "app_category"
+ feature_names: "device_id"
+ feature_names: "device_ip"
+ feature_names: "device_model"
+ feature_names: "device_type"
+ feature_names: "device_conn_type"
+ feature_names: "c14"
+ feature_names: "c15"
+ feature_names: "c16"
+ feature_names: "c17"
+ feature_names: "c18"
+ feature_names: "c19"
+ feature_names: "c20"
+ feature_names: "c21"
+ wide_deep:WIDE
+ }
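+ # wide&deep as a backbone: the wide block embeds each feature to wide_output_dim=1 and
+ # its per-feature logits are summed via tf.add_n; the deep block is an MLP ending in a
+ # single linear unit; the Add layer sums the two logits into the final output.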
+ backbone {
+ blocks {
+ name: 'wide'
+ inputs {
+ feature_group_name: 'wide'
+ }
+ input_layer {
+ only_output_feature_list: true
+ wide_output_dim: 1
+ }
+ }
+ blocks {
+ name: 'deep_logit'
+ inputs {
+ feature_group_name: 'deep'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [256, 256, 256, 1]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
+ blocks {
+ name: 'final_logit'
+ inputs {
+ block_name: 'wide'
+ input_fn: 'lambda x: tf.add_n(x)'
+ }
+ inputs {
+ block_name: 'deep_logit'
+ }
+ merge_inputs_into_list: true
+ keras_layer {
+ class_name: 'Add'
+ }
+ }
+ concat_blocks: 'final_logit'
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-7
+}
diff --git a/samples/model_config/wide_and_deep_two_opti.config b/samples/model_config/wide_and_deep_two_opti.config
index 283a6d0c8..fa4ac3b01 100755
--- a/samples/model_config/wide_and_deep_two_opti.config
+++ b/samples/model_config/wide_and_deep_two_opti.config
@@ -10,30 +10,22 @@ train_config {
ftrl_optimizer: {
l1_reg: 10
learning_rate: {
- exponential_decay_learning_rate {
- initial_learning_rate: 0.0001
- decay_steps: 10000
- decay_factor: 0.5
- min_learning_rate: 0.0000001
+ constant_learning_rate {
+ learning_rate: 0.0005
}
}
}
- use_moving_average: false
}
optimizer_config: {
adam_optimizer: {
learning_rate: {
- exponential_decay_learning_rate {
- initial_learning_rate: 0.0001
- decay_steps: 10000
- decay_factor: 0.5
- min_learning_rate: 0.0000001
+ constant_learning_rate {
+ learning_rate: 0.0001
}
}
}
- use_moving_average: false
}
diff --git a/samples/model_config/xdeepfm_on_taobao_backbone.config b/samples/model_config/xdeepfm_on_taobao_backbone.config
new file mode 100644
index 000000000..06a97b3af
--- /dev/null
+++ b/samples/model_config/xdeepfm_on_taobao_backbone.config
@@ -0,0 +1,359 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/xdeepfm_taobao_ckpt"
+
+train_config {
+ log_step_count_steps: 200
+ optimizer_config: {
+ adam_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.001
+ decay_steps: 5000
+ decay_factor: 0.5
+ min_learning_rate: 0.00001
+ }
+ }
+ }
+ use_moving_average: false
+ }
+ save_checkpoints_steps: 1000
+ sync_replicas: false
+}
+
+eval_config {
+ metrics_set: {
+ auc {}
+ }
+}
+
+data_config {
+ input_fields {
+ input_name:'clk'
+ input_type: INT32
+ }
+ input_fields {
+ input_name:'buy'
+ input_type: INT32
+ }
+ input_fields {
+ input_name: 'pid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'adgroup_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cate_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'campaign_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'customer'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'brand'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'user_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_segid'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'cms_group_id'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'final_gender_code'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'age_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'pvalue_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'shopping_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'occupation'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'new_user_class_level'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_category_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'tag_brand_list'
+ input_type: STRING
+ }
+ input_fields {
+ input_name: 'price'
+ input_type: INT32
+ }
+
+ label_fields: 'clk'
+ batch_size: 256
+ num_epochs: 100000
+ prefetch_size: 32
+ input_type: CSVInput
+ shuffle_buffer_size: 25600
+ shuffle: true
+}
+
+feature_config: {
+ features: {
+ input_names: 'pid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'adgroup_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cate_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10000
+ }
+ features: {
+ input_names: 'campaign_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'customer'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'brand'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'user_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100000
+ }
+ features: {
+ input_names: 'cms_segid'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'cms_group_id'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 100
+ }
+ features: {
+ input_names: 'final_gender_code'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'age_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'pvalue_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'shopping_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'occupation'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'new_user_class_level'
+ feature_type: IdFeature
+ embedding_dim: 16
+ hash_bucket_size: 10
+ }
+ features: {
+ input_names: 'tag_category_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'tag_brand_list'
+ feature_type: TagFeature
+ separator: '|'
+ hash_bucket_size: 100000
+ embedding_dim: 16
+ }
+ features: {
+ input_names: 'price'
+ feature_type: IdFeature
+ embedding_dim: 16
+ num_buckets: 50
+ }
+}
+
+
+model_config: {
+ model_name: 'xDeepFM'
+ model_class: 'RankModel'
+ feature_groups: {
+ group_name: 'features'
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep:DEEP
+ }
+
+ feature_groups: {
+ group_name: "wide"
+ feature_names: 'user_id'
+ feature_names: 'cms_segid'
+ feature_names: 'cms_group_id'
+ feature_names: 'age_level'
+ feature_names: 'pvalue_level'
+ feature_names: 'shopping_level'
+ feature_names: 'occupation'
+ feature_names: 'new_user_class_level'
+ feature_names: 'tag_category_list'
+ feature_names: 'tag_brand_list'
+ feature_names: 'adgroup_id'
+ feature_names: 'cate_id'
+ feature_names: 'campaign_id'
+ feature_names: 'customer'
+ feature_names: 'brand'
+ feature_names: 'price'
+ feature_names: 'pid'
+ wide_deep:WIDE
+ }
+ backbone {
+ blocks {
+ name: 'wide'
+ inputs {
+ feature_group_name: 'wide'
+ }
+ input_layer {
+ only_output_feature_list: true
+ wide_output_dim: 1
+ }
+ }
+
+ blocks {
+ name: 'features'
+ inputs {
+ feature_group_name: 'features'
+ }
+ input_layer {
+ output_2d_tensor_and_feature_list: true
+ }
+ }
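+ # 'features' emits [flat 2-D tensor, per-feature embedding list]; input_slice '[1]' passes
+ # the list (stacked to [batch, num_features, dim]) to CIN, while '[0]' feeds the flat
+ # tensor to the DNN branch.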
+ blocks {
+ name: 'cin'
+ inputs {
+ block_name: 'features'
+ input_slice: '[1]'
+ }
+ extra_input_fn: 'lambda x: tf.stack(x, axis=1)'
+ keras_layer {
+ class_name: 'CIN'
+ cin {
+ hidden_feature_sizes: [64, 64, 64]
+ }
+ }
+ }
+
+ blocks {
+ name: 'dnn'
+ inputs {
+ block_name: 'features'
+ input_slice: '[0]'
+ }
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [128, 64]
+ }
+ }
+ }
+
+ blocks {
+ name: 'final_logit'
+ inputs {
+ block_name: 'wide'
+ input_fn: 'lambda x: tf.add_n(x)'
+ }
+ inputs {
+ block_name: 'cin'
+ }
+ inputs {
+ block_name: 'dnn'
+ }
+
+ keras_layer {
+ class_name: 'MLP'
+ mlp {
+ hidden_units: [32, 1]
+ use_final_bn: false
+ final_activation: 'linear'
+ }
+ }
+ }
+ concat_blocks: 'final_logit'
+ }
+ model_params {
+ l2_regularization: 1e-4
+ }
+ embedding_regularization: 1e-4
+}
diff --git a/samples/odps_script/boundary/create_external_boundary_table.sql b/samples/odps_script/boundary/create_external_boundary_table.sql
deleted file mode 100644
index 7c80741ea..000000000
--- a/samples/odps_script/boundary/create_external_boundary_table.sql
+++ /dev/null
@@ -1,73 +0,0 @@
-drop TABLE IF EXISTS external_boundary_test_{TIME_STAMP} ;
-create EXTERNAL table external_boundary_test_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price double
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/train/'
-;
-
-
-drop TABLE IF EXISTS external_boundary_train_{TIME_STAMP} ;
-create EXTERNAL table external_boundary_train_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price double
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/test/'
-;
-
-
-drop TABLE IF EXISTS external_boundary_info_table_{TIME_STAMP} ;
-create EXTERNAL table external_boundary_info_table_{TIME_STAMP}(
- feature STRING
- ,json STRING
-)
-STORED BY 'com.aliyun.odps.TsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/boundary/'
-;
diff --git a/samples/odps_script/boundary/create_inner_boundary_table.sql b/samples/odps_script/boundary/create_inner_boundary_table.sql
index c2b972ce8..58bd91214 100644
--- a/samples/odps_script/boundary/create_inner_boundary_table.sql
+++ b/samples/odps_script/boundary/create_inner_boundary_table.sql
@@ -23,8 +23,7 @@ create table boundary_test_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE boundary_test_{TIME_STAMP}
-select * from external_boundary_test_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/train_{TIME_STAMP} boundary_test_{TIME_STAMP};
drop TABLE IF EXISTS boundary_train_{TIME_STAMP} ;
@@ -52,8 +51,7 @@ create table boundary_train_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE boundary_train_{TIME_STAMP}
-select * from external_boundary_train_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/test_{TIME_STAMP} boundary_train_{TIME_STAMP};
drop TABLE IF EXISTS boundary_info_table_{TIME_STAMP} ;
@@ -63,5 +61,4 @@ create table boundary_info_table_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE boundary_info_table_{TIME_STAMP}
-select * from external_boundary_info_table_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/boundary_{TIME_STAMP} boundary_info_table_{TIME_STAMP} -fd='\t';
diff --git a/samples/odps_script/boundary/finetune_multi_tower_conti.sql b/samples/odps_script/boundary/finetune_multi_tower_conti.sql
new file mode 100644
index 000000000..3cc427611
--- /dev/null
+++ b/samples/odps_script/boundary/finetune_multi_tower_conti.sql
@@ -0,0 +1,15 @@
+pai -name easy_rec_ext
+-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/taobao_multi_tower_boundary_test.config
+-Dcmd=train
+-Dboundary_table=odps://{ODPS_PROJ_NAME}/tables/boundary_info_table_{TIME_STAMP}
+-Dmodel_dir="oss://{OSS_BUCKET_NAME}/easy_rec_odps_test/{EXP_NAME}/edit_boundary_test/finetune/"
+-Dfine_tune_checkpoint='oss://{OSS_BUCKET_NAME}/easy_rec_odps_test/{EXP_NAME}/edit_boundary_test/checkpoints/'
+-Dedit_config_json='{"train_config.num_steps": 500}'
+-Dtrain_tables=odps://{ODPS_PROJ_NAME}/tables/boundary_train_{TIME_STAMP}
+-Deval_tables=odps://{ODPS_PROJ_NAME}/tables/boundary_test_{TIME_STAMP}
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
+-Darn={ROLEARN}
+-Dbuckets=oss://{OSS_BUCKET_NAME}/
+-DossHost={OSS_ENDPOINT}
+-Deval_method=separate
+;
diff --git a/samples/odps_script/boundary/finetune_multi_tower_model.sql b/samples/odps_script/boundary/finetune_multi_tower_model.sql
new file mode 100644
index 000000000..3697e9e74
--- /dev/null
+++ b/samples/odps_script/boundary/finetune_multi_tower_model.sql
@@ -0,0 +1,15 @@
+pai -name easy_rec_ext
+-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/taobao_multi_tower_boundary_test.config
+-Dcmd=train
+-Dboundary_table=odps://{ODPS_PROJ_NAME}/tables/boundary_info_table_{TIME_STAMP}
+-Dmodel_dir="oss://{OSS_BUCKET_NAME}/easy_rec_odps_test/{EXP_NAME}/edit_boundary_test/finetune/"
+-Dfine_tune_checkpoint='oss://{OSS_BUCKET_NAME}/easy_rec_odps_test/{EXP_NAME}/edit_boundary_test/checkpoints/'
+-Dedit_config_json='{"train_config.num_steps": 200}'
+-Dtrain_tables=odps://{ODPS_PROJ_NAME}/tables/boundary_train_{TIME_STAMP}
+-Deval_tables=odps://{ODPS_PROJ_NAME}/tables/boundary_test_{TIME_STAMP}
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
+-Darn={ROLEARN}
+-Dbuckets=oss://{OSS_BUCKET_NAME}/
+-DossHost={OSS_ENDPOINT}
+-Deval_method=separate
+;
diff --git a/samples/odps_script/boundary/train_compat.sql b/samples/odps_script/boundary/train_compat.sql
index 9904e0d9d..aaf90cc64 100644
--- a/samples/odps_script/boundary/train_compat.sql
+++ b/samples/odps_script/boundary/train_compat.sql
@@ -3,7 +3,7 @@ pai -name easy_rec_ext
-Dcmd=train
-Dboundary_table=odps://{ODPS_PROJ_NAME}/tables/boundary_info_table_{TIME_STAMP}
-Dtables=odps://{ODPS_PROJ_NAME}/tables/boundary_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/boundary_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
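Note on the -Dcluster option used throughout these sample scripts: dropping the "gpu":100 entry makes the PAI workers CPU-only. If the usual PAI resource convention applies here (cpu in 1/100-core units, memory in MB, gpu in 1/100-card units), a spec such as the one below requests one parameter server plus two workers with 10 cores and roughly 40 GB of memory each:

    -Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'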
diff --git a/samples/odps_script/boundary/train_multi_tower_model.sql b/samples/odps_script/boundary/train_multi_tower_model.sql
index 070ef249f..a9e3b8a33 100644
--- a/samples/odps_script/boundary/train_multi_tower_model.sql
+++ b/samples/odps_script/boundary/train_multi_tower_model.sql
@@ -4,7 +4,7 @@ pai -name easy_rec_ext
-Dboundary_table=odps://{ODPS_PROJ_NAME}/tables/boundary_info_table_{TIME_STAMP}
-Dtrain_tables=odps://{ODPS_PROJ_NAME}/tables/boundary_train_{TIME_STAMP}
-Deval_tables=odps://{ODPS_PROJ_NAME}/tables/boundary_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/configs/multi_tower_bst.config b/samples/odps_script/configs/multi_tower_bst.config
index 612afd65d..3df8ef94e 100644
--- a/samples/odps_script/configs/multi_tower_bst.config
+++ b/samples/odps_script/configs/multi_tower_bst.config
@@ -1,7 +1,7 @@
train_input_path: ""
eval_input_path: ""
-model_dir: "oss://{OSS_BUCKET_NAME}/{EXP_NAME}/multil_tower_bst"
+model_dir: "oss://{OSS_BUCKET_NAME}/{EXP_NAME}/multi_tower_bst"
train_config {
log_step_count_steps: 20
diff --git a/samples/odps_script/configs/multi_tower_din.config b/samples/odps_script/configs/multi_tower_din.config
index 24d1c5020..081076cb7 100644
--- a/samples/odps_script/configs/multi_tower_din.config
+++ b/samples/odps_script/configs/multi_tower_din.config
@@ -1,6 +1,6 @@
train_input_path: ""
eval_input_path: ""
-model_dir: "oss://{OSS_BUCKET_NAME}/{EXP_NAME}/multil_tower_din/"
+model_dir: "oss://{OSS_BUCKET_NAME}/{EXP_NAME}/multi_tower_din/"
train_config {
log_step_count_steps: 20
diff --git a/samples/odps_script/deep_fm/create_external_deepfm_table.sql b/samples/odps_script/deep_fm/create_external_deepfm_table.sql
deleted file mode 100644
index 031af15dd..000000000
--- a/samples/odps_script/deep_fm/create_external_deepfm_table.sql
+++ /dev/null
@@ -1,65 +0,0 @@
-drop TABLE IF EXISTS external_deepfm_train_{TIME_STAMP} ;
-create EXTERNAL table external_deepfm_train_{TIME_STAMP}(
- label BIGINT
- ,`hour` string
- ,c1 STRING
- ,banner_pos STRING
- ,site_id STRING
- ,site_domain STRING
- ,site_category STRING
- ,app_id STRING
- ,app_domain STRING
- ,app_category STRING
- ,device_id STRING
- ,device_ip STRING
- ,device_model STRING
- ,device_type STRING
- ,device_conn_type STRING
- ,c14 STRING
- ,c15 STRING
- ,c16 STRING
- ,c17 STRING
- ,c18 STRING
- ,c19 STRING
- ,c20 STRING
- ,c21 STRING
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_ENDPOINT_INTERNAL}/{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/train/'
-;
-
-drop TABLE IF EXISTS external_deepfm_test_{TIME_STAMP};
-create EXTERNAL table external_deepfm_test_{TIME_STAMP}(
- label BIGINT
- ,`hour` string
- ,c1 STRING
- ,banner_pos STRING
- ,site_id STRING
- ,site_domain STRING
- ,site_category STRING
- ,app_id STRING
- ,app_domain STRING
- ,app_category STRING
- ,device_id STRING
- ,device_ip STRING
- ,device_model STRING
- ,device_type STRING
- ,device_conn_type STRING
- ,c14 STRING
- ,c15 STRING
- ,c16 STRING
- ,c17 STRING
- ,c18 STRING
- ,c19 STRING
- ,c20 STRING
- ,c21 STRING
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_ENDPOINT_INTERNAL}/{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/test/'
-;
diff --git a/samples/odps_script/deep_fm/create_inner_deepfm_table.sql b/samples/odps_script/deep_fm/create_inner_deepfm_table.sql
index bf70f0ac1..1b7955013 100644
--- a/samples/odps_script/deep_fm/create_inner_deepfm_table.sql
+++ b/samples/odps_script/deep_fm/create_inner_deepfm_table.sql
@@ -26,11 +26,7 @@ create table deepfm_train_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE deepfm_train_{TIME_STAMP}
-select * from external_deepfm_train_{TIME_STAMP} ;
-
-desc deepfm_train_{TIME_STAMP};
-desc external_deepfm_train_{TIME_STAMP};
+tunnel upload {TEST_DATA_DIR}/train_{TIME_STAMP} deepfm_train_{TIME_STAMP};
drop TABLE IF EXISTS deepfm_test_{TIME_STAMP};
create table deepfm_test_{TIME_STAMP}(
@@ -60,7 +56,4 @@ create table deepfm_test_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE deepfm_test_{TIME_STAMP}
-select * from external_deepfm_test_{TIME_STAMP};
-desc deepfm_test_{TIME_STAMP};
-desc external_deepfm_test_{TIME_STAMP};
+tunnel upload {TEST_DATA_DIR}/test_{TIME_STAMP} deepfm_test_{TIME_STAMP};
diff --git a/samples/odps_script/deep_fm/eval_deepfm.sql b/samples/odps_script/deep_fm/eval_deepfm.sql
index 502c3ae72..fd1395b9d 100644
--- a/samples/odps_script/deep_fm/eval_deepfm.sql
+++ b/samples/odps_script/deep_fm/eval_deepfm.sql
@@ -2,7 +2,7 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dwd_avazu_ctr_deepmodel_ext.config
-Dcmd=evaluate
-Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
--Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/deep_fm/export_rtp_ckpt.sql b/samples/odps_script/deep_fm/export_rtp_ckpt.sql
new file mode 100644
index 000000000..9b1669e0c
--- /dev/null
+++ b/samples/odps_script/deep_fm/export_rtp_ckpt.sql
@@ -0,0 +1,8 @@
+pai -name easy_rec_ext
+-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dwd_avazu_ctr_deepmodel_ext.config
+-Dcmd=export_checkpoint
+-Darn={ROLEARN}
+-Dbuckets=oss://{OSS_BUCKET_NAME}/
+-DossHost={OSS_ENDPOINT}
+-Dexport_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/dwd_avazu_ctr/export_model/savemodel/
+-Dbatch_size=256;
diff --git a/samples/odps_script/deep_fm/predict_deepfm.sql b/samples/odps_script/deep_fm/predict_deepfm.sql
index 62f6f4951..ced078b92 100644
--- a/samples/odps_script/deep_fm/predict_deepfm.sql
+++ b/samples/odps_script/deep_fm/predict_deepfm.sql
@@ -1,7 +1,7 @@
drop table if exists ctr_test_output_v1;
pai -name easy_rec_ext
-Dcmd=predict
--Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000, "gpu":100}}'
+-Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-Dsaved_model_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/dwd_avazu_ctr/checkpoints1/savemodel/
diff --git a/samples/odps_script/deep_fm/train_deepfm_model.sql b/samples/odps_script/deep_fm/train_deepfm_model.sql
index 882dc6cff..df0f113a0 100644
--- a/samples/odps_script/deep_fm/train_deepfm_model.sql
+++ b/samples/odps_script/deep_fm/train_deepfm_model.sql
@@ -2,7 +2,7 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dwd_avazu_ctr_deepmodel_ext.config
-Dcmd=train
-Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/dssm/create_external_dssm_table.sql b/samples/odps_script/dssm/create_external_dssm_table.sql
deleted file mode 100644
index b1e488a78..000000000
--- a/samples/odps_script/dssm/create_external_dssm_table.sql
+++ /dev/null
@@ -1,60 +0,0 @@
-drop TABLE IF EXISTS external_dssm_test_{TIME_STAMP} ;
-create EXTERNAL table external_dssm_test_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/train/'
-;
-
-
-drop TABLE IF EXISTS external_dssm_train_{TIME_STAMP} ;
-create EXTERNAL table external_dssm_train_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/test/'
-;
diff --git a/samples/odps_script/dssm/create_inner_dssm_table.sql b/samples/odps_script/dssm/create_inner_dssm_table.sql
index 964deeafc..6cb258b82 100644
--- a/samples/odps_script/dssm/create_inner_dssm_table.sql
+++ b/samples/odps_script/dssm/create_inner_dssm_table.sql
@@ -23,8 +23,7 @@ create table dssm_test_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE dssm_test_{TIME_STAMP}
-select * from external_dssm_test_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/test_{TIME_STAMP} dssm_test_{TIME_STAMP};
drop TABLE IF EXISTS dssm_train_{TIME_STAMP} ;
@@ -52,5 +51,4 @@ create table dssm_train_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE dssm_train_{TIME_STAMP}
-select * from external_dssm_train_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/train_{TIME_STAMP} dssm_train_{TIME_STAMP};
diff --git a/samples/odps_script/dssm/eval_dssm.sql b/samples/odps_script/dssm/eval_dssm.sql
index 34a21436b..0678329ce 100644
--- a/samples/odps_script/dssm/eval_dssm.sql
+++ b/samples/odps_script/dssm/eval_dssm.sql
@@ -2,7 +2,7 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dssm_demo.config
-Dcmd=evaluate
-Dtables=odps://{ODPS_PROJ_NAME}/tables/dssm_test_{TIME_STAMP}
--Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/dssm/predict_dssm.sql b/samples/odps_script/dssm/predict_dssm.sql
index 960103fb9..707877981 100644
--- a/samples/odps_script/dssm/predict_dssm.sql
+++ b/samples/odps_script/dssm/predict_dssm.sql
@@ -1,7 +1,7 @@
drop table if exists dssm_test_output_v1_{TIME_STAMP};
pai -name easy_rec_ext
-Dcmd=predict
--Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000, "gpu":100}}'
+-Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-Dsaved_model_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/dssm/savemodel/
diff --git a/samples/odps_script/dssm/train_dssm_model.sql b/samples/odps_script/dssm/train_dssm_model.sql
index 5e74e17b0..40a265bcd 100644
--- a/samples/odps_script/dssm/train_dssm_model.sql
+++ b/samples/odps_script/dssm/train_dssm_model.sql
@@ -2,7 +2,7 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dssm_demo.config
-Dcmd=train
-Dtables=odps://{ODPS_PROJ_NAME}/tables/dssm_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/dssm_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/embedding_variable/create_table.sql b/samples/odps_script/embedding_variable/create_table.sql
index d03ac8f6f..2cb157dfd 100644
--- a/samples/odps_script/embedding_variable/create_table.sql
+++ b/samples/odps_script/embedding_variable/create_table.sql
@@ -1,36 +1,19 @@
-drop TABLE IF EXISTS external_ev_test_{TIME_STAMP} ;
-create EXTERNAL table external_ev_test_{TIME_STAMP}(
+drop TABLE IF EXISTS inner_ev_test_{TIME_STAMP};
+create table inner_ev_test_{TIME_STAMP}(
clk bigint
,user_id string
,item_id string
,features string
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}',
- 'odps.text.option.delimiter'=';'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/fg_data/test/'
-;
+);
+tunnel upload {TEST_DATA_DIR}/fg_data/test_${TIME_STAMP} inner_ev_test_{TIME_STAMP} -fd=';';
-drop TABLE IF EXISTS external_ev_train_{TIME_STAMP} ;
-create EXTERNAL table external_ev_train_{TIME_STAMP}(
+drop TABLE IF EXISTS inner_ev_train_{TIME_STAMP};
+create table inner_ev_train_{TIME_STAMP}(
clk bigint
,user_id string
,item_id string
,features string
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}',
- 'odps.text.option.delimiter'=';'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/fg_data/train/'
-;
+);
-drop table if exists inner_ev_test_{TIME_STAMP};
-create table inner_ev_test_{TIME_STAMP} as select * from external_ev_test_{TIME_STAMP};
-
-drop table if exists inner_ev_train_{TIME_STAMP};
-create table inner_ev_train_{TIME_STAMP} as select * from external_ev_train_{TIME_STAMP};
+tunnel upload {TEST_DATA_DIR}/fg_data/train_${TIME_STAMP} inner_ev_train_{TIME_STAMP} -fd=';';
diff --git a/samples/odps_script/embedding_variable/train.sql b/samples/odps_script/embedding_variable/train.sql
index 230b39f1f..db62a8c6e 100644
--- a/samples/odps_script/embedding_variable/train.sql
+++ b/samples/odps_script/embedding_variable/train.sql
@@ -3,7 +3,7 @@ pai -name easy_rec_ext
-Dcmd=train
-Dtrain_tables=odps://{ODPS_PROJ_NAME}/tables/inner_ev_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/inner_ev_train_{TIME_STAMP}
-Deval_tables=odps://{ODPS_PROJ_NAME}/tables/inner_ev_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/embedding_variable/train_work_que.sql b/samples/odps_script/embedding_variable/train_work_que.sql
index ae33a33b0..c49b8c6aa 100644
--- a/samples/odps_script/embedding_variable/train_work_que.sql
+++ b/samples/odps_script/embedding_variable/train_work_que.sql
@@ -3,7 +3,7 @@ pai -name easy_rec_ext
-Dcmd=train
-Dtrain_tables=odps://{ODPS_PROJ_NAME}/tables/inner_ev_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/inner_ev_train_{TIME_STAMP}
-Deval_tables=odps://{ODPS_PROJ_NAME}/tables/inner_ev_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/mmoe/create_external_mmoe_table.sql b/samples/odps_script/mmoe/create_external_mmoe_table.sql
deleted file mode 100644
index f7ccc8fad..000000000
--- a/samples/odps_script/mmoe/create_external_mmoe_table.sql
+++ /dev/null
@@ -1,59 +0,0 @@
-drop TABLE IF EXISTS external_mmoe_test_{TIME_STAMP} ;
-create EXTERNAL table external_mmoe_test_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/test/'
-;
-
-drop TABLE IF EXISTS external_mmoe_train_{TIME_STAMP} ;
-create EXTERNAL table external_mmoe_train_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/train/'
-;
diff --git a/samples/odps_script/mmoe/create_inner_mmoe_table.sql b/samples/odps_script/mmoe/create_inner_mmoe_table.sql
index 02a67f2c7..299778023 100644
--- a/samples/odps_script/mmoe/create_inner_mmoe_table.sql
+++ b/samples/odps_script/mmoe/create_inner_mmoe_table.sql
@@ -23,10 +23,7 @@ create table mmoe_train_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE mmoe_train_{TIME_STAMP}
-select * from external_mmoe_train_{TIME_STAMP} ;
-
-
+tunnel upload {TEST_DATA_DIR}/tb_data/train_{TIME_STAMP} mmoe_train_{TIME_STAMP};
drop TABLE IF EXISTS mmoe_test_{TIME_STAMP};
create table mmoe_test_{TIME_STAMP}(
@@ -53,5 +50,4 @@ create table mmoe_test_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE mmoe_test_{TIME_STAMP}
-select * from external_mmoe_test_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/test_{TIME_STAMP} mmoe_test_{TIME_STAMP};
diff --git a/samples/odps_script/mmoe/eval_mmoe.sql b/samples/odps_script/mmoe/eval_mmoe.sql
index 5afb24ea0..30cd47d28 100644
--- a/samples/odps_script/mmoe/eval_mmoe.sql
+++ b/samples/odps_script/mmoe/eval_mmoe.sql
@@ -2,7 +2,7 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/mmoe_demo.config
-Dcmd=evaluate
-Dtables=odps://{ODPS_PROJ_NAME}/tables/mmoe_test_{TIME_STAMP}
--Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/mmoe/predict_mmoe.sql b/samples/odps_script/mmoe/predict_mmoe.sql
index b971fbf8e..6f19681d2 100644
--- a/samples/odps_script/mmoe/predict_mmoe.sql
+++ b/samples/odps_script/mmoe/predict_mmoe.sql
@@ -1,7 +1,7 @@
drop table if exists mmoe_test_output_v1_{TIME_STAMP};
pai -name easy_rec_ext
-Dcmd=predict
--Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000, "gpu":100}}'
+-Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-Dsaved_model_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/mmoe/savemodel/
diff --git a/samples/odps_script/mmoe/train_mmoe_model.sql b/samples/odps_script/mmoe/train_mmoe_model.sql
index a0415e081..abd065300 100644
--- a/samples/odps_script/mmoe/train_mmoe_model.sql
+++ b/samples/odps_script/mmoe/train_mmoe_model.sql
@@ -2,7 +2,7 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/mmoe_demo.config
-Dcmd=train
-Dtables=odps://{ODPS_PROJ_NAME}/tables/mmoe_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/mmoe_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/multi_tower/create_external_multi_tower_table.sql b/samples/odps_script/multi_tower/create_external_multi_tower_table.sql
deleted file mode 100644
index b5ed9a92e..000000000
--- a/samples/odps_script/multi_tower/create_external_multi_tower_table.sql
+++ /dev/null
@@ -1,59 +0,0 @@
-drop TABLE if EXISTS external_multil_tower_train_{TIME_STAMP} ;
-create EXTERNAL table external_multil_tower_train_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/train/'
-;
-
-drop TABLE if EXISTS external_multil_tower_test_{TIME_STAMP} ;
-create EXTERNAL table external_multil_tower_test_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/test/'
-;
diff --git a/samples/odps_script/multi_tower/create_inner_multil_tower_table.sql b/samples/odps_script/multi_tower/create_inner_multi_tower_table.sql
similarity index 69%
rename from samples/odps_script/multi_tower/create_inner_multil_tower_table.sql
rename to samples/odps_script/multi_tower/create_inner_multi_tower_table.sql
index 73b029874..b31828765 100644
--- a/samples/odps_script/multi_tower/create_inner_multil_tower_table.sql
+++ b/samples/odps_script/multi_tower/create_inner_multi_tower_table.sql
@@ -1,5 +1,5 @@
-drop TABLE IF EXISTS multil_tower_test_{TIME_STAMP} ;
-create table multil_tower_test_{TIME_STAMP}(
+drop TABLE IF EXISTS multi_tower_test_{TIME_STAMP} ;
+create table multi_tower_test_{TIME_STAMP}(
clk bigint
,buy bigint
,pid string
@@ -23,12 +23,10 @@ create table multil_tower_test_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE multil_tower_test_{TIME_STAMP}
-select * from external_multil_tower_test_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/test_{TIME_STAMP} multi_tower_test_{TIME_STAMP};
-
-drop TABLE IF EXISTS multil_tower_train_{TIME_STAMP} ;
-create table multil_tower_train_{TIME_STAMP}(
+drop TABLE IF EXISTS multi_tower_train_{TIME_STAMP} ;
+create table multi_tower_train_{TIME_STAMP}(
clk bigint
,buy bigint
,pid string
@@ -52,5 +50,4 @@ create table multil_tower_train_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE multil_tower_train_{TIME_STAMP}
-select * from external_multil_tower_train_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/train_{TIME_STAMP} multi_tower_train_{TIME_STAMP};
diff --git a/samples/odps_script/multi_tower/drop_multil_tower_table.sql b/samples/odps_script/multi_tower/drop_multil_tower_table.sql
index 08d2e8f39..a7ecdb562 100644
--- a/samples/odps_script/multi_tower/drop_multil_tower_table.sql
+++ b/samples/odps_script/multi_tower/drop_multil_tower_table.sql
@@ -1,5 +1,5 @@
-drop TABLE IF EXISTS external_multil_tower_test_{TIME_STAMP} ;
-drop TABLE IF EXISTS external_multil_tower_train_{TIME_STAMP} ;
-drop TABLE IF EXISTS multil_tower_test_{TIME_STAMP} ;
-drop TABLE IF EXISTS multil_tower_train_{TIME_STAMP} ;
-drop TABLE IF EXISTS multil_tower_test_output_v1_{TIME_STAMP} ;
+drop TABLE IF EXISTS external_multi_tower_test_{TIME_STAMP} ;
+drop TABLE IF EXISTS external_multi_tower_train_{TIME_STAMP} ;
+drop TABLE IF EXISTS multi_tower_test_{TIME_STAMP} ;
+drop TABLE IF EXISTS multi_tower_train_{TIME_STAMP} ;
+drop TABLE IF EXISTS multi_tower_test_output_v1_{TIME_STAMP} ;
diff --git a/samples/odps_script/multi_tower/eval_multil_tower.sql b/samples/odps_script/multi_tower/eval_multil_tower.sql
index 692f89791..c22fa2dae 100644
--- a/samples/odps_script/multi_tower/eval_multil_tower.sql
+++ b/samples/odps_script/multi_tower/eval_multil_tower.sql
@@ -1,8 +1,8 @@
pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/multi_tower_bst.config
-Dcmd=evaluate
--Dtables=odps://{ODPS_PROJ_NAME}/tables/multil_tower_test_{TIME_STAMP}
--Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dtables=odps://{ODPS_PROJ_NAME}/tables/multi_tower_test_{TIME_STAMP}
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/multi_tower/export_again_multi_tower.sql b/samples/odps_script/multi_tower/export_again_multi_tower.sql
new file mode 100644
index 000000000..f7149202b
--- /dev/null
+++ b/samples/odps_script/multi_tower/export_again_multi_tower.sql
@@ -0,0 +1,10 @@
+pai -name easy_rec_ext
+-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/multi_tower_bst.config
+-Dcmd=export
+-Dexport_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/multil_tower/savemodel/
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
+-Darn={ROLEARN}
+-Dbuckets=oss://{OSS_BUCKET_NAME}/
+-Dextra_params="--clear_export"
+-DossHost={OSS_ENDPOINT}
+;
diff --git a/samples/odps_script/multi_tower/predict_multil_tower.sql b/samples/odps_script/multi_tower/predict_multil_tower.sql
index e2c4b2746..7290a5e51 100644
--- a/samples/odps_script/multi_tower/predict_multil_tower.sql
+++ b/samples/odps_script/multi_tower/predict_multil_tower.sql
@@ -1,12 +1,12 @@
-drop table if exists multil_tower_test_output_v1_{TIME_STAMP};
+drop table if exists multi_tower_test_output_v1_{TIME_STAMP};
pai -name easy_rec_ext
-Dcmd=predict
--Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000, "gpu":100}}'
+-Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-Dsaved_model_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/multil_tower/savemodel/
--Dinput_table=odps://{ODPS_PROJ_NAME}/tables/multil_tower_test_{TIME_STAMP}
--Doutput_table=odps://{ODPS_PROJ_NAME}/tables/multil_tower_test_output_v1_{TIME_STAMP}
+-Dinput_table=odps://{ODPS_PROJ_NAME}/tables/multi_tower_test_{TIME_STAMP}
+-Doutput_table=odps://{ODPS_PROJ_NAME}/tables/multi_tower_test_output_v1_{TIME_STAMP}
-Dexcluded_cols=label
-Dreserved_cols=ALL_COLUMNS
-Dbatch_size=1024
diff --git a/samples/odps_script/multi_tower/train_multil_tower_bst_model.sql b/samples/odps_script/multi_tower/train_multil_tower_bst_model.sql
index b62a83bab..6825a03f9 100644
--- a/samples/odps_script/multi_tower/train_multil_tower_bst_model.sql
+++ b/samples/odps_script/multi_tower/train_multil_tower_bst_model.sql
@@ -1,8 +1,8 @@
pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/multi_tower_bst.config
-Dcmd=train
--Dtables=odps://{ODPS_PROJ_NAME}/tables/multil_tower_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/multil_tower_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dtables=odps://{ODPS_PROJ_NAME}/tables/multi_tower_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/multi_tower_test_{TIME_STAMP}
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/multi_tower/train_multil_tower_din_model.sql b/samples/odps_script/multi_tower/train_multil_tower_din_model.sql
index 722259131..1e4d0db3a 100644
--- a/samples/odps_script/multi_tower/train_multil_tower_din_model.sql
+++ b/samples/odps_script/multi_tower/train_multil_tower_din_model.sql
@@ -1,8 +1,8 @@
pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/multi_tower_din.config
-Dcmd=train
--Dtables=odps://{ODPS_PROJ_NAME}/tables/multil_tower_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/multil_tower_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dtables=odps://{ODPS_PROJ_NAME}/tables/multi_tower_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/multi_tower_test_{TIME_STAMP}
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/multi_value/create_external_multi_value_table.sql b/samples/odps_script/multi_value/create_external_multi_value_table.sql
deleted file mode 100644
index bf46174df..000000000
--- a/samples/odps_script/multi_value/create_external_multi_value_table.sql
+++ /dev/null
@@ -1,60 +0,0 @@
-drop TABLE IF EXISTS external_multi_value_test_{TIME_STAMP} ;
-create EXTERNAL table external_multi_value_test_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/train/'
-;
-
-
-drop TABLE IF EXISTS external_multi_value_train_{TIME_STAMP} ;
-create EXTERNAL table external_multi_value_train_{TIME_STAMP}(
- clk bigint
- ,buy bigint
- ,pid string
- ,adgroup_id string
- ,cate_id string
- ,campaign_id string
- ,customer string
- ,brand string
- ,user_id string
- ,cms_segid string
- ,cms_group_id string
- ,final_gender_code string
- ,age_level string
- ,pvalue_level string
- ,shopping_level string
- ,occupation string
- ,new_user_class_level string
- ,tag_category_list string
- ,tag_brand_list string
- ,price bigint
-)
-STORED BY 'com.aliyun.odps.CsvStorageHandler'
-WITH SERDEPROPERTIES (
- 'odps.properties.rolearn'='{ROLEARN}'
-)
-LOCATION 'oss://{OSS_BUCKET_NAME}/{EXP_NAME}/test_data/tb_data/test/'
-;
diff --git a/samples/odps_script/multi_value/create_inner_multi_value_table.sql b/samples/odps_script/multi_value/create_inner_multi_value_table.sql
index defd9bdc5..9165ec11f 100644
--- a/samples/odps_script/multi_value/create_inner_multi_value_table.sql
+++ b/samples/odps_script/multi_value/create_inner_multi_value_table.sql
@@ -23,8 +23,7 @@ create table multi_value_test_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE multi_value_test_{TIME_STAMP}
-select * from external_multi_value_test_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/test_{TIME_STAMP} multi_value_test_{TIME_STAMP};
drop TABLE IF EXISTS multi_value_train_{TIME_STAMP} ;
@@ -52,5 +51,4 @@ create table multi_value_train_{TIME_STAMP}(
)
;
-INSERT OVERWRITE TABLE multi_value_train_{TIME_STAMP}
-select * from external_multi_value_train_{TIME_STAMP} ;
+tunnel upload {TEST_DATA_DIR}/tb_data/train_{TIME_STAMP} multi_value_train_{TIME_STAMP};
diff --git a/samples/odps_script/other_test/test_eval_checkpoint_path.sql b/samples/odps_script/other_test/test_eval_checkpoint_path.sql
index b9857938c..22533bd7b 100644
--- a/samples/odps_script/other_test/test_eval_checkpoint_path.sql
+++ b/samples/odps_script/other_test/test_eval_checkpoint_path.sql
@@ -3,7 +3,7 @@ pai -name easy_rec_ext
-Dcmd=evaluate
-Dcheckpoint_path=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/dwd_avazu_ctr2/checkpoints5/model.ckpt-100
-Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
--Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}'
+-Dcluster='{"worker" : {"count":1, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/other_test/test_predict_selected_cols.sql b/samples/odps_script/other_test/test_predict_selected_cols.sql
index 33cf6f819..0c5407ebc 100644
--- a/samples/odps_script/other_test/test_predict_selected_cols.sql
+++ b/samples/odps_script/other_test/test_predict_selected_cols.sql
@@ -1,7 +1,7 @@
drop table if exists deepfm_output_v1_{TIME_STAMP};
pai -name easy_rec_ext
-Dcmd=predict
--Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000, "gpu":100}}'
+-Dcluster='{"worker" : {"count":2, "cpu":1000, "memory":20000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-Dsaved_model_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/dwd_avazu_ctr2/checkpoints5/savemodel/
diff --git a/samples/odps_script/other_test/test_train_before_export.sql b/samples/odps_script/other_test/test_train_before_export.sql
index e4dc0b37c..798e0d80f 100644
--- a/samples/odps_script/other_test/test_train_before_export.sql
+++ b/samples/odps_script/other_test/test_train_before_export.sql
@@ -1,10 +1,11 @@
pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dwd_avazu_ctr_deepmodel_ext_v5.config
-Dcmd=train
--Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dtrain_tables=odps://{ODPS_PROJ_NAME}/tables/deepfm_train_{TIME_STAMP}
+-Deval_tables=odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":5, "cpu":1000, "memory":40000}}'
+-Deval_method='separate'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
--Dversion=20201029
;
diff --git a/samples/odps_script/other_test/test_train_best_export.sql b/samples/odps_script/other_test/test_train_best_export.sql
index cc7b2057a..11bba2b76 100644
--- a/samples/odps_script/other_test/test_train_best_export.sql
+++ b/samples/odps_script/other_test/test_train_best_export.sql
@@ -2,11 +2,11 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dwd_avazu_ctr_deepmodel_ext_best_export.config
-Dcmd=train
-Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
-Dhpo_param_path=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/hpo_param.json
--Dwith_evaluator=1
+-Deval_method='separate'
-Dhpo_metric_save_path=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/dwd_avazu_ctr2/checkpoints4/hpo
;
diff --git a/samples/odps_script/other_test/test_train_distribute_strategy_ess.sql b/samples/odps_script/other_test/test_train_distribute_strategy_ess.sql
index 286eec1de..8f25d5e33 100644
--- a/samples/odps_script/other_test/test_train_distribute_strategy_ess.sql
+++ b/samples/odps_script/other_test/test_train_distribute_strategy_ess.sql
@@ -3,7 +3,7 @@ pai -name easy_rec_ext
-Dcmd=train
-Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
-Ddistribute_strategy=ess
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/other_test/test_train_hpo_with_evaluator.sql b/samples/odps_script/other_test/test_train_hpo_with_evaluator.sql
index 0d390acf4..808c77951 100644
--- a/samples/odps_script/other_test/test_train_hpo_with_evaluator.sql
+++ b/samples/odps_script/other_test/test_train_hpo_with_evaluator.sql
@@ -2,7 +2,7 @@ pai -name easy_rec_ext
-Dconfig=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/configs/dwd_avazu_ctr_deepmodel_ext_v4.config
-Dcmd=train
-Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/other_test/test_train_version.sql b/samples/odps_script/other_test/test_train_version.sql
index 91526dc35..c5293caca 100644
--- a/samples/odps_script/other_test/test_train_version.sql
+++ b/samples/odps_script/other_test/test_train_version.sql
@@ -3,7 +3,7 @@ pai -name easy_rec_ext
-Dcmd=train
-Dtables=odps://{ODPS_PROJ_NAME}/tables/deepfm_train_{TIME_STAMP},odps://{ODPS_PROJ_NAME}/tables/deepfm_test_{TIME_STAMP}
-Dmodel_dir=oss://{OSS_BUCKET_NAME}/{EXP_NAME}/dwd_avazu_ctr2/checkpoints_version/
--Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000,"gpu":100, "memory":40000}}'
+-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":2, "cpu":1000, "memory":40000}}'
-Darn={ROLEARN}
-Dbuckets=oss://{OSS_BUCKET_NAME}/
-DossHost={OSS_ENDPOINT}
diff --git a/samples/odps_script/vector_retrieve/drop_table.sql b/samples/odps_script/vector_retrieve/drop_table.sql
index 3550efc6b..7d7f03062 100644
--- a/samples/odps_script/vector_retrieve/drop_table.sql
+++ b/samples/odps_script/vector_retrieve/drop_table.sql
@@ -1,3 +1,3 @@
drop TABLE IF EXISTS query_vector_{TIME_STAMP};
drop TABLE IF EXISTS doc_vector_{TIME_STAMP};
-drop TABLE IF EXISTS result_vector_{TIME_STAMP};
\ No newline at end of file
+drop TABLE IF EXISTS result_vector_{TIME_STAMP};
diff --git a/samples/odps_script/vector_retrieve/run_vector_retrieve.sql b/samples/odps_script/vector_retrieve/run_vector_retrieve.sql
index 2314a3eea..2f4559c54 100644
--- a/samples/odps_script/vector_retrieve/run_vector_retrieve.sql
+++ b/samples/odps_script/vector_retrieve/run_vector_retrieve.sql
@@ -13,4 +13,4 @@ pai -name easy_rec_ext
-Dknn_feature_dims=4
-Dknn_index_type='ivfflat'
-Dknn_feature_delimiter=','
-;
\ No newline at end of file
+;
diff --git a/samples/rtp_fg/fg_ev.json b/samples/rtp_fg/fg_ev.json
new file mode 100644
index 000000000..af190d627
--- /dev/null
+++ b/samples/rtp_fg/fg_ev.json
@@ -0,0 +1,26 @@
+{
+ "features": [
+ {"expression": "user:user_id", "feature_name": "user_id", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100000, "embedding_dim": 16, "group":"user", "use_embedding_variable":true},
+ {"expression": "user:cms_segid", "feature_name": "cms_segid", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100, "embedding_dim": 16, "group":"user"},
+ {"expression": "user:cms_group_id", "feature_name": "cms_group_id", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100, "embedding_dim": 16, "group":"user"},
+ {"expression": "user:age_level", "feature_name": "age_level", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 10, "embedding_dim": 16, "group":"user"},
+ {"expression": "user:pvalue_level", "feature_name": "pvalue_level", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 10, "embedding_dim": 16, "group":"user"},
+ {"expression": "user:shopping_level", "feature_name": "shopping_level", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 10, "embedding_dim": 16, "group":"user"},
+ {"expression": "user:occupation", "feature_name": "occupation", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 10, "embedding_dim": 16, "group":"user"},
+ {"expression": "user:new_user_class_level", "feature_name": "new_user_class_level", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 10, "embedding_dim": 16, "group":"user"},
+ {"expression": "item:adgroup_id", "feature_name": "adgroup_id", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100000, "embedding_dim": 16, "group":"item"},
+ {"expression": "item:cate_id", "feature_name": "cate_id", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100000, "embedding_dim": 16, "group":"item"},
+ {"expression": "item:campaign_id", "feature_name": "campaign_id", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100000, "embedding_dim": 16, "group":"item"},
+ {"expression": "item:customer", "feature_name": "customer", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100000, "embedding_dim": 16, "group":"item"},
+ {"expression": "item:brand", "feature_name": "brand", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100000, "embedding_dim": 16, "group":"item"},
+ {"expression": "item:price", "feature_name": "price", "feature_type":"raw_feature", "value_type":"Integer", "combiner":"mean", "group":"item"},
+ {"expression": "item:pid", "feature_name": "pid", "feature_type":"id_feature", "value_type":"String", "combiner":"mean", "hash_bucket_size": 100000, "embedding_dim": 16, "group":"item"},
+ {"expression": "user:tag_category_list", "feature_name": "user_tag_cate", "feature_type":"id_feature", "hash_bucket_size":100000, "group":"user"},
+ {"map": "user:tag_brand_list", "key":"item:brand", "feature_name": "combo_brand", "feature_type":"lookup_feature", "needDiscrete":true, "hash_bucket_size":100000, "group":"combo"},
+ {"map": "user:tag_category_list", "key":"item:cate_id", "feature_name": "combo_cate_id", "feature_type":"lookup_feature", "needDiscrete":true, "hash_bucket_size":10000, "group":"combo"}
+ ],
+ "reserves": [
+ "user_id", "campaign_id", "clk"
+ ],
+ "multi_val_sep": "|"
+}
diff --git a/samples/rtp_fg/fg_test_extensions_final.config b/samples/rtp_fg/fg_test_extensions_final.config
index 58266998c..a3a4e3040 100644
--- a/samples/rtp_fg/fg_test_extensions_final.config
+++ b/samples/rtp_fg/fg_test_extensions_final.config
@@ -173,7 +173,8 @@ model_config {
export_config {
multi_placeholder: false
}
-fg_json_path: "samples/rtp_fg/fg_test_extensions.json"
+# ! means fg config is already loaded
+fg_json_path: "!samples/rtp_fg/fg_test_extensions.json"
feature_config {
features {
input_names: "user_id"
@@ -307,23 +308,23 @@ feature_config {
embedding_dim: 16
hash_bucket_size: 100000
separator: ""
- combiner: "mean"
+ combiner: "sum"
}
features {
input_names: "combo_brand"
- feature_type: IdFeature
+ feature_type: TagFeature
embedding_dim: 16
hash_bucket_size: 100000
separator: ""
- combiner: "mean"
+ combiner: "sum"
}
features {
input_names: "combo_cate_id"
- feature_type: IdFeature
+ feature_type: TagFeature
embedding_dim: 16
hash_bucket_size: 10000
separator: ""
- combiner: "mean"
+ combiner: "sum"
}
features {
input_names: "opt_content_long_seq_svid"
@@ -332,6 +333,7 @@ feature_config {
hash_bucket_size: 100000
separator: ""
combiner: "mean"
+ sub_feature_type: IdFeature
}
features {
input_names: "opt_content_long_seq_source_type"
@@ -340,6 +342,7 @@ feature_config {
hash_bucket_size: 100000
separator: ""
combiner: "mean"
+ sub_feature_type: IdFeature
}
}
diff --git a/scripts/build.sh b/scripts/build.sh
index abd0f30b5..c138bcca3 100755
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -1,3 +1,5 @@
#!/bin/sh
-cd ../ && sh -x scripts/gen_proto.sh && python3.7 setup.py sdist bdist_wheel && cp package/dist/easy*.whl . && cd -
+sh -x scripts/gen_proto.sh
+python setup.py sdist bdist_wheel
+ls -lh dist/easy*.whl
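With the simplified build script, a local build is run from the repository root; a minimal sketch of building and then installing the resulting wheel (the exact file name depends on the version in easy_rec/version.py):

    # regenerate protos, then build sdist and wheel into dist/
    sh scripts/build.sh
    # install the freshly built wheel (name pattern matches the ls above)
    pip install dist/easy*.whl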
diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh
new file mode 100644
index 000000000..16a80775a
--- /dev/null
+++ b/scripts/build_docker.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+bash scripts/gen_proto.sh
+if [ $? -ne 0 ]
+then
+ echo "gen proto failed"
+ exit 1
+fi
+
+version=`grep "__version__" easy_rec/version.py | awk '{ if($1 == "__version__") print $NF}'`
+# strip "'"
+version=${version//\'/}
+echo "EasyRec Version: $version"
+
+if [ -z "$version" ]
+then
+ echo "Failed to get EasyRec version"
+ exit 1
+fi
+
+sudo docker build --network=host . -f docker/Dockerfile -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-${version}
diff --git a/scripts/build_docker_tf112.sh b/scripts/build_docker_tf112.sh
new file mode 100644
index 000000000..345b1d5ea
--- /dev/null
+++ b/scripts/build_docker_tf112.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+bash scripts/gen_proto.sh
+if [ $? -ne 0 ]
+then
+ echo "gen proto failed"
+ exit 1
+fi
+
+version=`grep "__version__" easy_rec/version.py | awk '{ if($1 == "__version__") print $NF}'`
+# strip "'"
+version=${version//\'/}
+echo "EasyRec Version: $version"
+
+if [ -z "$version" ]
+then
+ echo "Failed to get EasyRec version"
+ exit 1
+fi
+
+sudo docker build --network=host . -f docker/Dockerfile_tf112 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py27-tf1.12-${version}
diff --git a/scripts/build_docker_tf115.sh b/scripts/build_docker_tf115.sh
new file mode 100644
index 000000000..e6ef8667b
--- /dev/null
+++ b/scripts/build_docker_tf115.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+bash scripts/gen_proto.sh
+if [ $? -ne 0 ]
+then
+ echo "gen proto failed"
+ exit 1
+fi
+
+version=`grep "__version__" easy_rec/version.py | awk '{ if($1 == "__version__") print $NF}'`
+# strip "'"
+version=${version//\'/}
+echo "EasyRec Version: $version"
+
+if [ -z "$version" ]
+then
+ echo "Failed to get EasyRec version"
+ exit 1
+fi
+
+sudo docker build --network=host . -f docker/Dockerfile_tf115 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-${version}
diff --git a/scripts/build_docker_tf210.sh b/scripts/build_docker_tf210.sh
new file mode 100644
index 000000000..33bc1a11d
--- /dev/null
+++ b/scripts/build_docker_tf210.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+bash scripts/gen_proto.sh
+if [ $? -ne 0 ]
+then
+ echo "gen proto failed"
+ exit 1
+fi
+
+version=`grep "__version__" easy_rec/version.py | awk '{ if($1 == "__version__") print $NF}'`
+# strip "'"
+version=${version//\'/}
+echo "EasyRec Version: $version"
+
+if [ -z "$version" ]
+then
+ echo "Failed to get EasyRec version"
+ exit 1
+fi
+
+sudo docker build --progress=plain --network=host . -f docker/Dockerfile_tf210 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.10-${version}
diff --git a/scripts/build_docker_tf212.sh b/scripts/build_docker_tf212.sh
new file mode 100644
index 000000000..e56bd1871
--- /dev/null
+++ b/scripts/build_docker_tf212.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+bash scripts/gen_proto.sh
+if [ $? -ne 0 ]
+then
+ echo "gen proto failed"
+ exit 1
+fi
+
+version=`grep "__version__" easy_rec/version.py | awk '{ if($1 == "__version__") print $NF}'`
+# strip "'"
+version=${version//\'/}
+echo "EasyRec Version: $version"
+
+if [ -z "$version" ]
+then
+ echo "Failed to get EasyRec version"
+ exit 1
+fi
+
+sudo docker build --network=host . -f docker/Dockerfile_tf212 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-${version}
diff --git a/scripts/build_docs.sh b/scripts/build_docs.sh
index 929df6044..7d880270c 100644
--- a/scripts/build_docs.sh
+++ b/scripts/build_docs.sh
@@ -15,3 +15,7 @@ cd docs
rm -rf build
make html
rm -rf build/html/_modules
+
+python post_fix.py build/html/search.html
+
+echo "view docs: python -m http.server --directory=docs/build/html/ 8081"
diff --git a/scripts/build_read_the_docs.sh b/scripts/build_read_the_docs.sh
index 650f4e519..4c6314b6d 100644
--- a/scripts/build_read_the_docs.sh
+++ b/scripts/build_read_the_docs.sh
@@ -7,7 +7,5 @@ bash scripts/gen_proto.sh
PATH=./protoc/bin/ protoc/bin/protoc --doc_out=html,proto.html:docs/source easy_rec/python/protos/*.proto
sed -i 's##
#g;s## #g' docs/source/proto.html
-pip3 install -r requirements/docs.txt
-pip install -r requirements/docs.txt
+pip3 install protobuf==3.20
pip3 install tensorflow==2.3
-pip install tensorflow==2.3
diff --git a/scripts/ci_test.sh b/scripts/ci_test.sh
index 1e6696353..75cda963e 100755
--- a/scripts/ci_test.sh
+++ b/scripts/ci_test.sh
@@ -1,5 +1,7 @@
#!/usr/bin/env bash
+echo "will test pull_request(number=$PULL_REQUEST_NUM)"
+
# pip install
pip install oss2
pip install -r requirements.txt
@@ -7,6 +9,24 @@ pip install -r requirements.txt
# update/generate proto
bash scripts/gen_proto.sh
+if [ -n "$PULL_REQUEST_NUM" ]
+then
+ # check updates
+ PYTHONPATH=. python scripts/ci_test_change_files.py --pull_request_num $PULL_REQUEST_NUM --exclude_dir docs
+ flag=$?
+ if [ $flag -eq 2 ]
+ then
+ echo "ci_test_passed=0" >> $GITHUB_OUTPUT
+ exit
+ fi
+ if [ $flag -ne 0 ]
+ then
+ # there are no code changes related to this test
+ echo "ci_test_passed=1" >> $GITHUB_OUTPUT
+ exit
+ fi
+fi
+
export CUDA_VISIBLE_DEVICES=""
export TEST_DEVICES=""
@@ -16,24 +36,12 @@ else
export TEST_DIR="/tmp/easy_rec_test_${USER}_`date +%s`"
fi
-export UnitTestSucceedFlag=EasyRecUnitSucceed
-
-PYTHONPATH=. python -m easy_rec.python.test.run --list_test_to_file UNIT_TEST_CASE_LIST
-
-for test_name in `cat UNIT_TEST_CASE_LIST`
-do
- rm -rf $UnitTestSucceedFlag
- # run test
- PYTHONPATH=. python -m easy_rec.python.test.run --pattern ${test_name}.*
- # for github
- if [ ! -e "$UnitTestSucceedFlag" ]
- then
- echo "::set-output name=ci_test_passed::0"
- exit
- fi
-done
+PYTHONPATH=. python -m easy_rec.python.test.run # --pattern export_test.*
# for github
-echo "::set-output name=ci_test_passed::1"
-rm -rf $UnitTestSucceedFlag
-rm -rf UNIT_TEST_CASE_LIST
+if [ $? -eq 0 ]
+then
+ echo "ci_test_passed=1" >> $GITHUB_OUTPUT
+else
+ echo "ci_test_passed=0" >> $GITHUB_OUTPUT
+fi
diff --git a/scripts/ci_test_change_files.py b/scripts/ci_test_change_files.py
new file mode 100644
index 000000000..0502f90c1
--- /dev/null
+++ b/scripts/ci_test_change_files.py
@@ -0,0 +1,41 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import json
+import logging
+import sys
+
+try:
+ from easy_rec.python.utils.io_util import http_read
+except Exception as ex:
+ logging.error(ex)
+ sys.exit(2)
+
+logging.basicConfig(
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--pull_request_num', type=int, default=None, help='pull request number')
+ parser.add_argument(
+ '--exclude_dirs', nargs='*', type=str, help='the directory to be ignored')
+
+ args = parser.parse_args()
+
+  url = 'https://api.github.com/repos/alibaba/EasyRec/pulls/%d/files' % args.pull_request_num
+ pull_request_data = http_read(url)
+
+ changes = json.loads(pull_request_data)
+ change_dir = []
+ for obj in changes:
+ filename = obj['filename']
+ toks = filename.split('/')
+ if len(toks) > 0:
+ if toks[0] not in args.exclude_dirs:
+ change_dir.append(toks[0])
+
+ change_dir = list(set(change_dir))
+ logging.info('changed directories: %s' % ','.join(change_dir))
+ if len(change_dir) == 0:
+ sys.exit(1)
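As wired into scripts/ci_test.sh above, this helper communicates through its exit status: 0 when at least one changed top-level directory is not excluded (so the unit tests should run), 1 when every change falls under an excluded directory such as docs (tests are skipped and ci_test_passed=1 is reported), and 2 when importing easy_rec.python.utils.io_util fails (ci_test_passed=0 is reported). A minimal local invocation, with a hypothetical pull request number:

    # run from the repository root so the easy_rec package is importable
    PYTHONPATH=. python scripts/ci_test_change_files.py \
        --pull_request_num 1234 --exclude_dirs docs
    echo "exit status: $?"   # 0: run tests, 1: docs-only change, 2: import failure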
diff --git a/scripts/gen_proto.sh b/scripts/gen_proto.sh
index 263df44da..a6752b08b 100644
--- a/scripts/gen_proto.sh
+++ b/scripts/gen_proto.sh
@@ -1,20 +1,39 @@
#!/bin/bash
-ROOT_URL="http://easy-rec.oss-cn-hangzhou.aliyuncs.com/data/tools/"
+ROOT_URL="http://easyrec.oss-cn-beijing.aliyuncs.com/tools/"
+if [ -z "$TMPDIR" ]
+then
+ TMPDIR="/tmp"
+fi
+
+cache_file=$TMPDIR/protoc-3.4.0.tar.gz
if [[ ! -d protoc ]]; then
- if [[ "$(uname)" == "Darwin" ]]; then
- curl ${ROOT_URL}protoc-3.4.0-osx-x86_64.tar.gz -o protoc-3.4.0.tar.gz
- elif [[ "$(expr substr $(uname -s) 1 5)" == "Linux" ]]; then
- wget ${ROOT_URL}protoc-3.4.0-linux-x86_64.tar.gz -O protoc-3.4.0.tar.gz
+ if [ ! -e "$cache_file" ]
+ then
+ if [[ "$(uname)" == "Darwin" ]]; then
+ curl ${ROOT_URL}protoc-3.4.0-osx-x86_64.tar.gz -o $cache_file
+ flag=$?
+ elif [[ "$(expr substr $(uname -s) 1 5)" == "Linux" ]]; then
+ wget ${ROOT_URL}protoc-3.4.0-linux-x86_64.tar.gz -O $cache_file
+ flag=$?
+ else
+ echo "unknown system $(uname -a)"
+ exit 1
+ fi
+ if [ $flag -ne 0 ]
+ then
+ echo "Download protoc-3.4.0.tar.gz failed"
+ exit 1
+ fi
fi
+
mkdir protoc
- tar -xf protoc-3.4.0.tar.gz -C protoc
+ tar -xf $cache_file -C protoc
fi
+
protoc/bin/protoc easy_rec/python/protos/*.proto --python_out=.
+
if [ $? -ne 0 ]
then
exit 1
fi
-
-#PATH=protoc/bin protoc/bin/protoc --doc_out=html,index.html:. easy_rec/python/protos/*.proto
-#sed -i 's##
#g;s## #g' index.html
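The reworked gen_proto.sh caches the downloaded protoc archive as $TMPDIR/protoc-3.4.0.tar.gz (falling back to /tmp), so repeated runs reuse the tarball instead of downloading it again. A small sketch of pointing the cache at a persistent directory (the path is illustrative):

    # keep the protoc tarball somewhere durable so CI re-runs skip the download
    mkdir -p $HOME/.cache/easyrec
    TMPDIR=$HOME/.cache/easyrec bash scripts/gen_proto.sh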
diff --git a/scripts/git/post-checkout b/scripts/git/post-checkout
new file mode 100755
index 000000000..73812baa4
--- /dev/null
+++ b/scripts/git/post-checkout
@@ -0,0 +1 @@
+python git-lfs/git_lfs.py pull -f
diff --git a/pre-commit b/scripts/git/pre-commit
similarity index 77%
rename from pre-commit
rename to scripts/git/pre-commit
index 964bd792e..65cc7b3c8 100755
--- a/pre-commit
+++ b/scripts/git/pre-commit
@@ -29,3 +29,17 @@ if [ $result -eq 0 ];then
else
exit 1
fi
+
+python git-lfs/git_lfs.py push
+
+pwd
+
+gt_py36=`python --version 2>&1 | awk '{ print ($2 >= 3.6) }'`
+
+if [ $gt_py36 -eq 1 ]
+then
+ pip install pre-commit
+ pre-commit run -a
+else
+ echo "[WARNING] pre-commit is not supported, please use python >= 3.6"
+fi
diff --git a/scripts/init.sh b/scripts/init.sh
index f6e266faf..503a535ca 100644
--- a/scripts/init.sh
+++ b/scripts/init.sh
@@ -2,8 +2,14 @@
# init pre-commit check hook
rm -rf .git/hooks/pre-commit
-cp pre-commit .git/hooks/
+cp scripts/git/pre-commit .git/hooks/
chmod a+rx .git/hooks/pre-commit
+rm -rf .git/hooks/post-checkout
+cp scripts/git/post-checkout .git/hooks/
+chmod a+rx .git/hooks/post-checkout
+
+python git-lfs/git_lfs.py pull
+
# compile proto files
source scripts/gen_proto.sh
diff --git a/scripts/kafka_test.sh b/scripts/kafka_test.sh
new file mode 100644
index 000000000..e18193629
--- /dev/null
+++ b/scripts/kafka_test.sh
@@ -0,0 +1 @@
+kafka_install_dir=../kafka_2.13-3.1.0/ oss_path=oss://yangxi-bj/export_embedding_taobao_fg_step_0 oss_ak=xxx oss_sk=xxx oss_endpoint=oss-cn-beijing.aliyuncs.com TEST_DEVICES='' PYTHONPATH=.:pai_jobs/ python -m easy_rec.python.test.kafka_test KafkaTest.test_kafka_processor
diff --git a/scripts/kill_all.sh b/scripts/kill_all.sh
new file mode 100644
index 000000000..df2968852
--- /dev/null
+++ b/scripts/kill_all.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+ps aux | grep "easy_rec.python.train_eval" | grep -v grep | awk '{ print $2}' | while read line_str; do echo "kill $line_str"; kill -9 $line_str; done
diff --git a/scripts/pre-commit b/scripts/pre-commit
new file mode 100755
index 000000000..fbd34dfde
--- /dev/null
+++ b/scripts/pre-commit
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# File generated by pre-commit: https://pre-commit.com
+# ID: 138fd403232d2ddd5efb44317e38bf03
+import os
+import sys
+
+# we try our best, but the shebang of this script is difficult to determine:
+# - macos doesn't ship with python3
+# - windows executables are almost always `python.exe`
+# therefore we continue to support python2 for this small script
+if sys.version_info < (3, 3):
+ from distutils.spawn import find_executable as which
+else:
+ from shutil import which
+
+# work around https://github.com/Homebrew/homebrew-core/issues/30445
+os.environ.pop('__PYVENV_LAUNCHER__', None)
+
+# start templated
+INSTALL_PYTHON = '/apsarapangu/disk2/yancheng.lgq/miniconda3/envs/tf_1_15/bin/python'
+ARGS = [
+ 'hook-impl', '--config=.pre-commit-config.yaml', '--hook-type=pre-commit'
+]
+# end templated
+ARGS.extend(('--hook-dir', os.path.realpath(os.path.dirname(__file__))))
+ARGS.append('--')
+ARGS.extend(sys.argv[1:])
+
+DNE = '`pre-commit` not found. Did you forget to activate your virtualenv?'
+if os.access(INSTALL_PYTHON, os.X_OK):
+ CMD = [INSTALL_PYTHON, '-mpre_commit']
+elif which('pre-commit'):
+ CMD = ['pre-commit']
+else:
+ raise SystemExit(DNE)
+
+CMD.extend(ARGS)
+if sys.platform == 'win32': # https://bugs.python.org/issue19124
+ import subprocess
+
+ if sys.version_info < (3, 7): # https://bugs.python.org/issue25942
+ raise SystemExit(subprocess.Popen(CMD).wait())
+ else:
+ raise SystemExit(subprocess.call(CMD))
+else:
+ os.execvp(CMD[0], CMD)
diff --git a/scripts/train_ngpu.sh b/scripts/train_ngpu.sh
new file mode 100644
index 000000000..85875d410
--- /dev/null
+++ b/scripts/train_ngpu.sh
@@ -0,0 +1,175 @@
+#!/bin/bash
+
+LOG_DIR="logs/"
+
+START_GPU=0
+START_PORT=2001
+WORKER_NUM=8
+PS_NUM=2
+HOST='localhost'
+
+usage() {
+ echo "Usage: `basename $0` -c criteo.config -s gpu_id -p start_port \
+-m model_dir -f fine_tune_ckpt -W worker_num -P ps_num -E extra_args \
+-H hostname -L logdir -N exp_name"
+}
+
+args="--continue_train"
+
+while getopts "c:s:p:m:f:P:W:E:H:L:N:" arg; do
+ case $arg in
+ c)
+ PIPELINE_CONFIG=$OPTARG
+ ;;
+ s)
+ START_GPU=$OPTARG
+ ;;
+ p)
+ START_PORT=$OPTARG
+ ;;
+ m)
+ args="$args --model_dir $OPTARG"
+ ;;
+ f)
+ args="$args --fine_tune_checkpoint $OPTARG"
+ ;;
+ W)
+ WORKER_NUM=$OPTARG
+ ;;
+ P)
+ PS_NUM=$OPTARG
+ ;;
+ E)
+ args="$args $OPTARG"
+ ;;
+ H)
+ HOST=$OPTARG
+ ;;
+ L)
+ LOG_DIR=$OPTARG
+ ;;
+ N)
+ EXP=$OPTARG
+ ;;
+ *)
+ usage
+ exit 1
+ ;;
+ esac
+done
+
+shift $(($OPTIND - 1))
+
+if [ $# -gt 0 ]
+then
+ args="$args $@"
+fi
+
+if [ -z "$EXP" ]
+then
+ EXP="easy_rec"
+fi
+
+EXP=${EXP}_${START_GPU}_${START_PORT}
+
+if [ -z "$PIPELINE_CONFIG" ]
+then
+ usage
+ exit 1
+fi
+
+if [ ! -e $PIPELINE_CONFIG ]
+then
+ usage
+ exit 1
+fi
+
+if [ ! -e $LOG_DIR ]
+then
+ mkdir $LOG_DIR
+fi
+
+echo "pipeline config: ${PIPELINE_CONFIG}"
+echo "start gpu: ${START_GPU}"
+echo "start port: ${START_PORT}"
+echo "worker_num: ${WORKER_NUM}"
+echo "ps_num: ${PS_NUM}"
+echo "host: ${HOST}"
+echo "exp: ${EXP}"
+echo "logdir: ${LOG_DIR}"
+echo "more args: ${args}"
+
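+# build comma-separated host lists: PS_NUM ps ports, one master port, then WORKER_NUM-1 worker ports, starting at START_PORT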
+ps_hosts="\"$HOST:$START_PORT\""
+for ps_id in `seq 1 $((PS_NUM-1))`
+do
+ ps_hosts=${ps_hosts}",\"$HOST:$((START_PORT+ps_id))\""
+done
+
+master_hosts=\"$HOST:$((START_PORT+PS_NUM))\"
+
+worker_hosts="\"$HOST:$((START_PORT+PS_NUM+1))\""
+for worker_id in `seq 1 $((WORKER_NUM-2))`
+do
+ worker_hosts=${worker_hosts}",\"$HOST:$((START_PORT+PS_NUM+worker_id+1))\""
+done
+
+echo "ps_hosts: $ps_hosts"
+echo "master_hosts: $master_hosts"
+echo "worker_hosts: $worker_hosts"
+
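+# TF_CONFIG cluster spec shared by every ps/master/worker process launched below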
+cluster_spec="{
+ \"ps\": [$ps_hosts],
+ \"master\": [$master_hosts],
+ \"worker\": [$worker_hosts]
+ }"
+
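+# parameter servers run on CPU only (CUDA_VISIBLE_DEVICES='')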
+for ps_id in `seq 0 $((PS_NUM-1))`
+do
+ echo "start ps: ${ps_id}"
+ log_file=$LOG_DIR/log_${EXP}_ps_${ps_id}.txt
+ # Parameter Server Process
+ export TF_CONFIG="{
+ \"cluster\":$cluster_spec,
+ \"task\":
+ {
+ \"type\": \"ps\",
+ \"index\": ${ps_id}
+ }
+ }"
+ CUDA_VISIBLE_DEVICES='' nohup python -m easy_rec.python.train_eval \
+ --pipeline_config_path $PIPELINE_CONFIG $args \
+ > $log_file &
+done
+
+
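+# master (chief) task runs on the first GPU (START_GPU)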
+log_file_1=$LOG_DIR/log_${EXP}_master.txt
+export TF_CONFIG="{
+ \"cluster\":$cluster_spec,
+ \"task\":
+ {
+ \"type\": \"master\",
+ \"index\": 0
+ }
+ }"
+CUDA_VISIBLE_DEVICES=$START_GPU nohup python -m easy_rec.python.train_eval \
+ --pipeline_config_path $PIPELINE_CONFIG $args\
+ > $log_file_1 &
+echo $log_file_1
+
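+# launch WORKER_NUM-1 worker tasks, each pinned to the next GPU (START_GPU is incremented per worker)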
+for worker_id in `seq 0 $((WORKER_NUM-2))`
+do
+ log_file_2=$LOG_DIR/log_${EXP}_worker_$((worker_id)).txt
+ export TF_CONFIG="{
+ \"cluster\":$cluster_spec,
+ \"task\":
+ {
+ \"type\": \"worker\",
+ \"index\": $worker_id
+ }
+ }"
+ ((START_GPU++))
+ CUDA_VISIBLE_DEVICES=$START_GPU nohup python -m easy_rec.python.train_eval \
+ --pipeline_config_path $PIPELINE_CONFIG $args\
+ > $log_file_2 &
+ echo $log_file_2
+done
diff --git a/scripts/train_ps.sh b/scripts/train_ps.sh
index 4eef82628..58053a226 100644
--- a/scripts/train_ps.sh
+++ b/scripts/train_ps.sh
@@ -28,9 +28,9 @@ then
shift 1
fi
-ps_hosts="localhost:2227"
-chief_hosts="localhost:2223"
-worker_hosts="localhost:2224"
+ps_hosts="localhost:2327"
+chief_hosts="localhost:2323"
+worker_hosts="localhost:2324"
gpus=""
echo "ps_hosts=${ps_hosts}"
diff --git a/setup.cfg b/setup.cfg
index fa156e1ee..f0223c47a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
force_single_line = true
known_standard_library = setuptools
known_first_party = easy_rec
-known_third_party = absl,common_io,future,google,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
+known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,scipy,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,tensorflow_probability,yaml
no_lines_before = LOCALFOLDER
default_section = THIRDPARTY
skip = easy_rec/python/protos
diff --git a/setup.py b/setup.py
index fb65fe0f8..26b94f707 100644
--- a/setup.py
+++ b/setup.py
@@ -52,11 +52,11 @@ def parse_require_file(fpath):
setup(
name='easy-rec',
version=get_version(),
- description='An framework for deep learning on recommendation',
+ description='An easy-to-use framework for Recommendation',
doc=readme(),
author='EasyRec Team',
author_email='easy_rec@alibaba-inc.com',
- url='http://gitlab.alibaba-inc.com/pai_biz_arch/EasyRec',
+ url='https://github.com/alibaba/EasyRec',
packages=find_packages(),
include_package_data=True,
classifiers=[