@@ -873,10 +873,12 @@ export class MathBackendWebGL implements KernelBackend {
 
     const dtype = upcastType(a.dtype, b.dtype);
 
+    const hasBias = bias != null;
+    const fusedActivation =
+        activation ? mapActivationToShaderProgram(activation, true) : null;
     const program = new MatMulPackedProgram(
         a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
-        !!bias,
-        activation ? mapActivationToShaderProgram(activation, true) : null);
+        hasBias, fusedActivation);
     const output =
         this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
     const inputs: TensorHandle[] = [a, b];
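The hunk above folds the optional bias add and the activation into the packed matMul shader by passing `hasBias` and the mapped activation snippet to `MatMulPackedProgram`. For reference, a minimal sketch of the unfused equivalent using the public TensorFlow.js API (assuming ReLU as the activation; the function name and imports are illustrative, not part of this diff):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Unfused reference: batched matMul, then bias add, then ReLU.
// The fused MatMulPackedProgram above performs all three in one packed shader.
function referenceFusedMatMul(
    a: tf.Tensor3D, b: tf.Tensor3D, bias?: tf.Tensor): tf.Tensor {
  let result: tf.Tensor = tf.matMul(a, b);  // [batch, outerShapeA, outerShapeB]
  if (bias != null) {
    result = tf.add(result, bias);          // bias broadcast over the output
  }
  return tf.relu(result);                   // element-wise activation
}
```

Fusing the three steps avoids materializing the intermediate matMul and bias-add results as separate textures.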
@@ -1815,15 +1817,18 @@ export class MathBackendWebGL implements KernelBackend {
     return this.compileAndRun(program, [x]) as T;
   }
 
-  conv2dByMatMul(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
-      Tensor4D {
+  private conv2dByMatMul(
+      x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
+      activation?: Activation): Tensor4D {
     // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
     // result from 2D to 4D.
     const xShape = x.shape;
     const xTexData = this.texData.get(x.dataId);
     const sharedMatMulDim = convInfo.inChannels;
     const outerShapeX = xShape[0] * xShape[1] * xShape[2];
     const outerShapeFilter = convInfo.outChannels;
+    const transposeA = false;
+    const transposeB = false;
 
     // TODO: Once reduction ops are packed, batchMatMul will always be packed
     // and we can remove this condition.
@@ -1843,8 +1848,11 @@ export class MathBackendWebGL implements KernelBackend {
           this.reshape(
               filter, [1, convInfo.inChannels, convInfo.outChannels]) as
           Tensor3D;
+
       return this.reshape<Rank.R4>(
-          this.batchMatMul(xReshaped, filterReshaped, false, false),
+          this.fusedBatchMatMul(
+              xReshaped, filterReshaped, transposeA, transposeB, bias,
+              activation),
           convInfo.outShape);
     }
 
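`conv2dByMatMul` treats a 1x1 convolution as a single matMul by flattening the spatial dimensions before calling the (now fused) batchMatMul. A small worked example of the shape bookkeeping, with made-up sizes (the real values come from `Conv2DInfo`):

```ts
// Illustrative only: hypothetical NHWC input and a 1x1 filter.
const xShape = [1, 28, 28, 16];                          // [batch, H, W, inChannels]
const outChannels = 32;

const sharedMatMulDim = xShape[3];                       // inChannels = 16
const outerShapeX = xShape[0] * xShape[1] * xShape[2];   // 1 * 28 * 28 = 784

// x is viewed as [1, 784, 16] and the filter as [1, 16, 32]; the fused
// batchMatMul yields [1, 784, 32], reshaped back to [1, 28, 28, 32].
console.log(outerShapeX, sharedMatMulDim, outChannels);  // 784 16 32
```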
@@ -1880,8 +1888,8 @@ export class MathBackendWebGL implements KernelBackend {
         this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]) as
         Tensor3D;
 
-    const pointwiseConv =
-        this.batchMatMul(xReshaped, filterReshaped, false, false);
+    const pointwiseConv = this.fusedBatchMatMul(
+        xReshaped, filterReshaped, transposeA, transposeB, bias, activation);
     const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
     util.assert(
         pointwiseConvTexData.isPacked,
@@ -1896,8 +1904,9 @@ export class MathBackendWebGL implements KernelBackend {
                        pointwiseConv.dtype, this) as Tensor4D;
   }
 
-  conv2dWithIm2Row(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
-      Tensor4D {
+  private conv2dWithIm2Row(
+      x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
+      activation?: Activation): Tensor4D {
     // Rearranges conv2d input so each block to be convolved over forms the
     // column of a new matrix with shape [filterWidth * filterHeight *
     // inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1915,6 +1924,8 @@ export class MathBackendWebGL implements KernelBackend {
     const sharedDim = filterWidth * filterHeight * inChannels;
     const numCols = outHeight * outWidth;
     const x2ColShape = [sharedDim, numCols];
+    const transposeA = true;
+    const transposeB = false;
 
     const xSqueezed = x.squeeze([0]);
     const w2Row = filter.reshape([1, sharedDim, -1]) as Tensor3D;
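For the im2col path, the shared dimension and column count fix the shape of the unrolled input. A sketch with hypothetical sizes to make the arithmetic concrete (the real values are derived from `Conv2DInfo`):

```ts
// Illustrative only: a 3x3 conv over a single 28x28x8 image, 'VALID' padding.
const filterHeight = 3, filterWidth = 3, inChannels = 8;
const outHeight = 26, outWidth = 26, outChannels = 16;

const sharedDim = filterWidth * filterHeight * inChannels;  // 3 * 3 * 8 = 72
const numCols = outHeight * outWidth;                       // 26 * 26 = 676
const x2ColShape = [sharedDim, numCols];                    // [72, 676]

// With transposeA = true, the [72, 676] im2col matrix participates as
// [676, 72] x [72, 16] against the reshaped filter, giving [676, 16]:
// one row of output channels per output pixel.
console.log(x2ColShape, [numCols, outChannels]);            // [72, 676] [676, 16]
```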
@@ -1926,14 +1937,46 @@ export class MathBackendWebGL implements KernelBackend {
       1, x2ColShape[0], x2ColShape[1]
     ]) as Tensor3D;
 
+    const hasBias = bias != null;
+    const fusedActivation =
+        activation ? mapActivationToShaderProgram(activation, true) : null;
     const matmulProgram = new MatMulPackedProgram(
-        im2Col.shape, [1, numCols, convInfo.outChannels], true, false);
-    const product =
-        this.compileAndRun<Tensor4D>(matmulProgram, [im2Col, w2Row]);
+        im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,
+        transposeB, hasBias, fusedActivation);
+    const inputs: TensorHandle[] = [im2Col, w2Row];
+    if (bias) {
+      inputs.push(bias);
+    }
+    const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);
 
     return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
   }
 
+  fusedConv2d(
+      x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
+      activation?: Activation): Tensor4D {
+    if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
+        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
+        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
+        (convInfo.padInfo.type === 'SAME' ||
+         convInfo.padInfo.type === 'VALID')) {
+      return this.conv2dByMatMul(x, filter, convInfo, bias, activation);
+    }
+    if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
+      return this.conv2dWithIm2Row(x, filter, convInfo, bias, activation);
+    }
+
+    const hasBias = bias != null;
+    const fusedActivation =
+        activation ? mapActivationToShaderProgram(activation, false) : null;
+    const program = new Conv2DProgram(convInfo, hasBias, fusedActivation);
+    const inputs: TensorHandle[] = [x, filter];
+    if (bias) {
+      inputs.push(bias);
+    }
+    return this.compileAndRun(program, inputs);
+  }
+
   conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
     if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
         convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
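The new `fusedConv2d` entry point routes to one of three implementations: the matMul shortcut for pointwise convolutions, the im2col path for single-image inputs when `WEBGL_CONV_IM2COL` is enabled, and the general fused `Conv2DProgram` otherwise. A standalone sketch of that routing predicate, written against a simplified stand-in for `Conv2DInfo` (field names here are illustrative; the backend reads the real `Conv2DInfo`):

```ts
// Simplified stand-in for the handful of Conv2DInfo fields the dispatch reads.
interface ConvInfoLite {
  filterHeight: number; filterWidth: number;
  dilationHeight: number; dilationWidth: number;
  strideHeight: number; strideWidth: number;
  padType: string;   // e.g. 'SAME' or 'VALID'
  batchSize: number;
}

// Mirrors the branch order in fusedConv2d above; illustrative, not backend code.
function chooseConvPath(info: ConvInfoLite, im2colEnabled: boolean):
    'conv2dByMatMul'|'conv2dWithIm2Row'|'Conv2DProgram' {
  const isPointwise = info.filterHeight === 1 && info.filterWidth === 1 &&
      info.dilationHeight === 1 && info.dilationWidth === 1 &&
      info.strideHeight === 1 && info.strideWidth === 1 &&
      (info.padType === 'SAME' || info.padType === 'VALID');
  if (isPointwise) {
    return 'conv2dByMatMul';    // 1x1 conv collapses into a (fused) matMul
  }
  if (im2colEnabled && info.batchSize === 1) {
    return 'conv2dWithIm2Row';  // single-image conv via im2col + matMul
  }
  return 'Conv2DProgram';       // general fused conv shader
}
```

Whichever path is taken, the optional bias and activation are forwarded so they execute inside the same shader as the convolution.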