Skip to content

Commit 33146d2

Browse files
authored
Remove Tensor.get(). Make tensor.buffer() async. Add tensor.array() (tensorflow#1537)
FEATURE BREAKING - Remove `tensor.get(...)` and remove all usage in tfjs-core - Make `tensor.buffer()` async and introduce `tensor.bufferSync()` - Introduce `tensor.array()` and `tensor.arraySync()` which return the data as a nested array. Fixes tensorflow/tfjs#979 Fixes tensorflow/tfjs#928 Fixes tensorflow/tfjs#1124
1 parent 8f1f8b1 commit 33146d2

21 files changed

+682
-659
lines changed

src/engine_test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {describeWithFlags} from './jasmine_util';
2020
import {MathBackendCPU} from './kernels/backend_cpu';
2121
import {MathBackendWebGL} from './kernels/backend_webgl';
2222
import {Tensor} from './tensor';
23-
import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, expectNumbersClose, WEBGL_ENVS} from './test_util';
23+
import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from './test_util';
2424

2525
describeWithFlags('fromPixels + regular math op', WEBGL_ENVS, () => {
2626
it('debug mode does not error when no nans', () => {
@@ -200,7 +200,7 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
200200
return tf.sum(y);
201201
})([a, b]);
202202

203-
expectNumbersClose(value.get(), 10);
203+
expectArraysClose(value, 10);
204204

205205
// de/dy = 1
206206
// dy/dm = step(m)
@@ -235,7 +235,7 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
235235
});
236236
})([a, b]);
237237

238-
expectNumbersClose(value.get(), 10);
238+
expectArraysClose(value, 10);
239239

240240
// de/dy = 1
241241
// dy/dm = step(m)

src/environment_test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ describeWithFlags('epsilon', {}, () => {
187187
});
188188

189189
it('abs(epsilon) > 0', () => {
190-
expect(tf.abs(ENV.get('EPSILON')).get()).toBeGreaterThan(0);
190+
expect(tf.abs(ENV.get('EPSILON')).arraySync()).toBeGreaterThan(0);
191191
});
192192
});
193193

src/kernels/backend_cpu.ts

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,11 @@ export class MathBackendCPU implements KernelBackend {
261261
}
262262

263263
const buffer = ops.buffer(size, x.dtype);
264+
const xBuf = x.bufferSync();
264265
for (let i = 0; i < buffer.size; ++i) {
265266
const loc = buffer.indexToLoc(i);
266267
const xLoc = loc.map((idx, j) => idx + begin[j]);
267-
buffer.values[i] = x.get(...xLoc);
268+
buffer.values[i] = xBuf.get(...xLoc);
268269
}
269270
return buffer.toTensor() as T;
270271
}
@@ -286,15 +287,15 @@ export class MathBackendCPU implements KernelBackend {
286287
}
287288

288289
const buffer = ops.buffer(size, x.dtype);
289-
290+
const xBuf = x.bufferSync();
290291
for (let i = 0; i < buffer.size; i++) {
291292
const loc = buffer.indexToLoc(i);
292293

293294
const newLoc: number[] = new Array(loc.length);
294295
for (let j = 0; j < newLoc.length; j++) {
295296
newLoc[j] = loc[j] * strides[j] + beginIndex[j];
296297
}
297-
buffer.set(x.get(...newLoc), ...loc);
298+
buffer.set(xBuf.get(...newLoc), ...loc);
298299
}
299300

300301
return buffer.toTensor().reshape(shape) as T;
@@ -325,13 +326,13 @@ export class MathBackendCPU implements KernelBackend {
325326
this.assertNotComplex(x, 'reverse');
326327

327328
const buffer = ops.buffer(x.shape, x.dtype);
328-
const xBuffer = x.buffer();
329+
const xBuf = x.bufferSync();
329330

330331
for (let i = 0; i < buffer.size; i++) {
331332
const outLoc = buffer.indexToLoc(i);
332333
const inLoc = outLoc.slice();
333334
axis.forEach(ax => inLoc[ax] = x.shape[ax] - 1 - inLoc[ax]);
334-
buffer.set(xBuffer.get(...inLoc), ...outLoc);
335+
buffer.set(xBuf.get(...inLoc), ...outLoc);
335336
}
336337

337338
return buffer.toTensor() as T;
@@ -1719,7 +1720,8 @@ export class MathBackendCPU implements KernelBackend {
17191720

17201721
const leftPad = convInfo.padInfo.left;
17211722
const topPad = convInfo.padInfo.top;
1722-
1723+
const xBuf = x.bufferSync();
1724+
const dyBuf = dy.bufferSync();
17231725
for (let wR = 0; wR < filterHeight; ++wR) {
17241726
const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
17251727
const yRMax = Math.min(
@@ -1739,7 +1741,7 @@ export class MathBackendCPU implements KernelBackend {
17391741
const xR = wR + yR * strideHeight - topPad;
17401742
for (let yC = yCMin; yC < yCMax; ++yC) {
17411743
const xC = wC + yC * strideWidth - leftPad;
1742-
dotProd += x.get(b, xR, xC, d1) * dy.get(b, yR, yC, d2);
1744+
dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
17431745
}
17441746
}
17451747
}
@@ -1970,6 +1972,8 @@ export class MathBackendCPU implements KernelBackend {
19701972
const topPad = convInfo.padInfo.top;
19711973
const chMul = convInfo.outChannels / convInfo.inChannels;
19721974

1975+
const xBuf = x.bufferSync();
1976+
const dyBuf = dy.bufferSync();
19731977
for (let wR = 0; wR < filterHeight; ++wR) {
19741978
const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
19751979
const yRMax = Math.min(
@@ -1990,7 +1994,7 @@ export class MathBackendCPU implements KernelBackend {
19901994
const xR = wR + yR * strideHeight - topPad;
19911995
for (let yC = yCMin; yC < yCMax; ++yC) {
19921996
const xC = wC + yC * strideWidth - leftPad;
1993-
dotProd += x.get(b, xR, xC, d1) * dy.get(b, yR, yC, d2);
1997+
dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
19941998
}
19951999
}
19962000
}
@@ -2009,7 +2013,7 @@ export class MathBackendCPU implements KernelBackend {
20092013
newShape[i] = x.shape[i] * reps[i];
20102014
}
20112015
const result = ops.buffer(newShape, x.dtype);
2012-
const xBuf = x.buffer();
2016+
const xBuf = x.bufferSync();
20132017
for (let i = 0; i < result.values.length; ++i) {
20142018
const newLoc = result.indexToLoc(i);
20152019

@@ -2032,7 +2036,7 @@ export class MathBackendCPU implements KernelBackend {
20322036
const outShape = paddings.map(
20332037
(p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
20342038
const start = paddings.map(p => p[0]);
2035-
const xBuffer = x.buffer();
2039+
const xBuffer = x.bufferSync();
20362040
const buffer = ops.buffer(outShape, x.dtype as 'float32');
20372041
if (constantValue !== 0) {
20382042
buffer.values.fill(constantValue);
@@ -2041,7 +2045,7 @@ export class MathBackendCPU implements KernelBackend {
20412045
for (let i = 0; i < x.size; i++) {
20422046
const coords = xBuffer.indexToLoc(i);
20432047
const outCoords = coords.map((c, i) => c + start[i]);
2044-
buffer.set(x.get(...coords), ...outCoords);
2048+
buffer.set(xBuffer.get(...coords), ...outCoords);
20452049
}
20462050
return buffer.toTensor() as T;
20472051
}
@@ -2056,7 +2060,7 @@ export class MathBackendCPU implements KernelBackend {
20562060
const values = x.dataSync();
20572061
const result = buffer(newShape, x.dtype);
20582062

2059-
const xBuf = x.buffer();
2063+
const xBuf = x.bufferSync();
20602064
for (let i = 0; i < x.size; ++i) {
20612065
const loc = xBuf.indexToLoc(i);
20622066

@@ -2079,7 +2083,7 @@ export class MathBackendCPU implements KernelBackend {
20792083
const indicesValues = indices.dataSync();
20802084
newShape[axis] = indicesValues.length;
20812085
const result = buffer(newShape, x.dtype);
2082-
const xBuf = x.buffer();
2086+
const xBuf = x.bufferSync();
20832087

20842088
for (let i = 0; i < result.size; ++i) {
20852089
const newLoc = result.indexToLoc(i);
@@ -2226,6 +2230,7 @@ export class MathBackendCPU implements KernelBackend {
22262230
const padTop = convInfo.padInfo.top;
22272231
const padLeft = convInfo.padInfo.left;
22282232

2233+
const xBuf = x.bufferSync();
22292234
for (let b = 0; b < convInfo.batchSize; ++b) {
22302235
for (let d = 0; d < convInfo.inChannels; ++d) {
22312236
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
@@ -2252,7 +2257,7 @@ export class MathBackendCPU implements KernelBackend {
22522257
const wR = xR - xRCorner;
22532258
for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
22542259
const wC = xC - xCCorner;
2255-
const pixel = x.get(b, xR, xC, d);
2260+
const pixel = xBuf.get(b, xR, xC, d);
22562261
if (pixel > maxValue) {
22572262
maxValue = pixel;
22582263
maxPosition = wR * effectiveFilterWidth + wC;
@@ -2282,6 +2287,9 @@ export class MathBackendCPU implements KernelBackend {
22822287
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
22832288
const dx = ops.buffer<Rank.R4>(x.shape, 'float32');
22842289

2290+
const maxPosBuf = maxPositions.bufferSync();
2291+
const dyBuf = dy.bufferSync();
2292+
22852293
for (let b = 0; b < convInfo.batchSize; ++b) {
22862294
for (let d = 0; d < convInfo.inChannels; ++d) {
22872295
for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
@@ -2303,15 +2311,15 @@ export class MathBackendCPU implements KernelBackend {
23032311
continue;
23042312
}
23052313
const maxPos = effectiveFilterHeight * effectiveFilterWidth -
2306-
1 - maxPositions.get(b, dyR, dyC, d);
2314+
1 - maxPosBuf.get(b, dyR, dyC, d);
23072315
const curPos = wR * effectiveFilterWidth + wC;
23082316

23092317
const mask = maxPos === curPos ? 1 : 0;
23102318
if (mask === 0) {
23112319
continue;
23122320
}
23132321

2314-
const pixel = dy.get(b, dyR, dyC, d);
2322+
const pixel = dyBuf.get(b, dyR, dyC, d);
23152323
dotProd += pixel * mask;
23162324
}
23172325
}
@@ -2340,6 +2348,8 @@ export class MathBackendCPU implements KernelBackend {
23402348

23412349
const avgMultiplier = 1 / (filterHeight * filterWidth);
23422350

2351+
const dyBuf = dy.bufferSync();
2352+
23432353
for (let b = 0; b < convInfo.batchSize; ++b) {
23442354
for (let d = 0; d < convInfo.inChannels; ++d) {
23452355
for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
@@ -2361,7 +2371,7 @@ export class MathBackendCPU implements KernelBackend {
23612371
continue;
23622372
}
23632373

2364-
const pixel = dy.get(b, dyR, dyC, d);
2374+
const pixel = dyBuf.get(b, dyR, dyC, d);
23652375
dotProd += pixel;
23662376
}
23672377
}
@@ -2842,10 +2852,11 @@ export class MathBackendCPU implements KernelBackend {
28422852

28432853
const res = new Float32Array(indices.size * depth);
28442854
res.fill(offValue);
2855+
const indicesVal = indices.dataSync();
28452856

28462857
for (let event = 0; event < indices.size; ++event) {
2847-
if (indices.get(event) >= 0 && indices.get(event) < depth) {
2848-
res[event * depth + indices.get(event)] = onValue;
2858+
if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
2859+
res[event * depth + indicesVal[event]] = onValue;
28492860
}
28502861
}
28512862
return ops.tensor2d(res, [indices.size, depth], 'int32');
@@ -3040,8 +3051,8 @@ export class MathBackendCPU implements KernelBackend {
30403051
resVals[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
30413052
}
30423053
} else {
3043-
const aBuf = a.buffer();
3044-
const bBuf = b.buffer();
3054+
const aBuf = a.bufferSync();
3055+
const bBuf = b.bufferSync();
30453056
for (let i = 0; i < resVals.length; ++i) {
30463057
const loc = result.indexToLoc(i);
30473058

@@ -3090,8 +3101,8 @@ export class MathBackendCPU implements KernelBackend {
30903101
imagVals[i] = result.imag;
30913102
}
30923103
} else {
3093-
const aRealBuf = this.data.get(a.dataId).complexTensors.real.buffer();
3094-
const bRealBuf = this.data.get(b.dataId).complexTensors.real.buffer();
3104+
const aRealBuf = this.data.get(a.dataId).complexTensors.real.bufferSync();
3105+
const bRealBuf = this.data.get(b.dataId).complexTensors.real.bufferSync();
30953106
for (let i = 0; i < realVals.length; i++) {
30963107
const loc = realResult.indexToLoc(i);
30973108

src/kernels/backend_webgl.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,10 +2267,10 @@ export class MathBackendWebGL implements KernelBackend {
22672267
// trying to upload a small value.
22682268
const debugFlag = ENV.get('DEBUG');
22692269
ENV.set('DEBUG', false);
2270-
const underflowCheckVluae = this.abs(scalar(1e-8)).get();
2270+
const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0];
22712271
ENV.set('DEBUG', debugFlag);
22722272

2273-
if (underflowCheckVluae > 0) {
2273+
if (underflowCheckValue > 0) {
22742274
return 32;
22752275
}
22762276
return 16;

src/ops/arithmetic_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,11 +457,11 @@ describeWithFlags('mul', ALL_ENVS, () => {
457457

458458
expect(da.shape).toEqual(a.shape);
459459
expect(da.dtype).toEqual('float32');
460-
expectArraysClose(da, [b.get() * dy.get()]);
460+
expectArraysClose(da, b.mul(dy));
461461

462462
expect(db.shape).toEqual(b.shape);
463463
expect(db.dtype).toEqual('float32');
464-
expectArraysClose(db, [a.get() * dy.get()]);
464+
expectArraysClose(db, a.mul(dy));
465465
});
466466

467467
it('gradient: Tensor1D', () => {

0 commit comments

Comments
 (0)