@@ -1042,3 +1042,123 @@ def get_config(self, ):
10421042 config = {'bilinear_type' : self .bilinear_type , 'seed' : self .seed }
10431043 base_config = super (BilinearInteraction , self ).get_config ()
10441044 return dict (list (base_config .items ()) + list (config .items ()))
1045+
1046+
1047+ class FieldWiseBiInteraction (Layer ):
1048+ """Field-Wise Bi-Interaction Layer used in FLEN,compress the
1049+ pairwise element-wise product of features into one single vector.
1050+
1051+ Input shape
1052+ - A list of 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
1053+
1054+ Output shape
1055+ - 2D tensor with shape: ``(batch_size,embedding_size)``.
1056+
1057+ Arguments
1058+ - **use_bias** : Boolean, if use bias.
1059+ - **l2_reg** : Float, l2 regularization coefficient.
1060+ - **seed** : A Python integer to use as random seed.
1061+
1062+ References
1063+ [1] hen W, Zhan L, Ci Y, Lin C https://arxiv.org/pdf/1911.04690
1064+ """
1065+ def __init__ (self , l2_reg = 1e-5 , seed = 1024 , ** kwargs ):
1066+
1067+ self .l2_reg = l2_reg
1068+ self .seed = seed
1069+
1070+ super (FieldWiseBiInteraction , self ).__init__ (** kwargs )
1071+
1072+ def build (self , input_shape ):
1073+
1074+ if not isinstance (input_shape , list ) or len (input_shape ) < 2 :
1075+ raise ValueError (
1076+ 'A `Field-Wise Bi-Interaction` layer should be called '
1077+ 'on a list of at least 2 inputs' )
1078+
1079+ self .num_fields = len (input_shape )
1080+ embedding_size = input_shape [0 ][- 1 ]
1081+
1082+ self .kernel_inter = self .add_weight (
1083+ name = 'kernel_inter' ,
1084+ shape = (int (self .num_fields * (self .num_fields - 1 ) / 2 ), 1 ),
1085+ initializer = glorot_normal (seed = self .seed ),
1086+ regularizer = l2 (self .l2_reg ),
1087+ trainable = True )
1088+ self .bias_inter = self .add_weight (name = 'bias_inter' ,
1089+ shape = (embedding_size ),
1090+ initializer = Zeros (),
1091+ trainable = True )
1092+ self .kernel_intra = self .add_weight (
1093+ name = 'kernel_intra' ,
1094+ shape = (self .num_fields , 1 ),
1095+ initializer = glorot_normal (seed = self .seed ),
1096+ regularizer = l2 (self .l2_reg ),
1097+ trainable = True )
1098+ self .bias_intra = self .add_weight (name = 'bias_intra' ,
1099+ shape = (embedding_size ),
1100+ initializer = Zeros (),
1101+ trainable = True )
1102+
1103+ super (FieldWiseBiInteraction ,
1104+ self ).build (input_shape ) # Be sure to call this somewhere!
1105+
1106+ def call (self , inputs , ** kwargs ):
1107+
1108+ if K .ndim (inputs [0 ]) != 3 :
1109+ raise ValueError (
1110+ "Unexpected inputs dimensions %d, expect to be 3 dimensions" %
1111+ (K .ndim (inputs )))
1112+
1113+ field_wise_embeds_list = inputs
1114+
1115+ # MF module
1116+ field_wise_vectors = tf .concat ([
1117+ reduce_sum (field_i_vectors , axis = 1 , keep_dims = True )
1118+ for field_i_vectors in field_wise_embeds_list
1119+ ], 1 )
1120+
1121+ left = []
1122+ right = []
1123+ for i in range (self .num_fields ):
1124+ for j in range (i + 1 , self .num_fields ):
1125+ left .append (i )
1126+ right .append (j )
1127+
1128+ embeddings_left = tf .gather (params = field_wise_vectors ,
1129+ indices = left ,
1130+ axis = 1 )
1131+ embeddings_right = tf .gather (params = field_wise_vectors ,
1132+ indices = right ,
1133+ axis = 1 )
1134+
1135+ embeddings_prod = embeddings_left * embeddings_right
1136+ field_weighted_embedding = embeddings_prod * self .kernel_inter
1137+ h_mf = reduce_sum (field_weighted_embedding , axis = 1 )
1138+ h_mf = tf .nn .bias_add (h_mf , self .bias_inter )
1139+
1140+ # FM module
1141+ square_of_sum_list = [
1142+ tf .square (reduce_sum (field_i_vectors , axis = 1 , keep_dims = True ))
1143+ for field_i_vectors in field_wise_embeds_list
1144+ ]
1145+ sum_of_square_list = [
1146+ reduce_sum (field_i_vectors * field_i_vectors ,
1147+ axis = 1 ,
1148+ keep_dims = True )
1149+ for field_i_vectors in field_wise_embeds_list
1150+ ]
1151+
1152+ field_fm = tf .concat ([
1153+ square_of_sum - sum_of_square for square_of_sum , sum_of_square in
1154+ zip (square_of_sum_list , sum_of_square_list )
1155+ ], 1 )
1156+
1157+ h_fm = reduce_sum (field_fm * self .kernel_intra , axis = 1 )
1158+
1159+ h_fm = tf .nn .bias_add (h_fm , self .bias_intra )
1160+
1161+ return h_mf + h_fm
1162+
1163+ def compute_output_shape (self , input_shape ):
1164+ return (None , input_shape [0 ][- 1 ])
0 commit comments