
Commit 0aa9e04

Add partial support for broadcasting of bias in fusedMatMul. (tensorflow#1502)
PERF
1 parent 017a3bf commit 0aa9e04

6 files changed: 58 additions & 35 deletions
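In user-facing terms, the change lets the bias passed to tf.fused.matMul be any tensor that broadcasts to the output shape (e.g. a rank-1 per-column bias), instead of requiring a bias of the full output shape. A quick usage sketch, mirroring the new test in src/ops/fused_test.ts (the standard @tensorflow/tfjs entry point is assumed):

import * as tf from '@tensorflow/tfjs';

const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
const bias = tf.tensor1d([1, 1]);  // shape [2]; broadcasts across rows

// relu(a · b + bias), computed as a single fused kernel:
const d = tf.fused.matMul(a, b, false, false, bias, 'relu');
d.print();  // [[1, 9], [0, 21]]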

src/kernels/backend.ts

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
  */

 import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
-import {FusableActivation} from '../ops/fused_util';
+import {Activation} from '../ops/fused_util';
 import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
 import {DataType, DataValues, Rank, ShapeMap} from '../types';

@@ -124,7 +124,7 @@ export class KernelBackend implements TensorStorage, BackendTimer {

   fusedBatchMatMul(
       a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor3D, activation?: FusableActivation): Tensor3D {
+      bias?: Tensor, activation?: Activation): Tensor3D {
     throw new Error('Not yet implemented');
   }


src/kernels/backend_cpu.ts

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,7 @@ import * as broadcast_util from '../ops/broadcast_util';
 import * as concat_util from '../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
 import * as erf_util from '../ops/erf_util';
-import {FusableActivation} from '../ops/fused_util';
+import {Activation} from '../ops/fused_util';
 import * as gather_nd_util from '../ops/gather_nd_util';
 import * as ops from '../ops/ops';
 import {buffer, scalar, tensor, tensor3d, tensor4d} from '../ops/ops';
@@ -46,7 +46,7 @@ import {topkImpl} from './topk_impl';
 import {whereImpl} from './where_impl';

 function mapActivation(
-    backend: MathBackendCPU, activation: FusableActivation, x: Tensor): Tensor {
+    backend: MathBackendCPU, activation: Activation, x: Tensor): Tensor {
   if (activation === 'linear') {
     return backend.linear(x);
   } else if (activation === 'relu') {
@@ -484,7 +484,7 @@ export class MathBackendCPU implements KernelBackend {

   fusedBatchMatMul(
       a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor3D, activation?: FusableActivation): Tensor3D {
+      bias?: Tensor, activation?: Activation): Tensor3D {
     let result = this.batchMatMul(a, b, transposeA, transposeB);
     if (bias) {
       result = this.add(result, bias) as Tensor3D;
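The CPU backend needs nothing beyond the signature change because add() already performs NumPy-style broadcasting, so a rank-1 bias expands across the rank-3 matmul output for free. A minimal sketch of that behavior (shapes chosen to match the kernel's [batch, rows, cols] output; not the backend's own code):

import * as tf from '@tensorflow/tfjs';

const matmulOut = tf.tensor3d([0, 8, -3, 20], [1, 2, 2]);  // batchMatMul result
const bias = tf.tensor1d([1, 1]);                          // shape [2]

// add() broadcasts the bias over the two leading dimensions:
matmulOut.add(bias).print();  // [[[1, 9], [-2, 21]]]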

src/kernels/backend_webgl.ts

Lines changed: 12 additions & 12 deletions
@@ -25,7 +25,7 @@ import * as axis_util from '../ops/axis_util';
 import * as broadcast_util from '../ops/broadcast_util';
 import {computeOutShape} from '../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
-import {FusableActivation} from '../ops/fused_util';
+import {Activation} from '../ops/fused_util';
 import * as gather_nd_util from '../ops/gather_nd_util';
 import * as reduce_util from '../ops/reduce_util';
 import * as scatter_nd_util from '../ops/scatter_nd_util';
@@ -132,7 +132,7 @@ export interface WebGLTimingInfo extends TimingInfo {
 }

 function mapActivationToShaderProgram(
-    activation: FusableActivation, packed = false): string {
+    activation: Activation, packed = false): string {
   if (activation === 'linear') {
     if (packed) {
       return unary_packed_op.LINEAR;
@@ -789,7 +789,7 @@ export class MathBackendWebGL implements KernelBackend {

   fusedBatchMatMul(
       a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor3D, activation?: FusableActivation): Tensor3D {
+      bias?: Tensor, activation?: Activation): Tensor3D {
     const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
     const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
     const [batch, , ] = a.shape;
@@ -807,17 +807,17 @@ export class MathBackendWebGL implements KernelBackend {
           activation ? mapActivationToShaderProgram(activation, true) : null);
       const output =
           this.makePackedTensor(program.outputShape, dtype) as Tensor2D;
-      const inputs = [aSqueezed, bSqueezed];
+      const inputs: TensorHandle[] = [aSqueezed, bSqueezed];
       if (bias) {
-        inputs.push(bias.as2D(bias.shape[1], bias.shape[2]));
+        inputs.push(bias);
       }
       const result = this.compileAndRun<Tensor2D>(program, inputs, output);
       return result.reshape([1, result.shape[0], result.shape[1]]);
     } else {
       const program = new MatMulProgram(
           a.shape, b.shape, transposeA, transposeB, !!bias,
           activation ? mapActivationToShaderProgram(activation) : null);
-      const inputs = [a, b];
+      const inputs: TensorHandle[] = [a, b];
       if (bias) {
         inputs.push(bias);
       }
@@ -1441,8 +1441,8 @@ export class MathBackendWebGL implements KernelBackend {
   }

   exp<T extends Tensor>(x: T): T {
-    let program: UnaryOpProgram | UnaryOpPackedProgram;
-    if(ENV.get('WEBGL_PACK')) {
+    let program: UnaryOpProgram|UnaryOpPackedProgram;
+    if (ENV.get('WEBGL_PACK')) {
       program = new UnaryOpPackedProgram(x.shape, unary_op.EXP);
     } else {
       program = new UnaryOpProgram(x.shape, unary_op.EXP);
@@ -1456,8 +1456,8 @@ export class MathBackendWebGL implements KernelBackend {
   }

   log<T extends Tensor>(x: T): T {
-    let program: UnaryOpProgram | UnaryOpPackedProgram;
-    if(ENV.get('WEBGL_PACK')) {
+    let program: UnaryOpProgram|UnaryOpPackedProgram;
+    if (ENV.get('WEBGL_PACK')) {
       program = new UnaryOpPackedProgram(x.shape, unary_packed_op.LOG);
     } else {
       program = new UnaryOpProgram(x.shape, unary_op.LOG);
@@ -1492,8 +1492,8 @@ export class MathBackendWebGL implements KernelBackend {
   }

   relu<T extends Tensor>(x: T): T {
-    let program: UnaryOpProgram | UnaryOpPackedProgram;
-    if(ENV.get('WEBGL_PACK')) {
+    let program: UnaryOpProgram|UnaryOpPackedProgram;
+    if (ENV.get('WEBGL_PACK')) {
       program = new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU);
     } else {
       program = new UnaryOpProgram(x.shape, unary_op.RELU);
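Two notes on this file. First, the packed (batch === 1) path no longer reshapes the bias with as2D(...) before feeding it to the shader; the bias tensor is passed through as-is, with the broadcast presumably handled where the program samples it. Second, the inputs arrays gained an explicit TensorHandle[] annotation: TypeScript infers Tensor2D[] from the array literal, which would reject pushing the now rank-generic bias. The exp/log/relu hunks above are formatting-only cleanups. A type-level sketch of the inference issue, using hypothetical stand-in types rather than the real classes:

// Stand-in types for illustration only:
type Tensor = {shape: number[]};
type Tensor2D = Tensor & {rank: 2};
type TensorHandle = Tensor;

declare const aSqueezed: Tensor2D;
declare const bSqueezed: Tensor2D;
declare const bias: Tensor;

const inferred = [aSqueezed, bSqueezed];  // inferred element type: Tensor2D
// inferred.push(bias);                   // would not compile: Tensor lacks 'rank'

const inputs: TensorHandle[] = [aSqueezed, bSqueezed];
inputs.push(bias);                        // fine with the widened element type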

src/ops/fused_ops.ts

Lines changed: 28 additions & 17 deletions
@@ -22,7 +22,9 @@ import {makeTypesMatch} from '../tensor_util';
 import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 import * as util from '../util';
-import {FusableActivation} from './fused_util';
+
+import * as broadcast_util from './broadcast_util';
+import {Activation} from './fused_util';

 /**
  * Computes the dot product of two matrices with optional activation and bias.
@@ -45,7 +47,7 @@ import {FusableActivation} from './fused_util';
 /** @doc {heading: 'Operations', subheading: 'Matrices', namespace: 'fused'} */
 function matMul_<T extends Tensor>(
     a: T|TensorLike, b: T|TensorLike, transposeA = false, transposeB = false,
-    bias?: T|TensorLike, activation: FusableActivation = 'linear'): T {
+    bias?: Tensor|TensorLike, activation: Activation = 'linear'): T {
   let $a = convertToTensor(a, 'a', 'fused matMul');
   let $b = convertToTensor(b, 'b', 'fused matMul');
   [$a, $b] = makeTypesMatch($a, $b);
@@ -89,21 +91,15 @@ function matMul_<T extends Tensor>(
           $a.as3D(batchDimA, outerShapeA, innerShapeA);
   const b3D = transposeB ? $b.as3D(batchDimB, outerShapeB, innerShapeB) :
                            $b.as3D(batchDimB, innerShapeB, outerShapeB);
-  let bias3D: Tensor3D;
+
+  let $bias: Tensor;
   if (bias != null) {
-    let $bias = convertToTensor(bias, 'bias', 'fused matMul');
+    $bias = convertToTensor(bias, 'bias', 'fused matMul');
     [$bias] = makeTypesMatch($bias, $a);

-    const rowsBias = $bias.shape[$bias.rank - 2];
-    const colsBias = $bias.shape[$bias.rank - 1];
-
     util.assert(
-        outerShapeA === rowsBias && outerShapeB === colsBias,
-        `Error in fused matMul: inner dimensions of bias shape ${
-            $bias.shape} must match outer shapes (${outerShapeA}) and (${
-            outerShapeB}) of Tensors with shapes ${$a.shape} and ${$b.shape}`);
-
-    bias3D = $bias.as3D(batchDimA, rowsBias, colsBias);
+        broadcast_util.getBroadcastDims(outShape, $bias.shape).length === 0,
+        `Error in fused matMul: broadcasting is not supported for bias add.`);
   }

   const grad = (dy: Tensor3D, saved: Tensor[]) => {
@@ -120,7 +116,20 @@ function matMul_<T extends Tensor>(
           `implemented yet.`);
     }

-    const biasGradient = bias != null ? {$bias: () => dyActivation} : {};
+    let biasGradient = {};
+    if (bias != null) {
+      biasGradient = {
+        $bias: () => {
+          let res = dyActivation;
+          const reduceAxes =
+              broadcast_util.getReductionAxes($bias.shape, outShape);
+          if (reduceAxes.length > 0) {
+            res = res.sum(reduceAxes);
+          }
+          return res.reshape($bias.shape);
+        }
+      };
+    }

     if (!transposeA && !transposeB) {
       return Object.assign(
@@ -155,14 +164,16 @@ function matMul_<T extends Tensor>(

   const inputs: {$a: Tensor, $b: Tensor, $bias?: Tensor} = {$a: a3D, $b: b3D};
   if (bias != null) {
-    inputs.$bias = bias3D;
+    inputs.$bias = $bias;
   }

   const res = ENV.engine.runKernel(
       (backend, save) => save(backend.fusedBatchMatMul(
-          a3D, b3D, transposeA, transposeB, bias3D, activation)),
+          a3D, b3D, transposeA, transposeB, $bias, activation)),
       inputs, grad);
   return res.reshape(outShape) as T;
 }

-export const matMul = op({matMul_});
+export const matMul = op({matMul_});
+
+export {Activation};
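Two broadcast_util helpers carry the new behavior. The assertion uses getBroadcastDims(outShape, $bias.shape) to list the axes along which the output itself would have to be broadcast to match the bias; requiring that list to be empty means the bias may broadcast up to the output but may never enlarge it, which is the "partial" in the commit title. In the gradient, getReductionAxes($bias.shape, outShape) returns the axes the forward pass broadcast along, and dy must be summed over exactly those axes before being reshaped back to the bias shape. An illustrative re-implementation of that axis computation (a sketch of the semantics, not the library's broadcast_util):

// Axes of outShape along which inShape was broadcast (shapes align from
// the right; missing leading dimensions count as broadcast).
function getReductionAxes(inShape: number[], outShape: number[]): number[] {
  const axes: number[] = [];
  for (let i = 0; i < outShape.length; i++) {
    const inIdx = i - (outShape.length - inShape.length);
    if (inIdx < 0 || (inShape[inIdx] === 1 && outShape[i] > 1)) {
      axes.push(i);
    }
  }
  return axes;
}

// Bias of shape [2] broadcast into the 3D kernel output [1, 2, 2]:
console.log(getReductionAxes([2], [1, 2, 2]));  // [0, 1]
// so the bias gradient is dy.sum([0, 1]).reshape([2]).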

src/ops/fused_test.ts

Lines changed: 12 additions & 0 deletions
@@ -61,6 +61,18 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
     expectArraysClose(d, [1, 9, 0, 21]);
   });

+  it('A x B with relu and broadcasted bias', () => {
+    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
+    const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
+    const c = tf.tensor1d([1, 1]);
+    const act: tf.fused.Activation = 'relu';
+
+    const d = tf.fused.matMul(a, b, false, false, c, act);
+
+    expect(d.shape).toEqual([2, 2]);
+    expectArraysClose(d, [1, 9, 0, 21]);
+  });
+
   it('A x B with bias only', () => {
     const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
     const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
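For reference, the expected values in the broadcast test: a · b = [[0, 8], [-3, 20]]; adding the broadcast bias [1, 1] gives [[1, 9], [-2, 21]], and relu clamps the -2 to 0, which yields the flattened expectation [1, 9, 0, 21].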

src/ops/fused_util.ts

Lines changed: 1 addition & 1 deletion
@@ -15,4 +15,4 @@
  * =============================================================================
  */

-export type FusableActivation = 'linear'|'relu';
+export type Activation = 'linear'|'relu';
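The renamed union is re-exported from fused_ops (see the export {Activation} above), which is what makes the tf.fused.Activation annotation in the new test resolve. A one-line sketch of its use (assuming the usual tfjs entry point):

import * as tf from '@tensorflow/tfjs';

const act: tf.fused.Activation = 'relu';  // only 'linear' and 'relu' are fusable today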
