11# -*- coding:utf-8 -*- 
22""" 
33
4- Author: 
5- 4+ Authors: 
5+ 6+     Harshit Pande 
67
78""" 
89
1112import  tensorflow  as  tf 
1213from  tensorflow .python .keras  import  backend  as  K 
1314from  tensorflow .python .keras .initializers  import  (Zeros , glorot_normal ,
14-                                                   glorot_uniform )
15+                                                   glorot_uniform ,  TruncatedNormal )
1516from  tensorflow .python .keras .layers  import  Layer 
1617from  tensorflow .python .keras .regularizers  import  l2 
18+ from  tensorflow .python .keras .backend  import  batch_dot 
1719from  tensorflow .python .layers  import  utils 
1820
1921from  .activation  import  activation_layer 
@@ -1052,7 +1054,7 @@ class FieldWiseBiInteraction(Layer):
10521054
10531055      Output shape 
10541056        - 2D tensor with shape: ``(batch_size,embedding_size)``. 
1055-       
1057+ 
10561058      Arguments 
10571059        - **use_bias** : Boolean, if use bias. 
10581060        - **seed** : A Python integer to use as random seed. 
@@ -1062,7 +1064,7 @@ class FieldWiseBiInteraction(Layer):
10621064
10631065    """ 
10641066
1065-     def  __init__ (self ,use_bias = True , seed = 1024 , ** kwargs ):
1067+     def  __init__ (self ,  use_bias = True , seed = 1024 , ** kwargs ):
10661068        self .use_bias  =  use_bias 
10671069        self .seed  =  seed 
10681070
@@ -1167,3 +1169,80 @@ def get_config(self, ):
11671169        config  =  {'use_bias' : self .use_bias , 'seed' : self .seed }
11681170        base_config  =  super (FieldWiseBiInteraction , self ).get_config ()
11691171        return  dict (list (base_config .items ()) +  list (config .items ()))
1172+ 
1173+ 
1174+ class  FwFM (Layer ):
1175+     """Field-weighted Factorization Machines 
1176+ 
1177+       Input shape 
1178+         - 3D tensor with shape: ``(batch_size,field_size,embedding_size)``. 
1179+ 
1180+       Output shape 
1181+         - 2D tensor with shape: ``(batch_size, 1)``. 
1182+ 
1183+       Arguments 
1184+         - **num_fields** : integer for number of fields 
1185+         - **regularizer** : L2 regularizer weight for the field strength parameters of FwFM 
1186+ 
1187+       References 
1188+         - [Field-weighted Factorization Machines for Click-Through Rate Prediction in Display Advertising] 
1189+         https://arxiv.org/pdf/1806.03514.pdf 
1190+     """ 
1191+ 
1192+     def  __init__ (self , num_fields = 4 , regularizer = 0.000001 , ** kwargs ):
1193+         self .num_fields  =  num_fields 
1194+         self .regularizer  =  regularizer 
1195+         super (FwFM , self ).__init__ (** kwargs )
1196+ 
1197+     def  build (self , input_shape ):
1198+         if  len (input_shape ) !=  3 :
1199+             raise  ValueError ("Unexpected inputs dimensions % d,\  
1200+                               expect to be 3 dimensions"  %  (len (input_shape )))
1201+ 
1202+         if  input_shape [1 ] !=  self .num_fields :
1203+             raise  ValueError ("Mismatch in number of fields {} and \  
1204+                   concatenated embeddings dims {}" .format (self .num_fields , input_shape [1 ]))
1205+ 
1206+         self .field_strengths  =  self .add_weight (name = 'field_pair_strengths' ,
1207+                                                shape = (self .num_fields , self .num_fields ),
1208+                                                initializer = TruncatedNormal (),
1209+                                                regularizer = l2 (self .regularizer ),
1210+                                                trainable = True )
1211+ 
1212+         super (FwFM , self ).build (input_shape )  # Be sure to call this somewhere! 
1213+ 
1214+     def  call (self , inputs , ** kwargs ):
1215+         if  K .ndim (inputs ) !=  3 :
1216+             raise  ValueError (
1217+                 "Unexpected inputs dimensions %d, expect to be 3 dimensions" 
1218+                 %  (K .ndim (inputs )))
1219+ 
1220+         if  inputs .shape [1 ] !=  self .num_fields :
1221+             raise  ValueError ("Mismatch in number of fields {} and \  
1222+                   concatenated embeddings dims {}" .format (self .num_fields , inputs .shape [1 ]))
1223+ 
1224+         pairwise_inner_prods  =  []
1225+         for  fi , fj  in  itertools .combinations (range (self .num_fields ), 2 ):
1226+             # get field strength for pair fi and fj 
1227+             r_ij  =  self .field_strengths [fi , fj ]
1228+ 
1229+             # get embeddings for the features of both the fields 
1230+             feat_embed_i  =  tf .squeeze (inputs [0 :, fi :fi  +  1 , 0 :], axis = 1 )
1231+             feat_embed_j  =  tf .squeeze (inputs [0 :, fj :fj  +  1 , 0 :], axis = 1 )
1232+ 
1233+             f  =  tf .scalar_mul (r_ij , batch_dot (feat_embed_i , feat_embed_j , axes = 1 ))
1234+             pairwise_inner_prods .append (f )
1235+ 
1236+         sum_  =  tf .add_n (pairwise_inner_prods )
1237+         return  sum_ 
1238+ 
1239+     def  compute_output_shape (self , input_shape ):
1240+         return  (None , 1 )
1241+ 
1242+     def  get_config (self ):
1243+         config  =  super (FwFM , self ).get_config ().copy ()
1244+         config .update ({
1245+             'num_fields' : self .num_fields ,
1246+             'regularizer' : self .regularizer 
1247+         })
1248+         return  config 
0 commit comments