@@ -22,15 +22,15 @@ import * as array_ops_util from '../ops/array_ops_util';
22
22
import * as axis_util from '../ops/axis_util' ;
23
23
import * as broadcast_util from '../ops/broadcast_util' ;
24
24
import * as concat_util from '../ops/concat_util' ;
25
- import { Conv2DInfo } from '../ops/conv_util' ;
25
+ import { Conv2DInfo , Conv3DInfo } from '../ops/conv_util' ;
26
26
import * as erf_util from '../ops/erf_util' ;
27
27
import * as gather_nd_util from '../ops/gather_nd_util' ;
28
28
import * as ops from '../ops/ops' ;
29
29
import { buffer , scalar , tensor , tensor3d , tensor4d } from '../ops/ops' ;
30
30
import * as scatter_nd_util from '../ops/scatter_nd_util' ;
31
31
import * as selu_util from '../ops/selu_util' ;
32
32
import { getStridedSlicedInfo } from '../ops/slice_util' ;
33
- import { DataId , Scalar , setTensorTracker , Tensor , Tensor1D , Tensor2D , Tensor3D , Tensor4D , TensorBuffer } from '../tensor' ;
33
+ import { DataId , Scalar , setTensorTracker , Tensor , Tensor1D , Tensor2D , Tensor3D , Tensor4D , Tensor5D , TensorBuffer } from '../tensor' ;
34
34
import { DataType , DataTypeMap , DataValues , NumericDataType , Rank , ShapeMap , TypedArray , upcastType } from '../types' ;
35
35
import * as util from '../util' ;
36
36
import { now } from '../util' ;
@@ -1420,6 +1420,74 @@ export class MathBackendCPU implements KernelBackend {
1420
1420
return y . toTensor ( ) as Tensor4D ;
1421
1421
}
1422
1422
1423
+ conv3d ( x : Tensor5D , filter : Tensor5D , convInfo : Conv3DInfo ) : Tensor5D {
1424
+ const filterDepth = convInfo . filterDepth ;
1425
+ const filterHeight = convInfo . filterHeight ;
1426
+ const filterWidth = convInfo . filterWidth ;
1427
+ const dilationDepth = convInfo . dilationDepth ;
1428
+ const dilationHeight = convInfo . dilationHeight ;
1429
+ const dilationWidth = convInfo . dilationWidth ;
1430
+ const padFront = convInfo . padInfo . front ;
1431
+ const padLeft = convInfo . padInfo . left ;
1432
+ const padTop = convInfo . padInfo . top ;
1433
+ const y = ops . buffer < Rank . R5 > ( convInfo . outShape , x . dtype as 'float32' ) ;
1434
+
1435
+ const xVals = x . dataSync ( ) ;
1436
+ const wVals = filter . dataSync ( ) ;
1437
+ const yVals = y . values ;
1438
+
1439
+ for ( let b = 0 ; b < convInfo . batchSize ; ++ b ) {
1440
+ const xOffset1 = b * x . strides [ 0 ] ;
1441
+ const yOffset1 = b * y . strides [ 0 ] ;
1442
+ for ( let yF = 0 ; yF < convInfo . outDepth ; ++ yF ) {
1443
+ const yOffset2 = yOffset1 + yF * y . strides [ 1 ] ;
1444
+ const xFCorner = yF * convInfo . strideDepth - padFront ;
1445
+ for ( let wF = 0 ; wF < filterDepth ; wF ++ ) {
1446
+ const xF = xFCorner + wF * dilationDepth ;
1447
+ if ( xF < 0 || xF >= convInfo . inDepth ) {
1448
+ continue ;
1449
+ }
1450
+ const wOffset1 = wF * filter . strides [ 0 ] ;
1451
+ const xOffset2 = xOffset1 + xF * x . strides [ 1 ] ;
1452
+
1453
+ for ( let yR = 0 ; yR < convInfo . outHeight ; ++ yR ) {
1454
+ const yOffset3 = yOffset2 + yR * y . strides [ 2 ] ;
1455
+ const xRCorner = yR * convInfo . strideHeight - padTop ;
1456
+ for ( let wR = 0 ; wR < filterHeight ; wR ++ ) {
1457
+ const xR = xRCorner + wR * dilationHeight ;
1458
+ if ( xR < 0 || xR >= convInfo . inHeight ) {
1459
+ continue ;
1460
+ }
1461
+ const wOffset2 = wOffset1 + wR * filter . strides [ 1 ] ;
1462
+ const xOffset3 = xOffset2 + xR * x . strides [ 2 ] ;
1463
+ for ( let yC = 0 ; yC < convInfo . outWidth ; ++ yC ) {
1464
+ const yOffset4 = yOffset3 + yC * convInfo . outChannels ;
1465
+ const xCCorner = yC * convInfo . strideWidth - padLeft ;
1466
+ for ( let wC = 0 ; wC < filterWidth ; wC ++ ) {
1467
+ const xC = xCCorner + wC * dilationWidth ;
1468
+ if ( xC < 0 || xC >= convInfo . inWidth ) {
1469
+ continue ;
1470
+ }
1471
+ const wOffset3 = wOffset2 + wC * filter . strides [ 2 ] ;
1472
+ const xOffset4 = xOffset3 + xC * convInfo . inChannels ;
1473
+ let wOffset4 = wOffset3 ;
1474
+ for ( let d1 = 0 ; d1 < convInfo . inChannels ; ++ d1 ) {
1475
+ const xVal = xVals [ xOffset4 + d1 ] ;
1476
+ for ( let d2 = 0 ; d2 < convInfo . outChannels ; ++ d2 ) {
1477
+ yVals [ yOffset4 + d2 ] += xVal * wVals [ wOffset4 + d2 ] ;
1478
+ }
1479
+ wOffset4 += convInfo . outChannels ;
1480
+ }
1481
+ }
1482
+ }
1483
+ }
1484
+ }
1485
+ }
1486
+ }
1487
+ }
1488
+ return y . toTensor ( ) ;
1489
+ }
1490
+
1423
1491
conv2dDerInput ( dy : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo ) :
1424
1492
Tensor4D {
1425
1493
this . assertNotComplex ( [ dy , filter ] , 'conv2dDerInput' ) ;
@@ -1486,6 +1554,91 @@ export class MathBackendCPU implements KernelBackend {
1486
1554
return dx . toTensor ( ) ;
1487
1555
}
1488
1556
1557
+ conv3dDerInput ( dy : Tensor5D , filter : Tensor5D , convInfo : Conv3DInfo ) :
1558
+ Tensor5D {
1559
+ const dx = ops . buffer < Rank . R5 > ( convInfo . inShape , 'float32' ) ;
1560
+ const dxValues = dx . values ;
1561
+ const [ dxS0 , dxS1 , dxS2 , dxS3 ] = dx . strides ;
1562
+ const dyValues = dy . dataSync ( ) ;
1563
+ const [ dyS0 , dyS1 , dyS2 , dyS3 ] = dy . strides ;
1564
+ const fltValues = filter . dataSync ( ) ;
1565
+ const [ fltS0 , fltS1 , fltS2 , fltS3 ] = filter . strides ;
1566
+ const {
1567
+ batchSize,
1568
+ filterDepth,
1569
+ filterHeight,
1570
+ filterWidth,
1571
+ inChannels,
1572
+ inDepth,
1573
+ inHeight,
1574
+ inWidth,
1575
+ outChannels,
1576
+ outDepth,
1577
+ outHeight,
1578
+ outWidth,
1579
+ strideDepth,
1580
+ strideHeight,
1581
+ strideWidth
1582
+ } = convInfo ;
1583
+ const frontPad = filterDepth - 1 - convInfo . padInfo . front ;
1584
+ const topPad = filterHeight - 1 - convInfo . padInfo . top ;
1585
+ const leftPad = filterWidth - 1 - convInfo . padInfo . left ;
1586
+
1587
+ for ( let b = 0 ; b < batchSize ; ++ b ) {
1588
+ for ( let d1 = 0 ; d1 < inChannels ; ++ d1 ) {
1589
+ // Frames of depth
1590
+ for ( let xF = 0 ; xF < inDepth ; ++ xF ) {
1591
+ const xFCorner = xF - frontPad ;
1592
+ const xFMin = Math . max ( 0 , Math . ceil ( xFCorner / strideDepth ) ) ;
1593
+ const yFMax =
1594
+ Math . min ( outDepth , ( filterDepth + xFCorner ) / strideDepth ) ;
1595
+
1596
+ // Rows as per standard 2d matrix notation
1597
+ for ( let xR = 0 ; xR < inHeight ; ++ xR ) {
1598
+ const xRCorner = xR - topPad ;
1599
+ const xRMin = Math . max ( 0 , Math . ceil ( xRCorner / strideHeight ) ) ;
1600
+ const yRMax =
1601
+ Math . min ( outHeight , ( filterHeight + xRCorner ) / strideHeight ) ;
1602
+ // Columns as per standard 2d matrix notation
1603
+ for ( let xC = 0 ; xC < inWidth ; ++ xC ) {
1604
+ const xCCorner = xC - leftPad ;
1605
+ const xCMin = Math . max ( 0 , Math . ceil ( xCCorner / strideWidth ) ) ;
1606
+ const yCMax =
1607
+ Math . min ( outWidth , ( filterWidth + xCCorner ) / strideWidth ) ;
1608
+
1609
+ let dotProd = 0 ;
1610
+ for ( let yF = xFMin ; yF < yFMax ; ++ yF ) {
1611
+ const wF = yF * strideDepth - xFCorner ;
1612
+
1613
+ for ( let yR = xRMin ; yR < yRMax ; ++ yR ) {
1614
+ const wR = yR * strideHeight - xRCorner ;
1615
+
1616
+ for ( let yC = xCMin ; yC < yCMax ; ++ yC ) {
1617
+ const wC = yC * strideWidth - xCCorner ;
1618
+ const dyOffset =
1619
+ dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC ;
1620
+ const fltOffset = fltS0 * ( filterDepth - 1 - wF ) +
1621
+ fltS1 * ( filterHeight - 1 - wR ) +
1622
+ fltS2 * ( filterWidth - 1 - wC ) + fltS3 * d1 ;
1623
+
1624
+ for ( let d2 = 0 ; d2 < outChannels ; ++ d2 ) {
1625
+ const pixel = dyValues [ dyOffset + d2 ] ;
1626
+ const weight = fltValues [ fltOffset + d2 ] ;
1627
+ dotProd += pixel * weight ;
1628
+ }
1629
+ }
1630
+ }
1631
+ }
1632
+ dxValues [ dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1 ] =
1633
+ dotProd ;
1634
+ }
1635
+ }
1636
+ }
1637
+ }
1638
+ }
1639
+ return dx . toTensor ( ) ;
1640
+ }
1641
+
1489
1642
conv2dDerFilter ( x : Tensor4D , dy : Tensor4D , convInfo : Conv2DInfo ) : Tensor4D {
1490
1643
this . assertNotComplex ( [ x , dy ] , 'conv2dDerFilter' ) ;
1491
1644
@@ -1529,6 +1682,85 @@ export class MathBackendCPU implements KernelBackend {
1529
1682
return dW . toTensor ( ) ;
1530
1683
}
1531
1684
1685
+ conv3dDerFilter ( x : Tensor5D , dy : Tensor5D , convInfo : Conv3DInfo ) : Tensor5D {
1686
+ const strideDepth = convInfo . strideDepth ;
1687
+ const strideHeight = convInfo . strideHeight ;
1688
+ const strideWidth = convInfo . strideWidth ;
1689
+ const filterDepth = convInfo . filterDepth ;
1690
+ const filterHeight = convInfo . filterHeight ;
1691
+ const filterWidth = convInfo . filterWidth ;
1692
+
1693
+ const dw = ops . buffer < Rank . R5 > ( convInfo . filterShape , 'float32' ) ;
1694
+ const dwValues = dw . values ;
1695
+ const [ dwS0 , dwS1 , dwS2 , dwS3 ] = dw . strides ;
1696
+ const dyValues = dy . dataSync ( ) ;
1697
+ const [ dyS0 , dyS1 , dyS2 , dyS3 ] = dy . strides ;
1698
+ const xValues = x . dataSync ( ) ;
1699
+ const [ xS0 , xS1 , xS2 , xS3 ] = x . strides ;
1700
+
1701
+ const frontPad = convInfo . padInfo . front ;
1702
+ const leftPad = convInfo . padInfo . left ;
1703
+ const topPad = convInfo . padInfo . top ;
1704
+
1705
+ for ( let wF = 0 ; wF < filterDepth ; ++ wF ) {
1706
+ const yFMin = Math . max ( 0 , Math . ceil ( ( frontPad - wF ) / strideDepth ) ) ;
1707
+ const yFMax = Math . min (
1708
+ convInfo . outDepth , ( convInfo . inDepth + frontPad - wF ) / strideDepth ) ;
1709
+ const wOffset1 = wF * dwS0 ;
1710
+
1711
+ for ( let wR = 0 ; wR < filterHeight ; ++ wR ) {
1712
+ const yRMin = Math . max ( 0 , Math . ceil ( ( topPad - wR ) / strideHeight ) ) ;
1713
+ const yRMax = Math . min (
1714
+ convInfo . outHeight ,
1715
+ ( convInfo . inHeight + topPad - wR ) / strideHeight ) ;
1716
+ const wOffset2 = wR * dwS1 + wOffset1 ;
1717
+
1718
+ for ( let wC = 0 ; wC < filterWidth ; ++ wC ) {
1719
+ const yCMin = Math . max ( 0 , Math . ceil ( ( leftPad - wC ) / strideWidth ) ) ;
1720
+ const yCMax = Math . min (
1721
+ convInfo . outWidth ,
1722
+ ( convInfo . inWidth + leftPad - wC ) / strideWidth ) ;
1723
+ const wOffset3 = wC * dwS2 + wOffset2 ;
1724
+
1725
+ for ( let d1 = 0 ; d1 < convInfo . inChannels ; ++ d1 ) {
1726
+ const wOffset4 = d1 * dwS3 + wOffset3 ;
1727
+
1728
+ for ( let d2 = 0 ; d2 < convInfo . outChannels ; ++ d2 ) {
1729
+ let dotProd = 0 ;
1730
+ for ( let b = 0 ; b < convInfo . batchSize ; ++ b ) {
1731
+ const xOffset1 = b * xS0 ;
1732
+ const yOffset1 = b * dyS0 ;
1733
+
1734
+ for ( let yF = yFMin ; yF < yFMax ; ++ yF ) {
1735
+ const xF = wF + yF * strideDepth - frontPad ;
1736
+ const xOffset2 = xF * xS1 + xOffset1 ;
1737
+ const yOffset2 = yF * dyS1 + yOffset1 ;
1738
+
1739
+ for ( let yR = yRMin ; yR < yRMax ; ++ yR ) {
1740
+ const xR = wR + yR * strideHeight - topPad ;
1741
+ const xOffset3 = xR * xS2 + xOffset2 ;
1742
+ const yOffset3 = yR * dyS2 + yOffset2 ;
1743
+
1744
+ for ( let yC = yCMin ; yC < yCMax ; ++ yC ) {
1745
+ const xC = wC + yC * strideWidth - leftPad ;
1746
+ const xOffset4 = xC * xS3 + xOffset3 ;
1747
+ const yOffset4 = yC * dyS3 + yOffset3 ;
1748
+
1749
+ dotProd +=
1750
+ xValues [ xOffset4 + d1 ] * dyValues [ yOffset4 + d2 ] ;
1751
+ }
1752
+ }
1753
+ }
1754
+ }
1755
+ dwValues [ wOffset4 + d2 ] = dotProd ;
1756
+ }
1757
+ }
1758
+ }
1759
+ }
1760
+ }
1761
+ return dw . toTensor ( ) ;
1762
+ }
1763
+
1532
1764
depthwiseConv2D ( x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo ) :
1533
1765
Tensor4D {
1534
1766
this . assertNotComplex ( [ x , filter ] , 'depthwiseConv2D' ) ;
0 commit comments