Skip to content

Commit 9424e64

Browse files
authored
Adds packed support for all non-complex binary ops. (tensorflow#1534)
PERF
1 parent 3d3b417 commit 9424e64

File tree

3 files changed

+177
-36
lines changed

3 files changed

+177
-36
lines changed

src/kernels/backend_webgl.ts

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ import * as binaryop_complex_gpu from './webgl/binaryop_complex_gpu';
5151
import {BinaryOpComplexProgram} from './webgl/binaryop_complex_gpu';
5252
import * as binaryop_gpu from './webgl/binaryop_gpu';
5353
import {BinaryOpProgram} from './webgl/binaryop_gpu';
54-
import {BinaryOpPackedProgram, PACKED_DIV, PACKED_INT_DIV, PACKED_POW} from './webgl/binaryop_packed_gpu';
54+
import * as binaryop_packed_gpu from './webgl/binaryop_packed_gpu';
55+
import {BinaryOpPackedProgram} from './webgl/binaryop_packed_gpu';
5556
import {ClipProgram} from './webgl/clip_gpu';
5657
import {ClipPackedProgram} from './webgl/clip_packed_gpu';
5758
import {ComplexAbsProgram} from './webgl/complex_abs_gpu';
@@ -768,8 +769,8 @@ export class MathBackendWebGL implements KernelBackend {
768769

769770
const dtype = upcastType(a.dtype, b.dtype);
770771

771-
const program = new MatMulPackedProgram(a.shape,
772-
[batch, outerShapeA, outerShapeB], transposeA, transposeB);
772+
const program = new MatMulPackedProgram(
773+
a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB);
773774
const output =
774775
this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
775776
return this.compileAndRun<Tensor3D>(program, [a, b], output);
@@ -784,8 +785,9 @@ export class MathBackendWebGL implements KernelBackend {
784785

785786
const dtype = upcastType(a.dtype, b.dtype);
786787

787-
const program = new MatMulPackedProgram(a.shape,
788-
[batch, outerShapeA, outerShapeB], transposeA, transposeB, !!bias,
788+
const program = new MatMulPackedProgram(
789+
a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
790+
!!bias,
789791
activation ? mapActivationToShaderProgram(activation, true) : null);
790792
const output =
791793
this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
@@ -1099,12 +1101,18 @@ export class MathBackendWebGL implements KernelBackend {
10991101
}
11001102

11011103
equal(a: Tensor, b: Tensor): Tensor {
1104+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1105+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.EQUAL, 'bool');
1106+
}
11021107
const program = new BinaryOpProgram(binaryop_gpu.EQUAL, a.shape, b.shape);
11031108
const output = this.makeOutputArray(program.outputShape, 'bool');
11041109
return this.compileAndRun(program, [a, b], output);
11051110
}
11061111

11071112
notEqual(a: Tensor, b: Tensor): Tensor {
1113+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1114+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.NOT_EQUAL, 'bool');
1115+
}
11081116
const program =
11091117
new BinaryOpProgram(binaryop_gpu.NOT_EQUAL, a.shape, b.shape);
11101118
const output = this.makeOutputArray(program.outputShape, 'bool');
@@ -1116,12 +1124,19 @@ export class MathBackendWebGL implements KernelBackend {
11161124
return this.cpuBackend.less(a, b);
11171125
}
11181126

1127+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1128+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS, 'bool');
1129+
}
1130+
11191131
const program = new BinaryOpProgram(binaryop_gpu.LESS, a.shape, b.shape);
11201132
const output = this.makeOutputArray(program.outputShape, 'bool');
11211133
return this.compileAndRun(program, [a, b], output);
11221134
}
11231135

11241136
lessEqual(a: Tensor, b: Tensor): Tensor {
1137+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1138+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS_EQUAL, 'bool');
1139+
}
11251140
const program =
11261141
new BinaryOpProgram(binaryop_gpu.LESS_EQUAL, a.shape, b.shape);
11271142
const output = this.makeOutputArray(program.outputShape, 'bool');
@@ -1133,12 +1148,20 @@ export class MathBackendWebGL implements KernelBackend {
11331148
return this.cpuBackend.greater(a, b);
11341149
}
11351150

1151+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1152+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.GREATER, 'bool');
1153+
}
1154+
11361155
const program = new BinaryOpProgram(binaryop_gpu.GREATER, a.shape, b.shape);
11371156
const output = this.makeOutputArray(program.outputShape, 'bool');
11381157
return this.compileAndRun(program, [a, b], output);
11391158
}
11401159

11411160
greaterEqual(a: Tensor, b: Tensor): Tensor {
1161+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1162+
return this.packedBinaryOp(
1163+
a, b, binaryop_packed_gpu.GREATER_EQUAL, 'bool');
1164+
}
11421165
const program =
11431166
new BinaryOpProgram(binaryop_gpu.GREATER_EQUAL, a.shape, b.shape);
11441167
const output = this.makeOutputArray(program.outputShape, 'bool');
@@ -1151,13 +1174,19 @@ export class MathBackendWebGL implements KernelBackend {
11511174
}
11521175

11531176
logicalAnd(a: Tensor, b: Tensor): Tensor {
1177+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1178+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_AND, 'bool');
1179+
}
11541180
const program =
11551181
new BinaryOpProgram(binaryop_gpu.LOGICAL_AND, a.shape, b.shape);
11561182
const output = this.makeOutputArray(program.outputShape, 'bool');
11571183
return this.compileAndRun(program, [a, b], output);
11581184
}
11591185

11601186
logicalOr(a: Tensor, b: Tensor): Tensor {
1187+
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1188+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_OR, 'bool');
1189+
}
11611190
const program =
11621191
new BinaryOpProgram(binaryop_gpu.LOGICAL_OR, a.shape, b.shape);
11631192
const output = this.makeOutputArray(program.outputShape, 'bool');
@@ -1198,12 +1227,17 @@ export class MathBackendWebGL implements KernelBackend {
11981227
return this.cpuBackend.minimum(a, b);
11991228
}
12001229

1201-
const program = new BinaryOpProgram(binaryop_gpu.MIN, a.shape, b.shape);
1202-
return this.compileAndRun(program, [a, b]);
1230+
const program = ENV.get('WEBGL_PACK_BINARY_OPERATIONS') ?
1231+
new BinaryOpPackedProgram(binaryop_packed_gpu.MIN, a.shape, b.shape) :
1232+
new BinaryOpProgram(binaryop_gpu.MIN, a.shape, b.shape);
1233+
const customSetup = program.getCustomSetupFunc();
1234+
return this.compileAndRun(program, [a, b], null, customSetup);
12031235
}
12041236

12051237
mod(a: Tensor, b: Tensor): Tensor {
1206-
const program = new BinaryOpProgram(binaryop_gpu.MOD, a.shape, b.shape);
1238+
const program = ENV.get('WEBGL_PACK_BINARY_OPERATIONS') ?
1239+
new BinaryOpPackedProgram(binaryop_packed_gpu.MOD, a.shape, b.shape) :
1240+
new BinaryOpProgram(binaryop_gpu.MOD, a.shape, b.shape);
12071241
const customSetup = program.getCustomSetupFunc();
12081242
return this.compileAndRun(program, [a, b], null, customSetup);
12091243
}
@@ -1222,8 +1256,11 @@ export class MathBackendWebGL implements KernelBackend {
12221256
return this.cpuBackend.maximum(a, b);
12231257
}
12241258

1225-
const program = new BinaryOpProgram(binaryop_gpu.MAX, a.shape, b.shape);
1226-
return this.compileAndRun(program, [a, b]);
1259+
const program = ENV.get('WEBGL_PACK_BINARY_OPERATIONS') ?
1260+
new BinaryOpPackedProgram(binaryop_packed_gpu.MAX, a.shape, b.shape) :
1261+
new BinaryOpProgram(binaryop_gpu.MAX, a.shape, b.shape);
1262+
const customSetup = program.getCustomSetupFunc();
1263+
return this.compileAndRun(program, [a, b], null, customSetup);
12271264
}
12281265

12291266
all(x: Tensor, axes: number[]): Tensor {
@@ -1245,7 +1282,9 @@ export class MathBackendWebGL implements KernelBackend {
12451282
}
12461283

12471284
squaredDifference(a: Tensor, b: Tensor): Tensor {
1248-
const program =
1285+
const program = ENV.get('WEBGL_PACK_BINARY_OPERATIONS') ?
1286+
new BinaryOpPackedProgram(
1287+
binaryop_gpu.SQUARED_DIFFERENCE, a.shape, b.shape) :
12491288
new BinaryOpProgram(binaryop_gpu.SQUARED_DIFFERENCE, a.shape, b.shape);
12501289
return this.compileAndRun(program, [a, b]);
12511290
}
@@ -1254,7 +1293,7 @@ export class MathBackendWebGL implements KernelBackend {
12541293
const op = binaryop_gpu.DIV;
12551294
const outputDtype = 'float32';
12561295
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1257-
return this.packedBinaryOp(a, b, PACKED_DIV, outputDtype);
1296+
return this.packedBinaryOp(a, b, binaryop_packed_gpu.DIV, outputDtype);
12581297
}
12591298
const program = new BinaryOpProgram(op, a.shape, b.shape);
12601299
const output = this.makeOutputArray(program.outputShape, outputDtype);
@@ -1265,7 +1304,8 @@ export class MathBackendWebGL implements KernelBackend {
12651304
const op = binaryop_gpu.INT_DIV;
12661305
const outputDtype = 'int32';
12671306
if (ENV.get('WEBGL_PACK_BINARY_OPERATIONS')) {
1268-
return this.packedBinaryOp(a, b, PACKED_INT_DIV, outputDtype);
1307+
return this.packedBinaryOp(
1308+
a, b, binaryop_packed_gpu.INT_DIV, outputDtype);
12691309
}
12701310
const program = new BinaryOpProgram(op, a.shape, b.shape);
12711311
const output = this.makeOutputArray(program.outputShape, outputDtype);
@@ -1285,7 +1325,8 @@ export class MathBackendWebGL implements KernelBackend {
12851325
return this.compileAndRun<Tensor>(program, [a, b], output);
12861326
}
12871327

1288-
private packedBinaryOp(a: Tensor, b: Tensor, op: string, dtype: DataType) {
1328+
private packedBinaryOp(
1329+
a: TensorHandle, b: TensorHandle, op: string, dtype: DataType) {
12891330
const program = new BinaryOpPackedProgram(op, a.shape, b.shape);
12901331
const output = this.makePackedTensor(program.outputShape, dtype) as Tensor;
12911332
return this.compileAndRun<Tensor>(program, [a, b], output);
@@ -1305,14 +1346,14 @@ export class MathBackendWebGL implements KernelBackend {
13051346
].map(complexParts => {
13061347
const [aPart, bPart] = complexParts;
13071348

1349+
const aHandle = this.makeComplexComponentTensorHandle(a, aPart);
1350+
const bHandle = this.makeComplexComponentTensorHandle(b, bPart);
1351+
13081352
const program = new BinaryOpProgram(op, a.shape, b.shape);
13091353
const output = this.makeOutputArray(
13101354
program.outputShape,
13111355
upcastType(aPart.dtype, bPart.dtype)) as Tensor;
13121356

1313-
const aHandle = this.makeComplexComponentTensorHandle(a, aPart);
1314-
const bHandle = this.makeComplexComponentTensorHandle(b, bPart);
1315-
13161357
return this.compileAndRun<Tensor>(program, [aHandle, bHandle], output);
13171358
});
13181359

@@ -1362,7 +1403,7 @@ export class MathBackendWebGL implements KernelBackend {
13621403
pow<T extends Tensor>(a: T, b: Tensor): T {
13631404
const usePackedOp = ENV.get('WEBGL_PACK_BINARY_OPERATIONS');
13641405
const program = usePackedOp ?
1365-
new BinaryOpPackedProgram(PACKED_POW, a.shape, b.shape) :
1406+
new BinaryOpPackedProgram(binaryop_packed_gpu.POW, a.shape, b.shape) :
13661407
new BinaryOpProgram(binaryop_gpu.POW, a.shape, b.shape);
13671408
const dtype = upcastType(a.dtype, b.dtype);
13681409
const output = usePackedOp ?
@@ -1454,7 +1495,9 @@ export class MathBackendWebGL implements KernelBackend {
14541495
}
14551496

14561497
prelu<T extends Tensor>(x: T, alpha: T): T {
1457-
const program =
1498+
const program = ENV.get('WEBGL_PACK_BINARY_OPERATIONS') ?
1499+
new BinaryOpPackedProgram(
1500+
binaryop_packed_gpu.PRELU, x.shape, alpha.shape) :
14581501
new BinaryOpProgram(binaryop_gpu.PRELU, x.shape, alpha.shape);
14591502
return this.compileAndRun(program, [x, alpha]) as T;
14601503
}
@@ -1465,7 +1508,9 @@ export class MathBackendWebGL implements KernelBackend {
14651508
}
14661509

14671510
eluDer<T extends Tensor>(dy: T, y: T): T {
1468-
const program =
1511+
const program = ENV.get('WEBGL_PACK_BINARY_OPERATIONS') ?
1512+
new BinaryOpPackedProgram(
1513+
binaryop_packed_gpu.ELU_DER, dy.shape, y.shape) :
14691514
new BinaryOpProgram(binaryop_gpu.ELU_DER, dy.shape, y.shape);
14701515
return this.compileAndRun(program, [dy, y]) as T;
14711516
}
@@ -1550,8 +1595,11 @@ export class MathBackendWebGL implements KernelBackend {
15501595
}
15511596

15521597
atan2<T extends Tensor>(a: T, b: T): T {
1553-
const program = new BinaryOpProgram(binaryop_gpu.ATAN2, a.shape, b.shape);
1554-
return this.compileAndRun(program, [a, b]) as T;
1598+
const program = ENV.get('WEBGL_PACK_BINARY_OPERATIONS') ?
1599+
new BinaryOpPackedProgram(binaryop_packed_gpu.ATAN2, a.shape, b.shape) :
1600+
new BinaryOpProgram(binaryop_gpu.ATAN2, a.shape, b.shape);
1601+
const customSetup = program.getCustomSetupFunc();
1602+
return this.compileAndRun(program, [a, b], null, customSetup) as T;
15551603
}
15561604

15571605
sinh<T extends Tensor>(x: T): T {
@@ -1680,12 +1728,13 @@ export class MathBackendWebGL implements KernelBackend {
16801728

16811729
const im2ColProgram =
16821730
new Im2ColProgram(x2ColShape, xSqueezed.shape, convInfo);
1683-
const im2Col = this.compileAndRun<Tensor2D>(im2ColProgram, [xSqueezed]).
1684-
reshape([1, x2ColShape[0], x2ColShape[1]]) as Tensor3D;
1731+
const im2Col =
1732+
this.compileAndRun<Tensor2D>(im2ColProgram, [xSqueezed]).reshape([
1733+
1, x2ColShape[0], x2ColShape[1]
1734+
]) as Tensor3D;
16851735

16861736
const matmulProgram = new MatMulPackedProgram(
1687-
im2Col.shape, [1, numCols, convInfo.outChannels], true,
1688-
false);
1737+
im2Col.shape, [1, numCols, convInfo.outChannels], true, false);
16891738
const product =
16901739
this.compileAndRun<Tensor4D>(matmulProgram, [im2Col, w2Row]);
16911740

@@ -1825,9 +1874,9 @@ export class MathBackendWebGL implements KernelBackend {
18251874

18261875
reshape<R extends Rank>(x: Tensor, shape: ShapeMap[R]): Tensor<R> {
18271876
const texData = this.texData.get(x.dataId);
1828-
if (texData.isPacked && !webgl_util.isReshapeFree(x.shape, shape)
1829-
&& !(texData.texture !== null &&
1830-
webgl_util.isReshapeFree(texData.shape, shape))) {
1877+
if (texData.isPacked && !webgl_util.isReshapeFree(x.shape, shape) &&
1878+
!(texData.texture !== null &&
1879+
webgl_util.isReshapeFree(texData.shape, shape))) {
18311880
return this.packedReshape(x, shape);
18321881
}
18331882
return backend_util.reshapeTensor(x, shape);

0 commit comments

Comments
 (0)