This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 5cc5267

Fuse prelu activation. (#1867)
FEATURE PERF
1 parent 5aa35a3 commit 5cc5267

8 files changed: +454 −108 lines

src/backends/backend.ts

Lines changed: 4 additions & 4 deletions

@@ -16,7 +16,7 @@
  */

 import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
-import {Activation} from '../ops/fused_util';
+import {Activation, FusedBatchMatMulConfig} from '../ops/fused_util';
 import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
 import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';

@@ -132,8 +132,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
   }

   fusedBatchMatMul(
-      a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor, activation?: Activation): Tensor3D {
+      {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
+          FusedBatchMatMulConfig): Tensor3D {
     throw new Error('Not yet implemented');
   }

@@ -413,7 +413,7 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {

   fusedConv2d(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     throw new Error('Not yet implemented');
   }
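The signature change replaces four positional arguments with a single config object, so every call site names its fields explicitly. A minimal sketch of the config's shape, inferred from the destructured parameter above; the authoritative `FusedBatchMatMulConfig` is defined in src/ops/fused_util.ts, which is not shown in this section:

// Sketch only: field names taken from the destructuring above; optionality
// inferred from the old signature (bias?, activation?).
export interface FusedBatchMatMulConfig {
  a: Tensor3D;
  b: Tensor3D;
  transposeA: boolean;
  transposeB: boolean;
  bias?: Tensor;
  activation?: Activation;
  preluActivationWeights?: Tensor;
}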

src/backends/cpu/backend_cpu.ts

Lines changed: 14 additions & 7 deletions

@@ -26,7 +26,7 @@ import * as broadcast_util from '../../ops/broadcast_util';
 import * as concat_util from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
 import * as erf_util from '../../ops/erf_util';
-import {Activation} from '../../ops/fused_util';
+import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
 import * as ops from '../../ops/ops';
 import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';

@@ -47,11 +47,14 @@ import {topkImpl} from '../topk_impl';
 import {whereImpl} from '../where_impl';

 function mapActivation(
-    backend: MathBackendCPU, activation: Activation, x: Tensor): Tensor {
+    backend: MathBackendCPU, x: Tensor, activation: Activation,
+    preluActivationWeights?: Tensor): Tensor {
   if (activation === 'linear') {
     return backend.linear(x);
   } else if (activation === 'relu') {
     return backend.relu(x);
+  } else if (activation === 'prelu') {
+    return backend.prelu(x, preluActivationWeights);
   }
   throw new Error(
       `Activation ${activation} has not been implemented for the CPU backend.`);

@@ -522,14 +525,16 @@ export class MathBackendCPU implements KernelBackend {
   }

   fusedBatchMatMul(
-      a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor, activation?: Activation): Tensor3D {
+      {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
+          FusedBatchMatMulConfig): Tensor3D {
     let result = this.batchMatMul(a, b, transposeA, transposeB);
     if (bias) {
       result = this.add(result, bias) as Tensor3D;
     }
     if (activation) {
-      result = mapActivation(this, activation, result) as Tensor3D;
+      result =
+          mapActivation(this, result, activation, preluActivationWeights) as
+          Tensor3D;
     }
     return result;
   }

@@ -1515,14 +1520,16 @@ export class MathBackendCPU implements KernelBackend {

   fusedConv2d(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     let result = this.conv2d(x, filter, convInfo);

     if (bias) {
       result = this.add(result, bias) as Tensor4D;
     }
     if (activation) {
-      result = mapActivation(this, activation, result) as Tensor4D;
+      result =
+          mapActivation(this, result, activation, preluActivationWeights) as
+          Tensor4D;
     }
     return result;
   }
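On the CPU backend the fused ops stay compositional: run the unfused kernel, add the bias, then apply the activation, which for 'prelu' is backend.prelu(x, preluActivationWeights). As a reminder of the math, prelu leaves non-negative values unchanged and scales negative ones by a learned alpha. A minimal sketch of that elementwise rule on flat arrays, illustrative only and assuming alpha has already been broadcast to x's length (the real op handles dtype and shape broadcasting):

// Illustrative sketch of prelu's elementwise rule; not the backend kernel.
function preluSketch(x: Float32Array, alpha: Float32Array): Float32Array {
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) {
    out[i] = x[i] < 0 ? alpha[i] * x[i] : x[i];
  }
  return out;
}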

src/backends/webgl/backend_webgl.ts

Lines changed: 49 additions & 16 deletions

@@ -28,7 +28,7 @@ import * as array_ops_util from '../../ops/array_ops_util';
 import * as axis_util from '../../ops/axis_util';
 import {computeOutShape} from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
-import {Activation} from '../../ops/fused_util';
+import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
 import * as reduce_util from '../../ops/reduce_util';
 import * as scatter_nd_util from '../../ops/scatter_nd_util';

@@ -174,6 +174,11 @@ function mapActivationToShaderProgram(
       return unary_packed_op.RELU;
     }
     return unary_op.RELU;
+  } else if (activation === 'prelu') {
+    if (packed) {
+      return binaryop_packed_gpu.PRELU;
+    }
+    return binaryop_gpu.PRELU;
   }
   throw new Error(`Activation ${
       activation} has not been implemented for the WebGL backend.`);

@@ -865,26 +870,30 @@ export class MathBackendWebGL implements KernelBackend {
   }

   fusedBatchMatMul(
-      a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor, activation?: Activation): Tensor3D {
+      {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
+          FusedBatchMatMulConfig): Tensor3D {
     const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
     const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
     const [batch, , ] = a.shape;

     const dtype = upcastType(a.dtype, b.dtype);

     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, true) : null;
     const program = new MatMulPackedProgram(
         a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
-        hasBias, fusedActivation);
+        hasBias, fusedActivation, hasPreluActivationWeights);
     const output =
         this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
     const inputs: TensorHandle[] = [a, b];
     if (bias) {
       inputs.push(bias);
     }
+    if (preluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     return this.compileAndRun<Tensor3D>(program, inputs, output);
   }

@@ -1819,7 +1828,7 @@ export class MathBackendWebGL implements KernelBackend {

   private conv2dByMatMul(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
     // result from 2D to 4D.
     const xShape = x.shape;

@@ -1850,9 +1859,15 @@
         Tensor3D;

     return this.reshape<Rank.R4>(
-        this.fusedBatchMatMul(
-            xReshaped, filterReshaped, transposeA, transposeB, bias,
-            activation),
+        this.fusedBatchMatMul({
+          a: xReshaped,
+          b: filterReshaped,
+          transposeA,
+          transposeB,
+          bias,
+          activation,
+          preluActivationWeights
+        }),
         convInfo.outShape);
   }

@@ -1888,8 +1903,15 @@
     this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]) as
         Tensor3D;

-    const pointwiseConv = this.fusedBatchMatMul(
-        xReshaped, filterReshaped, transposeA, transposeB, bias, activation);
+    const pointwiseConv = this.fusedBatchMatMul({
+      a: xReshaped,
+      b: filterReshaped,
+      transposeA,
+      transposeB,
+      bias,
+      activation,
+      preluActivationWeights
+    });
     const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
     util.assert(
         pointwiseConvTexData.isPacked,

@@ -1906,7 +1928,7 @@

   private conv2dWithIm2Row(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     // Rearranges conv2d input so each block to be convolved over forms the
     // column of a new matrix with shape [filterWidth * filterHeight *
     // inChannels, outHeight * outWidth]. The filter is also rearranged so each

@@ -1938,42 +1960,53 @@
     ]) as Tensor3D;

     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, true) : null;
     const matmulProgram = new MatMulPackedProgram(
         im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,
-        transposeB, hasBias, fusedActivation);
+        transposeB, hasBias, fusedActivation, hasPreluActivationWeights);
     const inputs: TensorHandle[] = [im2Col, w2Row];
     if (bias) {
       inputs.push(bias);
     }
+    if (hasPreluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);

     return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
   }

   fusedConv2d(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
         convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
         convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
         (convInfo.padInfo.type === 'SAME' ||
          convInfo.padInfo.type === 'VALID')) {
-      return this.conv2dByMatMul(x, filter, convInfo, bias, activation);
+      return this.conv2dByMatMul(
+          x, filter, convInfo, bias, activation, preluActivationWeights);
     }
     if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
-      return this.conv2dWithIm2Row(x, filter, convInfo, bias, activation);
+      return this.conv2dWithIm2Row(
+          x, filter, convInfo, bias, activation, preluActivationWeights);
     }

     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, false) : null;
-    const program = new Conv2DProgram(convInfo, hasBias, fusedActivation);
+    const program = new Conv2DProgram(
+        convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
     const inputs: TensorHandle[] = [x, filter];
     if (bias) {
       inputs.push(bias);
     }
+    if (preluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     return this.compileAndRun(program, inputs);
   }
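Rather than introducing new shader code, mapActivationToShaderProgram resolves 'prelu' to the existing binary-op shader sources, whose second operand b is fed by the preluActivationWeights texture pushed onto inputs. For orientation, those constants read roughly as below at this point in the codebase; this is an assumption from memory of binaryop_gpu.ts and binaryop_packed_gpu.ts (where both are simply named PRELU in their own modules), not a quote from this commit:

// Assumed contents of the reused shader-source constants; illustrative only.
const PRELU = `return (a < 0.) ? b * a : a;`;
const PRELU_PACKED = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;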

src/backends/webgl/conv_gpu.ts

Lines changed: 17 additions & 5 deletions

@@ -24,7 +24,8 @@ export class Conv2DProgram implements GPGPUProgram {
   userCode: string;

   constructor(
-      convInfo: Conv2DInfo, addBias = false, activation: string = null) {
+      convInfo: Conv2DInfo, addBias = false, activation: string = null,
+      hasPreluActivationWeights = false) {
     this.outputShape = convInfo.outShape;
     const padTop = convInfo.padInfo.top;
     const padLeft = convInfo.padInfo.left;

@@ -40,11 +41,18 @@

     let activationSnippet = '', applyActivationSnippet = '';
     if (activation) {
-      activationSnippet = `
-        float activation(float x) {
+      if (hasPreluActivationWeights) {
+        activationSnippet = `float activation(float a) {
+          float b = getPreluActivationWeightsAtOutCoords();
           ${activation}
-        }
-      `;
+        }`;
+      } else {
+        activationSnippet = `
+          float activation(float x) {
+            ${activation}
+          }
+        `;
+      }

       applyActivationSnippet = `result = activation(result);`;
     }

@@ -54,6 +62,10 @@
       this.variableNames.push('bias');
     }

+    if (hasPreluActivationWeights) {
+      this.variableNames.push('preluActivationWeights');
+    }
+
     this.userCode = `
       ${activationSnippet}
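Note why the prelu branch renames the helper's parameter to a and declares b before splicing in ${activation}: the reused binary-op bodies refer to operands a and b, so they drop into the generated helper unchanged. Assuming the unpacked PRELU body sketched earlier, the assembled snippet would evaluate to:

// Illustrative: the value activationSnippet takes when `activation` is the
// assumed body 'return (a < 0.) ? b * a : a;'.
const assembled = `float activation(float a) {
  float b = getPreluActivationWeightsAtOutCoords();
  return (a < 0.) ? b * a : a;
}`;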

src/backends/webgl/mulmat_packed_gpu.ts

Lines changed: 16 additions & 5 deletions

@@ -26,7 +26,7 @@ export class MatMulPackedProgram implements GPGPUProgram {
   constructor(
       aShape: [number, number, number], outputShape: [number, number, number],
       transposeA = false, transposeB = false, addBias = false,
-      activation: string = null) {
+      activation: string = null, hasPreluActivation = false) {
     this.outputShape = outputShape;

     const sharedDim = transposeA ? aShape[1] : aShape[2];

@@ -39,9 +39,16 @@

     let activationSnippet = '', applyActivationSnippet = '';
     if (activation) {
-      activationSnippet = `vec4 activation(vec4 x) {
-        ${activation}
-      }`;
+      if (hasPreluActivation) {
+        activationSnippet = `vec4 activation(vec4 a) {
+          vec4 b = getPreluActivationWeightsAtOutCoords();
+          ${activation}
+        }`;
+      } else {
+        activationSnippet = `vec4 activation(vec4 x) {
+          ${activation}
+        }`;
+      }

       applyActivationSnippet = `result = activation(result);`;
     }

@@ -51,6 +58,10 @@
       this.variableNames.push('bias');
     }

+    if (hasPreluActivation) {
+      this.variableNames.push('preluActivationWeights');
+    }
+
     this.userCode = `
       ${activationSnippet}

@@ -82,4 +93,4 @@
     }
   `;
   }
-}
\ No newline at end of file
+}
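The packed program mirrors conv_gpu.ts with vec4 arithmetic, since it operates on 2x2 texel blocks. As a usage sketch, this is how the WebGL backend above wires it up for a prelu-fused matmul; the shapes are placeholders and the activation string is the assumed packed PRELU body from earlier:

// Usage sketch; arguments follow the updated constructor order.
const packedPrelu = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
const program = new MatMulPackedProgram(
    [1, 128, 256],   // aShape
    [1, 128, 64],    // outputShape
    false, false,    // transposeA, transposeB
    true,            // addBias
    packedPrelu,     // activation source spliced into the shader
    true);           // hasPreluActivation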
