22import  pytest 
33from  deepctr .models .din  import  DIN 
44from  deepctr .layers .activation  import  Dice 
5- from  deepctr .utils  import  custom_objects ,SingleFeat 
5+ from  deepctr .utils  import  custom_objects ,  SingleFeat 
66from  tensorflow .python .keras .models  import  load_model , save_model 
7- from  ..utils  import  check_model 
87
98
109def  get_xy_fd ():
11-     feature_dim_dict  =  {"sparse" :[SingleFeat ('user' ,4 ),SingleFeat ('gender' ,2 ),SingleFeat ('item' ,4 ),SingleFeat ('item_gender' ,2 )], "dense" : []}
10+     feature_dim_dict  =  {"sparse" : [SingleFeat ('user' , 4 ), SingleFeat (
11+         'gender' , 2 ), SingleFeat ('item' , 4 ), SingleFeat ('item_gender' , 2 )], "dense" : []}
1212    behavior_feature_list  =  ["item" ]
1313    uid  =  np .array ([1 , 2 , 3 ])
1414    ugender  =  np .array ([0 , 1 , 0 ])
@@ -37,7 +37,7 @@ def test_DIN_model_io():
3737
3838    model  =  DIN (feature_dim_dict , behavior_feature_list , hist_len_max = 4 , embedding_size = 8 , att_activation = Dice ,
3939
40-                   hidden_size = [4 , 4 , 4 ], keep_prob = 0.6 ,)
40+                 hidden_size = [4 , 4 , 4 ], keep_prob = 0.6 ,)
4141
4242    model .compile ('adam' , 'binary_crossentropy' ,
4343                  metrics = ['binary_crossentropy' ])
@@ -53,7 +53,7 @@ def test_DIN_att():
5353    x , y , feature_dim_dict , behavior_feature_list  =  get_xy_fd ()
5454
5555    model  =  DIN (feature_dim_dict , behavior_feature_list , hist_len_max = 4 , embedding_size = 8 ,
56-                   hidden_size = [4 , 4 , 4 ], keep_prob = 0.6 ,)
56+                 hidden_size = [4 , 4 , 4 ], keep_prob = 0.6 ,)
5757
5858    model .compile ('adam' , 'binary_crossentropy' ,
5959                  metrics = ['binary_crossentropy' ])
@@ -74,17 +74,5 @@ def test_DIN_att():
7474    print (model_name  +  " test pass!" )
7575
7676
77- # def test_DIN_sum(): 
78- # 
79- #     model_name = "DIN_sum" 
80- #     x, y, feature_dim_dict, behavior_feature_list = get_xy_fd() 
81- # 
82- #     model = DIN(feature_dim_dict, behavior_feature_list, hist_len_max=4, embedding_size=8, 
83- #                 use_din=False, hidden_size=[4, 4, 4], keep_prob=0.6, activation="sigmoid") 
84- # 
85- #     check_model(model, model_name, x, y) 
86- 
87- 
8877if  __name__  ==  "__main__" :
8978    test_DIN_att ()
90-     #test_DIN_sum() 
0 commit comments