Skip to content

Commit 2e3142b

Browse files
authored
Pack unary ops. (tensorflow#1505)
PERF
1 parent 4e7b780 commit 2e3142b

File tree

3 files changed

+88
-11
lines changed

3 files changed

+88
-11
lines changed

src/kernels/backend_webgl.ts

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@ import {TransposeProgram} from './webgl/transpose_gpu';
106106
import * as unary_op from './webgl/unaryop_gpu';
107107
import {UnaryOpProgram} from './webgl/unaryop_gpu';
108108
import * as unary_packed_op from './webgl/unaryop_packed_gpu';
109+
import {UnaryOpPackedProgram} from './webgl/unaryop_packed_gpu';
109110
import {UnpackProgram} from './webgl/unpack_gpu';
110111
import * as webgl_util from './webgl/webgl_util';
111112
import {whereImpl} from './where_impl';
@@ -1440,7 +1441,12 @@ export class MathBackendWebGL implements KernelBackend {
14401441
}
14411442

14421443
/**
 * Runs the `unary_op.EXP` shader snippet over every element of `x`.
 *
 * When the `WEBGL_PACK` environment flag is truthy the packed program
 * (`UnaryOpPackedProgram`) is used; otherwise the unpacked `UnaryOpProgram`
 * is used. Both are built from the same `unary_op.EXP` snippet.
 *
 * @param x Input tensor; its shape is also the output shape.
 * @returns The result of `compileAndRun`, cast to the input tensor type.
 */
exp<T extends Tensor>(x: T): T {
  // Single declaration: the scraped diff carried both the removed `const`
  // and the added `let` for `program`, which cannot coexist in one scope.
  const program = ENV.get('WEBGL_PACK') ?
      new UnaryOpPackedProgram(x.shape, unary_op.EXP) :
      new UnaryOpProgram(x.shape, unary_op.EXP);
  return this.compileAndRun(program, [x]) as T;
}
14461452

@@ -1450,7 +1456,12 @@ export class MathBackendWebGL implements KernelBackend {
14501456
}
14511457

14521458
/**
 * Runs the LOG shader snippet over every element of `x`.
 *
 * When the `WEBGL_PACK` environment flag is truthy the packed program is
 * built from `unary_packed_op.LOG`; otherwise the unpacked program is built
 * from `unary_op.LOG`. The program's custom setup callback is forwarded to
 * `compileAndRun` so it can configure the program before it runs.
 *
 * @param x Input tensor; its shape is also the output shape.
 * @returns The result of `compileAndRun`, cast to the input tensor type.
 */
log<T extends Tensor>(x: T): T {
  // Single declaration: the scraped diff carried both the removed `const`
  // and the added `let` for `program`, which cannot coexist in one scope.
  const program = ENV.get('WEBGL_PACK') ?
      new UnaryOpPackedProgram(x.shape, unary_packed_op.LOG) :
      new UnaryOpProgram(x.shape, unary_op.LOG);
  const customSetup = program.getCustomSetupFunc();
  return this.compileAndRun(program, [x], null, customSetup) as T;
}
@@ -1481,7 +1492,12 @@ export class MathBackendWebGL implements KernelBackend {
14811492
}
14821493

14831494
/**
 * Runs the RELU shader snippet over every element of `x`.
 *
 * When the `WEBGL_PACK` environment flag is truthy the packed program is
 * built from `unary_packed_op.RELU`; otherwise the unpacked program is built
 * from `unary_op.RELU`.
 *
 * @param x Input tensor; its shape is also the output shape.
 * @returns The result of `compileAndRun`, cast to the input tensor type.
 */
relu<T extends Tensor>(x: T): T {
  // Single declaration: the scraped diff carried both the removed `const`
  // and the added `let` for `program`, which cannot coexist in one scope.
  const program = ENV.get('WEBGL_PACK') ?
      new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU) :
      new UnaryOpProgram(x.shape, unary_op.RELU);
  return this.compileAndRun(program, [x]) as T;
}
14871503

src/kernels/webgl/shader_compiler.ts

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -504,15 +504,15 @@ function getOutputPackedNDCoords(
504504
int b${b} = index / ${texelsInBatchN};
505505
index -= b${b} * ${texelsInBatchN};
506506
` + batches;
507-
coords = `b${b}, ` + coords;
507+
coords = `b${b}, ` + coords;
508508
}
509509

510510
return `
511511
ivec${shape.length} getOutputCoords() {
512512
ivec2 resTexRC = ivec2(resultUV.yx *
513513
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
514514
int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
515-
515+
516516
${batches}
517517
518518
int b = index / ${texelsInBatch};
@@ -984,7 +984,7 @@ function getSampler3D(inputInfo: InputInfo): string {
984984
`;
985985
}
986986

987-
function getPackedSamplerND(inputInfo: InputInfo): string {
987+
function getPackedSamplerND(inputInfo: InputInfo): string {
988988
const shape = inputInfo.shapeInfo.logicalShape;
989989
const rank = shape.length;
990990
const texName = inputInfo.name;
@@ -1001,7 +1001,7 @@ function getPackedSamplerND(inputInfo: InputInfo): string {
10011001
let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;
10021002
for (let b = 2; b < rank - 1; b++) {
10031003
params = `int b${b}, ` + params;
1004-
texelsInBatch *= shape[rank - b - 1];
1004+
texelsInBatch *= shape[rank - b - 1];
10051005
index = `b${b} * ${texelsInBatch} + ` + index;
10061006
}
10071007
const glsl = getGlslDifferences();
@@ -1011,7 +1011,7 @@ function getPackedSamplerND(inputInfo: InputInfo): string {
10111011
int texR = index / ${texNumC};
10121012
int texC = index - texR * ${texNumC};
10131013
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});
1014-
return ${glsl.texture2D}(${texName}, uv);
1014+
return ${glsl.texture2D}(${texName}, uv);
10151015
}
10161016
`;
10171017
}

src/kernels/webgl/unaryop_packed_gpu.ts

Lines changed: 64 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,70 @@
1515
* =============================================================================
1616
*/
1717

18-
const CHECK_NAN_SNIPPET = `if (hasNaN(x)) return x;`;
18+
import {GPGPUContext} from './gpgpu_context';
19+
import {GPGPUProgram} from './gpgpu_math';
1920

2021
/** GLSL op snippet that returns its input `x` unchanged. */
export const LINEAR = `return x;`;
2122

22-
/**
 * GLSL op snippet for a packed (vec4) natural logarithm.
 *
 * Lanes whose input is strictly negative are overwritten with `NAN`, a
 * uniform the enclosing program is expected to declare and populate, rather
 * than relying on whatever `log` yields for a negative argument.
 *
 * The mask is named `isNegative` (not `isNaN`) so it does not shadow the
 * `isNaN(float)` helper that other snippets in this file call.
 */
export const LOG = `
  vec4 result = log(x);
  vec4 isNegative = vec4(lessThan(x, vec4(0.0)));
  result.r = isNegative.r == 1.0 ? NAN : result.r;
  result.g = isNegative.g == 1.0 ? NAN : result.g;
  result.b = isNegative.b == 1.0 ? NAN : result.b;
  result.a = isNegative.a == 1.0 ? NAN : result.a;

  return result;
`;

/**
 * GLSL op snippet for a packed (vec4) ReLU.
 *
 * Uses `max(x, 0.0)` instead of multiplying `x` by a 0/1 mask: the product
 * form evaluates `-inf * 0.0` for a negative-infinity lane, which is NaN
 * under IEEE-754 rules, whereas ReLU of negative infinity should be 0.
 *
 * NaN inputs are propagated explicitly per lane afterwards, because the
 * result of `max` on a NaN operand is not something to rely on. `isNaN` is
 * a float helper assumed to be supplied by the surrounding shader source
 * (not defined in this file).
 */
export const RELU = `
  vec4 result = max(x, vec4(0.0));

  result.r = isNaN(x.r) ? x.r : result.r;
  result.g = isNaN(x.g) ? x.g : result.g;
  result.b = isNaN(x.b) ? x.b : result.b;
  result.a = isNaN(x.a) ? x.a : result.a;

  return result;
`;
44+
45+
/**
 * GPGPU program that applies a unary op snippet to a packed texture, i.e.
 * the snippet receives and returns a `vec4` per invocation.
 *
 * The generated shader declares a `uniform float NAN` so that snippets can
 * emit NaN explicitly; callers that use such a snippet should pass the
 * callback from `getCustomSetupFunc()` when running the program so the
 * uniform is populated.
 */
export class UnaryOpPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];
  usesPackedTextures = true;

  // Cached location of the `NAN` uniform, looked up lazily on first setup.
  startLoc: WebGLUniformLocation;

  /**
   * @param aShape Shape of the input; used as the output shape as-is.
   * @param opSnippet GLSL statements forming the body of
   *     `vec4 unaryOperation(vec4 x)`; must return a `vec4`.
   */
  constructor(aShape: number[], opSnippet: string) {
    this.outputShape = aShape;
    this.userCode = `
      uniform float NAN;
      vec4 unaryOperation(vec4 x) {
        ${opSnippet}
      }

      void main() {
        vec4 x = getAAtOutCoords();
        vec4 y = unaryOperation(x);

        setOutput(y);
      }
    `;
  }

  /**
   * Returns a callback that uploads a NaN value to the shader's `NAN`
   * uniform. The uniform location is resolved on first use and cached on
   * the instance.
   */
  getCustomSetupFunc():
      (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => void {
    return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
      if (this.startLoc == null) {
        this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'NAN');
        if (this.startLoc == null) {
          // This means the compiler has optimized and realized it doesn't need
          // the uniform.
          return;
        }
      }
      gpgpu.gl.uniform1f(this.startLoc, NaN);
    };
  }
}

0 commit comments

Comments
 (0)