@@ -77,7 +77,6 @@ import {Im2ColProgram} from './webgl/im2col_gpu';
77
77
import { LRNProgram } from './webgl/lrn_gpu' ;
78
78
import { LRNGradProgram } from './webgl/lrn_grad_gpu' ;
79
79
import { MaxPool2DBackpropProgram } from './webgl/max_pool_backprop_gpu' ;
80
- import { MatMulProgram } from './webgl/mulmat_gpu' ;
81
80
import { MatMulPackedProgram } from './webgl/mulmat_packed_gpu' ;
82
81
import { MultinomialProgram } from './webgl/multinomial_gpu' ;
83
82
import { OneHotProgram } from './webgl/onehot_gpu' ;
@@ -769,26 +768,11 @@ export class MathBackendWebGL implements KernelBackend {
769
768
770
769
const dtype = upcastType ( a . dtype , b . dtype ) ;
771
770
772
- // TODO(https://github.com/tensorflow/tfjs/issues/693): Support 3D tensors
773
- if ( batch === 1 ) {
774
- const aSqueezed = a . as2D ( a . shape [ 1 ] , a . shape [ 2 ] ) ;
775
- const bSqueezed = b . as2D ( b . shape [ 1 ] , b . shape [ 2 ] ) ;
776
-
777
- const program = new MatMulPackedProgram (
778
- aSqueezed . shape , bSqueezed . shape , [ outerShapeA , outerShapeB ] ,
779
- transposeA , transposeB ) ;
780
- const output =
781
- this . makePackedTensor ( program . outputShape , dtype ) as Tensor2D ;
782
- const result =
783
- this . compileAndRun < Tensor2D > ( program , [ aSqueezed , bSqueezed ] , output ) ;
784
- return result . reshape ( [ 1 , result . shape [ 0 ] , result . shape [ 1 ] ] ) ;
785
- } else {
786
- const program =
787
- new MatMulProgram ( a . shape , b . shape , transposeA , transposeB ) ;
788
- const output =
789
- this . makeOutputArray ( program . outputShape , dtype ) as Tensor3D ;
790
- return this . compileAndRun ( program , [ a , b ] , output ) ;
791
- }
771
+ const program = new MatMulPackedProgram ( a . shape ,
772
+ [ batch , outerShapeA , outerShapeB ] , transposeA , transposeB ) ;
773
+ const output =
774
+ this . makePackedTensor ( program . outputShape , dtype ) as Tensor3D ;
775
+ return this . compileAndRun < Tensor3D > ( program , [ a , b ] , output ) ;
792
776
}
793
777
794
778
fusedBatchMatMul (
@@ -800,35 +784,16 @@ export class MathBackendWebGL implements KernelBackend {
800
784
801
785
const dtype = upcastType ( a . dtype , b . dtype ) ;
802
786
803
- // TODO(https://github.com/tensorflow/tfjs/issues/693): Support 3D tensors
804
- if ( batch === 1 ) {
805
- const aSqueezed = a . as2D ( a . shape [ 1 ] , a . shape [ 2 ] ) ;
806
- const bSqueezed = b . as2D ( b . shape [ 1 ] , b . shape [ 2 ] ) ;
807
-
808
- const program = new MatMulPackedProgram (
809
- aSqueezed . shape , bSqueezed . shape , [ outerShapeA , outerShapeB ] ,
810
- transposeA , transposeB , ! ! bias ,
811
- activation ? mapActivationToShaderProgram ( activation , true ) : null ) ;
812
- const output =
813
- this . makePackedTensor ( program . outputShape , dtype ) as Tensor2D ;
814
- const inputs : TensorHandle [ ] = [ aSqueezed , bSqueezed ] ;
815
- if ( bias ) {
816
- inputs . push ( bias ) ;
817
- }
818
- const result = this . compileAndRun < Tensor2D > ( program , inputs , output ) ;
819
- return result . reshape ( [ 1 , result . shape [ 0 ] , result . shape [ 1 ] ] ) ;
820
- } else {
821
- const program = new MatMulProgram (
822
- a . shape , b . shape , transposeA , transposeB , ! ! bias ,
823
- activation ? mapActivationToShaderProgram ( activation ) : null ) ;
824
- const inputs : TensorHandle [ ] = [ a , b ] ;
825
- if ( bias ) {
826
- inputs . push ( bias ) ;
827
- }
828
- const output =
829
- this . makeOutputArray ( program . outputShape , dtype ) as Tensor3D ;
830
- return this . compileAndRun ( program , inputs , output ) ;
787
+ const program = new MatMulPackedProgram ( a . shape ,
788
+ [ batch , outerShapeA , outerShapeB ] , transposeA , transposeB , ! ! bias ,
789
+ activation ? mapActivationToShaderProgram ( activation , true ) : null ) ;
790
+ const output =
791
+ this . makePackedTensor ( program . outputShape , dtype ) as Tensor3D ;
792
+ const inputs : TensorHandle [ ] = [ a , b ] ;
793
+ if ( bias ) {
794
+ inputs . push ( bias ) ;
831
795
}
796
+ return this . compileAndRun < Tensor3D > ( program , inputs , output ) ;
832
797
}
833
798
834
799
multiply ( a : Tensor , b : Tensor ) : Tensor {
@@ -1711,14 +1676,15 @@ export class MathBackendWebGL implements KernelBackend {
1711
1676
const x2ColShape = [ sharedDim , numCols ] ;
1712
1677
1713
1678
const xSqueezed = x . squeeze ( [ 0 ] ) ;
1714
- const w2Row = filter . reshape ( [ sharedDim , - 1 ] ) as Tensor2D ;
1679
+ const w2Row = filter . reshape ( [ 1 , sharedDim , - 1 ] ) as Tensor3D ;
1715
1680
1716
1681
const im2ColProgram =
1717
1682
new Im2ColProgram ( x2ColShape , xSqueezed . shape , convInfo ) ;
1718
- const im2Col = this . compileAndRun < Tensor2D > ( im2ColProgram , [ xSqueezed ] ) ;
1683
+ const im2Col = this . compileAndRun < Tensor2D > ( im2ColProgram , [ xSqueezed ] ) .
1684
+ reshape ( [ 1 , x2ColShape [ 0 ] , x2ColShape [ 1 ] ] ) as Tensor3D ;
1719
1685
1720
1686
const matmulProgram = new MatMulPackedProgram (
1721
- im2Col . shape , w2Row . shape , [ numCols , convInfo . outChannels ] , true ,
1687
+ im2Col . shape , [ 1 , numCols , convInfo . outChannels ] , true ,
1722
1688
false ) ;
1723
1689
const product =
1724
1690
this . compileAndRun < Tensor4D > ( matmulProgram , [ im2Col , w2Row ] ) ;
0 commit comments