@@ -28,7 +28,7 @@ import * as array_ops_util from '../../ops/array_ops_util';
28
28
import * as axis_util from '../../ops/axis_util' ;
29
29
import { computeOutShape } from '../../ops/concat_util' ;
30
30
import { Conv2DInfo , Conv3DInfo } from '../../ops/conv_util' ;
31
- import { Activation } from '../../ops/fused_util' ;
31
+ import { Activation , FusedBatchMatMulConfig } from '../../ops/fused_util' ;
32
32
import * as gather_nd_util from '../../ops/gather_nd_util' ;
33
33
import * as reduce_util from '../../ops/reduce_util' ;
34
34
import * as scatter_nd_util from '../../ops/scatter_nd_util' ;
@@ -174,6 +174,11 @@ function mapActivationToShaderProgram(
174
174
return unary_packed_op . RELU ;
175
175
}
176
176
return unary_op . RELU ;
177
+ } else if ( activation === 'prelu' ) {
178
+ if ( packed ) {
179
+ return binaryop_packed_gpu . PRELU ;
180
+ }
181
+ return binaryop_gpu . PRELU ;
177
182
}
178
183
throw new Error ( `Activation ${
179
184
activation } has not been implemented for the WebGL backend.`) ;
@@ -865,26 +870,30 @@ export class MathBackendWebGL implements KernelBackend {
865
870
}
866
871
867
872
/**
 * Batched matrix multiply with an optionally fused bias add and fused
 * activation (including 'prelu', whose per-channel alpha weights arrive as
 * `preluActivationWeights`).
 *
 * Runs a single packed MatMul shader; bias and prelu-weight tensors are
 * appended to the shader inputs only when present, and matching flags are
 * passed to the program so the generated GLSL includes those terms.
 */
fusedBatchMatMul(
    {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
        FusedBatchMatMulConfig): Tensor3D {
  // Logical output dims depend on which operands are transposed.
  const leftDim = transposeA ? a.shape[2] : a.shape[1];
  const rightDim = transposeB ? b.shape[1] : b.shape[2];
  const batch = a.shape[0];

  const dtype = upcastType(a.dtype, b.dtype);

  const biasPresent = bias != null;
  const preluWeightsPresent = preluActivationWeights != null;
  // `true`: matmul always runs packed, so use the packed activation snippet.
  const fusedActivation =
      activation ? mapActivationToShaderProgram(activation, true) : null;
  const program = new MatMulPackedProgram(
      a.shape, [batch, leftDim, rightDim], transposeA, transposeB,
      biasPresent, fusedActivation, preluWeightsPresent);
  const output =
      this.makePackedTensor(program.outputShape, dtype) as Tensor3D;

  // Input order must mirror the program flags: a, b, then the optionals.
  const inputs: TensorHandle[] = [a, b];
  if (bias) {
    inputs.push(bias);
  }
  if (preluActivationWeights) {
    inputs.push(preluActivationWeights);
  }
  return this.compileAndRun<Tensor3D>(program, inputs, output);
}
890
899
@@ -1819,7 +1828,7 @@ export class MathBackendWebGL implements KernelBackend {
1819
1828
1820
1829
private conv2dByMatMul (
1821
1830
x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
1822
- activation ?: Activation ) : Tensor4D {
1831
+ activation ?: Activation , preluActivationWeights ?: Tensor ) : Tensor4D {
1823
1832
// Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
1824
1833
// result from 2D to 4D.
1825
1834
const xShape = x . shape ;
@@ -1850,9 +1859,15 @@ export class MathBackendWebGL implements KernelBackend {
1850
1859
Tensor3D ;
1851
1860
1852
1861
return this . reshape < Rank . R4 > (
1853
- this . fusedBatchMatMul (
1854
- xReshaped , filterReshaped , transposeA , transposeB , bias ,
1855
- activation ) ,
1862
+ this . fusedBatchMatMul ( {
1863
+ a : xReshaped ,
1864
+ b : filterReshaped ,
1865
+ transposeA,
1866
+ transposeB,
1867
+ bias,
1868
+ activation,
1869
+ preluActivationWeights
1870
+ } ) ,
1856
1871
convInfo . outShape ) ;
1857
1872
}
1858
1873
@@ -1888,8 +1903,15 @@ export class MathBackendWebGL implements KernelBackend {
1888
1903
this . reshape ( filter , [ 1 , convInfo . inChannels , convInfo . outChannels ] ) as
1889
1904
Tensor3D ;
1890
1905
1891
- const pointwiseConv = this . fusedBatchMatMul (
1892
- xReshaped , filterReshaped , transposeA , transposeB , bias , activation ) ;
1906
+ const pointwiseConv = this . fusedBatchMatMul ( {
1907
+ a : xReshaped ,
1908
+ b : filterReshaped ,
1909
+ transposeA,
1910
+ transposeB,
1911
+ bias,
1912
+ activation,
1913
+ preluActivationWeights
1914
+ } ) ;
1893
1915
const pointwiseConvTexData = this . texData . get ( pointwiseConv . dataId ) ;
1894
1916
util . assert (
1895
1917
pointwiseConvTexData . isPacked ,
@@ -1906,7 +1928,7 @@ export class MathBackendWebGL implements KernelBackend {
1906
1928
1907
1929
private conv2dWithIm2Row (
1908
1930
x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
1909
- activation ?: Activation ) : Tensor4D {
1931
+ activation ?: Activation , preluActivationWeights ?: Tensor ) : Tensor4D {
1910
1932
// Rearranges conv2d input so each block to be convolved over forms the
1911
1933
// column of a new matrix with shape [filterWidth * filterHeight *
1912
1934
// inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1938,42 +1960,53 @@ export class MathBackendWebGL implements KernelBackend {
1938
1960
] ) as Tensor3D ;
1939
1961
1940
1962
const hasBias = bias != null ;
1963
+ const hasPreluActivationWeights = preluActivationWeights != null ;
1941
1964
const fusedActivation =
1942
1965
activation ? mapActivationToShaderProgram ( activation , true ) : null ;
1943
1966
const matmulProgram = new MatMulPackedProgram (
1944
1967
im2Col . shape , [ 1 , numCols , convInfo . outChannels ] , transposeA ,
1945
- transposeB , hasBias , fusedActivation ) ;
1968
+ transposeB , hasBias , fusedActivation , hasPreluActivationWeights ) ;
1946
1969
const inputs : TensorHandle [ ] = [ im2Col , w2Row ] ;
1947
1970
if ( bias ) {
1948
1971
inputs . push ( bias ) ;
1949
1972
}
1973
+ if ( hasPreluActivationWeights ) {
1974
+ inputs . push ( preluActivationWeights ) ;
1975
+ }
1950
1976
const product = this . compileAndRun < Tensor4D > ( matmulProgram , inputs ) ;
1951
1977
1952
1978
return product . reshape ( [ 1 , outHeight , outWidth , convInfo . outChannels ] ) ;
1953
1979
}
1954
1980
1955
1981
/**
 * Conv2D with optionally fused bias and activation ('prelu' supplies its
 * alpha tensor via `preluActivationWeights`).
 *
 * Dispatch order:
 *   1. A 1x1, stride-1, dilation-1 convolution is algebraically a matmul,
 *      so route it to the fused matmul path.
 *   2. If im2col is enabled and the batch is 1, use the im2col path.
 *   3. Otherwise run the direct Conv2DProgram shader.
 */
fusedConv2d(
    x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
    activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
  const isPointwise = convInfo.filterHeight === 1 &&
      convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 &&
      convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 &&
      convInfo.strideWidth === 1 &&
      (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID');
  if (isPointwise) {
    return this.conv2dByMatMul(
        x, filter, convInfo, bias, activation, preluActivationWeights);
  }
  if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
    return this.conv2dWithIm2Row(
        x, filter, convInfo, bias, activation, preluActivationWeights);
  }

  const biasPresent = bias != null;
  const preluWeightsPresent = preluActivationWeights != null;
  // `false`: the direct conv program is unpacked, so use the unpacked
  // activation snippet.
  const fusedActivation =
      activation ? mapActivationToShaderProgram(activation, false) : null;
  const program = new Conv2DProgram(
      convInfo, biasPresent, fusedActivation, preluWeightsPresent);

  // Input order must mirror the program flags: x, filter, then optionals.
  const inputs: TensorHandle[] = [x, filter];
  if (bias) {
    inputs.push(bias);
  }
  if (preluActivationWeights) {
    inputs.push(preluActivationWeights);
  }
  return this.compileAndRun(program, inputs);
}
1979
2012
0 commit comments