Skip to content

Commit 4beeddf

Browse files
tedhtchangdsmilkov
authored andcommitted
use precomputed tensor size (tensorflow#1608)
INTERNAL PR resolves tensorflow/tfjs#873 and responds to @dsmilkov 's comment tensorflow#1371 (review)
1 parent f782fc8 commit 4beeddf

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

src/engine.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
259259
// string tensors are counted when writing values.
260260
let bytes = 0;
261261
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
262-
bytes = util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype);
262+
bytes = a.size * util.bytesPerElement(a.dtype);
263263
}
264264
this.tensorInfo.set(a.dataId, {
265265
backend: backend != null ? backend : this.backend,

src/kernels/backend_cpu.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2740,7 +2740,7 @@ export class MathBackendCPU implements KernelBackend {
27402740
const channels = x.shape[3];
27412741
const maxD = channels - 1;
27422742
const xValues = x.dataSync();
2743-
const size = util.sizeFromShape(x.shape);
2743+
const size = x.size;
27442744
const result = new Float32Array(size);
27452745

27462746
function sumAcrossChannels(offset: number) {
@@ -2776,8 +2776,8 @@ export class MathBackendCPU implements KernelBackend {
27762776
const dyValues = dy.dataSync();
27772777
const inputImageValues = inputImage.dataSync();
27782778
const outputImageValues = outputImage.dataSync();
2779-
const result = new Float32Array(util.sizeFromShape(dy.shape));
2780-
const size = util.sizeFromShape(dy.shape);
2779+
const result = new Float32Array(dy.size);
2780+
const size = dy.size;
27812781

27822782
for (let offset = 0; offset < size; offset++) {
27832783
const currentChannel = offset % channels;

src/ops/loss_ops.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import {customGrad} from '../globals';
1919
import {Tensor} from '../tensor';
2020
import {convertToTensor} from '../tensor_util_env';
2121
import {TensorLike} from '../types';
22-
import {assertShapesMatch, sizeFromShape} from '../util';
22+
import {assertShapesMatch} from '../util';
2323

2424
import {expandShapeToKeepDim} from './axis_util';
2525

@@ -66,7 +66,7 @@ function computeWeightedLoss_<T extends Tensor, O extends Tensor>(
6666
return weightedLoss.mean();
6767
} else {
6868
const broadcastFactor =
69-
sizeFromShape($losses.shape) / sizeFromShape($weights.shape);
69+
$losses.size / $weights.size;
7070
const result = weightedLoss.sum().div($weights.sum());
7171
return broadcastFactor > 1 ? result.div(scalar(broadcastFactor)) :
7272
result as O;

src/tensor.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ export class TensorBuffer<R extends Rank, D extends DataType = 'float32'> {
6565
`call tf.complex(real, imag).`);
6666
}
6767
this.values =
68-
values || util.getArrayFromDType(dtype, util.sizeFromShape(this.shape));
68+
values || util.getArrayFromDType(dtype, this.size);
6969
this.strides = computeStrides(shape);
7070
}
7171

0 commit comments

Comments
 (0)