1313from tensorflow .python .keras .layers import Embedding , Input , Flatten
1414from tensorflow .python .keras .regularizers import l2
1515
16- from .layers .sequence import SequencePoolingLayer
16+ from .layers .sequence import SequencePoolingLayer , SequenceMultiplyLayer
1717from .layers .utils import Hash ,concat_fun ,Linear
1818
19-
2019class SparseFeat (namedtuple ('SparseFeat' , ['name' , 'dimension' , 'use_hash' , 'dtype' ,'embedding_name' ,'embedding' ])):
2120 __slots__ = ()
2221
@@ -25,6 +24,16 @@ def __new__(cls, name, dimension, use_hash=False, dtype="int32", embedding_name=
2524 embedding_name = name
2625 return super (SparseFeat , cls ).__new__ (cls , name , dimension , use_hash , dtype , embedding_name ,embedding )
2726
27+ def __hash__ (self ):
28+ return self .name .__hash__ ()
29+
30+ def __eq__ (self , other ):
31+ if self .name == other .name :
32+ return True
33+ return False
34+
35+ def __repr__ (self ):
36+ return 'SparseFeat:' + self .name
2837
2938class DenseFeat (namedtuple ('DenseFeat' , ['name' , 'dimension' , 'dtype' ])):
3039 __slots__ = ()
@@ -33,6 +42,16 @@ def __new__(cls, name, dimension=1, dtype="float32"):
3342
3443 return super (DenseFeat , cls ).__new__ (cls , name , dimension , dtype )
3544
45+ def __hash__ (self ):
46+ return self .name .__hash__ ()
47+
48+ def __eq__ (self , other ):
49+ if self .name == other .name :
50+ return True
51+ return False
52+
53+ def __repr__ (self ):
54+ return 'DenseFeat:' + self .name
3655
3756class VarLenSparseFeat (namedtuple ('VarLenFeat' , ['name' , 'dimension' , 'maxlen' , 'combiner' , 'use_hash' , 'dtype' ,'embedding_name' ,'embedding' ])):
3857 __slots__ = ()
@@ -42,6 +61,17 @@ def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype
4261 embedding_name = name
4362 return super (VarLenSparseFeat , cls ).__new__ (cls , name , dimension , maxlen , combiner , use_hash , dtype , embedding_name ,embedding )
4463
64+ def __hash__ (self ):
65+ return self .name .__hash__ ()
66+
67+ def __eq__ (self , other ):
68+ if self .name == other .name :
69+ return True
70+ return False
71+
72+ def __repr__ (self ):
73+ return 'VarLenSparseFeat:' + self .name
74+
4575def get_feature_names (feature_columns ):
4676 features = build_input_features (feature_columns )
4777 return list (features .keys ())
@@ -209,6 +239,30 @@ def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_colu
209239 pooling_vec_list .append (vec )
210240 return pooling_vec_list
211241
242+ def get_varlen_multiply_list (embedding_dict , features , varlen_sparse_feature_columns_name_dict ):
243+ multiply_vec_list = []
244+ print (embedding_dict )
245+ for key_feature in varlen_sparse_feature_columns_name_dict :
246+ for value_feature in varlen_sparse_feature_columns_name_dict [key_feature ]:
247+ key_feature_length_name = key_feature .name + '_seq_length'
248+ if isinstance (value_feature , VarLenSparseFeat ):
249+ value_input = embedding_dict [value_feature .name ]
250+ elif isinstance (value_feature , DenseFeat ):
251+ value_input = features [value_feature .name ]
252+ else :
253+ raise TypeError ("Invalid feature column type,got" ,type (value_feature ))
254+ if key_feature_length_name in features :
255+ varlen_vec = SequenceMultiplyLayer (supports_masking = False )(
256+ [embedding_dict [key_feature .name ], features [key_feature_length_name ], value_input ])
257+ vec = SequencePoolingLayer ('sum' , supports_masking = False )(
258+ [varlen_vec , features [key_feature_length_name ]])
259+ else :
260+ varlen_vec = SequenceMultiplyLayer (supports_masking = True )(
261+ [embedding_dict [key_feature .name ], value_input ])
262+ vec = SequencePoolingLayer ('sum' , supports_masking = True )( varlen_vec )
263+ multiply_vec_list .append (vec )
264+ return multiply_vec_list
265+
212266def get_dense_input (features ,feature_columns ):
213267 dense_feature_columns = list (filter (lambda x :isinstance (x ,DenseFeat ),feature_columns )) if feature_columns else []
214268 dense_input_list = []
0 commit comments