@@ -34,9 +34,8 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../ops/
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../tensor';
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../types';
import * as util from '../util';
-// import * as asm from './asm';
+import * as asm from './asm';
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../util';
-
import {BackendTimingInfo, DataMover, DataStorage, KernelBackend} from './backend';
import * as backend_util from './backend_util';
import * as complex_util from './complex_util';
@@ -460,63 +459,60 @@ export class MathBackendCPU implements KernelBackend {
      transposeB: boolean): Tensor3D {
    this.assertNotComplex([a, b], 'matMul');

-    const sharedDim = transposeA ? a.shape[1] : a.shape[2];
+    // const sharedDim = transposeA ? a.shape[1] : a.shape[2];
    const leftDim = transposeA ? a.shape[2] : a.shape[1];
    const rightDim = transposeB ? b.shape[1] : b.shape[2];
    const batchDim = a.shape[0];
-    // const nWorkers = navigator.hardwareConcurrency || 4;
    const outShape = [batchDim, leftDim, rightDim];
-    // if (batchDim === 1 && a.shape[0] >= nWorkers) {
-    //   console.warn('asking for asm');
-    //   const values = asm.matmul(a.squeeze([0]), b.squeeze([0]));
-    //   return Tensor.make(outShape, {values}, a.dtype);
-    // }
-
-    const compute = async () => {
-      const [aValues, bValues] = await Promise.all([a.data(), b.data()]);
-      const [aOuterStep, aInnerStep] =
-          transposeA ? [1, a.strides[1]] : [a.strides[1], 1];
-      const [bInnerStep, bOuterStep] =
-          transposeB ? [1, b.strides[1]] : [b.strides[1], 1];
-
-      const resVals = util.getTypedArrayFromDType(
-          a.dtype as 'float32', sizeFromShape(outShape));
-      const blockSize = this.blockSize;
-
-      for (let batch = 0; batch < batchDim; batch++) {
-        const aBatch = batch * a.strides[0];
-        const bBatch = batch * b.strides[0];
-        for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
-          const iBlock = i0 + blockSize < leftDim ? i0 + blockSize : leftDim;
-          for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
-            const jBlock =
-                j0 + blockSize < rightDim ? j0 + blockSize : rightDim;
-            for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
-              // for when blockSize doesn't evenly divide the input
-              const kBlock =
-                  k0 + blockSize < sharedDim ? k0 + blockSize : sharedDim;
-
-              for (let i = i0; i < iBlock; i++) {
-                const iDim = i * rightDim;
-                const iStep = aBatch + i * aOuterStep;
-                for (let j = j0; j < jBlock; j++) {
-                  const jStep = bBatch + j * bOuterStep;
-                  let sum = 0.0;
-
-                  for (let k = k0; k < kBlock; k++) {
-                    sum += aValues[k * aInnerStep + iStep] *
-                        bValues[k * bInnerStep + jStep];
-                  }
-                  resVals[iDim + j] += sum;
-                }
-              }
-            }
-          }
-        }
-      }
-      return resVals;
-    };
-    return Tensor.make(outShape, {values: compute()}, a.dtype) as Tensor3D;
+    const values = asm.matmul(a, b, transposeA, transposeB);
+    return Tensor.make(outShape, {values}, a.dtype);
+
+    // const compute = async () => {
+    //   const [aValues, bValues] = await Promise.all([a.data(), b.data()]);
+    //   const [aOuterStep, aInnerStep] =
+    //       transposeA ? [1, a.strides[1]] : [a.strides[1], 1];
+    //   const [bInnerStep, bOuterStep] =
+    //       transposeB ? [1, b.strides[1]] : [b.strides[1], 1];
+
+    //   const resVals = util.getTypedArrayFromDType(
+    //       a.dtype as 'float32', sizeFromShape(outShape));
+    //   const blockSize = this.blockSize;
+
+    //   for (let batch = 0; batch < batchDim; batch++) {
+    //     const aBatch = batch * a.strides[0];
+    //     const bBatch = batch * b.strides[0];
+    //     const resBatch = batch * leftDim * rightDim;
+    //     for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
+    //       const iBlock = i0 + blockSize < leftDim ? i0 + blockSize : leftDim;
+    //       for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
+    //         const jBlock =
+    //             j0 + blockSize < rightDim ? j0 + blockSize : rightDim;
+    //         for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
+    //           // for when blockSize doesn't evenly divide the input
+    //           const kBlock =
+    //               k0 + blockSize < sharedDim ? k0 + blockSize : sharedDim;
+
+    //           for (let i = i0; i < iBlock; i++) {
+    //             const iDim = resBatch + i * rightDim;
+    //             const iStep = aBatch + i * aOuterStep;
+    //             for (let j = j0; j < jBlock; j++) {
+    //               const jStep = bBatch + j * bOuterStep;
+    //               let sum = 0.0;
+
+    //               for (let k = k0; k < kBlock; k++) {
+    //                 sum += aValues[k * aInnerStep + iStep] *
+    //                     bValues[k * bInnerStep + jStep];
+    //               }
+    //               resVals[iDim + j] += sum;
+    //             }
+    //           }
+    //         }
+    //       }
+    //     }
+    //   }
+    //   return resVals;
+    // };
+    // return Tensor.make(outShape, {values: compute()}, a.dtype) as Tensor3D;
  }

  fusedBatchMatMul(
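Note: the './asm' module that the new code path imports is not part of this diff, so the exact shape of asm.matmul is an assumption read off the call site: it must accept (a, b, transposeA, transposeB) and return a flat typed array of batchDim * leftDim * rightDim values that Tensor.make can wrap. The sketch below is a hypothetical plain-TypeScript stand-in for that contract, reusing the stride handling from the commented-out kernel; the real module presumably delegates the inner loops to asm.js/WASM rather than a naive triple loop.

// asm.ts (hypothetical sketch, not the module shipped with this commit).
// Assumes float32 inputs and that the file sits next to backend_cpu.ts so
// the '../tensor' relative import resolves.
import {Tensor3D} from '../tensor';

export function matmul(
    a: Tensor3D, b: Tensor3D, transposeA: boolean,
    transposeB: boolean): Float32Array {
  const sharedDim = transposeA ? a.shape[1] : a.shape[2];
  const leftDim = transposeA ? a.shape[2] : a.shape[1];
  const rightDim = transposeB ? b.shape[1] : b.shape[2];
  const batchDim = a.shape[0];

  // Pull the CPU-resident values synchronously.
  const aValues = a.dataSync() as Float32Array;
  const bValues = b.dataSync() as Float32Array;
  const out = new Float32Array(batchDim * leftDim * rightDim);

  // Same stride trick as the commented-out kernel: pick step sizes so the
  // innermost loop always walks the shared dimension, transposed or not.
  const [aOuterStep, aInnerStep] =
      transposeA ? [1, a.strides[1]] : [a.strides[1], 1];
  const [bInnerStep, bOuterStep] =
      transposeB ? [1, b.strides[1]] : [b.strides[1], 1];

  for (let batch = 0; batch < batchDim; batch++) {
    const aBatch = batch * a.strides[0];
    const bBatch = batch * b.strides[0];
    const outBatch = batch * leftDim * rightDim;
    for (let i = 0; i < leftDim; i++) {
      for (let j = 0; j < rightDim; j++) {
        let sum = 0;
        for (let k = 0; k < sharedDim; k++) {
          sum += aValues[aBatch + i * aOuterStep + k * aInnerStep] *
              bValues[bBatch + j * bOuterStep + k * bInnerStep];
        }
        out[outBatch + i * rightDim + j] = sum;
      }
    }
  }
  return out;
}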