Skip to content

Commit ef29cc1

Browse files
authored
Simplify tf.ENV.registerBackend (remove setTensorTracker param) (tensorflow#1517)
INTERNAL There is no need for custom backends to provide their own tensor trackers since the tracker is the engine (a global). Tested that double tfjs import works: ```html <script src="http://github.com/tf-core.min.js"></script> <script src="http://github.com/tf-core.min.js"></script> <script> tf.square(3).print(); </script> ```
1 parent 58555ae commit ef29cc1

File tree

4 files changed

+14
-20
lines changed

4 files changed

+14
-20
lines changed

src/canvas_util.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ const WEBGL_ATTRIBUTES: WebGLContextAttributes = {
2929

3030
export function getWebGLContext(webGLVersion: number): WebGLRenderingContext {
3131
if (!(webGLVersion in contexts)) {
32-
const canvas = document.createElement('canvas');
33-
canvas.addEventListener('webglcontextlost', ev => {
34-
ev.preventDefault();
35-
delete contexts[webGLVersion];
36-
}, false);
3732
contexts[webGLVersion] = getWebGLRenderingContext(webGLVersion);
3833
}
3934
const gl = contexts[webGLVersion];
@@ -61,6 +56,10 @@ function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext {
6156
}
6257

6358
const canvas = document.createElement('canvas');
59+
canvas.addEventListener('webglcontextlost', ev => {
60+
ev.preventDefault();
61+
delete contexts[webGLVersion];
62+
}, false);
6463
if (webGLVersion === 1) {
6564
return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
6665
canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES)) as

src/environment.ts

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import * as device_util from './device_util';
1919
import {Engine, MemoryInfo, ProfileInfo, ScopeFn, TimingInfo} from './engine';
2020
import {Features, getFeaturesFromURL, getMaxTexturesInShader, getNumMBBeforePaging, getWebGLDisjointQueryTimerVersion, getWebGLMaxTextureSize, isChrome, isDownloadFloatTextureEnabled, isRenderToFloatTextureEnabled, isWebGLFenceEnabled, isWebGLVersionEnabled} from './environment_util';
2121
import {KernelBackend} from './kernels/backend';
22-
import {DataId, setTensorTracker, Tensor, TensorTracker} from './tensor';
22+
import {DataId, setTensorTracker, Tensor} from './tensor';
2323
import {TensorContainer} from './tensor_types';
2424
import {getTensorsInContainer} from './tensor_util';
2525

@@ -416,15 +416,11 @@ export class Environment {
416416
* the best backend. Defaults to 1.
417417
* @return False if the creation/registration failed. True otherwise.
418418
*/
419-
registerBackend(
420-
name: string, factory: () => KernelBackend, priority = 1,
421-
setTensorTrackerFn?: (f: () => TensorTracker) => void): boolean {
419+
registerBackend(name: string, factory: () => KernelBackend, priority = 1):
420+
boolean {
422421
if (name in this.registry) {
423422
console.warn(
424423
`${name} backend was already registered. Reusing existing backend`);
425-
if (setTensorTrackerFn != null) {
426-
setTensorTrackerFn(() => this.engine);
427-
}
428424
return false;
429425
}
430426
try {
@@ -480,8 +476,10 @@ function getOrMakeEnvironment(): Environment {
480476
const ns = getGlobalNamespace();
481477
if (ns.ENV == null) {
482478
ns.ENV = new Environment(getFeaturesFromURL());
483-
setTensorTracker(() => ns.ENV.engine);
484479
}
480+
// Tell the current tensor interface that the global engine is responsible for
481+
// tracking.
482+
setTensorTracker(() => ns.ENV.engine);
485483
return ns.ENV;
486484
}
487485

src/kernels/backend_cpu.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import {buffer, scalar, tensor, tensor3d, tensor4d} from '../ops/ops';
3232
import * as scatter_nd_util from '../ops/scatter_nd_util';
3333
import * as selu_util from '../ops/selu_util';
3434
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../ops/slice_util';
35-
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../tensor';
35+
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../tensor';
3636
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../types';
3737
import * as util from '../util';
3838
import {now} from '../util';
@@ -3363,5 +3363,4 @@ export class MathBackendCPU implements KernelBackend {
33633363
}
33643364
}
33653365

3366-
ENV.registerBackend(
3367-
'cpu', () => new MathBackendCPU(), 1 /* priority */, setTensorTracker);
3366+
ENV.registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);

src/kernels/backend_webgl.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import * as segment_util from '../ops/segment_util';
3232
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../ops/slice_util';
3333
import {softmax} from '../ops/softmax';
3434
import {range, scalar, tensor} from '../ops/tensor_ops';
35-
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
35+
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
3636
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../types';
3737
import * as util from '../util';
3838
import {getTypedArrayFromDType, sizeFromShape} from '../util';
@@ -2304,9 +2304,7 @@ export class MathBackendWebGL implements KernelBackend {
23042304
}
23052305

23062306
if (ENV.get('IS_BROWSER')) {
2307-
ENV.registerBackend(
2308-
'webgl', () => new MathBackendWebGL(), 2 /* priority */,
2309-
setTensorTracker);
2307+
ENV.registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
23102308
}
23112309

23122310
function float32ToTypedArray<D extends NumericDataType>(

0 commit comments

Comments
 (0)