Skip to content

Commit 5cdd7f5

Browse files
authored
For matMul gradient, broadcasting logic must account for the temporary reshaping of inputs. (tensorflow#1598)
BUG
1 parent 3be0717 commit 5cdd7f5

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

src/ops/fused_ops.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ function matMul_<T extends Tensor>(
120120
biasGradient = {
121121
$bias: () => {
122122
let res = dyActivation;
123+
// Using dyActivation as reference shape because outputShape does not
124+
// account for the fact that we temporarily reshape inputs to 3D as
125+
// part of batched matMul.
123126
const reduceAxes =
124-
broadcast_util.getReductionAxes($bias.shape, outShape);
127+
broadcast_util.getReductionAxes($bias.shape, dyActivation.shape);
125128
if (reduceAxes.length > 0) {
126129
res = res.sum(reduceAxes);
127130
}

src/ops/fused_test.ts

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
3333
it('A x B with relu', () => {
3434
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
3535
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
36+
const transposeA = false;
37+
const transposeB = false;
3638

37-
const c = tf.fused.matMul(a, b, false, false, null, 'relu');
39+
const c = tf.fused.matMul(a, b, transposeA, transposeB, null, 'relu');
3840

3941
expect(c.shape).toEqual([2, 2]);
4042
expectArraysClose(c, [0, 8, 0, 20]);
@@ -43,8 +45,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
4345
it('A x B with relu transpose', () => {
4446
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
4547
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [2, 3]);
48+
const transposeA = false;
49+
const transposeB = true;
4650

47-
const c = tf.fused.matMul(a, b, false, true, null, 'relu');
51+
const c = tf.fused.matMul(a, b, transposeA, transposeB, null, 'relu');
4852

4953
expect(c.shape).toEqual([2, 2]);
5054
expectArraysClose(c, [0, 9, 0, 24]);
@@ -54,8 +58,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
5458
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
5559
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
5660
const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
61+
const transposeA = false;
62+
const transposeB = false;
5763

58-
const d = tf.fused.matMul(a, b, false, false, c, 'relu');
64+
const d = tf.fused.matMul(a, b, transposeA, transposeB, c, 'relu');
5965

6066
expect(d.shape).toEqual([2, 2]);
6167
expectArraysClose(d, [1, 9, 0, 21]);
@@ -66,8 +72,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
6672
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
6773
const c = tf.tensor1d([1, 1]);
6874
const act: tf.fused.Activation = 'relu';
75+
const transposeA = false;
76+
const transposeB = false;
6977

70-
const d = tf.fused.matMul(a, b, false, false, c, act);
78+
const d = tf.fused.matMul(a, b, transposeA, transposeB, c, act);
7179

7280
expect(d.shape).toEqual([2, 2]);
7381
expectArraysClose(d, [1, 9, 0, 21]);
@@ -78,8 +86,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
7886
const b = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]);
7987
const c = tf.tensor2d([1, 2], [1, 2]);
8088
const act: tf.fused.Activation = 'relu';
89+
const transposeA = false;
90+
const transposeB = false;
8191

82-
const d = tf.fused.matMul(a, b, false, false, c, act);
92+
const d = tf.fused.matMul(a, b, transposeA, transposeB, c, act);
8393

8494
expect(d.shape).toEqual([2, 2, 2]);
8595
expectArraysClose(d, [2, 6, 0, 18, 0, 30, 0, 42]);
@@ -89,8 +99,10 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
8999
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
90100
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
91101
const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
102+
const transposeA = false;
103+
const transposeB = false;
92104

93-
const d = tf.fused.matMul(a, b, false, false, c, 'linear');
105+
const d = tf.fused.matMul(a, b, transposeA, transposeB, c, 'linear');
94106

95107
expect(d.shape).toEqual([2, 2]);
96108
expectArraysClose(d, [1, 9, -2, 21]);
@@ -100,14 +112,16 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
100112
const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
101113
const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
102114
const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
115+
const transposeA = false;
116+
const transposeB = false;
103117

104118
const grads = tf.grads((a, b) => {
105-
const prod = tf.matMul(a, b, false, false);
119+
const prod = tf.matMul(a, b, transposeA, transposeB);
106120
return tf.relu(prod);
107121
});
108122

109123
const fusedGrads = tf.grads((a, b) => {
110-
return tf.fused.matMul(a, b, false, false, null, 'relu');
124+
return tf.fused.matMul(a, b, transposeA, transposeB, null, 'relu');
111125
});
112126

113127
const [da, db] = grads([a, b], dy);
@@ -120,17 +134,19 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
120134
const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
121135
const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
122136
const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
137+
const transposeA = false;
138+
const transposeB = false;
123139

124140
const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
125141

126142
const grads = tf.grads((a, b, c) => {
127-
const prod = tf.matMul(a, b, false, false);
143+
const prod = tf.matMul(a, b, transposeA, transposeB);
128144
const sum = tf.add(prod, c);
129145
return tf.relu(sum);
130146
});
131147

132148
const fusedGrads = tf.grads((a, b, c) => {
133-
return tf.fused.matMul(a, b, false, false, c, 'relu');
149+
return tf.fused.matMul(a, b, transposeA, transposeB, c, 'relu');
134150
});
135151

136152
const [da, db, dc] = grads([a, b, c], dy);
@@ -145,17 +161,46 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
145161
const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [3, 2]);
146162
const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
147163
const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
164+
const transposeA = true;
165+
const transposeB = false;
148166

149167
const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
150168

151169
const grads = tf.grads((a, b, c) => {
152-
const prod = tf.matMul(a, b, true, false);
170+
const prod = tf.matMul(a, b, transposeA, transposeB);
153171
const sum = tf.add(prod, c);
154172
return tf.relu(sum);
155173
});
156174

157175
const fusedGrads = tf.grads((a, b, c) => {
158-
return tf.fused.matMul(a, b, true, false, c, 'relu');
176+
return tf.fused.matMul(a, b, transposeA, transposeB, c, 'relu');
177+
});
178+
179+
const [da, db, dc] = grads([a, b, c], dy);
180+
const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy);
181+
182+
expectArraysClose(da, fusedDa);
183+
expectArraysClose(db, fusedDb);
184+
expectArraysClose(dc, fusedDc);
185+
});
186+
187+
it('A x B with relu and broadcasted bias gradient', () => {
188+
const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
189+
const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
190+
const c = tf.tensor2d([[1]]);
191+
const transposeA = false;
192+
const transposeB = false;
193+
194+
const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
195+
196+
const grads = tf.grads((a, b, c) => {
197+
const prod = tf.matMul(a, b, transposeA, transposeB);
198+
const sum = tf.add(prod, c);
199+
return tf.relu(sum);
200+
});
201+
202+
const fusedGrads = tf.grads((a, b, c) => {
203+
return tf.fused.matMul(a, b, transposeA, transposeB, c, 'relu');
159204
});
160205

161206
const [da, db, dc] = grads([a, b, c], dy);

0 commit comments

Comments (0)