@@ -68,16 +68,16 @@ def _model_fn(features, labels, mode, config):
6868 cross_out = CrossNet (cross_num , l2_reg = l2_reg_cross )(dnn_input )
6969 stack_out = tf .keras .layers .Concatenate ()([cross_out , deep_out ])
7070 final_logit = tf .keras .layers .Dense (
71- 1 , use_bias = False , activation = None )(stack_out )
71+ 1 , use_bias = False , kernel_initializer = tf . keras . initializers . glorot_normal ( seed ) )(stack_out )
7272 elif len (dnn_hidden_units ) > 0 : # Only Deep
7373 deep_out = DNN (dnn_hidden_units , dnn_activation , l2_reg_dnn , dnn_dropout ,
7474 dnn_use_bn , seed )(dnn_input , training = train_flag )
7575 final_logit = tf .keras .layers .Dense (
76- 1 , use_bias = False , activation = None )(deep_out )
76+ 1 , use_bias = False , kernel_initializer = tf . keras . initializers . glorot_normal ( seed ) )(deep_out )
7777 elif cross_num > 0 : # Only Cross
7878 cross_out = CrossNet (cross_num , l2_reg = l2_reg_cross )(dnn_input )
7979 final_logit = tf .keras .layers .Dense (
80- 1 , use_bias = False , activation = None )(cross_out )
80+ 1 , use_bias = False , kernel_initializer = tf . keras . initializers . glorot_normal ( seed ) )(cross_out )
8181 else : # Error
8282 raise NotImplementedError
8383
0 commit comments