@@ -45,7 +45,7 @@ describeWithFlags('fromPixels + regular math op', WEBGL_ENVS, () => {
45
45
} ) ;
46
46
47
47
describeWithFlags ( 'gradients' , ALL_ENVS , ( ) => {
48
- it ( 'matmul + relu' , ( ) => {
48
+ it ( 'matmul + relu' , async ( ) => {
49
49
const a = tf . tensor2d ( [ - 1 , 2 , - 3 , 10 , - 20 , 30 ] , [ 2 , 3 ] ) ;
50
50
const b = tf . tensor2d ( [ 2 , - 3 , 4 , - 1 , 2 , - 3 ] , [ 3 , 2 ] ) ;
51
51
@@ -67,13 +67,17 @@ describeWithFlags('gradients', ALL_ENVS, () => {
67
67
expect ( da . shape ) . toEqual ( a . shape ) ;
68
68
let transposeA = false ;
69
69
let transposeB = true ;
70
- expectArraysClose ( da , tf . matMul ( dedm , b , transposeA , transposeB ) ) ;
70
+ expectArraysClose (
71
+ await da . data ( ) ,
72
+ await tf . matMul ( dedm , b , transposeA , transposeB ) . data ( ) ) ;
71
73
72
74
// de/db = dot(aT, de/dy)
73
75
expect ( db . shape ) . toEqual ( b . shape ) ;
74
76
transposeA = true ;
75
77
transposeB = false ;
76
- expectArraysClose ( db , tf . matMul ( a , dedm , transposeA , transposeB ) ) ;
78
+ expectArraysClose (
79
+ await db . data ( ) ,
80
+ await tf . matMul ( a , dedm , transposeA , transposeB ) . data ( ) ) ;
77
81
} ) ;
78
82
79
83
it ( 'grad(f)' , ( ) => {
@@ -186,7 +190,7 @@ describeWithFlags('gradients', ALL_ENVS, () => {
186
190
} ) ;
187
191
188
192
describeWithFlags ( 'valueAndGradients' , ALL_ENVS , ( ) => {
189
- it ( 'matmul + relu' , ( ) => {
193
+ it ( 'matmul + relu' , async ( ) => {
190
194
const a = tf . tensor2d ( [ - 1 , 2 , - 3 , 10 , - 20 , 30 ] , [ 2 , 3 ] ) ;
191
195
const b = tf . tensor2d ( [ 2 , - 3 , 4 , - 1 , 2 , - 3 ] , [ 3 , 2 ] ) ;
192
196
@@ -200,7 +204,7 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
200
204
return tf . sum ( y ) ;
201
205
} ) ( [ a , b ] ) ;
202
206
203
- expectArraysClose ( value , 10 ) ;
207
+ expectArraysClose ( await value . data ( ) , 10 ) ;
204
208
205
209
// de/dy = 1
206
210
// dy/dm = step(m)
@@ -211,15 +215,19 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
211
215
// de/da = dot(de/dy, bT)
212
216
let transposeA = false ;
213
217
let transposeB = true ;
214
- expectArraysClose ( da , tf . matMul ( dedm , b , transposeA , transposeB ) ) ;
218
+ expectArraysClose (
219
+ await da . data ( ) ,
220
+ await tf . matMul ( dedm , b , transposeA , transposeB ) . data ( ) ) ;
215
221
216
222
// de/db = dot(aT, de/dy)
217
223
transposeA = true ;
218
224
transposeB = false ;
219
- expectArraysClose ( db , tf . matMul ( a , dedm , transposeA , transposeB ) ) ;
225
+ expectArraysClose (
226
+ await db . data ( ) ,
227
+ await tf . matMul ( a , dedm , transposeA , transposeB ) . data ( ) ) ;
220
228
} ) ;
221
229
222
- it ( 'matmul + relu + inner tidy' , ( ) => {
230
+ it ( 'matmul + relu + inner tidy' , async ( ) => {
223
231
const a = tf . tensor2d ( [ - 1 , 2 , - 3 , 10 , - 20 , 30 ] , [ 2 , 3 ] ) ;
224
232
const b = tf . tensor2d ( [ 2 , - 3 , 4 , - 1 , 2 , - 3 ] , [ 3 , 2 ] ) ;
225
233
@@ -235,7 +243,7 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
235
243
} ) ;
236
244
} ) ( [ a , b ] ) ;
237
245
238
- expectArraysClose ( value , 10 ) ;
246
+ expectArraysClose ( await value . data ( ) , 10 ) ;
239
247
240
248
// de/dy = 1
241
249
// dy/dm = step(m)
@@ -246,12 +254,16 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
246
254
// de/da = dot(de/dy, bT)
247
255
let transposeA = false ;
248
256
let transposeB = true ;
249
- expectArraysClose ( da , tf . matMul ( dedm , b , transposeA , transposeB ) ) ;
257
+ expectArraysClose (
258
+ await da . data ( ) ,
259
+ await tf . matMul ( dedm , b , transposeA , transposeB ) . data ( ) ) ;
250
260
251
261
// de/db = dot(aT, de/dy)
252
262
transposeA = true ;
253
263
transposeB = false ;
254
- expectArraysClose ( db , tf . matMul ( a , dedm , transposeA , transposeB ) ) ;
264
+ expectArraysClose (
265
+ await db . data ( ) ,
266
+ await tf . matMul ( a , dedm , transposeA , transposeB ) . data ( ) ) ;
255
267
} ) ;
256
268
} ) ;
257
269
0 commit comments