Skip to content

Commit 72f5026

Browse files
committed
save
1 parent 4d33109 commit 72f5026

10 files changed

+533
-492
lines changed

src/engine_test.ts

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ describeWithFlags('fromPixels + regular math op', WEBGL_ENVS, () => {
4545
});
4646

4747
describeWithFlags('gradients', ALL_ENVS, () => {
48-
it('matmul + relu', () => {
48+
it('matmul + relu', async () => {
4949
const a = tf.tensor2d([-1, 2, -3, 10, -20, 30], [2, 3]);
5050
const b = tf.tensor2d([2, -3, 4, -1, 2, -3], [3, 2]);
5151

@@ -67,13 +67,17 @@ describeWithFlags('gradients', ALL_ENVS, () => {
6767
expect(da.shape).toEqual(a.shape);
6868
let transposeA = false;
6969
let transposeB = true;
70-
expectArraysClose(da, tf.matMul(dedm, b, transposeA, transposeB));
70+
expectArraysClose(
71+
await da.data(),
72+
await tf.matMul(dedm, b, transposeA, transposeB).data());
7173

7274
// de/db = dot(aT, de/dy)
7375
expect(db.shape).toEqual(b.shape);
7476
transposeA = true;
7577
transposeB = false;
76-
expectArraysClose(db, tf.matMul(a, dedm, transposeA, transposeB));
78+
expectArraysClose(
79+
await db.data(),
80+
await tf.matMul(a, dedm, transposeA, transposeB).data());
7781
});
7882

7983
it('grad(f)', () => {
@@ -186,7 +190,7 @@ describeWithFlags('gradients', ALL_ENVS, () => {
186190
});
187191

188192
describeWithFlags('valueAndGradients', ALL_ENVS, () => {
189-
it('matmul + relu', () => {
193+
it('matmul + relu', async () => {
190194
const a = tf.tensor2d([-1, 2, -3, 10, -20, 30], [2, 3]);
191195
const b = tf.tensor2d([2, -3, 4, -1, 2, -3], [3, 2]);
192196

@@ -200,7 +204,7 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
200204
return tf.sum(y);
201205
})([a, b]);
202206

203-
expectArraysClose(value, 10);
207+
expectArraysClose(await value.data(), 10);
204208

205209
// de/dy = 1
206210
// dy/dm = step(m)
@@ -211,15 +215,19 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
211215
// de/da = dot(de/dy, bT)
212216
let transposeA = false;
213217
let transposeB = true;
214-
expectArraysClose(da, tf.matMul(dedm, b, transposeA, transposeB));
218+
expectArraysClose(
219+
await da.data(),
220+
await tf.matMul(dedm, b, transposeA, transposeB).data());
215221

216222
// de/db = dot(aT, de/dy)
217223
transposeA = true;
218224
transposeB = false;
219-
expectArraysClose(db, tf.matMul(a, dedm, transposeA, transposeB));
225+
expectArraysClose(
226+
await db.data(),
227+
await tf.matMul(a, dedm, transposeA, transposeB).data());
220228
});
221229

222-
it('matmul + relu + inner tidy', () => {
230+
it('matmul + relu + inner tidy', async () => {
223231
const a = tf.tensor2d([-1, 2, -3, 10, -20, 30], [2, 3]);
224232
const b = tf.tensor2d([2, -3, 4, -1, 2, -3], [3, 2]);
225233

@@ -235,7 +243,7 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
235243
});
236244
})([a, b]);
237245

238-
expectArraysClose(value, 10);
246+
expectArraysClose(await value.data(), 10);
239247

240248
// de/dy = 1
241249
// dy/dm = step(m)
@@ -246,12 +254,16 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
246254
// de/da = dot(de/dy, bT)
247255
let transposeA = false;
248256
let transposeB = true;
249-
expectArraysClose(da, tf.matMul(dedm, b, transposeA, transposeB));
257+
expectArraysClose(
258+
await da.data(),
259+
await tf.matMul(dedm, b, transposeA, transposeB).data());
250260

251261
// de/db = dot(aT, de/dy)
252262
transposeA = true;
253263
transposeB = false;
254-
expectArraysClose(db, tf.matMul(a, dedm, transposeA, transposeB));
264+
expectArraysClose(
265+
await db.data(),
266+
await tf.matMul(a, dedm, transposeA, transposeB).data());
255267
});
256268
});
257269

src/kernels/asm.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import {Tensor} from '../tensor';
2+
import {worker} from './worker';
23

34
function makeWorkers(n: number) {
45
const workers = [];
6+
const blobURL = URL.createObjectURL(new Blob(
7+
['(', worker.toString(), ')()'], {type: 'application/javascript'}));
8+
59
for (let i = 0; i < n; i++) {
6-
workers.push(new Worker('worker/worker.js'));
10+
workers.push(new Worker(blobURL));
711
}
812
return workers;
913
}

0 commit comments

Comments
 (0)