Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 9f5ec47

Browse files
syt123450dsmilkov
authored and committed
Add NCHW dataFormat support for conv2d (#1791)
This PR makes `conv2d` support the `NCHW` dataFormat; `NCHW` works for both inference and gradient computation. The `NCHW` dataFormat is supported in the CPU and WebGL backends. Feature requested in [tensorflow/tfjs#1648](tensorflow/tfjs#1648). **Features:** * Make conv2d support NCHW. * Make conv2dDerInput support NCHW. * Make conv2dDerFilter support NCHW. * Add a helper method `convertConv2DDataFormat` to convert `'NHWC'|'NCHW'` format into `'channelsLast'|'channelsFirst'` format. **Tests:** * Add unit tests for NCHW conv2D inference. * Add unit tests for NCHW conv2D gradient. * Add unit tests for `convertConv2DDataFormat`. **Changes in Kernel:** * Make CPU kernel functions `conv2d`, `conv2dDerInput`, `conv2dDerFilter` support `channelsFirst` and `channelsLast` dataFormat. * Make GPU kernel functions `conv2d`, `conv2dDerInput`, `conv2dDerFilter` support `channelsFirst` and `channelsLast` dataFormat. * Make GPU kernel functions `conv2dByMatMul`, `conv2dWithIm2Row` support `channelsFirst` and `channelsLast` dataFormat. * Make GPGPUPrograms `Conv2DProgram`, `Conv2DDerInputProgram`, `Conv2DDerFilterProgram`, `Im2ColPackedProgram` support `channelsFirst` and `channelsLast` dataFormat.
1 parent 6183caf commit 9f5ec47

9 files changed

+429
-96
lines changed

src/backends/cpu/backend_cpu.ts

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,40 +1558,52 @@ export class MathBackendCPU implements KernelBackend {
15581558
const dilationWidth = convInfo.dilationWidth;
15591559
const padLeft = convInfo.padInfo.left;
15601560
const padTop = convInfo.padInfo.top;
1561+
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
1562+
15611563
const y = ops.buffer(convInfo.outShape, x.dtype as 'float32');
15621564

1565+
const xBatchStride = x.strides[0];
1566+
const xRowStride = isChannelsLast ? x.strides[1] : x.strides[2];
1567+
const xColStride = isChannelsLast ? x.strides[2] : 1;
1568+
const xChannelStride = isChannelsLast ? 1 : x.strides[1];
1569+
const yBatchStride = y.strides[0];
1570+
const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
1571+
const yColStride = isChannelsLast ? y.strides[2] : 1;
1572+
const yChannelStride = isChannelsLast ? 1 : y.strides[1];
1573+
15631574
const xVals = this.readSync(x.dataId) as TypedArray;
15641575
const wVals = this.readSync(filter.dataId) as TypedArray;
15651576
const yVals = y.values;
15661577

15671578
for (let b = 0; b < convInfo.batchSize; ++b) {
1568-
const xOffset1 = b * x.strides[0];
1569-
const yOffset1 = b * y.strides[0];
1579+
const xOffset1 = b * xBatchStride;
1580+
const yOffset1 = b * yBatchStride;
15701581
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
1571-
const yOffset2 = yOffset1 + yR * y.strides[1];
1582+
const yOffset2 = yOffset1 + yR * yRowStride;
15721583
const xRCorner = yR * convInfo.strideHeight - padTop;
15731584
for (let wR = 0; wR < filterHeight; wR++) {
15741585
const xR = xRCorner + wR * dilationHeight;
15751586
if (xR < 0 || xR >= convInfo.inHeight) {
15761587
continue;
15771588
}
15781589
const wOffset1 = wR * filter.strides[0];
1579-
const xOffset2 = xOffset1 + xR * x.strides[1];
1590+
const xOffset2 = xOffset1 + xR * xRowStride;
15801591
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
1581-
const yOffset3 = yOffset2 + yC * convInfo.outChannels;
1592+
const yOffset3 = yOffset2 + yC * yColStride;
15821593
const xCCorner = yC * convInfo.strideWidth - padLeft;
15831594
for (let wC = 0; wC < filterWidth; wC++) {
15841595
const xC = xCCorner + wC * dilationWidth;
15851596
if (xC < 0 || xC >= convInfo.inWidth) {
15861597
continue;
15871598
}
15881599
const wOffset2 = wOffset1 + wC * filter.strides[1];
1589-
const xOffset3 = xOffset2 + xC * convInfo.inChannels;
1600+
const xOffset3 = xOffset2 + xC * xColStride;
15901601
let wOffset3 = wOffset2;
15911602
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
1592-
const xVal = xVals[xOffset3 + d1];
1603+
const xVal = xVals[xOffset3 + d1 * xChannelStride];
15931604
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
1594-
yVals[yOffset3 + d2] += xVal * wVals[wOffset3 + d2];
1605+
yVals[yOffset3 + d2 * yChannelStride] +=
1606+
xVal * wVals[wOffset3 + d2];
15951607
}
15961608
wOffset3 += convInfo.outChannels;
15971609
}
@@ -1677,9 +1689,7 @@ export class MathBackendCPU implements KernelBackend {
16771689

16781690
const dx = ops.buffer<Rank.R4>(convInfo.inShape, 'float32');
16791691
const dxValues = dx.values;
1680-
const [dxS0, dxS1, dxS2] = dx.strides;
16811692
const dyValues = this.readSync(dy.dataId) as TypedArray;
1682-
const [dyS0, dyS1, dyS2] = dy.strides;
16831693
const fltValues = this.readSync(filter.dataId) as TypedArray;
16841694
const [fltS0, fltS1, fltS2] = filter.strides;
16851695
const {
@@ -1693,11 +1703,22 @@ export class MathBackendCPU implements KernelBackend {
16931703
outHeight,
16941704
outWidth,
16951705
strideHeight,
1696-
strideWidth
1706+
strideWidth,
1707+
dataFormat
16971708
} = convInfo;
16981709
const topPad = filterHeight - 1 - convInfo.padInfo.top;
16991710
const leftPad = filterWidth - 1 - convInfo.padInfo.left;
17001711

1712+
const isChannelsLast = dataFormat === 'channelsLast';
1713+
const xBatchStride = dx.strides[0];
1714+
const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
1715+
const xColStride = isChannelsLast ? dx.strides[2] : 1;
1716+
const xChannelStride = isChannelsLast ? 1 : dx.strides[1];
1717+
const yBatchStride = dy.strides[0];
1718+
const yRowStride = isChannelsLast ? dy.strides[1] : dy.strides[2];
1719+
const yColStride = isChannelsLast ? dy.strides[2] : 1;
1720+
const yChannelStride = isChannelsLast ? 1 : dy.strides[1];
1721+
17011722
for (let b = 0; b < batchSize; ++b) {
17021723
for (let d1 = 0; d1 < inChannels; ++d1) {
17031724
for (let xR = 0; xR < inHeight; ++xR) {
@@ -1718,18 +1739,21 @@ export class MathBackendCPU implements KernelBackend {
17181739

17191740
for (let yC = xCMin; yC < yCMax; ++yC) {
17201741
const wC = yC * strideWidth - xCCorner;
1721-
const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
1742+
const dyOffset =
1743+
yBatchStride * b + yRowStride * yR + yColStride * yC;
17221744
const fltOffset = fltS0 * (filterHeight - 1 - wR) +
17231745
fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
17241746

17251747
for (let d2 = 0; d2 < outChannels; ++d2) {
1726-
const pixel = dyValues[dyOffset + d2];
1748+
const pixel = dyValues[dyOffset + yChannelStride * d2];
17271749
const weight = fltValues[fltOffset + d2];
17281750
dotProd += pixel * weight;
17291751
}
17301752
}
17311753
}
1732-
dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
1754+
const dxOffset = xBatchStride * b + xRowStride * xR +
1755+
xColStride * xC + xChannelStride * d1;
1756+
dxValues[dxOffset] = dotProd;
17331757
}
17341758
}
17351759
}
@@ -1829,6 +1853,7 @@ export class MathBackendCPU implements KernelBackend {
18291853
const strideWidth = convInfo.strideWidth;
18301854
const filterHeight = convInfo.filterHeight;
18311855
const filterWidth = convInfo.filterWidth;
1856+
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
18321857
const dW = ops.buffer<Rank.R4>(convInfo.filterShape, 'float32');
18331858

18341859
const leftPad = convInfo.padInfo.left;
@@ -1854,7 +1879,13 @@ export class MathBackendCPU implements KernelBackend {
18541879
const xR = wR + yR * strideHeight - topPad;
18551880
for (let yC = yCMin; yC < yCMax; ++yC) {
18561881
const xC = wC + yC * strideWidth - leftPad;
1857-
dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
1882+
if (isChannelsLast) {
1883+
dotProd +=
1884+
xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
1885+
} else {
1886+
dotProd +=
1887+
xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC);
1888+
}
18581889
}
18591890
}
18601891
}

src/backends/webgl/backend_webgl.ts

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,6 +1843,7 @@ export class MathBackendWebGL implements KernelBackend {
18431843
const sharedMatMulDim = convInfo.inChannels;
18441844
const outerShapeX = xShape[0] * xShape[1] * xShape[2];
18451845
const outerShapeFilter = convInfo.outChannels;
1846+
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
18461847
const transposeA = false;
18471848
const transposeB = false;
18481849

@@ -1856,10 +1857,10 @@ export class MathBackendWebGL implements KernelBackend {
18561857
if (batchMatMulWillBeUnpacked || !ENV.getBool('WEBGL_LAZILY_UNPACK') ||
18571858
!ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ||
18581859
!reshapeWillBeExpensive) {
1860+
const targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] :
1861+
xShape[0] * xShape[2] * xShape[3];
18591862
const xReshaped =
1860-
this.reshape(
1861-
x, [1, xShape[0] * xShape[1] * xShape[2], convInfo.inChannels]) as
1862-
Tensor3D;
1863+
this.reshape(x, [1, targetShape, convInfo.inChannels]) as Tensor3D;
18631864
const filterReshaped =
18641865
this.reshape(
18651866
filter, [1, convInfo.inChannels, convInfo.outChannels]) as
@@ -1879,17 +1880,19 @@ export class MathBackendWebGL implements KernelBackend {
18791880
}
18801881

18811882
// Following optimization is specific to packed |x| with odd row count
1882-
// ('row count' refers to x.shape[2]): we avoid expensive packed 2x2
1883-
// reshape by padding row count to next, even number. When x.shape[2] is
1884-
// odd, the result of packed batchMatMul is the same (has the same texture
1885-
// layout and and values in the texture) as it is for even x.shape[2] + 1.
1886-
// We make the odd-rows tensor to look like even-rows tensor before the
1887-
// operation and, after the batchMatMul, fix the even-rows result to have
1888-
// odd number of rows.
1889-
const xReshaped =
1890-
Tensor.make(
1891-
[1, xShape[0] * xShape[1] * (xShape[2] + 1), convInfo.inChannels],
1892-
{dataId: x.dataId}, x.dtype, this) as Tensor3D;
1883+
// (For example, in channelsLast mode, 'row count' refers to x.shape[2]):
1884+
// we avoid expensive packed 2x2 reshape by padding row count to next,
1885+
// even number. When x.shape[2] is odd, the result of packed batchMatMul is
1886+
// the same (has the same texture layout and values in the texture) as
1887+
// it is for even x.shape[2] + 1. We make the odd-rows tensor to look like
1888+
// even-rows tensor before the operation and, after the batchMatMul,
1889+
// fix the even-rows result to have odd number of rows.
1890+
const targetShape = isChannelsLast ?
1891+
xShape[0] * xShape[1] * (xShape[2] + 1) :
1892+
xShape[0] * xShape[2] * (xShape[3] + 1);
1893+
const xReshaped = Tensor.make(
1894+
[1, targetShape, convInfo.inChannels],
1895+
{dataId: x.dataId}, x.dtype, this) as Tensor3D;
18931896

18941897
// xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
18951898
// Decrementing row count, after batchMatMul->...->compileProgram leads to
@@ -1948,8 +1951,11 @@ export class MathBackendWebGL implements KernelBackend {
19481951
inChannels,
19491952
outWidth,
19501953
outHeight,
1954+
dataFormat
19511955
} = convInfo;
19521956

1957+
const isChannelsLast = dataFormat === 'channelsLast';
1958+
19531959
const sharedDim = filterWidth * filterHeight * inChannels;
19541960
const numCols = outHeight * outWidth;
19551961
const x2ColShape = [sharedDim, numCols];
@@ -1982,7 +1988,11 @@ export class MathBackendWebGL implements KernelBackend {
19821988
}
19831989
const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);
19841990

1985-
return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
1991+
if (isChannelsLast) {
1992+
return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
1993+
} else {
1994+
return product.reshape([1, convInfo.outChannels, outHeight, outWidth]);
1995+
}
19861996
}
19871997

19881998
fusedConv2d(

src/backends/webgl/conv_backprop_gpu.ts

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export class Conv2DDerFilterProgram implements GPGPUProgram {
3030
const strideWidth = convInfo.strideWidth;
3131
const padTop = convInfo.padInfo.top;
3232
const padLeft = convInfo.padInfo.left;
33+
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
3334

3435
this.userCode = `
3536
void main() {
@@ -58,9 +59,16 @@ export class Conv2DDerFilterProgram implements GPGPUProgram {
5859
continue;
5960
}
6061
61-
float dyValue = getDy(b, yR, yC, d2);
62-
float xValue = getX(b, xR, xC, d1);
63-
dotProd += (xValue * dyValue);
62+
if (${isChannelsLast}) {
63+
float dyValue = getDy(b, yR, yC, d2);
64+
float xValue = getX(b, xR, xC, d1);
65+
dotProd += (xValue * dyValue);
66+
} else {
67+
float dyValue = getDy(b, d2, yR, yC);
68+
float xValue = getX(b, d1, xR, xC);
69+
dotProd += (xValue * dyValue);
70+
}
71+
6472
}
6573
}
6674
}
@@ -82,19 +90,24 @@ export class Conv2DDerInputProgram implements GPGPUProgram {
8290
const filterWidth = convInfo.filterWidth;
8391
const strideHeight = convInfo.strideHeight;
8492
const strideWidth = convInfo.strideWidth;
93+
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
8594

8695
const padTop = filterHeight - 1 - convInfo.padInfo.top;
8796
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
8897

98+
const rowDim = isChannelsLast ? 1 : 2;
99+
const colDim = isChannelsLast ? 2 : 3;
100+
const channelDim = isChannelsLast ? 3 : 1;
101+
89102
this.userCode = `
90103
const ivec2 pads = ivec2(${padTop}, ${padLeft});
91104
92105
void main() {
93106
ivec4 coords = getOutputCoords();
94107
int batch = coords[0];
95-
int d1 = coords[3];
108+
int d1 = coords[${channelDim}];
96109
97-
ivec2 dyCorner = coords.yz - pads;
110+
ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
98111
int dyRCorner = dyCorner.x;
99112
int dyCCorner = dyCorner.y;
100113
@@ -123,9 +136,17 @@ export class Conv2DDerInputProgram implements GPGPUProgram {
123136
int wCPerm = ${filterWidth} - 1 - wC;
124137
125138
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
126-
float xValue = getDy(batch, idyR, idyC, d2);
127-
float wValue = getW(wRPerm, wCPerm, d1, d2);
128-
dotProd += xValue * wValue;
139+
140+
if (${isChannelsLast}) {
141+
float xValue = getDy(batch, idyR, idyC, d2);
142+
float wValue = getW(wRPerm, wCPerm, d1, d2);
143+
dotProd += xValue * wValue;
144+
} else {
145+
float xValue = getDy(batch, d2, idyR, idyC);
146+
float wValue = getW(wRPerm, wCPerm, d1, d2);
147+
dotProd += xValue * wValue;
148+
}
149+
129150
}
130151
}
131152
}

0 commit comments

Comments
 (0)