@@ -1558,40 +1558,52 @@ export class MathBackendCPU implements KernelBackend {
1558
1558
const dilationWidth = convInfo . dilationWidth ;
1559
1559
const padLeft = convInfo . padInfo . left ;
1560
1560
const padTop = convInfo . padInfo . top ;
1561
+ const isChannelsLast = convInfo . dataFormat === 'channelsLast' ;
1562
+
1561
1563
const y = ops . buffer ( convInfo . outShape , x . dtype as 'float32' ) ;
1562
1564
1565
+ const xBatchStride = x . strides [ 0 ] ;
1566
+ const xRowStride = isChannelsLast ? x . strides [ 1 ] : x . strides [ 2 ] ;
1567
+ const xColStride = isChannelsLast ? x . strides [ 2 ] : 1 ;
1568
+ const xChannelStride = isChannelsLast ? 1 : x . strides [ 1 ] ;
1569
+ const yBatchStride = y . strides [ 0 ] ;
1570
+ const yRowStride = isChannelsLast ? y . strides [ 1 ] : y . strides [ 2 ] ;
1571
+ const yColStride = isChannelsLast ? y . strides [ 2 ] : 1 ;
1572
+ const yChannelStride = isChannelsLast ? 1 : y . strides [ 1 ] ;
1573
+
1563
1574
const xVals = this . readSync ( x . dataId ) as TypedArray ;
1564
1575
const wVals = this . readSync ( filter . dataId ) as TypedArray ;
1565
1576
const yVals = y . values ;
1566
1577
1567
1578
for ( let b = 0 ; b < convInfo . batchSize ; ++ b ) {
1568
- const xOffset1 = b * x . strides [ 0 ] ;
1569
- const yOffset1 = b * y . strides [ 0 ] ;
1579
+ const xOffset1 = b * xBatchStride ;
1580
+ const yOffset1 = b * yBatchStride ;
1570
1581
for ( let yR = 0 ; yR < convInfo . outHeight ; ++ yR ) {
1571
- const yOffset2 = yOffset1 + yR * y . strides [ 1 ] ;
1582
+ const yOffset2 = yOffset1 + yR * yRowStride ;
1572
1583
const xRCorner = yR * convInfo . strideHeight - padTop ;
1573
1584
for ( let wR = 0 ; wR < filterHeight ; wR ++ ) {
1574
1585
const xR = xRCorner + wR * dilationHeight ;
1575
1586
if ( xR < 0 || xR >= convInfo . inHeight ) {
1576
1587
continue ;
1577
1588
}
1578
1589
const wOffset1 = wR * filter . strides [ 0 ] ;
1579
- const xOffset2 = xOffset1 + xR * x . strides [ 1 ] ;
1590
+ const xOffset2 = xOffset1 + xR * xRowStride ;
1580
1591
for ( let yC = 0 ; yC < convInfo . outWidth ; ++ yC ) {
1581
- const yOffset3 = yOffset2 + yC * convInfo . outChannels ;
1592
+ const yOffset3 = yOffset2 + yC * yColStride ;
1582
1593
const xCCorner = yC * convInfo . strideWidth - padLeft ;
1583
1594
for ( let wC = 0 ; wC < filterWidth ; wC ++ ) {
1584
1595
const xC = xCCorner + wC * dilationWidth ;
1585
1596
if ( xC < 0 || xC >= convInfo . inWidth ) {
1586
1597
continue ;
1587
1598
}
1588
1599
const wOffset2 = wOffset1 + wC * filter . strides [ 1 ] ;
1589
- const xOffset3 = xOffset2 + xC * convInfo . inChannels ;
1600
+ const xOffset3 = xOffset2 + xC * xColStride ;
1590
1601
let wOffset3 = wOffset2 ;
1591
1602
for ( let d1 = 0 ; d1 < convInfo . inChannels ; ++ d1 ) {
1592
- const xVal = xVals [ xOffset3 + d1 ] ;
1603
+ const xVal = xVals [ xOffset3 + d1 * xChannelStride ] ;
1593
1604
for ( let d2 = 0 ; d2 < convInfo . outChannels ; ++ d2 ) {
1594
- yVals [ yOffset3 + d2 ] += xVal * wVals [ wOffset3 + d2 ] ;
1605
+ yVals [ yOffset3 + d2 * yChannelStride ] +=
1606
+ xVal * wVals [ wOffset3 + d2 ] ;
1595
1607
}
1596
1608
wOffset3 += convInfo . outChannels ;
1597
1609
}
@@ -1677,9 +1689,7 @@ export class MathBackendCPU implements KernelBackend {
1677
1689
1678
1690
const dx = ops . buffer < Rank . R4 > ( convInfo . inShape , 'float32' ) ;
1679
1691
const dxValues = dx . values ;
1680
- const [ dxS0 , dxS1 , dxS2 ] = dx . strides ;
1681
1692
const dyValues = this . readSync ( dy . dataId ) as TypedArray ;
1682
- const [ dyS0 , dyS1 , dyS2 ] = dy . strides ;
1683
1693
const fltValues = this . readSync ( filter . dataId ) as TypedArray ;
1684
1694
const [ fltS0 , fltS1 , fltS2 ] = filter . strides ;
1685
1695
const {
@@ -1693,11 +1703,22 @@ export class MathBackendCPU implements KernelBackend {
1693
1703
outHeight,
1694
1704
outWidth,
1695
1705
strideHeight,
1696
- strideWidth
1706
+ strideWidth,
1707
+ dataFormat
1697
1708
} = convInfo ;
1698
1709
const topPad = filterHeight - 1 - convInfo . padInfo . top ;
1699
1710
const leftPad = filterWidth - 1 - convInfo . padInfo . left ;
1700
1711
1712
+ const isChannelsLast = dataFormat === 'channelsLast' ;
1713
+ const xBatchStride = dx . strides [ 0 ] ;
1714
+ const xRowStride = isChannelsLast ? dx . strides [ 1 ] : dx . strides [ 2 ] ;
1715
+ const xColStride = isChannelsLast ? dx . strides [ 2 ] : 1 ;
1716
+ const xChannelStride = isChannelsLast ? 1 : dx . strides [ 1 ] ;
1717
+ const yBatchStride = dy . strides [ 0 ] ;
1718
+ const yRowStride = isChannelsLast ? dy . strides [ 1 ] : dy . strides [ 2 ] ;
1719
+ const yColStride = isChannelsLast ? dy . strides [ 2 ] : 1 ;
1720
+ const yChannelStride = isChannelsLast ? 1 : dy . strides [ 1 ] ;
1721
+
1701
1722
for ( let b = 0 ; b < batchSize ; ++ b ) {
1702
1723
for ( let d1 = 0 ; d1 < inChannels ; ++ d1 ) {
1703
1724
for ( let xR = 0 ; xR < inHeight ; ++ xR ) {
@@ -1718,18 +1739,21 @@ export class MathBackendCPU implements KernelBackend {
1718
1739
1719
1740
for ( let yC = xCMin ; yC < yCMax ; ++ yC ) {
1720
1741
const wC = yC * strideWidth - xCCorner ;
1721
- const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC ;
1742
+ const dyOffset =
1743
+ yBatchStride * b + yRowStride * yR + yColStride * yC ;
1722
1744
const fltOffset = fltS0 * ( filterHeight - 1 - wR ) +
1723
1745
fltS1 * ( filterWidth - 1 - wC ) + fltS2 * d1 ;
1724
1746
1725
1747
for ( let d2 = 0 ; d2 < outChannels ; ++ d2 ) {
1726
- const pixel = dyValues [ dyOffset + d2 ] ;
1748
+ const pixel = dyValues [ dyOffset + yChannelStride * d2 ] ;
1727
1749
const weight = fltValues [ fltOffset + d2 ] ;
1728
1750
dotProd += pixel * weight ;
1729
1751
}
1730
1752
}
1731
1753
}
1732
- dxValues [ dxS0 * b + dxS1 * xR + dxS2 * xC + d1 ] = dotProd ;
1754
+ const dxOffset = xBatchStride * b + xRowStride * xR +
1755
+ xColStride * xC + xChannelStride * d1 ;
1756
+ dxValues [ dxOffset ] = dotProd ;
1733
1757
}
1734
1758
}
1735
1759
}
@@ -1829,6 +1853,7 @@ export class MathBackendCPU implements KernelBackend {
1829
1853
const strideWidth = convInfo . strideWidth ;
1830
1854
const filterHeight = convInfo . filterHeight ;
1831
1855
const filterWidth = convInfo . filterWidth ;
1856
+ const isChannelsLast = convInfo . dataFormat === 'channelsLast' ;
1832
1857
const dW = ops . buffer < Rank . R4 > ( convInfo . filterShape , 'float32' ) ;
1833
1858
1834
1859
const leftPad = convInfo . padInfo . left ;
@@ -1854,7 +1879,13 @@ export class MathBackendCPU implements KernelBackend {
1854
1879
const xR = wR + yR * strideHeight - topPad ;
1855
1880
for ( let yC = yCMin ; yC < yCMax ; ++ yC ) {
1856
1881
const xC = wC + yC * strideWidth - leftPad ;
1857
- dotProd += xBuf . get ( b , xR , xC , d1 ) * dyBuf . get ( b , yR , yC , d2 ) ;
1882
+ if ( isChannelsLast ) {
1883
+ dotProd +=
1884
+ xBuf . get ( b , xR , xC , d1 ) * dyBuf . get ( b , yR , yC , d2 ) ;
1885
+ } else {
1886
+ dotProd +=
1887
+ xBuf . get ( b , d1 , xR , xC ) * dyBuf . get ( b , d2 , yR , yC ) ;
1888
+ }
1858
1889
}
1859
1890
}
1860
1891
}
0 commit comments