1818from ...layers .utils import concat_func , add_func , combined_dnn_input , reduce_sum
1919
2020
21- def DeepFEFMEstimator (linear_feature_columns , dnn_feature_columns , embedding_size = 48 ,
22- dnn_hidden_units = (1024 , 1024 , 1024 ), l2_reg_linear = 0.000001 , l2_reg_embedding_feat = 0.00001 ,
23- l2_reg_embedding_field = 0.0000001 , l2_reg_dnn = 0 , seed = 1024 , dnn_dropout = 0.2 ,
21+ def DeepFEFMEstimator (linear_feature_columns , dnn_feature_columns ,
22+ dnn_hidden_units = (128 , 128 ), l2_reg_linear = 0.00001 , l2_reg_embedding_feat = 0.00001 ,
23+ l2_reg_embedding_field = 0.00001 , l2_reg_dnn = 0 , seed = 1024 , dnn_dropout = 0.0 ,
2424 dnn_activation = 'relu' , dnn_use_bn = False , task = 'binary' , model_dir = None ,
2525 config = None , linear_optimizer = 'Ftrl' , dnn_optimizer = 'Adagrad' , training_chief_hooks = None ):
2626 """Instantiates the DeepFEFM Network architecture or the shallow FEFM architecture (Ablation support not provided
2727 as estimator is meant for production, Ablation support provided in DeepFEFM implementation in models
2828
2929 :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
3030 :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
31- :param embedding_size: positive integer,sparse feature embedding_size
3231 :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
3332 :param l2_reg_linear: float. L2 regularizer strength applied to linear part
3433 :param l2_reg_embedding_feat: float. L2 regularizer strength applied to embedding vector of features
@@ -62,10 +61,11 @@ def _model_fn(features, labels, mode, config):
6261 sparse_embedding_list , dense_value_list = input_from_feature_columns (features , dnn_feature_columns ,
6362 l2_reg_embedding = l2_reg_embedding_feat )
6463
65- fefm_interaction_embedding = FEFMLayer (num_fields = len ( sparse_embedding_list ), embedding_size = embedding_size ,
66- regularizer = l2_reg_embedding_field )(concat_func (sparse_embedding_list , axis = 1 ))
64+ fefm_interaction_embedding = FEFMLayer (
65+ regularizer = l2_reg_embedding_field )(concat_func (sparse_embedding_list , axis = 1 ))
6766
68- fefm_logit = tf .keras .layers .Lambda (lambda x : reduce_sum (x , axis = 1 , keep_dims = True ))(fefm_interaction_embedding )
67+ fefm_logit = tf .keras .layers .Lambda (lambda x : reduce_sum (x , axis = 1 , keep_dims = True ))(
68+ fefm_interaction_embedding )
6969
7070 final_logit_components .append (fefm_logit )
7171
@@ -87,6 +87,3 @@ def _model_fn(features, labels, mode, config):
8787 training_chief_hooks = training_chief_hooks )
8888
8989 return tf .estimator .Estimator (_model_fn , model_dir = model_dir , config = config )
90-
91-
92-
0 commit comments