Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 5c0d017

Browse files
authored
Fuse conv2d with bias and activation. (#1859)
FEATURE PERF
1 parent 95e44a5 commit 5c0d017

File tree

7 files changed

+495
-15
lines changed

7 files changed

+495
-15
lines changed

src/backends/backend.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
411411
throw new Error('Not yet implemented');
412412
}
413413

414+
/**
 * Fused conv2d: a 2D convolution optionally followed by a bias add and an
 * activation, exposed as a single backend operation.
 *
 * Base-class stub: the CPU and WebGL backends in this commit provide real
 * implementations; invoking this default always throws.
 *
 * @param x Input tensor, rank 4.
 * @param filter Convolution filter, rank 4.
 * @param convInfo Precomputed convolution parameters (strides, pads, shapes).
 * @param bias Optional bias tensor added to the convolution result.
 * @param activation Optional activation applied after the bias add.
 * @throws Error Always ('Not yet implemented') — subclasses must override.
 */
fusedConv2d(
    x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
    activation?: Activation): Tensor4D {
  throw new Error('Not yet implemented');
}
419+
414420
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
415421
throw new Error('Not yet implemented');
416422
}

src/backends/cpu/backend_cpu.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,6 +1513,20 @@ export class MathBackendCPU implements KernelBackend {
15131513
return Tensor.make(x.shape, {values: resultValues}) as T;
15141514
}
15151515

1516+
/**
 * Fused conv2d for the CPU backend.
 *
 * The CPU path performs no actual kernel fusion: it chains the unfused
 * operations — conv2d, then (if provided) a bias add, then (if provided)
 * the activation — and returns the final tensor.
 *
 * @param x Input tensor, rank 4.
 * @param filter Convolution filter, rank 4.
 * @param convInfo Precomputed convolution parameters.
 * @param bias Optional bias added to the convolution output.
 * @param activation Optional activation applied last, via the shared
 *     mapActivation helper dispatching on this backend.
 */
fusedConv2d(
    x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
    activation?: Activation): Tensor4D {
  // Run the plain (unfused) convolution first.
  let result = this.conv2d(x, filter, convInfo);

  if (bias) {
    result = this.add(result, bias) as Tensor4D;
  }
  if (activation) {
    result = mapActivation(this, activation, result) as Tensor4D;
  }
  return result;
}
1529+
15161530
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
15171531
this.assertNotComplex([x, filter], 'conv2d');
15181532

src/backends/webgl/backend_webgl.ts

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -873,10 +873,12 @@ export class MathBackendWebGL implements KernelBackend {
873873

874874
const dtype = upcastType(a.dtype, b.dtype);
875875

876+
const hasBias = bias != null;
877+
const fusedActivation =
878+
activation ? mapActivationToShaderProgram(activation, true) : null;
876879
const program = new MatMulPackedProgram(
877880
a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
878-
!!bias,
879-
activation ? mapActivationToShaderProgram(activation, true) : null);
881+
hasBias, fusedActivation);
880882
const output =
881883
this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
882884
const inputs: TensorHandle[] = [a, b];
@@ -1815,15 +1817,18 @@ export class MathBackendWebGL implements KernelBackend {
18151817
return this.compileAndRun(program, [x]) as T;
18161818
}
18171819

1818-
conv2dByMatMul(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
1819-
Tensor4D {
1820+
private conv2dByMatMul(
1821+
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1822+
activation?: Activation): Tensor4D {
18201823
// Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
18211824
// result from 2D to 4D.
18221825
const xShape = x.shape;
18231826
const xTexData = this.texData.get(x.dataId);
18241827
const sharedMatMulDim = convInfo.inChannels;
18251828
const outerShapeX = xShape[0] * xShape[1] * xShape[2];
18261829
const outerShapeFilter = convInfo.outChannels;
1830+
const transposeA = false;
1831+
const transposeB = false;
18271832

18281833
// TODO: Once reduction ops are packed, batchMatMul will always be packed
18291834
// and we can remove this condition.
@@ -1843,8 +1848,11 @@ export class MathBackendWebGL implements KernelBackend {
18431848
this.reshape(
18441849
filter, [1, convInfo.inChannels, convInfo.outChannels]) as
18451850
Tensor3D;
1851+
18461852
return this.reshape<Rank.R4>(
1847-
this.batchMatMul(xReshaped, filterReshaped, false, false),
1853+
this.fusedBatchMatMul(
1854+
xReshaped, filterReshaped, transposeA, transposeB, bias,
1855+
activation),
18481856
convInfo.outShape);
18491857
}
18501858

@@ -1880,8 +1888,8 @@ export class MathBackendWebGL implements KernelBackend {
18801888
this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]) as
18811889
Tensor3D;
18821890

1883-
const pointwiseConv =
1884-
this.batchMatMul(xReshaped, filterReshaped, false, false);
1891+
const pointwiseConv = this.fusedBatchMatMul(
1892+
xReshaped, filterReshaped, transposeA, transposeB, bias, activation);
18851893
const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
18861894
util.assert(
18871895
pointwiseConvTexData.isPacked,
@@ -1896,8 +1904,9 @@ export class MathBackendWebGL implements KernelBackend {
18961904
pointwiseConv.dtype, this) as Tensor4D;
18971905
}
18981906

1899-
conv2dWithIm2Row(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
1900-
Tensor4D {
1907+
private conv2dWithIm2Row(
1908+
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1909+
activation?: Activation): Tensor4D {
19011910
// Rearranges conv2d input so each block to be convolved over forms the
19021911
// column of a new matrix with shape [filterWidth * filterHeight *
19031912
// inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1915,6 +1924,8 @@ export class MathBackendWebGL implements KernelBackend {
19151924
const sharedDim = filterWidth * filterHeight * inChannels;
19161925
const numCols = outHeight * outWidth;
19171926
const x2ColShape = [sharedDim, numCols];
1927+
const transposeA = true;
1928+
const transposeB = false;
19181929

19191930
const xSqueezed = x.squeeze([0]);
19201931
const w2Row = filter.reshape([1, sharedDim, -1]) as Tensor3D;
@@ -1926,14 +1937,46 @@ export class MathBackendWebGL implements KernelBackend {
19261937
1, x2ColShape[0], x2ColShape[1]
19271938
]) as Tensor3D;
19281939

1940+
const hasBias = bias != null;
1941+
const fusedActivation =
1942+
activation ? mapActivationToShaderProgram(activation, true) : null;
19291943
const matmulProgram = new MatMulPackedProgram(
1930-
im2Col.shape, [1, numCols, convInfo.outChannels], true, false);
1931-
const product =
1932-
this.compileAndRun<Tensor4D>(matmulProgram, [im2Col, w2Row]);
1944+
im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,
1945+
transposeB, hasBias, fusedActivation);
1946+
const inputs: TensorHandle[] = [im2Col, w2Row];
1947+
if (bias) {
1948+
inputs.push(bias);
1949+
}
1950+
const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);
19331951

19341952
return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
19351953
}
19361954

1955+
/**
 * Fused conv2d for the WebGL backend: convolution with an optional bias add
 * and activation folded into a single GPU program where possible.
 *
 * Dispatch mirrors the plain conv2d() below:
 *  - A 1x1 filter with unit stride/dilation reduces to a matrix multiply,
 *    handled by the fused conv2dByMatMul path.
 *  - With the WEBGL_CONV_IM2COL flag set and batch size 1, the im2col +
 *    matMul path (conv2dWithIm2Row) is used instead.
 *  - Otherwise a Conv2DProgram is compiled with the bias/activation snippets
 *    baked into the shader source (see conv_gpu.ts in this commit).
 *
 * @param x Input tensor, rank 4.
 * @param filter Convolution filter, rank 4.
 * @param convInfo Precomputed convolution parameters.
 * @param bias Optional bias; when present it is passed as an extra shader
 *     input ('bias' variable in Conv2DProgram).
 * @param activation Optional activation compiled into the shader.
 */
fusedConv2d(
    x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
    activation?: Activation): Tensor4D {
  // 1x1 convolution with no striding/dilation is exactly a matrix multiply.
  if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
      convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
      convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
      (convInfo.padInfo.type === 'SAME' ||
       convInfo.padInfo.type === 'VALID')) {
    return this.conv2dByMatMul(x, filter, convInfo, bias, activation);
  }
  if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
    return this.conv2dWithIm2Row(x, filter, convInfo, bias, activation);
  }

  const hasBias = bias != null;
  // NOTE(review): second argument is false here but true in the packed
  // matMul paths above — presumably selects the unpacked shader variant of
  // the activation snippet; confirm against mapActivationToShaderProgram.
  const fusedActivation =
      activation ? mapActivationToShaderProgram(activation, false) : null;
  const program = new Conv2DProgram(convInfo, hasBias, fusedActivation);
  const inputs: TensorHandle[] = [x, filter];
  if (bias) {
    inputs.push(bias);
  }
  return this.compileAndRun(program, inputs);
}
1979+
19371980
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
19381981
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
19391982
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&

src/backends/webgl/conv_gpu.ts

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ export class Conv2DProgram implements GPGPUProgram {
2323
outputShape: number[];
2424
userCode: string;
2525

26-
constructor(convInfo: Conv2DInfo) {
26+
constructor(
27+
convInfo: Conv2DInfo, addBias = false, activation: string = null) {
2728
this.outputShape = convInfo.outShape;
2829
const padTop = convInfo.padInfo.top;
2930
const padLeft = convInfo.padInfo.left;
@@ -37,7 +38,25 @@ export class Conv2DProgram implements GPGPUProgram {
3738
const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
3839
const inputDepthVec4Remainder = convInfo.inChannels % 4;
3940

41+
let activationSnippet = '', applyActivationSnippet = '';
42+
if (activation) {
43+
activationSnippet = `
44+
float activation(float x) {
45+
${activation}
46+
}
47+
`;
48+
49+
applyActivationSnippet = `result = activation(result);`;
50+
}
51+
52+
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
53+
if (addBias) {
54+
this.variableNames.push('bias');
55+
}
56+
4057
this.userCode = `
58+
${activationSnippet}
59+
4160
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
4261
const ivec2 pads = ivec2(${padTop}, ${padLeft});
4362
@@ -113,7 +132,11 @@ export class Conv2DProgram implements GPGPUProgram {
113132
}
114133
}
115134
}
116-
setOutput(dotProd);
135+
136+
float result = dotProd;
137+
${addBiasSnippet}
138+
${applyActivationSnippet}
139+
setOutput(result);
117140
}
118141
`;
119142
}

src/ops/conv.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,7 @@ export const conv1d = op({conv1d_});
927927
export const conv2d = op({conv2d_});
928928
export const conv3d = op({conv3d_});
929929
export const conv2dDerFilter = op({conv2dDerFilter_});
930+
export const conv2dDerInput = op({conv2dDerInput_});
930931
export const depthwiseConv2d = op({depthwiseConv2d_});
931932
export const separableConv2d = op({separableConv2d_});
932933
export const conv2dTranspose = op({conv2dTranspose_});

0 commit comments

Comments (0)