diff --git a/src/engine.ts b/src/engine.ts index 74a0f47028..5af0ac0e8a 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -259,7 +259,7 @@ export class Engine implements TensorManager, TensorTracker, DataMover { // string tensors are counted when writing values. let bytes = 0; if (a.dtype !== 'complex64' && a.dtype !== 'string') { - bytes = util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype); + bytes = a.size * util.bytesPerElement(a.dtype); } this.tensorInfo.set(a.dataId, { backend: backend != null ? backend : this.backend, diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index fd04f2acaf..069e4e7599 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -2740,7 +2740,7 @@ export class MathBackendCPU implements KernelBackend { const channels = x.shape[3]; const maxD = channels - 1; const xValues = x.dataSync(); - const size = util.sizeFromShape(x.shape); + const size = x.size; const result = new Float32Array(size); function sumAcrossChannels(offset: number) { @@ -2776,8 +2776,8 @@ export class MathBackendCPU implements KernelBackend { const dyValues = dy.dataSync(); const inputImageValues = inputImage.dataSync(); const outputImageValues = outputImage.dataSync(); - const result = new Float32Array(util.sizeFromShape(dy.shape)); - const size = util.sizeFromShape(dy.shape); + const result = new Float32Array(dy.size); + const size = dy.size; for (let offset = 0; offset < size; offset++) { const currentChannel = offset % channels; diff --git a/src/ops/loss_ops.ts b/src/ops/loss_ops.ts index a28f4fae9e..0e7f5a48b9 100644 --- a/src/ops/loss_ops.ts +++ b/src/ops/loss_ops.ts @@ -19,7 +19,7 @@ import {customGrad} from '../globals'; import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import {assertShapesMatch, sizeFromShape} from '../util'; +import {assertShapesMatch} from '../util'; import {expandShapeToKeepDim} from './axis_util'; @@ -66,7 +66,7 @@ function computeWeightedLoss_( return weightedLoss.mean(); } else { const broadcastFactor = - sizeFromShape($losses.shape) / sizeFromShape($weights.shape); + $losses.size / $weights.size; const result = weightedLoss.sum().div($weights.sum()); return broadcastFactor > 1 ? result.div(scalar(broadcastFactor)) : result as O; diff --git a/src/tensor.ts b/src/tensor.ts index 02d098c69c..0e1c2d1d20 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -65,7 +65,7 @@ export class TensorBuffer { `call tf.complex(real, imag).`); } this.values = - values || util.getArrayFromDType(dtype, util.sizeFromShape(this.shape)); + values || util.getArrayFromDType(dtype, this.size); this.strides = computeStrides(shape); }