@@ -33,8 +33,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
33
33
it ( 'A x B with relu' , ( ) => {
34
34
const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 2 , 3 ] ) ;
35
35
const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 3 , 2 ] ) ;
36
+ const transposeA = false ;
37
+ const transposeB = false ;
36
38
37
- const c = tf . fused . matMul ( a , b , false , false , null , 'relu' ) ;
39
+ const c = tf . fused . matMul ( a , b , transposeA , transposeB , null , 'relu' ) ;
38
40
39
41
expect ( c . shape ) . toEqual ( [ 2 , 2 ] ) ;
40
42
expectArraysClose ( c , [ 0 , 8 , 0 , 20 ] ) ;
@@ -43,8 +45,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
43
45
it ( 'A x B with relu transpose' , ( ) => {
44
46
const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 2 , 3 ] ) ;
45
47
const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 2 , 3 ] ) ;
48
+ const transposeA = false ;
49
+ const transposeB = true ;
46
50
47
- const c = tf . fused . matMul ( a , b , false , true , null , 'relu' ) ;
51
+ const c = tf . fused . matMul ( a , b , transposeA , transposeB , null , 'relu' ) ;
48
52
49
53
expect ( c . shape ) . toEqual ( [ 2 , 2 ] ) ;
50
54
expectArraysClose ( c , [ 0 , 9 , 0 , 24 ] ) ;
@@ -54,8 +58,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
54
58
const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 2 , 3 ] ) ;
55
59
const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 3 , 2 ] ) ;
56
60
const c = tf . tensor2d ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 ] ) ;
61
+ const transposeA = false ;
62
+ const transposeB = false ;
57
63
58
- const d = tf . fused . matMul ( a , b , false , false , c , 'relu' ) ;
64
+ const d = tf . fused . matMul ( a , b , transposeA , transposeB , c , 'relu' ) ;
59
65
60
66
expect ( d . shape ) . toEqual ( [ 2 , 2 ] ) ;
61
67
expectArraysClose ( d , [ 1 , 9 , 0 , 21 ] ) ;
@@ -66,8 +72,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
66
72
const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 3 , 2 ] ) ;
67
73
const c = tf . tensor1d ( [ 1 , 1 ] ) ;
68
74
const act : tf . fused . Activation = 'relu' ;
75
+ const transposeA = false ;
76
+ const transposeB = false ;
69
77
70
- const d = tf . fused . matMul ( a , b , false , false , c , act ) ;
78
+ const d = tf . fused . matMul ( a , b , transposeA , transposeB , c , act ) ;
71
79
72
80
expect ( d . shape ) . toEqual ( [ 2 , 2 ] ) ;
73
81
expectArraysClose ( d , [ 1 , 9 , 0 , 21 ] ) ;
@@ -78,8 +86,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
78
86
const b = tf . tensor3d ( [ 0 , 1 , - 3 , 2 , 2 , 1 , 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 2 , 3 , 2 ] ) ;
79
87
const c = tf . tensor2d ( [ 1 , 2 ] , [ 1 , 2 ] ) ;
80
88
const act : tf . fused . Activation = 'relu' ;
89
+ const transposeA = false ;
90
+ const transposeB = false ;
81
91
82
- const d = tf . fused . matMul ( a , b , false , false , c , act ) ;
92
+ const d = tf . fused . matMul ( a , b , transposeA , transposeB , c , act ) ;
83
93
84
94
expect ( d . shape ) . toEqual ( [ 2 , 2 , 2 ] ) ;
85
95
expectArraysClose ( d , [ 2 , 6 , 0 , 18 , 0 , 30 , 0 , 42 ] ) ;
@@ -89,8 +99,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
89
99
const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 2 , 3 ] ) ;
90
100
const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 3 , 2 ] ) ;
91
101
const c = tf . tensor2d ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 ] ) ;
102
+ const transposeA = false ;
103
+ const transposeB = false ;
92
104
93
- const d = tf . fused . matMul ( a , b , false , false , c , 'linear' ) ;
105
+ const d = tf . fused . matMul ( a , b , transposeA , transposeB , c , 'linear' ) ;
94
106
95
107
expect ( d . shape ) . toEqual ( [ 2 , 2 ] ) ;
96
108
expectArraysClose ( d , [ 1 , 9 , - 2 , 21 ] ) ;
@@ -100,14 +112,16 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
100
112
const a = tf . tensor2d ( [ 1 , 2 , 3 , 10 , 20 , - 30 ] , [ 2 , 3 ] ) ;
101
113
const b = tf . tensor2d ( [ 2 , 3 , 4 , - 1 , 2 , 3 ] , [ 3 , 2 ] ) ;
102
114
const dy = tf . tensor2d ( [ 1 , 10 , 20 , 30 ] , [ 2 , 2 ] ) ;
115
+ const transposeA = false ;
116
+ const transposeB = false ;
103
117
104
118
const grads = tf . grads ( ( a , b ) => {
105
- const prod = tf . matMul ( a , b , false , false ) ;
119
+ const prod = tf . matMul ( a , b , transposeA , transposeB ) ;
106
120
return tf . relu ( prod ) ;
107
121
} ) ;
108
122
109
123
const fusedGrads = tf . grads ( ( a , b ) => {
110
- return tf . fused . matMul ( a , b , false , false , null , 'relu' ) ;
124
+ return tf . fused . matMul ( a , b , transposeA , transposeB , null , 'relu' ) ;
111
125
} ) ;
112
126
113
127
const [ da , db ] = grads ( [ a , b ] , dy ) ;
@@ -120,17 +134,19 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
120
134
const a = tf . tensor2d ( [ 1 , 2 , 3 , 10 , 20 , - 30 ] , [ 2 , 3 ] ) ;
121
135
const b = tf . tensor2d ( [ 2 , 3 , 4 , - 1 , 2 , 3 ] , [ 3 , 2 ] ) ;
122
136
const c = tf . tensor2d ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 ] ) ;
137
+ const transposeA = false ;
138
+ const transposeB = false ;
123
139
124
140
const dy = tf . tensor2d ( [ 1 , 10 , 20 , 30 ] , [ 2 , 2 ] ) ;
125
141
126
142
const grads = tf . grads ( ( a , b , c ) => {
127
- const prod = tf . matMul ( a , b , false , false ) ;
143
+ const prod = tf . matMul ( a , b , transposeA , transposeB ) ;
128
144
const sum = tf . add ( prod , c ) ;
129
145
return tf . relu ( sum ) ;
130
146
} ) ;
131
147
132
148
const fusedGrads = tf . grads ( ( a , b , c ) => {
133
- return tf . fused . matMul ( a , b , false , false , c , 'relu' ) ;
149
+ return tf . fused . matMul ( a , b , transposeA , transposeB , c , 'relu' ) ;
134
150
} ) ;
135
151
136
152
const [ da , db , dc ] = grads ( [ a , b , c ] , dy ) ;
@@ -145,17 +161,46 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
145
161
const a = tf . tensor2d ( [ 1 , 2 , 3 , 10 , 20 , - 30 ] , [ 3 , 2 ] ) ;
146
162
const b = tf . tensor2d ( [ 2 , 3 , 4 , - 1 , 2 , 3 ] , [ 3 , 2 ] ) ;
147
163
const c = tf . tensor2d ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 ] ) ;
164
+ const transposeA = true ;
165
+ const transposeB = false ;
148
166
149
167
const dy = tf . tensor2d ( [ 1 , 10 , 20 , 30 ] , [ 2 , 2 ] ) ;
150
168
151
169
const grads = tf . grads ( ( a , b , c ) => {
152
- const prod = tf . matMul ( a , b , true , false ) ;
170
+ const prod = tf . matMul ( a , b , transposeA , transposeB ) ;
153
171
const sum = tf . add ( prod , c ) ;
154
172
return tf . relu ( sum ) ;
155
173
} ) ;
156
174
157
175
const fusedGrads = tf . grads ( ( a , b , c ) => {
158
- return tf . fused . matMul ( a , b , true , false , c , 'relu' ) ;
176
+ return tf . fused . matMul ( a , b , transposeA , transposeB , c , 'relu' ) ;
177
+ } ) ;
178
+
179
+ const [ da , db , dc ] = grads ( [ a , b , c ] , dy ) ;
180
+ const [ fusedDa , fusedDb , fusedDc ] = fusedGrads ( [ a , b , c ] , dy ) ;
181
+
182
+ expectArraysClose ( da , fusedDa ) ;
183
+ expectArraysClose ( db , fusedDb ) ;
184
+ expectArraysClose ( dc , fusedDc ) ;
185
+ } ) ;
186
+
187
+ it ( 'A x B with relu and broadcasted bias gradient' , ( ) => {
188
+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 10 , 20 , - 30 ] , [ 2 , 3 ] ) ;
189
+ const b = tf . tensor2d ( [ 2 , 3 , 4 , - 1 , 2 , 3 ] , [ 3 , 2 ] ) ;
190
+ const c = tf . tensor2d ( [ [ 1 ] ] ) ;
191
+ const transposeA = false ;
192
+ const transposeB = false ;
193
+
194
+ const dy = tf . tensor2d ( [ 1 , 10 , 20 , 30 ] , [ 2 , 2 ] ) ;
195
+
196
+ const grads = tf . grads ( ( a , b , c ) => {
197
+ const prod = tf . matMul ( a , b , transposeA , transposeB ) ;
198
+ const sum = tf . add ( prod , c ) ;
199
+ return tf . relu ( sum ) ;
200
+ } ) ;
201
+
202
+ const fusedGrads = tf . grads ( ( a , b , c ) => {
203
+ return tf . fused . matMul ( a , b , transposeA , transposeB , c , 'relu' ) ;
159
204
} ) ;
160
205
161
206
const [ da , db , dc ] = grads ( [ a , b , c ] , dy ) ;
0 commit comments