Skip to content

Commit faca022

Browse files
committed
save
1 parent f3b8706 commit faca022

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed

src/kernels/webworker.ts

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import {Conv2DInfo} from '../ops/conv_util';
2+
import {Tensor, Tensor4D} from '../tensor';
3+
import {TypedArray} from '../types';
4+
import {getTypedArrayFromDType} from '../util';
5+
6+
import {worker} from './worker';
7+
8+
/**
 * Spawns `n` web workers that all execute the imported `worker` function.
 *
 * The source text of `worker` is wrapped as an immediately-invoked expression
 * inside a JavaScript Blob; every worker is started from the object URL of
 * that single Blob.
 *
 * @param n Number of workers to create.
 * @returns The created workers, in creation order.
 */
function makeWorkers(n: number) {
  const scriptBlob = new Blob(
      ['(', worker.toString(), ')()'], {type: 'application/javascript'});
  const scriptUrl = URL.createObjectURL(scriptBlob);

  const pool = [];
  let created = 0;
  while (created < n) {
    pool.push(new Worker(scriptUrl));
    created++;
  }
  return pool;
}
18+
19+
/**
 * Computes the start index of each chunk when `n` items are divided into
 * `numSplits` contiguous chunks of near-equal size.
 *
 * Every chunk has `floor(n / numSplits)` items; the first `n % numSplits`
 * chunks get one extra item so the chunks cover all `n` items.
 *
 * @param n Total number of items to divide.
 * @param numSplits Number of chunks.
 * @returns `numSplits` ascending start indices; the first one is 0.
 */
function computeOffsets(n: number, numSplits: number) {
  const baseSize = Math.floor(n / numSplits);
  const remainder = n % numSplits;

  const starts = [];
  let cursor = 0;
  let split = 0;
  while (split < numSplits) {
    starts.push(cursor);
    // The leading `remainder` chunks absorb the leftover items.
    cursor += split < remainder ? baseSize + 1 : baseSize;
    split++;
  }
  return starts;
}
32+
33+
/** Identifier that pairs a job posted to a worker with that worker's reply. */
export type MsgId = number;

/** Bookkeeping kept for a job that has been posted but not yet answered. */
interface Work {
  /** Resolves the pending promise with the payload the worker sent back. */
  resolve: (data: {}) => void;
}

// Safari doesn't report hardwareConcurrency, so we assume 4 cores.
const nWorkers = navigator.hardwareConcurrency || 4;
const workers = makeWorkers(nWorkers);

/** Jobs awaiting a reply, keyed by the message id they were posted with. */
const workMap = new Map<MsgId, Work>();

/** Index into `workers` of the worker that receives the next job. */
let nextWorker = 0;
/** Message id handed to the next job. */
let nextMsgId = 0;

workers.forEach(pooledWorker => {
  // Replies arrive as [msgId, data]; route each one to the job waiting on it.
  pooledWorker.onmessage = msg => {
    const [msgId, data] = msg.data;
    const pending = workMap.get(msgId);
    if (pending === undefined) {
      throw new Error(`Received a worker reply for unknown message id ${msgId}`);
    }
    // Drop the entry before resolving so finished jobs do not pile up in the
    // map for the lifetime of the page.
    workMap.delete(msgId);
    pending.resolve(data);
  };
});
51+
52+
/**
 * Posts one job to the worker pool, picking workers in round-robin order.
 *
 * The job is posted as `[msgId, type, data]`, and its resolver is stored in
 * `workMap` under `msgId` so the pool's message handler can settle the
 * returned promise when the reply arrives.
 *
 * @param type Job kind, sent as the second element of the message.
 * @param data Job payload, sent as the third element of the message.
 * @returns A promise that resolves with the data the worker replies with.
 */
export function sendWork(type: string, data: {}): Promise<{}> {
  const slot = nextWorker;
  nextWorker = (slot + 1) % nWorkers;

  return new Promise(settle => {
    const msgId = nextMsgId;
    nextMsgId += 1;
    workMap.set(msgId, {resolve: settle});
    workers[slot].postMessage([msgId, type, data]);
  });
}
61+
62+
/**
 * Computes a 2D convolution by distributing output rows over the worker pool.
 *
 * The output height is cut into at most `nWorkers` contiguous row ranges. For
 * every batch element and row range one 'conv2d' job is posted, carrying only
 * the input rows that range reads plus the complete filter values. The partial
 * results are then copied into a single output buffer.
 *
 * The index math treats `x` as [batch, height, width, channels] and the
 * result as [batch, outHeight, outWidth, outChannels].
 *
 * NOTE(review): the output buffer is allocated from `x.dtype`, but that dtype
 * and every worker result are cast to float32 — confirm other dtypes are never
 * passed in.
 *
 * @param convInfo Convolution geometry; also forwarded unchanged to each job.
 * @param x Input tensor.
 * @param filter Filter tensor; its values are sent in full with every job.
 * @returns The flattened output values.
 */
export async function conv2d(
    convInfo: Conv2DInfo, x: Tensor4D, filter: Tensor4D): Promise<TypedArray> {
  const {
    batchSize,
    padInfo: {top},
    strideHeight,
    filterHeight,
    inHeight,
    dilationHeight,
    outHeight,
    outWidth,
    outChannels
  } = convInfo;

  const [xVal, wVal] = await Promise.all([x.data(), filter.data()]);
  const start = performance.now();
  const nSplits = Math.min(outHeight, nWorkers);
  const offsets = computeOffsets(outHeight, nSplits);
  const xBatchStride = x.strides[0];
  const xRowStride = x.strides[1];

  const jobs: Array<Promise<{}>> = [];
  for (let b = 0; b < batchSize; b++) {
    const xBatchOffset = b * xBatchStride;
    for (let i = 0; i < nSplits; i++) {
      // Output rows [yRStart, yREnd) handled by this job.
      const yRStart = offsets[i];
      const yREnd = i + 1 < offsets.length ? offsets[i + 1] : outHeight;
      // Input rows [xRStart, xREnd) read by those output rows, clamped to the
      // input (rows that fall in the padding are not sent).
      const xRStart = Math.max(0, yRStart * strideHeight - top);
      const xREnd = Math.min(
          inHeight,
          (yREnd - 1) * strideHeight - top +
              (filterHeight - 1) * dilationHeight + 1);
      // Copy the rows with slice() rather than taking a subarray() view:
      // postMessage structured-clones the whole ArrayBuffer behind a view, so
      // a view would ship the entire input tensor with every job.
      const xSub = xVal.slice(
          xBatchOffset + xRStart * xRowStride,
          xBatchOffset + xREnd * xRowStride);
      jobs.push(sendWork(
          'conv2d', [convInfo, yRStart, yREnd, xRStart, xREnd, xSub, wVal]));
    }
  }

  const yStrideHWC = outHeight * outWidth * outChannels;
  const yStrideWC = outWidth * outChannels;
  const y = getTypedArrayFromDType(
      x.dtype as 'float32', batchSize * outHeight * outWidth * outChannels);
  const results = await Promise.all(jobs);
  // Jobs were queued batch-major, so result b * nSplits + i holds the rows
  // that start at offsets[i] of batch element b.
  for (let b = 0; b < batchSize; b++) {
    for (let i = 0; i < nSplits; i++) {
      const data = results[b * nSplits + i];
      const resOffset = b * yStrideHWC + offsets[i] * yStrideWC;
      y.set(data as Float32Array, resOffset);
    }
  }
  console.log('conv2d worker took', (performance.now() - start).toFixed(0));
  return y;
}
114+
115+
/**
 * Computes a batched matrix multiplication by distributing rows of the left
 * operand over the worker pool.
 *
 * The `leftDim` rows of each batch element of `a` are cut into at most
 * `nWorkers` contiguous row ranges. Every range is posted as a 'matmul' job
 * together with the matching batch element of `b`, and the partial products
 * are copied into one [batch, leftDim, rightDim] result buffer.
 *
 * NOTE(review): `transposeA` / `transposeB` only select which shape entries
 * are used as dimensions. The flags are not forwarded to the worker, and `a`
 * is always split as if stored row-major [leftDim, innerDim] — confirm
 * whether transposed inputs are actually supported.
 *
 * NOTE(review): the result buffer is allocated from `a.dtype`, cast to
 * float32 — confirm other dtypes are never passed in.
 *
 * @param a Left operand; indexed as a rank-3 [batch, ·, ·] tensor.
 * @param b Right operand; indexed as a rank-3 [batch, ·, ·] tensor.
 * @param transposeA Swaps which of `a.shape[1]` / `a.shape[2]` is the inner
 *     dimension.
 * @param transposeB Swaps which of `b.shape[1]` / `b.shape[2]` is the right
 *     dimension.
 * @returns The flattened result values.
 */
export async function matmul(
    a: Tensor, b: Tensor, transposeA: boolean,
    transposeB: boolean): Promise<Float32Array> {
  const innerDim = transposeA ? a.shape[1] : a.shape[2];
  const leftDim = transposeA ? a.shape[2] : a.shape[1];
  const rightDim = transposeB ? b.shape[1] : b.shape[2];
  const batchDim = a.shape[0];
  const aSize = leftDim * innerDim;
  const bSize = innerDim * rightDim;
  const cSize = leftDim * rightDim;

  const nSplits = Math.min(leftDim, nWorkers);
  const offsets = computeOffsets(leftDim, nSplits);
  const [aVals, bVals] = await Promise.all([a.data(), b.data()]);

  const start = performance.now();
  const resVals = getTypedArrayFromDType(
      a.dtype as 'float32', batchDim * leftDim * rightDim);

  const jobs: Array<Promise<{}>> = [];
  // `batch` (not `b`) so the loop index does not shadow the tensor parameter.
  for (let batch = 0; batch < batchDim; batch++) {
    const aBatchStart = batch * aSize;
    const bBatchStart = batch * bSize;
    // The right operand is the same for every split of this batch element.
    // slice() copies just this range; a subarray() view would make
    // postMessage structured-clone the whole underlying buffer per job.
    const bSubVals = bVals.slice(bBatchStart, bBatchStart + bSize);
    for (let i = 0; i < nSplits; i++) {
      const aOffset = aBatchStart + offsets[i] * innerDim;
      // The last split ends at the end of this batch element. (It used to end
      // at aOffset + aSize, which ran into the next batch element's rows.)
      const nextOffset = i + 1 < offsets.length ?
          aBatchStart + offsets[i + 1] * innerDim :
          aBatchStart + aSize;
      const aSubVals = aVals.slice(aOffset, nextOffset);
      // Only innerDim is sent along; presumably the worker derives the row
      // and column counts from the array lengths — confirm against worker.
      console.log(
          `partial matmul ` +
          `${leftDim}x${innerDim} * ${innerDim}x${rightDim}`);
      jobs.push(sendWork('matmul', [aSubVals, bSubVals, innerDim]));
    }
  }

  const results = await Promise.all(jobs);
  // Jobs were queued batch-major, so result batch * nSplits + i holds the
  // rows that start at offsets[i] of that batch element.
  for (let batch = 0; batch < batchDim; batch++) {
    for (let i = 0; i < nSplits; i++) {
      const data = results[batch * nSplits + i];
      const resOffset = batch * cSize + offsets[i] * rightDim;
      resVals.set(data as Float32Array, resOffset);
    }
  }
  console.log('matmul worker took', (performance.now() - start).toFixed(0));
  return resVals;
}

0 commit comments

Comments
 (0)