Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Simplify tf.ENV.registerBackend (remove setTensorTracker param) #1517

Merged
merged 2 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/canvas_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ const WEBGL_ATTRIBUTES: WebGLContextAttributes = {

export function getWebGLContext(webGLVersion: number): WebGLRenderingContext {
if (!(webGLVersion in contexts)) {
const canvas = document.createElement('canvas');
canvas.addEventListener('webglcontextlost', ev => {
ev.preventDefault();
delete contexts[webGLVersion];
}, false);
contexts[webGLVersion] = getWebGLRenderingContext(webGLVersion);
}
const gl = contexts[webGLVersion];
Expand Down Expand Up @@ -61,6 +56,10 @@ function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext {
}

const canvas = document.createElement('canvas');
canvas.addEventListener('webglcontextlost', ev => {
ev.preventDefault();
delete contexts[webGLVersion];
}, false);
if (webGLVersion === 1) {
return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES)) as
Expand Down
14 changes: 6 additions & 8 deletions src/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import * as device_util from './device_util';
import {Engine, MemoryInfo, ProfileInfo, ScopeFn, TimingInfo} from './engine';
import {Features, getFeaturesFromURL, getMaxTexturesInShader, getNumMBBeforePaging, getWebGLDisjointQueryTimerVersion, getWebGLMaxTextureSize, isChrome, isDownloadFloatTextureEnabled, isRenderToFloatTextureEnabled, isWebGLFenceEnabled, isWebGLVersionEnabled} from './environment_util';
import {KernelBackend} from './kernels/backend';
import {DataId, setTensorTracker, Tensor, TensorTracker} from './tensor';
import {DataId, setTensorTracker, Tensor} from './tensor';
import {TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';

Expand Down Expand Up @@ -416,15 +416,11 @@ export class Environment {
* the best backend. Defaults to 1.
* @return False if the creation/registration failed. True otherwise.
*/
registerBackend(
name: string, factory: () => KernelBackend, priority = 1,
setTensorTrackerFn?: (f: () => TensorTracker) => void): boolean {
registerBackend(name: string, factory: () => KernelBackend, priority = 1):
boolean {
if (name in this.registry) {
console.warn(
`${name} backend was already registered. Reusing existing backend`);
if (setTensorTrackerFn != null) {
setTensorTrackerFn(() => this.engine);
}
return false;
}
try {
Expand Down Expand Up @@ -480,8 +476,10 @@ function getOrMakeEnvironment(): Environment {
const ns = getGlobalNamespace();
if (ns.ENV == null) {
ns.ENV = new Environment(getFeaturesFromURL());
setTensorTracker(() => ns.ENV.engine);
}
// Tell the current tensor interface that the global engine is responsible for
// tracking.
setTensorTracker(() => ns.ENV.engine);
return ns.ENV;
}

Expand Down
5 changes: 2 additions & 3 deletions src/kernels/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import {buffer, scalar, tensor, tensor3d, tensor4d} from '../ops/ops';
import * as scatter_nd_util from '../ops/scatter_nd_util';
import * as selu_util from '../ops/selu_util';
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../ops/slice_util';
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../tensor';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../tensor';
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../types';
import * as util from '../util';
import {now} from '../util';
Expand Down Expand Up @@ -3363,5 +3363,4 @@ export class MathBackendCPU implements KernelBackend {
}
}

ENV.registerBackend(
'cpu', () => new MathBackendCPU(), 1 /* priority */, setTensorTracker);
ENV.registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);
6 changes: 2 additions & 4 deletions src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import * as segment_util from '../ops/segment_util';
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../ops/slice_util';
import {softmax} from '../ops/softmax';
import {range, scalar, tensor} from '../ops/tensor_ops';
import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../types';
import * as util from '../util';
import {getTypedArrayFromDType, sizeFromShape} from '../util';
Expand Down Expand Up @@ -2298,9 +2298,7 @@ export class MathBackendWebGL implements KernelBackend {
}

if (ENV.get('IS_BROWSER')) {
ENV.registerBackend(
'webgl', () => new MathBackendWebGL(), 2 /* priority */,
setTensorTracker);
ENV.registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
}

function float32ToTypedArray<D extends NumericDataType>(
Expand Down