|
| 1 | +import {Conv2DInfo} from '../ops/conv_util'; |
| 2 | +import {Tensor, Tensor4D} from '../tensor'; |
| 3 | +import {TypedArray} from '../types'; |
| 4 | +import {getTypedArrayFromDType} from '../util'; |
| 5 | + |
| 6 | +import {worker} from './worker'; |
| 7 | + |
/**
 * Spawns `n` web workers that all execute the imported `worker` function.
 *
 * The function is stringified and wrapped in an immediately-invoked
 * expression inside a Blob, so every worker is started from one shared object
 * URL rather than from a separate script file.
 *
 * @param n Number of workers to create.
 * @returns The newly created workers, in creation order.
 */
function makeWorkers(n: number): Worker[] {
  const script = new Blob(
      ['(', worker.toString(), ')()'], {type: 'application/javascript'});
  const scriptURL = URL.createObjectURL(script);

  const pool: Worker[] = [];
  while (pool.length < n) {
    pool.push(new Worker(scriptURL));
  }
  return pool;
}
| 18 | + |
/**
 * Computes the start offset of each chunk when `n` items are divided into
 * `numSplits` contiguous chunks of near-equal size.
 *
 * Every chunk gets `floor(n / numSplits)` items and the first
 * `n % numSplits` chunks get one extra, e.g. `computeOffsets(10, 3)` yields
 * `[0, 4, 7]`.
 *
 * @param n Total number of items to divide.
 * @param numSplits Number of chunks to produce.
 * @returns One start offset per chunk, in ascending order.
 */
function computeOffsets(n: number, numSplits: number): number[] {
  const baseSize = Math.floor(n / numSplits);
  const remainder = n % numSplits;
  const starts: number[] = [];
  let cursor = 0;
  let split = 0;
  while (split < numSplits) {
    starts.push(cursor);
    // The leading `remainder` chunks absorb the leftover items.
    cursor += split < remainder ? baseSize + 1 : baseSize;
    split++;
  }
  return starts;
}
| 32 | + |
/** Identifier that pairs a message posted to a worker with its reply. */
export type MsgId = number;

/** Bookkeeping for a job that has been posted but not yet answered. */
interface Work {
  /** Settles the promise handed out by `sendWork` with the worker's reply. */
  resolve: (data: {}) => void;
}

/** Jobs still awaiting a reply, keyed by the id they were posted with. */
const workMap = new Map<MsgId, Work>();

// Safari doesn't report `navigator.hardwareConcurrency`, so assume 4 cores
// when the value is missing.
const nWorkers = navigator.hardwareConcurrency || 4;
const workers = makeWorkers(nWorkers);
workers.forEach(pooledWorker => {
  // Replies arrive as `[msgId, data]`; route each one to the job that is
  // waiting on that id.
  pooledWorker.onmessage = msg => {
    const [msgId, data] = msg.data;
    const work = workMap.get(msgId);
    if (work == null) {
      throw new Error(`Received a reply for unknown message id ${msgId}`);
    }
    // Forget the job once it is answered so the map does not grow without
    // bound over the lifetime of the page.
    workMap.delete(msgId);
    work.resolve(data);
  };
});

/** Index into `workers` of the worker that receives the next job. */
let nextWorker = 0;
/** Id assigned to the next posted message. */
let nextMsgId = 0;
| 51 | + |
/**
 * Posts one job to the worker pool and resolves with that worker's reply.
 *
 * Workers are chosen round-robin. Each job is tagged with a fresh message id
 * and registered in `workMap`, which the workers' `onmessage` handler uses to
 * find the promise to resolve.
 *
 * @param type Job name sent to the worker (this file uses `'conv2d'` and
 *     `'matmul'`).
 * @param data Payload posted to the worker alongside the id and type.
 * @returns A promise resolved with the data of the reply carrying the same
 *     message id. It rejects only if posting the message itself throws.
 */
export function sendWork(type: string, data: {}): Promise<{}> {
  const target = workers[nextWorker];
  nextWorker = (nextWorker + 1) % nWorkers;
  return new Promise<{}>((resolve, reject) => {
    const msgId = nextMsgId++;
    workMap.set(msgId, {resolve});
    try {
      target.postMessage([msgId, type, data]);
    } catch (err) {
      // The message never left, so no reply will ever clear this entry;
      // remove it here instead of leaking it.
      workMap.delete(msgId);
      reject(err);
    }
  });
}
| 61 | + |
/**
 * Computes a 2D convolution on the worker pool by dividing the output rows
 * among the workers.
 *
 * For every batch element the `outHeight` output rows are cut into at most
 * `nWorkers` contiguous ranges. Each job receives its output row range
 * `[yRStart, yREnd)`, the input row range `[xRstart, xREnd)` that the range
 * reads from, a view of `x` restricted to those input rows, and the complete
 * filter values. The partial outputs are then copied into one result buffer.
 *
 * @param convInfo Convolution geometry (sizes, strides, dilation, padding).
 * @param x Input tensor; `x.strides[0]` is used as the per-batch stride and
 *     `x.strides[1]` as the per-row stride.
 * @param filter Filter tensor; its values are sent unsplit with every job.
 * @returns Flat output values of length
 *     `batchSize * outHeight * outWidth * outChannels`.
 */
export async function conv2d(
    convInfo: Conv2DInfo, x: Tensor4D, filter: Tensor4D): Promise<TypedArray> {
  const {
    batchSize,
    padInfo: {top},
    strideHeight,
    filterHeight,
    inHeight,
    dilationHeight,
    outHeight,
    outWidth,
    outChannels
  } = convInfo;

  const inputs = await Promise.all([x.data(), filter.data()]);
  const xVal = inputs[0];
  const wVal = inputs[1];
  const start = performance.now();

  // Never create more splits than there are output rows to hand out.
  const nSplits = Math.min(outHeight, nWorkers);
  const offsets = computeOffsets(outHeight, nSplits);

  const jobs: Array<Promise<{}>> = [];
  for (let b = 0; b < batchSize; b++) {
    offsets.forEach((yRStart, i) => {
      const isLastSplit = i + 1 >= offsets.length;
      const yREnd = isLastSplit ? outHeight : offsets[i + 1];
      // First input row touched by this range, clamped at the top edge.
      const xRstart = Math.max(0, yRStart * strideHeight - top);
      // One past the last input row under the filter for the final output row
      // of the range, clamped at the bottom edge.
      const unclampedEnd = (yREnd - 1) * strideHeight - top +
          (filterHeight - 1) * dilationHeight + 1;
      const xREnd = Math.min(inHeight, unclampedEnd);
      const batchBase = b * x.strides[0];
      const xSub = xVal.subarray(
          batchBase + xRstart * x.strides[1],
          batchBase + xREnd * x.strides[1]);
      jobs.push(sendWork(
          'conv2d', [convInfo, yRStart, yREnd, xRstart, xREnd, xSub, wVal]));
    });
  }

  const yStrideWC = outWidth * outChannels;
  const yStrideHWC = outHeight * yStrideWC;
  const y =
      getTypedArrayFromDType(x.dtype as 'float32', batchSize * yStrideHWC);

  // Jobs were queued batch-major, so result k belongs to batch
  // floor(k / nSplits) and split k % nSplits.
  const results = await Promise.all(jobs);
  results.forEach((data, resIdx) => {
    const b = Math.floor(resIdx / nSplits);
    const i = resIdx % nSplits;
    y.set(data as Float32Array, b * yStrideHWC + offsets[i] * yStrideWC);
  });
  console.log('conv2d worker took', (performance.now() - start).toFixed(0));
  return y;
}
| 114 | + |
/**
 * Computes a batched matrix multiplication on the worker pool by dividing the
 * rows of the left operand among the workers.
 *
 * For every batch element the `leftDim` rows of `a` are cut into at most
 * `nWorkers` contiguous ranges. Each job receives a view of its rows of `a`,
 * a view of the whole `b` matrix for that batch, and `innerDim`. The partial
 * products are then copied into one result buffer.
 *
 * NOTE(review): `transposeA` / `transposeB` are only used here to pick the
 * dimensions; they are not part of the job payload, and rows of `a` are cut
 * as `innerDim`-sized contiguous runs. This looks correct only for
 * non-transposed operands; confirm against the worker implementation.
 *
 * NOTE(review): the operands are posted as `subarray` views. Check whether
 * the structured clone copies the whole underlying buffer per job; if so,
 * copying the range first may be cheaper.
 *
 * @param a Left operand, indexed as `[batch, rows, cols]`.
 * @param b Right operand, indexed as `[batch, rows, cols]`.
 * @param transposeA Whether `a` is to be read transposed.
 * @param transposeB Whether `b` is to be read transposed.
 * @returns Flat result values of length `batchDim * leftDim * rightDim`.
 */
export async function matmul(
    a: Tensor, b: Tensor, transposeA: boolean,
    transposeB: boolean): Promise<Float32Array> {
  const innerDim = transposeA ? a.shape[1] : a.shape[2];
  const leftDim = transposeA ? a.shape[2] : a.shape[1];
  const rightDim = transposeB ? b.shape[1] : b.shape[2];
  const batchDim = a.shape[0];
  const aSize = leftDim * innerDim;
  const bSize = innerDim * rightDim;
  const cSize = leftDim * rightDim;

  // Never create more splits than there are rows to hand out.
  const nSplits = Math.min(leftDim, nWorkers);
  const offsets = computeOffsets(leftDim, nSplits);
  const [aVals, bVals] = await Promise.all([a.data(), b.data()]);

  const start = performance.now();
  const resVals = getTypedArrayFromDType(
      a.dtype as 'float32', batchDim * leftDim * rightDim);

  const jobs: Array<Promise<{}>> = [];
  // `batch` (not `b`) so the loop index does not shadow the tensor parameter.
  for (let batch = 0; batch < batchDim; batch++) {
    const aBase = batch * aSize;
    const bOffset = batch * bSize;
    const bSubVals = bVals.subarray(bOffset, bOffset + bSize);
    for (let i = 0; i < nSplits; i++) {
      const rowStart = offsets[i];
      // The last split ends at the final row of THIS batch. Ending it at
      // `aOffset + aSize` would run into the next batch's rows.
      const rowEnd = i + 1 < offsets.length ? offsets[i + 1] : leftDim;
      const aSubVals = aVals.subarray(
          aBase + rowStart * innerDim, aBase + rowEnd * innerDim);
      console.log(
          `partial matmul ` +
          `${rowEnd - rowStart}x${innerDim} * ${innerDim}x${rightDim}`);
      jobs.push(sendWork('matmul', [aSubVals, bSubVals, innerDim]));
    }
  }

  // Jobs were queued batch-major, so the result for (batch, i) sits at index
  // batch * nSplits + i.
  const results = await Promise.all(jobs);
  for (let batch = 0; batch < batchDim; batch++) {
    for (let i = 0; i < nSplits; i++) {
      const data = results[batch * nSplits + i];
      const resOffset = batch * cSize + offsets[i] * rightDim;
      resVals.set(data as Float32Array, resOffset);
    }
  }
  console.log('matmul worker took', (performance.now() - start).toFixed(0));
  return resVals;
}
0 commit comments