Skip to content

Commit dce101e

Browse files
authored
Implement packed clipping op. (tensorflow#1412)
PERF
- Added `ClipPackedProgram`, which implements clipping for packed textures.
- Added ENV flag `WEBGL_PACK_CLIP`, which defaults to `false`.
- Added ENV flag `WEBGL_PACK`, which defaults to `false` and turns all packing-related flags on or off.
- Turn off NaN checking if the `PROD` flag is on.
- Support scalar sampling in packed ops.
- Partially support packed `get${Var}AtOutCoords` and broadcasting: broadcasting only occurs when the innermost dimensions are the same and the vector to be broadcast is rank 0 or 1.
1 parent ed2032a commit dce101e

9 files changed

+294
-75
lines changed

src/environment.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,18 @@ export class Environment {
308308
return isChrome();
309309
} else if (feature === 'WEBGL_CPU_FORWARD') {
310310
return true;
311-
} else if (feature === 'WEBGL_PACK_BATCHNORMALIZATION') {
311+
} else if (feature === 'WEBGL_PACK') {
312312
return false;
313+
} else if (feature === 'WEBGL_PACK_BATCHNORMALIZATION') {
314+
return this.get('WEBGL_PACK');
315+
} else if (feature === 'WEBGL_PACK_CLIP') {
316+
return this.get('WEBGL_PACK');
313317
} else if (feature === 'WEBGL_PACK_DEPTHWISECONV') {
314-
return false;
318+
return this.get('WEBGL_PACK');
315319
} else if (feature === 'WEBGL_LAZILY_UNPACK') {
316-
return false;
320+
return this.get('WEBGL_PACK');
317321
} else if (feature === 'WEBGL_CONV_IM2COL') {
318-
return false;
322+
return this.get('WEBGL_PACK');
319323
} else if (feature === 'WEBGL_PAGING_ENABLED') {
320324
return this.get('IS_BROWSER') && !this.get('PROD');
321325
} else if (feature === 'WEBGL_MAX_TEXTURE_SIZE') {

src/environment_util.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@ export interface Features {
2828
'WEBGL_LAZILY_UNPACK'?: boolean;
2929
// Whether the WebGL backend will sometimes forward ops to the CPU.
3030
'WEBGL_CPU_FORWARD'?: boolean;
31+
// Whether to turn all packing-related flags on.
32+
'WEBGL_PACK'?: boolean;
3133
// Whether we will pack the batchnormalization op.
3234
'WEBGL_PACK_BATCHNORMALIZATION'?: boolean;
35+
// Whether we will pack the clipping op.
36+
'WEBGL_PACK_CLIP'?: boolean;
3337
// Whether we pack the depthwise convolution op.
3438
'WEBGL_PACK_DEPTHWISECONV'?: boolean;
3539
// Whether we will use the im2col algorithm to speed up convolutions.
@@ -91,7 +95,9 @@ export const URL_PROPERTIES: URLProperty[] = [
9195
{name: 'IS_BROWSER', type: Type.BOOLEAN},
9296
{name: 'WEBGL_LAZILY_UNPACK', type: Type.BOOLEAN},
9397
{name: 'WEBGL_CPU_FORWARD', type: Type.BOOLEAN},
98+
{name: 'WEBGL_PACK', type: Type.BOOLEAN},
9499
{name: 'WEBGL_PACK_BATCHNORMALIZATION', type: Type.BOOLEAN},
100+
{name: 'WEBGL_PACK_CLIP', type: Type.BOOLEAN},
95101
{name: 'WEBGL_PACK_DEPTHWISECONV', type: Type.BOOLEAN},
96102
{name: 'WEBGL_CONV_IM2COL', type: Type.BOOLEAN},
97103
{name: 'WEBGL_MAX_TEXTURE_SIZE', type: Type.NUMBER},

src/kernels/backend_webgl.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import {BinaryOpComplexProgram} from './webgl/binaryop_complex_gpu';
5151
import * as binaryop_gpu from './webgl/binaryop_gpu';
5252
import {BinaryOpProgram} from './webgl/binaryop_gpu';
5353
import {ClipProgram} from './webgl/clip_gpu';
54+
import {ClipPackedProgram} from './webgl/clip_packed_gpu';
5455
import {ComplexAbsProgram} from './webgl/complex_abs_gpu';
5556
import {ConcatProgram} from './webgl/concat_gpu';
5657
import {Conv2DDerFilterProgram, Conv2DDerInputProgram} from './webgl/conv_backprop_gpu';
@@ -366,9 +367,11 @@ export class MathBackendWebGL implements KernelBackend {
366367
const {shape, dtype, texture, texShape} = this.texData.get(dataId);
367368
if (ENV.get('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
368369
if (this.texData.get(dataId).isPacked) {
369-
const batch = util.sizeFromShape(shape.slice(0, shape.length - 2));
370-
const rows = shape.length > 1 ? shape[shape.length - 2] : 1;
371-
const cols = shape[shape.length - 1];
370+
const batch = this.getBatchDim(shape);
371+
let rows = 1, cols = 1;
372+
if (shape.length) {
373+
[rows, cols] = this.getRowsCols(shape);
374+
}
372375
return this.gpgpu.downloadMatrixFromPackedTexture(
373376
texture, batch, rows, cols, texShape[0], texShape[1]);
374377
} else {
@@ -1337,7 +1340,12 @@ export class MathBackendWebGL implements KernelBackend {
13371340
}
13381341

13391342
clip<T extends Tensor>(x: T, min: number, max: number): T {
1340-
const program = new ClipProgram(x.shape, min, max);
1343+
let program;
1344+
if (ENV.get('WEBGL_PACK_CLIP')) {
1345+
program = new ClipPackedProgram(x.shape, min, max);
1346+
} else {
1347+
program = new ClipProgram(x.shape, min, max);
1348+
}
13411349
return this.compileAndRun(program, [x]) as T;
13421350
}
13431351

src/kernels/webgl/batchnorm_packed_gpu.ts

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,34 +33,29 @@ export class BatchNormPackedProgram implements GPGPUProgram {
3333
broadcast_util.assertAndGetBroadcastShape(xShape, meanShape);
3434
broadcast_util.assertAndGetBroadcastShape(xShape, varianceShape);
3535

36-
const meanSnippet = broadcastSample('mean', meanShape.length);
37-
const varianceSnippet = broadcastSample('variance', varianceShape.length);
38-
39-
let offsetSnippet = 'vec4 offset = vec4(0.0)';
36+
let offsetSnippet = 'vec4(0.0)';
4037
if (offsetShape != null) {
4138
broadcast_util.assertAndGetBroadcastShape(xShape, offsetShape);
4239
this.variableNames.push('offset');
43-
offsetSnippet = broadcastSample('offset', offsetShape.length);
40+
offsetSnippet = 'getOffsetAtOutCoords()';
4441
}
4542

46-
let scaleSnippet = 'vec4 scale = vec4(1.0)';
43+
let scaleSnippet = 'vec4(1.0)';
4744
if (scaleShape != null) {
4845
broadcast_util.assertAndGetBroadcastShape(xShape, scaleShape);
4946
this.variableNames.push('scale');
50-
scaleSnippet = broadcastSample('scale', scaleShape.length);
47+
scaleSnippet = 'getScaleAtOutCoords()';
5148
}
5249

5350
this.outputShape = xShape;
5451
this.userCode = `
5552
void main() {
56-
ivec4 rc = getOutputCoords();
57-
58-
${offsetSnippet};
59-
${scaleSnippet};
53+
vec4 offset = ${offsetSnippet};
54+
vec4 scale = ${scaleSnippet};
6055
61-
vec4 x = getX(rc.x, rc.y, rc.z, rc.w);
62-
${meanSnippet};
63-
${varianceSnippet};
56+
vec4 x = getXAtOutCoords();
57+
vec4 mean = getMeanAtOutCoords();
58+
vec4 variance = getVarianceAtOutCoords();
6459
6560
vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));
6661
@@ -69,14 +64,3 @@ export class BatchNormPackedProgram implements GPGPUProgram {
6964
`;
7065
}
7166
}
72-
73-
function broadcastSample(texName: string, rank: number): string {
74-
const texSampler = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;
75-
if (rank === 1) {
76-
return `
77-
vec4 ${texName}Sample = ${texSampler}(rc.w);
78-
vec4 ${texName} = vec4(${texName}Sample.xy, ${texName}Sample.xy);
79-
`;
80-
}
81-
return `vec4 ${texName} = ${texSampler}(rc.x, rc.y, rc.z, rc.w)`;
82-
}

src/kernels/webgl/clip_packed_gpu.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {GPGPUProgram} from './gpgpu_math';
19+
20+
export class ClipPackedProgram implements GPGPUProgram {
21+
variableNames = ['A'];
22+
usesPackedTextures = true;
23+
userCode: string;
24+
outputShape: number[];
25+
26+
constructor(aShape: number[], min: number, max: number) {
27+
this.outputShape = aShape;
28+
this.userCode = `
29+
void main() {
30+
vec4 value = getAAtOutCoords();
31+
32+
if (hasNaN(value)) {
33+
setOutput(value);
34+
return;
35+
}
36+
37+
setOutput(clamp(value, vec4(${min}), vec4(${max})));
38+
}
39+
`;
40+
}
41+
}

src/kernels/webgl/pack_gpu.ts

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,36 @@ export class PackProgram implements GPGPUProgram {
3333
this.outputShape = outputShape;
3434
const rank = outputShape.length;
3535

36-
const channels = getChannels('rc', rank);
37-
const dtype = getCoordsDataType(rank);
38-
const outOfBoundsCondition =
39-
getOutOfBoundsCondition(rank, outputShape, channels);
40-
const setup = getSetup(
41-
rank, outputShape[outputShape.length - 1],
42-
outputShape[outputShape.length - 2], channels);
43-
const output = getOutput(outputShape, channels);
44-
45-
this.userCode = `
46-
void main() {
47-
${dtype} rc = getOutputCoords();
48-
49-
if(${outOfBoundsCondition}) {
50-
gl_FragColor = vec4(0);
51-
} else {
52-
${setup}
53-
54-
setOutput(vec4(${output}));
36+
if (rank === 0) {
37+
this.userCode = `
38+
void main() {
39+
setOutput(vec4(getA(), 0., 0., 0.));
5540
}
56-
}
57-
`;
41+
`;
42+
} else {
43+
const channels = getChannels('rc', rank);
44+
const dtype = getCoordsDataType(rank);
45+
const outOfBoundsCondition =
46+
getOutOfBoundsCondition(rank, outputShape, channels);
47+
const setup = getSetup(
48+
rank, outputShape[outputShape.length - 1],
49+
outputShape[outputShape.length - 2], channels);
50+
const output = getOutput(outputShape, channels);
51+
52+
this.userCode = `
53+
void main() {
54+
${dtype} rc = getOutputCoords();
55+
56+
if(${outOfBoundsCondition}) {
57+
setOutput(vec4(0));
58+
} else {
59+
${setup}
60+
61+
setOutput(vec4(${output}));
62+
}
63+
}
64+
`;
65+
}
5866
}
5967
}
6068

0 commit comments

Comments
 (0)