Skip to content

Commit c6a71db

Browse files
zboldyga authored and dsmilkov committed
Add Conv3d (tensorflow#1238)
This is an implementation of conv3d, as discussed in: tensorflow/tfjs#470 FEATURE
1 parent 413e864 commit c6a71db

11 files changed

+1827
-15
lines changed

src/kernels/backend.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
* =============================================================================
1616
*/
1717

18-
import {Conv2DInfo} from '../ops/conv_util';
19-
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
18+
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
19+
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
2020
import {DataType, DataValues, Rank, ShapeMap} from '../types';
2121

2222
// Required information for all backends.
@@ -403,7 +403,16 @@ export class KernelBackend implements TensorStorage, BackendTimer {
403403
Tensor4D {
404404
throw new Error('Not yet implemented');
405405
}
406-
406+
/**
 * Backend hook for the forward 3D convolution kernel.
 *
 * The base class provides no implementation: calling it always throws.
 * Concrete backends are expected to override it.
 *
 * @param x Input tensor (rank 5).
 * @param filter Filter tensor (rank 5).
 * @param convInfo Precomputed convolution geometry.
 * @throws Error Always, with the message 'Not yet implemented'.
 */
conv3d(x: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  throw new Error('Not yet implemented');
}
409+
/**
 * Backend hook for the gradient of a 3D convolution with respect to its
 * input.
 *
 * The base class provides no implementation: calling it always throws.
 * Concrete backends are expected to override it.
 *
 * @param dy Upstream gradient tensor (rank 5).
 * @param filter Filter tensor (rank 5).
 * @param convInfo Precomputed convolution geometry.
 * @throws Error Always, with the message 'Not yet implemented'.
 */
conv3dDerInput(dy: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo):
    Tensor5D {
  throw new Error('Not yet implemented');
}
413+
/**
 * Backend hook for the gradient of a 3D convolution with respect to its
 * filter.
 *
 * The base class provides no implementation: calling it always throws.
 * Concrete backends are expected to override it.
 *
 * @param x Input tensor (rank 5).
 * @param dY Upstream gradient tensor (rank 5).
 * @param convInfo Precomputed convolution geometry.
 * @throws Error Always, with the message 'Not yet implemented'.
 */
conv3dDerFilter(x: Tensor5D, dY: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  throw new Error('Not yet implemented');
}
407416
/**
 * Backend hook for the max-pooling kernel.
 *
 * The base class provides no implementation: calling it always throws.
 * Concrete backends are expected to override it.
 *
 * @param x Input tensor (rank 4).
 * @param convInfo Precomputed pooling-window geometry.
 * @throws Error Always, with the message 'Not yet implemented'.
 */
maxPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
  throw new Error('Not yet implemented');
}

src/kernels/backend_cpu.ts

Lines changed: 234 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ import * as array_ops_util from '../ops/array_ops_util';
2222
import * as axis_util from '../ops/axis_util';
2323
import * as broadcast_util from '../ops/broadcast_util';
2424
import * as concat_util from '../ops/concat_util';
25-
import {Conv2DInfo} from '../ops/conv_util';
25+
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
2626
import * as erf_util from '../ops/erf_util';
2727
import * as gather_nd_util from '../ops/gather_nd_util';
2828
import * as ops from '../ops/ops';
2929
import {buffer, scalar, tensor, tensor3d, tensor4d} from '../ops/ops';
3030
import * as scatter_nd_util from '../ops/scatter_nd_util';
3131
import * as selu_util from '../ops/selu_util';
3232
import {getStridedSlicedInfo} from '../ops/slice_util';
33-
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer} from '../tensor';
33+
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../tensor';
3434
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../types';
3535
import * as util from '../util';
3636
import {now} from '../util';
@@ -1420,6 +1420,74 @@ export class MathBackendCPU implements KernelBackend {
14201420
return y.toTensor() as Tensor4D;
14211421
}
14221422

1423+
/**
 * CPU implementation of the forward 3D convolution.
 *
 * The raw value arrays are walked with explicit offsets. `x` is indexed as
 * [batch, depth, row, col, inChannel], `filter` as
 * [depth, row, col, inChannel, outChannel] and the result as
 * [batch, depth, row, col, outChannel]. Filter taps that fall outside the
 * input volume (because of padding) are skipped, which is equivalent to
 * zero padding.
 *
 * @param x Input tensor.
 * @param filter Filter tensor.
 * @param convInfo Precomputed geometry: shapes, strides, dilations, padding.
 * @returns A new tensor of shape `convInfo.outShape`.
 */
conv3d(x: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  // Same guard the conv2d kernels of this class apply before touching raw
  // values (presumably rejects complex64 inputs, whose storage this loop
  // cannot interpret).
  this.assertNotComplex([x, filter], 'conv3d');

  const filterDepth = convInfo.filterDepth;
  const filterHeight = convInfo.filterHeight;
  const filterWidth = convInfo.filterWidth;
  const dilationDepth = convInfo.dilationDepth;
  const dilationHeight = convInfo.dilationHeight;
  const dilationWidth = convInfo.dilationWidth;
  const padFront = convInfo.padInfo.front;
  const padLeft = convInfo.padInfo.left;
  const padTop = convInfo.padInfo.top;
  // The loop below only accumulates with `+=`, so it relies on the fresh
  // buffer starting out zero-filled (assumed behaviour of ops.buffer).
  const y = ops.buffer<Rank.R5>(convInfo.outShape, x.dtype as 'float32');

  const xVals = x.dataSync();
  const wVals = filter.dataSync();
  const yVals = y.values;

  for (let b = 0; b < convInfo.batchSize; ++b) {
    const xOffset1 = b * x.strides[0];
    const yOffset1 = b * y.strides[0];
    // Depth (frame) axis: output frame yF, filter frame wF, input frame xF.
    for (let yF = 0; yF < convInfo.outDepth; ++yF) {
      const yOffset2 = yOffset1 + yF * y.strides[1];
      const xFCorner = yF * convInfo.strideDepth - padFront;
      for (let wF = 0; wF < filterDepth; wF++) {
        const xF = xFCorner + wF * dilationDepth;
        if (xF < 0 || xF >= convInfo.inDepth) {
          continue;  // Tap lies in the depth padding.
        }
        const wOffset1 = wF * filter.strides[0];
        const xOffset2 = xOffset1 + xF * x.strides[1];

        // Row (height) axis.
        for (let yR = 0; yR < convInfo.outHeight; ++yR) {
          const yOffset3 = yOffset2 + yR * y.strides[2];
          const xRCorner = yR * convInfo.strideHeight - padTop;
          for (let wR = 0; wR < filterHeight; wR++) {
            const xR = xRCorner + wR * dilationHeight;
            if (xR < 0 || xR >= convInfo.inHeight) {
              continue;  // Tap lies in the row padding.
            }
            const wOffset2 = wOffset1 + wR * filter.strides[1];
            const xOffset3 = xOffset2 + xR * x.strides[2];

            // Column (width) axis. The channel axis is innermost, so one
            // column step equals the channel count.
            for (let yC = 0; yC < convInfo.outWidth; ++yC) {
              const yOffset4 = yOffset3 + yC * convInfo.outChannels;
              const xCCorner = yC * convInfo.strideWidth - padLeft;
              for (let wC = 0; wC < filterWidth; wC++) {
                const xC = xCCorner + wC * dilationWidth;
                if (xC < 0 || xC >= convInfo.inWidth) {
                  continue;  // Tap lies in the column padding.
                }
                const wOffset3 = wOffset2 + wC * filter.strides[2];
                const xOffset4 = xOffset3 + xC * convInfo.inChannels;
                // Walks the [inChannel, outChannel] slab of the filter for
                // this tap, one input channel at a time.
                let wOffset4 = wOffset3;
                for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                  const xVal = xVals[xOffset4 + d1];
                  for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                    yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
                  }
                  wOffset4 += convInfo.outChannels;
                }
              }
            }
          }
        }
      }
    }
  }
  return y.toTensor();
}
1490+
14231491
conv2dDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
14241492
Tensor4D {
14251493
this.assertNotComplex([dy, filter], 'conv2dDerInput');
@@ -1486,6 +1554,91 @@ export class MathBackendCPU implements KernelBackend {
14861554
return dx.toTensor();
14871555
}
14881556

1557+
/**
 * CPU implementation of the gradient of a 3D convolution with respect to
 * its input.
 *
 * For every input element [b, xF, xR, xC, d1] it sums, over all output
 * positions whose filter window covers that element and over all output
 * channels, `dy * filter`. The filter is addressed through mirrored spatial
 * indices (`filterDepth - 1 - wF`, and likewise for rows and columns).
 *
 * The dilation fields of `convInfo` are not read here.
 *
 * @param dy Upstream gradient, indexed as
 *     [batch, depth, row, col, outChannel].
 * @param filter Filter tensor, indexed as
 *     [depth, row, col, inChannel, outChannel].
 * @param convInfo Precomputed geometry: shapes, strides, padding.
 * @returns A new float32 tensor of shape `convInfo.inShape`.
 */
conv3dDerInput(dy: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo):
    Tensor5D {
  // Same guard conv2dDerInput in this class applies before reading raw
  // values (presumably rejects complex64 inputs).
  this.assertNotComplex([dy, filter], 'conv3dDerInput');

  const dx = ops.buffer<Rank.R5>(convInfo.inShape, 'float32');
  const dxValues = dx.values;
  const [dxS0, dxS1, dxS2, dxS3] = dx.strides;
  const dyValues = dy.dataSync();
  const [dyS0, dyS1, dyS2, dyS3] = dy.strides;
  const fltValues = filter.dataSync();
  const [fltS0, fltS1, fltS2, fltS3] = filter.strides;
  const {
    batchSize,
    filterDepth,
    filterHeight,
    filterWidth,
    inChannels,
    inDepth,
    inHeight,
    inWidth,
    outChannels,
    outDepth,
    outHeight,
    outWidth,
    strideDepth,
    strideHeight,
    strideWidth
  } = convInfo;
  // Padding as seen from the mirrored filter: (filter size - 1 - forward pad).
  const frontPad = filterDepth - 1 - convInfo.padInfo.front;
  const topPad = filterHeight - 1 - convInfo.padInfo.top;
  const leftPad = filterWidth - 1 - convInfo.padInfo.left;

  for (let b = 0; b < batchSize; ++b) {
    for (let d1 = 0; d1 < inChannels; ++d1) {
      // Frames of depth
      for (let xF = 0; xF < inDepth; ++xF) {
        const xFCorner = xF - frontPad;
        // [yFMin, yFMax) is the range of output frames that touch xF. The
        // upper bound may be fractional; the `<` comparison handles that.
        const yFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
        const yFMax =
            Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);

        // Rows as per standard 2d matrix notation
        for (let xR = 0; xR < inHeight; ++xR) {
          const xRCorner = xR - topPad;
          const yRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
          const yRMax =
              Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);

          // Columns as per standard 2d matrix notation
          for (let xC = 0; xC < inWidth; ++xC) {
            const xCCorner = xC - leftPad;
            const yCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
            const yCMax =
                Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);

            let dotProd = 0;
            for (let yF = yFMin; yF < yFMax; ++yF) {
              const wF = yF * strideDepth - xFCorner;

              for (let yR = yRMin; yR < yRMax; ++yR) {
                const wR = yR * strideHeight - xRCorner;

                for (let yC = yCMin; yC < yCMax; ++yC) {
                  const wC = yC * strideWidth - xCCorner;
                  const dyOffset =
                      dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
                  // Mirrored spatial filter indices; d1 selects the
                  // input-channel slab.
                  const fltOffset = fltS0 * (filterDepth - 1 - wF) +
                      fltS1 * (filterHeight - 1 - wR) +
                      fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;

                  for (let d2 = 0; d2 < outChannels; ++d2) {
                    const pixel = dyValues[dyOffset + d2];
                    const weight = fltValues[fltOffset + d2];
                    dotProd += pixel * weight;
                  }
                }
              }
            }
            dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] =
                dotProd;
          }
        }
      }
    }
  }
  return dx.toTensor();
}
1641+
14891642
conv2dDerFilter(x: Tensor4D, dy: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
14901643
this.assertNotComplex([x, dy], 'conv2dDerFilter');
14911644

@@ -1529,6 +1682,85 @@ export class MathBackendCPU implements KernelBackend {
15291682
return dW.toTensor();
15301683
}
15311684

1685+
/**
 * CPU implementation of the gradient of a 3D convolution with respect to
 * its filter.
 *
 * For every filter element [wF, wR, wC, d1, d2] it sums `x * dy` over the
 * batch and over all output positions whose corresponding input position
 * lies inside the input volume.
 *
 * The dilation fields of `convInfo` are not read here.
 *
 * @param x Forward-pass input, indexed as
 *     [batch, depth, row, col, inChannel].
 * @param dy Upstream gradient, indexed as
 *     [batch, depth, row, col, outChannel].
 * @param convInfo Precomputed geometry: shapes, strides, padding.
 * @returns A new float32 tensor of shape `convInfo.filterShape`.
 */
conv3dDerFilter(x: Tensor5D, dy: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  // Same guard conv2dDerFilter in this class applies before reading raw
  // values (presumably rejects complex64 inputs).
  this.assertNotComplex([x, dy], 'conv3dDerFilter');

  const strideDepth = convInfo.strideDepth;
  const strideHeight = convInfo.strideHeight;
  const strideWidth = convInfo.strideWidth;
  const filterDepth = convInfo.filterDepth;
  const filterHeight = convInfo.filterHeight;
  const filterWidth = convInfo.filterWidth;

  const dw = ops.buffer<Rank.R5>(convInfo.filterShape, 'float32');
  const dwValues = dw.values;
  const [dwS0, dwS1, dwS2, dwS3] = dw.strides;
  const dyValues = dy.dataSync();
  const [dyS0, dyS1, dyS2, dyS3] = dy.strides;
  const xValues = x.dataSync();
  const [xS0, xS1, xS2, xS3] = x.strides;

  const frontPad = convInfo.padInfo.front;
  const leftPad = convInfo.padInfo.left;
  const topPad = convInfo.padInfo.top;

  for (let wF = 0; wF < filterDepth; ++wF) {
    // [yFMin, yFMax) is the range of output frames for which filter frame wF
    // reads a real (non-padding) input frame. The upper bound may be
    // fractional; the `<` comparison handles that.
    const yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
    const yFMax = Math.min(
        convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
    const wOffset1 = wF * dwS0;

    for (let wR = 0; wR < filterHeight; ++wR) {
      const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
      const yRMax = Math.min(
          convInfo.outHeight,
          (convInfo.inHeight + topPad - wR) / strideHeight);
      const wOffset2 = wR * dwS1 + wOffset1;

      for (let wC = 0; wC < filterWidth; ++wC) {
        const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
        const yCMax = Math.min(
            convInfo.outWidth,
            (convInfo.inWidth + leftPad - wC) / strideWidth);
        const wOffset3 = wC * dwS2 + wOffset2;

        for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
          const wOffset4 = d1 * dwS3 + wOffset3;

          for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
            let dotProd = 0;
            for (let b = 0; b < convInfo.batchSize; ++b) {
              const xOffset1 = b * xS0;
              const yOffset1 = b * dyS0;

              for (let yF = yFMin; yF < yFMax; ++yF) {
                // Input frame read by filter frame wF at output frame yF.
                const xF = wF + yF * strideDepth - frontPad;
                const xOffset2 = xF * xS1 + xOffset1;
                const yOffset2 = yF * dyS1 + yOffset1;

                for (let yR = yRMin; yR < yRMax; ++yR) {
                  const xR = wR + yR * strideHeight - topPad;
                  const xOffset3 = xR * xS2 + xOffset2;
                  const yOffset3 = yR * dyS2 + yOffset2;

                  for (let yC = yCMin; yC < yCMax; ++yC) {
                    const xC = wC + yC * strideWidth - leftPad;
                    const xOffset4 = xC * xS3 + xOffset3;
                    const yOffset4 = yC * dyS3 + yOffset3;

                    dotProd +=
                        xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
                  }
                }
              }
            }
            dwValues[wOffset4 + d2] = dotProd;
          }
        }
      }
    }
  }
  return dw.toTensor();
}
1763+
15321764
depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
15331765
Tensor4D {
15341766
this.assertNotComplex([x, filter], 'depthwiseConv2D');

src/kernels/backend_webgl.ts

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ import {warn} from '../log';
2323
import * as array_ops_util from '../ops/array_ops_util';
2424
import * as axis_util from '../ops/axis_util';
2525
import {computeOutShape} from '../ops/concat_util';
26-
import {Conv2DInfo} from '../ops/conv_util';
26+
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
2727
import * as gather_nd_util from '../ops/gather_nd_util';
2828
import * as reduce_util from '../ops/reduce_util';
2929
import * as scatter_nd_util from '../ops/scatter_nd_util';
3030
import * as segment_util from '../ops/segment_util';
3131
import {getStridedSlicedInfo} from '../ops/slice_util';
3232
import {softmax} from '../ops/softmax';
3333
import {range, scalar, tensor} from '../ops/tensor_ops';
34-
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
34+
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
3535
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../types';
3636
import * as util from '../util';
3737
import {getTypedArrayFromDType, sizeFromShape} from '../util';
@@ -53,9 +53,9 @@ import {ClipProgram} from './webgl/clip_gpu';
5353
import {ClipPackedProgram} from './webgl/clip_packed_gpu';
5454
import {ComplexAbsProgram} from './webgl/complex_abs_gpu';
5555
import {ConcatProgram} from './webgl/concat_gpu';
56-
import {Conv2DDerFilterProgram, Conv2DDerInputProgram} from './webgl/conv_backprop_gpu';
56+
import {Conv2DDerFilterProgram, Conv2DDerInputProgram, Conv3DDerFilterProgram, Conv3DDerInputProgram} from './webgl/conv_backprop_gpu';
5757
import {DepthwiseConv2DDerFilterProgram, DepthwiseConv2DDerInputProgram} from './webgl/conv_backprop_gpu_depthwise';
58-
import {Conv2DProgram} from './webgl/conv_gpu';
58+
import {Conv2DProgram, Conv3DProgram} from './webgl/conv_gpu';
5959
import {DepthwiseConv2DProgram} from './webgl/conv_gpu_depthwise';
6060
import {DepthwiseConvPacked2DProgram} from './webgl/conv_packed_gpu_depthwise';
6161
import {CropAndResizeProgram} from './webgl/crop_and_resize_gpu';
@@ -1542,6 +1542,22 @@ export class MathBackendWebGL implements KernelBackend {
15421542
return this.compileAndRun(program, [x, dy]);
15431543
}
15441544

1545+
/**
 * WebGL forward 3D convolution: builds a `Conv3DProgram` from `convInfo` and
 * runs it with `x` and `filter` as inputs.
 *
 * @param x Input tensor.
 * @param filter Filter tensor.
 * @param convInfo Precomputed convolution geometry.
 * @returns The tensor produced by running the program.
 */
conv3d(x: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  return this.compileAndRun(new Conv3DProgram(convInfo), [x, filter]);
}
1549+
1550+
/**
 * WebGL gradient of a 3D convolution with respect to its input: builds a
 * `Conv3DDerInputProgram` from `convInfo` and runs it with `dy` and `filter`
 * as inputs.
 *
 * @param dy Upstream gradient tensor.
 * @param filter Filter tensor.
 * @param convInfo Precomputed convolution geometry.
 * @returns The tensor produced by running the program.
 */
conv3dDerInput(dy: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo):
    Tensor5D {
  return this.compileAndRun(new Conv3DDerInputProgram(convInfo), [dy, filter]);
}
1555+
1556+
/**
 * WebGL gradient of a 3D convolution with respect to its filter: builds a
 * `Conv3DDerFilterProgram` from `convInfo` and runs it with `x` and `dy` as
 * inputs.
 *
 * @param x Forward-pass input tensor.
 * @param dy Upstream gradient tensor.
 * @param convInfo Precomputed convolution geometry.
 * @returns The tensor produced by running the program.
 */
conv3dDerFilter(x: Tensor5D, dy: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  return this.compileAndRun(new Conv3DDerFilterProgram(convInfo), [x, dy]);
}
1560+
15451561
maxPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
15461562
const program = new Pool2DProgram(convInfo, 'max', false);
15471563
const output =

0 commit comments

Comments
 (0)