@@ -1022,12 +1022,13 @@ def call(self, inputs, **kwargs):
10221022 raise ValueError (
10231023 "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K .ndim (inputs )))
10241024
1025+ n = len (inputs )
10251026 if self .bilinear_type == "all" :
1026- p = [tf .multiply ( tf . tensordot (v_i , self .W , axes = (- 1 , 0 )), v_j )
1027- for v_i , v_j in itertools .combinations (inputs , 2 )]
1027+ vidots = [tf .tensordot (inputs [ i ] , self .W , axes = (- 1 , 0 )) for i in range ( n )]
1028+ p = [ tf . multiply ( vidots [ i ], inputs [ j ]) for i , j in itertools .combinations (range ( n ) , 2 )]
10281029 elif self .bilinear_type == "each" :
1029- p = [tf .multiply ( tf . tensordot (inputs [i ], self .W_list [i ], axes = (- 1 , 0 )), inputs [ j ])
1030- for i , j in itertools .combinations (range (len ( inputs ) ), 2 )]
1030+ vidots = [tf .tensordot (inputs [i ], self .W_list [i ], axes = (- 1 , 0 )) for i in range ( n - 1 )]
1031+ p = [ tf . multiply ( vidots [ i ], inputs [ j ]) for i , j in itertools .combinations (range (n ), 2 )]
10311032 elif self .bilinear_type == "interaction" :
10321033 p = [tf .multiply (tf .tensordot (v [0 ], w , axes = (- 1 , 0 )), v [1 ])
10331034 for v , w in zip (itertools .combinations (inputs , 2 ), self .W_list )]
0 commit comments