 from tensorflow.python.keras.layers import Layer, Activation, BatchNormalization
 from tensorflow.python.keras.regularizers import l2
-from tensorflow.python.keras.initializers import RandomNormal, Zeros, glorot_normal, glorot_uniform
+from tensorflow.python.keras.initializers import Zeros, glorot_normal, glorot_uniform
 from tensorflow.python.keras import backend as K
 
 import tensorflow as tf
@@ -61,7 +61,7 @@ class AFMLayer(Layer):
 
         - **l2_reg_w**: float between 0 and 1. L2 regularizer strength applied to attention network.
 
-        - **keep_prob**: float between 0 and 1. Fraction of the attention net output units to keep. 
+        - **keep_prob**: float between 0 and 1. Fraction of the attention net output units to keep.
 
         - **seed**: A Python integer to use as random seed.
 
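For reference, keep_prob follows the TF1 dropout convention: it is the fraction of units kept, not dropped. A minimal sketch of that convention, assuming the layer applies standard dropout to the attention net output:

import tensorflow as tf

x = tf.random_normal([4, 8])
# keep_prob=0.8 keeps ~80% of units; a TF2-style API would take rate=0.2 instead
y = tf.nn.dropout(x, keep_prob=0.8)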
@@ -175,9 +175,10 @@ def call(self, inputs,**kwargs):
 
         if isinstance(self.activation, str):
             output = Activation(self.activation)(x)
+        elif issubclass(self.activation, Layer):
+            output = self.activation()(x)
         else:
-            output = self.activation(x)
-
+            raise ValueError("Invalid activation of MLP, found %s. You should use a str or an Activation Layer Class." % (self.activation))
         output = tf.reshape(output, (-1, 1))
 
         return output
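The new elif/else makes the accepted activation types explicit: a string name is wrapped in keras' Activation layer, while an uninstantiated Layer subclass (e.g. a custom Dice-style activation) is constructed and then called. A minimal standalone sketch of the same dispatch; apply_activation is a hypothetical helper, and inspect.isclass guards against issubclass raising TypeError on non-class inputs:

import inspect
from tensorflow.python.keras.layers import Layer, Activation

def apply_activation(activation, x):
    if isinstance(activation, str):
        # built-in names like 'relu' or 'sigmoid' go through the Activation wrapper
        return Activation(activation)(x)
    elif inspect.isclass(activation) and issubclass(activation, Layer):
        # an uninstantiated Layer subclass: construct it, then call it on x
        return activation()(x)
    raise ValueError(
        "Invalid activation, found %s. You should use a str or an Activation Layer Class." % (activation,))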
@@ -254,7 +255,7 @@ def get_config(self,):
 
 class MLP(Layer):
     """The Multi Layer Perceptron
-
+
       Input shape
         - nD tensor with shape: ``(batch_size, ..., input_dim)``. The most common situation would be a 2D input with shape ``(batch_size, input_dim)``.
 
@@ -268,14 +269,14 @@ class MLP(Layer):
 
         - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix.
 
-        - **keep_prob**: float between 0 and 1. Fraction of the units to keep. 
+        - **keep_prob**: float between 0 and 1. Fraction of the units to keep.
 
         - **use_bn**: bool. Whether to use BatchNormalization before activation or not.
 
         - **seed**: A Python integer to use as random seed.
     """
 
-    def __init__(self, hidden_size, activation, l2_reg, keep_prob, use_bn, seed, **kwargs):
+    def __init__(self, hidden_size, activation='relu', l2_reg=0, keep_prob=1, use_bn=False, seed=1024, **kwargs):
         self.hidden_size = hidden_size
         self.activation = activation
         self.keep_prob = keep_prob
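With these defaults, MLP can be constructed without spelling out every hyperparameter, while old call sites keep working by passing arguments explicitly. A hypothetical usage sketch (deep_input is an assumed 2D tensor, not part of this commit):

mlp = MLP(hidden_size=(128, 64))   # relu activation, no dropout, no BN, seed 1024
deep_out = mlp(deep_input)         # -> (batch_size, 64)

mlp_tuned = MLP((128, 64), activation='relu', l2_reg=1e-5,
                keep_prob=0.9, use_bn=True, seed=2020)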
@@ -338,7 +339,7 @@ class BiInteractionPooling(Layer):
     """Bi-Interaction Layer used in Neural FM, compresses the pairwise element-wise product of features into one single vector.
 
       Input shape
-        - A list of 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
+        - A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
 
       Output shape
         - 3D tensor with shape: ``(batch_size,1,embedding_size)``.
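The docstring now matches what the layer actually consumes: a single stacked embedding tensor, not a list. For reference, Bi-Interaction pooling compresses the pairwise products via the square-of-sum minus sum-of-squares identity; a minimal sketch, assuming concated_embeds_value has shape (batch_size, field_size, embedding_size):

square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keep_dims=True))
sum_of_square = tf.reduce_sum(concated_embeds_value * concated_embeds_value, axis=1, keep_dims=True)
cross_term = 0.5 * (square_of_sum - sum_of_square)  # -> (batch_size, 1, embedding_size)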
@@ -381,7 +382,7 @@ class OutterProductLayer(Layer):
 
       Output shape
         - 2D tensor with shape: ``(batch_size,N*(N-1)/2)``.
-
+
       Arguments
         - **kernel_type**: str. The kernel weight matrix type to use, can be mat, vec or num.
 
@@ -557,7 +558,7 @@ def call(self, inputs,**kwargs):
         row = []
         col = []
         num_inputs = len(embed_list)
-        num_pairs = int(num_inputs * (num_inputs - 1) / 2)
+        # num_pairs = int(num_inputs * (num_inputs - 1) / 2)
 
 
         for i in range(num_inputs - 1):
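Commenting out num_pairs is harmless here because the nested loops below enumerate exactly those N*(N-1)/2 pairs. A runnable sketch of the indices they build:

row, col = [], []
num_inputs = 4
for i in range(num_inputs - 1):
    for j in range(i + 1, num_inputs):
        row.append(i)
        col.append(j)
# row == [0, 0, 0, 1, 1, 2], col == [1, 2, 3, 2, 3, 3]  -> 4*3/2 = 6 pairs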
@@ -604,7 +605,7 @@ class LocalActivationUnit(Layer):
 
         - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix of attention net.
 
-        - **keep_prob**: float between 0 and 1. Fraction of the attention net units to keep. 
+        - **keep_prob**: float between 0 and 1. Fraction of the attention net units to keep.
 
         - **use_bn**: bool. Whether to use BatchNormalization before activation or not in attention net.
 
@@ -614,7 +615,7 @@ class LocalActivationUnit(Layer):
         - [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)
     """
 
-    def __init__(self, hidden_size, activation, l2_reg, keep_prob, use_bn, seed, **kwargs):
+    def __init__(self, hidden_size=(64, 32), activation='sigmoid', l2_reg=0, keep_prob=1, use_bn=False, seed=1024, **kwargs):
         self.hidden_size = hidden_size
         self.activation = activation
         self.l2_reg = l2_reg
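The new defaults give LocalActivationUnit a ready-to-use configuration: two hidden layers of 64 and 32 units with sigmoid activation. A hypothetical usage sketch; query and keys are assumed tensors of shape (batch_size, 1, embedding_size) and (batch_size, T, embedding_size), matching the (batch_size, T, 1) output declared in compute_output_shape below:

attention = LocalActivationUnit()     # hidden_size=(64, 32), sigmoid activation
score = attention([query, keys])      # -> (batch_size, T, 1)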
@@ -663,7 +664,7 @@ def compute_output_shape(self, input_shape):
         return input_shape[1][:2] + (1,)
 
     def get_config(self,):
-        config = {'activation': self.activation, 'hidden_size': self.hidden_size, 'l2_reg': self.l2_reg, 'keep_prob': self.keep_prob, 'seed': self.seed}
+        config = {'activation': self.activation, 'hidden_size': self.hidden_size, 'l2_reg': self.l2_reg, 'keep_prob': self.keep_prob, 'use_bn': self.use_bn, 'seed': self.seed}
         base_config = super(LocalActivationUnit, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
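Including use_bn in get_config matters for serialization: a layer rebuilt from a config saved without it would silently fall back to use_bn=False. A minimal round-trip sketch using keras' standard get_config/from_config contract:

layer = LocalActivationUnit(hidden_size=(64, 32), use_bn=True)
config = layer.get_config()                         # now carries 'use_bn': True
restored = LocalActivationUnit.from_config(config)
assert restored.use_bn                              # was lost before this fix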