Skip to content

Commit f782fc8

Browse files
authored
If matMul is going to be unpacked, then don't try to avoid expensive reshape in pointwise conv. (tensorflow#1610)
BUG: This PR fixes tensorflow/tfjs#1335. Related: tensorflow/tfjs#1336.
1 parent b072cae commit f782fc8

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

src/kernels/backend_webgl.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,9 +1709,19 @@ export class MathBackendWebGL implements KernelBackend {
17091709
// result from 2D to 4D.
17101710
const xShape = x.shape;
17111711
const xTexData = this.texData.get(x.dataId);
1712-
if (!ENV.get('WEBGL_LAZILY_UNPACK') ||
1713-
!ENV.get('WEBGL_PACK_BINARY_OPERATIONS') || xShape[2] % 2 === 0 ||
1714-
!xTexData.isPacked) {
1712+
const sharedMatMulDim = convInfo.inChannels;
1713+
const outerShapeX = xShape[0] * xShape[1] * xShape[2];
1714+
const outerShapeFilter = convInfo.outChannels;
1715+
1716+
// TODO: Once reduction ops are packed, batchMatMul will always be packed
1717+
// and we can remove this condition.
1718+
const batchMatMulWillBeUnpacked =
1719+
(outerShapeX === 1 || outerShapeFilter === 1) &&
1720+
sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
1721+
const reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked;
1722+
1723+
if (batchMatMulWillBeUnpacked || !ENV.get('WEBGL_LAZILY_UNPACK') ||
1724+
!ENV.get('WEBGL_PACK_BINARY_OPERATIONS') || !reshapeWillBeExpensive) {
17151725
const xReshaped =
17161726
this.reshape(
17171727
x, [1, xShape[0] * xShape[1] * xShape[2], convInfo.inChannels]) as

src/ops/conv2d_test.ts

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ function generateCaseInputs(totalSizeTensor: number, totalSizeFilter: number) {
3434
return {input: inp, filter: filt};
3535
}
3636

37-
describeWithFlags('im2col', PACKED_ENVS, () => {
38-
it('should not leak memory', () => {
37+
describeWithFlags('conv to matmul', PACKED_ENVS, () => {
38+
it('im2col should not leak memory', () => {
3939
const inputDepth = 1;
4040
const inputShape: [number, number, number] = [2, 2, inputDepth];
4141
const outputDepth = 1;
@@ -55,6 +55,26 @@ describeWithFlags('im2col', PACKED_ENVS, () => {
5555

5656
expect(endNumBytes - startNumBytes).toEqual(4);
5757
});
58+
59+
it('pointwise conv should work when matmul is unpacked', () => {
60+
const inputDepth =
61+
1001; // this number must be greater than MATMUL_SHARED_DIM_THRESHOLD
62+
// for matmul to be unpacked
63+
const inputShape: [number, number, number] = [3, 3, inputDepth];
64+
const outputDepth = 1;
65+
const fSize = 1;
66+
const pad = 'same';
67+
const stride: [number, number] = [1, 1];
68+
69+
let x = tf.randomNormal(inputShape) as tf.Tensor3D;
70+
x = x.add(1); // this packs x so we can test the case where we mistakenly
71+
// want to avoid expensive reshape in pointwise conv2d even
72+
// though matmul is unpacked
73+
const w =
74+
tf.randomNormal([fSize, fSize, inputDepth, outputDepth]) as tf.Tensor4D;
75+
76+
expect(() => tf.conv2d(x, w, stride, pad)).not.toThrow();
77+
});
5878
});
5979

6080
describeWithFlags('conv2d', ALL_ENVS, () => {

0 commit comments

Comments
 (0)