Skip to content

Commit 29ecc85

Browse files
committed
save
1 parent 72f5026 commit 29ecc85

File tree

4 files changed

+127
-185
lines changed

4 files changed

+127
-185
lines changed

src/kernels/asm.ts

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import {Tensor} from '../tensor';
2+
import {getTypedArrayFromDType} from '../util';
3+
24
import {worker} from './worker';
35

46
function makeWorkers(n: number) {
@@ -26,36 +28,74 @@ function computeOffsets(n: number, numSplits: number) {
2628
return offsets;
2729
}
2830

29-
// Safari doesn't have one.
31+
// Safari doesn't tell you, so we assume 4 cores.
3032
const nWorkers = navigator.hardwareConcurrency || 4;
3133
const workers = makeWorkers(nWorkers);
34+
workers.forEach(worker => {
35+
worker.onmessage = msg => {
36+
const [msgId, data] = msg.data;
37+
workMap.get(msgId).resolve(data);
38+
};
39+
});
40+
let nextWorker = 0;
41+
let nextMsgId = 0;
42+
const workMap = new Map<MsgId, Work>();
3243

33-
export async function matmul(a: Tensor, b: Tensor): Promise<Float32Array> {
34-
const aSize = a.shape[0];
35-
const bSize = b.shape[1];
36-
const k = a.shape[1];
44+
export type MsgId = number;
3745

38-
const offsets = computeOffsets(aSize, nWorkers);
39-
const [aVals, bVals] = await Promise.all([a.data(), b.data()]);
46+
interface Work {
47+
resolve: (data: {}) => void;
48+
}
4049

41-
// const res = zeros([aSize, bSize]);
42-
const resVals = new Float32Array(aSize * bSize);
43-
let count = 0;
44-
return new Promise<Float32Array>(resolve => {
45-
workers.forEach((worker, i) => {
46-
worker.onmessage = e => {
47-
const offset = offsets[i] * bSize;
48-
resVals.set(e.data, offset);
49-
count++;
50-
if (count === nWorkers) {
51-
resolve(resVals);
52-
}
53-
};
54-
const offset = offsets[i] * k;
55-
const nextOffset =
56-
i + 1 < offsets.length ? offsets[i + 1] * k : undefined;
57-
const aSubVals = aVals.subarray(offset, nextOffset);
58-
worker.postMessage([aSubVals, bVals, k]);
59-
});
50+
export function sendWork(data: {}): Promise<{}> {
51+
const worker = workers[nextWorker];
52+
nextWorker = (nextWorker + 1) % nWorkers;
53+
return new Promise(resolve => {
54+
const msgId = nextMsgId++;
55+
workMap.set(msgId, {resolve});
56+
worker.postMessage([msgId, data]);
6057
});
6158
}
59+
60+
export async function matmul(
61+
a: Tensor, b: Tensor, transposeA: boolean,
62+
transposeB: boolean): Promise<Float32Array> {
63+
const innerDim = transposeA ? a.shape[1] : a.shape[2];
64+
const leftDim = transposeA ? a.shape[2] : a.shape[1];
65+
const rightDim = transposeB ? b.shape[1] : b.shape[2];
66+
const batchDim = a.shape[0];
67+
const aSize = leftDim * innerDim;
68+
const bSize = innerDim * rightDim;
69+
const cSize = leftDim * rightDim;
70+
71+
const nSplits = Math.min(leftDim, nWorkers);
72+
const offsets = computeOffsets(leftDim, nSplits);
73+
const [aVals, bVals] = await Promise.all([a.data(), b.data()]);
74+
75+
const resVals = getTypedArrayFromDType(
76+
a.dtype as 'float32', batchDim * leftDim * rightDim);
77+
78+
const jobs: Array<Promise<{}>> = [];
79+
for (let b = 0; b < batchDim; b++) {
80+
for (let i = 0; i < nSplits; i++) {
81+
const aOffset = b * aSize + offsets[i] * innerDim;
82+
const nextOffset = i + 1 < offsets.length ?
83+
b * aSize + offsets[i + 1] * innerDim :
84+
aOffset + aSize;
85+
const aSubVals = aVals.subarray(aOffset, nextOffset);
86+
const bOffset = b * bSize;
87+
const bSubVals = bVals.subarray(bOffset, bOffset + bSize);
88+
jobs.push(sendWork([aSubVals, bSubVals, innerDim]));
89+
}
90+
}
91+
const results = await Promise.all(jobs);
92+
for (let b = 0; b < batchDim; b++) {
93+
for (let i = 0; i < nSplits; i++) {
94+
const resIdx = b * nSplits + i;
95+
const data = results[resIdx];
96+
const resOffset = b * cSize + offsets[i] * rightDim;
97+
resVals.set(data as Float32Array, resOffset);
98+
}
99+
}
100+
return resVals;
101+
}

src/kernels/backend_cpu.ts

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../ops/
3434
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../tensor';
3535
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../types';
3636
import * as util from '../util';
37-
// import * as asm from './asm';
37+
import * as asm from './asm';
3838
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../util';
39-
4039
import {BackendTimingInfo, DataMover, DataStorage, KernelBackend} from './backend';
4140
import * as backend_util from './backend_util';
4241
import * as complex_util from './complex_util';
@@ -460,63 +459,60 @@ export class MathBackendCPU implements KernelBackend {
460459
transposeB: boolean): Tensor3D {
461460
this.assertNotComplex([a, b], 'matMul');
462461

463-
const sharedDim = transposeA ? a.shape[1] : a.shape[2];
462+
// const sharedDim = transposeA ? a.shape[1] : a.shape[2];
464463
const leftDim = transposeA ? a.shape[2] : a.shape[1];
465464
const rightDim = transposeB ? b.shape[1] : b.shape[2];
466465
const batchDim = a.shape[0];
467-
// const nWorkers = navigator.hardwareConcurrency || 4;
468466
const outShape = [batchDim, leftDim, rightDim];
469-
// if (batchDim === 1 && a.shape[0] >= nWorkers) {
470-
// console.warn('asking for asm');
471-
// const values = asm.matmul(a.squeeze([0]), b.squeeze([0]));
472-
// return Tensor.make(outShape, {values}, a.dtype);
473-
// }
474-
475-
const compute = async () => {
476-
const [aValues, bValues] = await Promise.all([a.data(), b.data()]);
477-
const [aOuterStep, aInnerStep] =
478-
transposeA ? [1, a.strides[1]] : [a.strides[1], 1];
479-
const [bInnerStep, bOuterStep] =
480-
transposeB ? [1, b.strides[1]] : [b.strides[1], 1];
481-
482-
const resVals = util.getTypedArrayFromDType(
483-
a.dtype as 'float32', sizeFromShape(outShape));
484-
const blockSize = this.blockSize;
485-
486-
for (let batch = 0; batch < batchDim; batch++) {
487-
const aBatch = batch * a.strides[0];
488-
const bBatch = batch * b.strides[0];
489-
for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
490-
const iBlock = i0 + blockSize < leftDim ? i0 + blockSize : leftDim;
491-
for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
492-
const jBlock =
493-
j0 + blockSize < rightDim ? j0 + blockSize : rightDim;
494-
for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
495-
// for when blockSize doesn't evenly divide the input
496-
const kBlock =
497-
k0 + blockSize < sharedDim ? k0 + blockSize : sharedDim;
498-
499-
for (let i = i0; i < iBlock; i++) {
500-
const iDim = i * rightDim;
501-
const iStep = aBatch + i * aOuterStep;
502-
for (let j = j0; j < jBlock; j++) {
503-
const jStep = bBatch + j * bOuterStep;
504-
let sum = 0.0;
505-
506-
for (let k = k0; k < kBlock; k++) {
507-
sum += aValues[k * aInnerStep + iStep] *
508-
bValues[k * bInnerStep + jStep];
509-
}
510-
resVals[iDim + j] += sum;
511-
}
512-
}
513-
}
514-
}
515-
}
516-
}
517-
return resVals;
518-
};
519-
return Tensor.make(outShape, {values: compute()}, a.dtype) as Tensor3D;
467+
const values = asm.matmul(a, b, transposeA, transposeB);
468+
return Tensor.make(outShape, {values}, a.dtype);
469+
470+
// const compute = async () => {
471+
// const [aValues, bValues] = await Promise.all([a.data(), b.data()]);
472+
// const [aOuterStep, aInnerStep] =
473+
// transposeA ? [1, a.strides[1]] : [a.strides[1], 1];
474+
// const [bInnerStep, bOuterStep] =
475+
// transposeB ? [1, b.strides[1]] : [b.strides[1], 1];
476+
477+
// const resVals = util.getTypedArrayFromDType(
478+
// a.dtype as 'float32', sizeFromShape(outShape));
479+
// const blockSize = this.blockSize;
480+
481+
// for (let batch = 0; batch < batchDim; batch++) {
482+
// const aBatch = batch * a.strides[0];
483+
// const bBatch = batch * b.strides[0];
484+
// const resBatch = batch * leftDim * rightDim;
485+
// for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
486+
// const iBlock = i0 + blockSize < leftDim ? i0 + blockSize : leftDim;
487+
// for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
488+
// const jBlock =
489+
// j0 + blockSize < rightDim ? j0 + blockSize : rightDim;
490+
// for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
491+
// // for when blockSize doesn't evenly divide the input
492+
// const kBlock =
493+
// k0 + blockSize < sharedDim ? k0 + blockSize : sharedDim;
494+
495+
// for (let i = i0; i < iBlock; i++) {
496+
// const iDim = resBatch + i * rightDim;
497+
// const iStep = aBatch + i * aOuterStep;
498+
// for (let j = j0; j < jBlock; j++) {
499+
// const jStep = bBatch + j * bOuterStep;
500+
// let sum = 0.0;
501+
502+
// for (let k = k0; k < kBlock; k++) {
503+
// sum += aValues[k * aInnerStep + iStep] *
504+
// bValues[k * bInnerStep + jStep];
505+
// }
506+
// resVals[iDim + j] += sum;
507+
// }
508+
// }
509+
// }
510+
// }
511+
// }
512+
// }
513+
// return resVals;
514+
// };
515+
// return Tensor.make(outShape, {values: compute()}, a.dtype) as Tensor3D;
520516
}
521517

522518
fusedBatchMatMul(

src/kernels/worker.ts

Lines changed: 5 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,6 @@ export function worker() {
1717
bSize = bSize | 0;
1818
mid = mid | 0;
1919

20-
// Variable declaration.
21-
var offset = 0;
22-
var i = 0;
23-
var j = 0;
24-
var k = 0;
25-
var bOffset = 0;
26-
var cOffset = 0;
27-
var a = fround(0);
28-
var b = fround(0);
29-
var dot = fround(0);
30-
31-
bOffset = imul(aSize, mid);
32-
cOffset = (bOffset + imul(bSize, mid)) | 0;
33-
34-
for (i = 0; (i | 0) < (aSize | 0); i = (i + 1) | 0) {
35-
for (j = 0; (j | 0) < (bSize | 0); j = (j + 1) | 0) {
36-
dot = fround(0);
37-
for (k = 0; (k | 0) < (mid | 0); k = (k + 1) | 0) {
38-
offset = (imul(i, mid) + k) << 2;
39-
a = fround(heap32[offset >> 2]); // a[i * mid + k]
40-
41-
offset = (bOffset + imul(k, bSize) + j) << 2;
42-
b = fround(heap32[offset >> 2]); // b[k * bSize + j]
43-
44-
dot = fround(dot + fround(a * b));
45-
}
46-
offset = (cOffset + imul(i, bSize) + j) << 2;
47-
heap32[offset >> 2] = fround(dot);
48-
}
49-
}
50-
}
51-
52-
function matmulBlocked(aSize: number, bSize: number, mid: number) {
53-
// Function arguments.
54-
aSize = aSize | 0;
55-
bSize = bSize | 0;
56-
mid = mid | 0;
57-
5820
// Variable declaration.
5921
var offset = 0;
6022
var blockSize = 48;
@@ -109,81 +71,24 @@ export function worker() {
10971
}
11072
}
11173
}
112-
return {matmul: matmul, matmulBlocked: matmulBlocked};
74+
return {matmul: matmul};
11375
}
11476

11577
var heap = new ArrayBuffer(1024 * 1024 * 16); // 16 MB heap
11678
var heapF32 = new Float32Array(heap);
11779
var asm = ASMModule(self as any, null, heap);
11880

119-
// @ts-ignore
120-
function matmulSimple(aVals: Float32Array, bVals: Float32Array, mid: number) {
121-
const aSize = aVals.length / mid;
122-
const bSize = bVals.length / mid;
123-
const res = new Float32Array(aSize * bSize);
124-
for (let i = 0; i < aSize; ++i) {
125-
const iMid = i * mid;
126-
const iBSize = i * bSize;
127-
for (let j = 0; j < bSize; ++j) {
128-
let dot = 0;
129-
for (let k = 0; k < mid; ++k) {
130-
dot += aVals[iMid + k] * bVals[k * bSize + j];
131-
}
132-
res[iBSize + j] = dot;
133-
}
134-
}
135-
return res;
136-
}
137-
138-
// @ts-ignore
139-
function matmulBlocked(
140-
aVals: Float32Array, bVals: Float32Array, mid: number) {
141-
const aSize = aVals.length / mid;
142-
const bSize = bVals.length / mid;
143-
144-
const res = new Float32Array(aSize * bSize);
145-
const blockSize = 48;
146-
147-
for (let i0 = 0; i0 < aSize; i0 += blockSize) {
148-
for (let j0 = 0; j0 < bSize; j0 += blockSize) {
149-
for (let k0 = 0; k0 < mid; k0 += blockSize) {
150-
// for when blockSize doesn't evenly divide the input
151-
const iBlock = Math.min(i0 + blockSize, aSize);
152-
const jBlock = Math.min(j0 + blockSize, bSize);
153-
const kBlock = Math.min(k0 + blockSize, mid);
154-
155-
for (let i = i0; i < iBlock; i++) {
156-
for (let j = j0; j < jBlock; j++) {
157-
let sum = 0.0;
158-
159-
for (let k = k0; k < kBlock; k++) {
160-
sum += aVals[i * mid + k] * bVals[k * bSize + j];
161-
}
162-
res[i * bSize + j] += sum;
163-
}
164-
}
165-
}
166-
}
167-
}
168-
return res;
169-
}
170-
171-
self.onmessage = function(e) {
172-
const [aVals, bVals, mid] = e.data;
81+
self.onmessage = function(msg) {
82+
const [msgId, [aVals, bVals, mid]] = msg.data;
17383

17484
const aSize = aVals.length / mid;
17585
const bSize = bVals.length / mid;
17686
heapF32.set(aVals, 0);
17787
heapF32.set(bVals, aVals.length);
17888
const offset = aVals.length + bVals.length;
17989
heapF32.fill(0, offset, offset + aSize * bSize);
180-
asm.matmulBlocked(aSize, bSize, mid);
90+
asm.matmul(aSize, bSize, mid);
18191
const res = heapF32.slice(offset, offset + aSize * bSize);
182-
183-
// const res = matmulSimple(aVals, bVals, mid);
184-
// const res = matmulBlocked(aVals, bVals, mid);
185-
186-
// @ts-ignore
187-
self.postMessage(res);
92+
self.postMessage([msgId, res], null);
18893
}
18994
}

0 commit comments

Comments (0)