Skip to content

Commit c157cc5

Browse files
authored
Make tfjs-core compatible with the closure compiler (tensorflow#1521)
- Add `/** @nocollapse */` on static properties to tell closure not to remove them. - Prepend `declare` in front of interfaces and types that are external (used in json deserialization/serialization) so closure doesn't rename its properties. - Fix issues where the promise was unused. Similar changes in tfjs-layers: tensorflow/tfjs-layers#442
1 parent 7179d24 commit c157cc5

13 files changed

+78
-72
lines changed

src/io/router_registry.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ export class IORouterRegistry {
8989
url: string|string[], handlerType: 'save'|'load',
9090
onProgress?: Function): IOHandler[] {
9191
const validHandlers: IOHandler[] = [];
92-
const routers = handlerType === 'load' ? this.getInstance().loadRouters :
93-
this.getInstance().saveRouters;
92+
const routers = handlerType === 'load' ?
93+
IORouterRegistry.getInstance().loadRouters :
94+
IORouterRegistry.getInstance().saveRouters;
9495
routers.forEach(router => {
9596
const handler = router(url, onProgress);
9697
if (handler !== null) {

src/io/weights_loader.ts

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,11 @@ export async function loadWeightsAsArrayBuffer(
5858
const fetchStartFraction = 0;
5959
const fetchEndFraction = 0.5;
6060

61-
if (onProgress != null) {
62-
util.monitorPromisesProgress(
63-
requests, onProgress, fetchStartFraction, fetchEndFraction);
64-
}
61+
const responses = onProgress == null ?
62+
await Promise.all(requests) :
63+
await util.monitorPromisesProgress(
64+
requests, onProgress, fetchStartFraction, fetchEndFraction);
6565

66-
const responses = await Promise.all(requests);
6766
const badContentType = responses.filter(response => {
6867
const contentType = response.headers.get(CONTENT_TYPE);
6968
return !contentType || contentType.indexOf(OCTET_STREAM_TYPE) === -1;
@@ -82,12 +81,10 @@ export async function loadWeightsAsArrayBuffer(
8281
const bufferStartFraction = 0.5;
8382
const bufferEndFraction = 1;
8483

85-
if (onProgress != null) {
86-
util.monitorPromisesProgress(
87-
bufferPromises, onProgress, bufferStartFraction, bufferEndFraction);
88-
}
89-
90-
const buffers = await Promise.all(bufferPromises);
84+
const buffers = onProgress == null ?
85+
await Promise.all(bufferPromises) :
86+
await util.monitorPromisesProgress(
87+
bufferPromises, onProgress, bufferStartFraction, bufferEndFraction);
9188
return buffers;
9289
}
9390

src/kernels/webgl/shader_compiler.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ function getFloatTextureSetRGBASnippet(glsl: GLSL): string {
216216
}
217217

218218
function getShaderPrefix(glsl: GLSL): string {
219-
let NAN_CHECKS = '';
219+
let nanChecks = '';
220220
if (ENV.get('PROD')) {
221-
NAN_CHECKS = `
221+
nanChecks = `
222222
bool isNaN(float val) {
223223
return false;
224224
}
@@ -232,7 +232,7 @@ function getShaderPrefix(glsl: GLSL): string {
232232
* Previous NaN check '(val < 0.0 || 0.0 < val || val == 0.0) ? false :
233233
* true' does not work on iOS 12
234234
*/
235-
NAN_CHECKS = `
235+
nanChecks = `
236236
bool isNaN(float val) {
237237
return (val < 1.0 || 0.0 < val || val == 0.0) ? false : true;
238238
}
@@ -275,7 +275,7 @@ function getShaderPrefix(glsl: GLSL): string {
275275
int v;
276276
};
277277
278-
${NAN_CHECKS}
278+
${nanChecks}
279279
280280
float getNaN(vec4 values) {
281281
return dot(vec4(1), values);

src/optimizers/adadelta_optimizer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ export class AdadeltaOptimizer extends Optimizer {
115115
epsilon: this.epsilon
116116
};
117117
}
118+
119+
/** @nocollapse */
118120
static fromConfig<T extends Serializable>(
119121
cls: SerializableConstructor<T>, config: ConfigDict): T {
120122
return new cls(config.learningRate, config.rho, config.epsilon);

src/optimizers/adagrad_optimizer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ export class AdagradOptimizer extends Optimizer {
8282
initialAccumulatorValue: this.initialAccumulatorValue,
8383
};
8484
}
85+
86+
/** @nocollapse */
8587
static fromConfig<T extends Serializable>(
8688
cls: SerializableConstructor<T>, config: ConfigDict): T {
8789
return new cls(config.learningRate, config.initialAccumulatorValue);

src/optimizers/adam_optimizer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ export class AdamOptimizer extends Optimizer {
139139
epsilon: this.epsilon,
140140
};
141141
}
142+
143+
/** @nocollapse */
142144
static fromConfig<T extends Serializable>(
143145
cls: SerializableConstructor<T>, config: ConfigDict): T {
144146
return new cls(

src/optimizers/adamax_optimizer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ export class AdamaxOptimizer extends Optimizer {
146146
decay: this.decay
147147
};
148148
}
149+
150+
/** @nocollapse */
149151
static fromConfig<T extends Serializable>(
150152
cls: SerializableConstructor<T>, config: ConfigDict): T {
151153
return new cls(

src/optimizers/momentum_optimizer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ export class MomentumOptimizer extends SGDOptimizer {
9393
useNesterov: this.useNesterov
9494
};
9595
}
96+
97+
/** @nocollapse */
9698
static fromConfig<T extends Serializable>(
9799
cls: SerializableConstructor<T>, config: ConfigDict): T {
98100
return new cls(config.learningRate, config.momentum, config.useNesterov);

src/optimizers/rmsprop_optimizer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ export class RMSPropOptimizer extends Optimizer {
165165
centered: this.centered
166166
};
167167
}
168+
169+
/** @nocollapse */
168170
static fromConfig<T extends Serializable>(
169171
cls: SerializableConstructor<T>, config: ConfigDict): T {
170172
return new cls(

src/optimizers/sgd_optimizer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ export class SGDOptimizer extends Optimizer {
6565
getConfig(): ConfigDict {
6666
return {learningRate: this.learningRate};
6767
}
68+
69+
/** @nocollapse */
6870
static fromConfig<T extends Serializable>(
6971
cls: SerializableConstructor<T>, config: ConfigDict): T {
7072
return new cls(config.learningRate);

src/serialization.ts

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ import {assert} from './util';
3232
* convertTsToPythonic from serialization_utils in -Layers.
3333
*
3434
*/
35-
export type ConfigDictValue =
36-
boolean|number|string|null|ConfigDictArray|ConfigDict;
37-
export interface ConfigDict {
35+
export declare type ConfigDictValue =
36+
boolean | number | string | null | ConfigDictArray | ConfigDict;
37+
export declare interface ConfigDict {
3838
[key: string]: ConfigDictValue;
3939
}
40-
export interface ConfigDictArray extends Array<ConfigDictValue> {}
40+
export declare interface ConfigDictArray extends Array<ConfigDictValue> {}
4141

4242
/**
4343
* Type to represent the class-type of Serializable objects.
@@ -47,11 +47,11 @@ export interface ConfigDictArray extends Array<ConfigDictValue> {}
4747
*
4848
* Source for this idea: https://stackoverflow.com/a/43607255
4949
*/
50-
export type SerializableConstructor<T extends Serializable> = {
50+
export declare type SerializableConstructor<T extends Serializable> = {
5151
// tslint:disable-next-line:no-any
5252
new (...args: any[]): T; className: string; fromConfig: FromConfigMethod<T>;
5353
};
54-
export type FromConfigMethod<T extends Serializable> =
54+
export declare type FromConfigMethod<T extends Serializable> =
5555
(cls: SerializableConstructor<T>, config: ConfigDict) => T;
5656

5757
/**
@@ -90,6 +90,7 @@ export abstract class Serializable {
9090
* @param cls A Constructor for the class to instantiate.
9191
* @param config The Configuration for the object.
9292
*/
93+
/** @nocollapse */
9394
static fromConfig<T extends Serializable>(
9495
cls: SerializableConstructor<T>, config: ConfigDict): T {
9596
return new cls(config);

src/util.ts

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -601,57 +601,51 @@ export function now(): number {
601601
/**
602602
* Monitor Promise.all progress, fire onProgress callback function.
603603
*
604-
* @param {Array<Promise<D | Function | {} | void>>} promises,
605-
* Promise list going to be monitored
606-
* @param {Function} onProgress, callback function.
607-
* Fired when a promise resolved.
608-
* @param {number} startFraction, Optional fraction start. Default to 0.
609-
* @param {number} endFraction, Optional fraction end. Default to 1.
604+
* @param promises Promise list going to be monitored
605+
* @param onProgress Callback function. Fired when a promise resolved.
606+
* @param startFraction Optional fraction start. Default to 0.
607+
* @param endFraction Optional fraction end. Default to 1.
610608
*/
611-
export function monitorPromisesProgress<D extends DataType>(
612-
promises: Array<Promise<D | Function | {} | void>>, onProgress: Function,
609+
export function monitorPromisesProgress(
610+
promises: Array<Promise<{}|void>>, onProgress: Function,
613611
startFraction?: number, endFraction?: number) {
614612
checkPromises(promises);
615613
startFraction = startFraction == null ? 0 : startFraction;
616614
endFraction = endFraction == null ? 1 : endFraction;
617615
checkFraction(startFraction, endFraction);
618616
let resolvedPromise = 0;
619617

620-
function registerMonitor(promise: Promise<D | Function | {} | void>) {
621-
promise.then((value: D | Function | {} | void) => {
622-
const fraction = startFraction + ++resolvedPromise / promises.length *
623-
(endFraction - startFraction);
618+
function registerMonitor(promise: Promise<{}>) {
619+
promise.then(value => {
620+
const fraction = startFraction +
621+
++resolvedPromise / promises.length * (endFraction - startFraction);
624622
// pass fraction as parameter to callback function.
625623
onProgress(fraction);
626624
return value;
627625
});
628626
return promise;
629627
}
630628

631-
function checkPromises(
632-
promises: Array<Promise<D | Function | {} | void>>): void {
629+
function checkPromises(promises: Array<Promise<{}|void>>): void {
633630
assert(
634631
promises != null && Array.isArray(promises) && promises.length > 0,
635-
'promises must be a none empty array'
636-
);
632+
'promises must be a none empty array');
637633
}
638634

639635
function checkFraction(startFraction: number, endFraction: number): void {
640636
assert(
641637
startFraction >= 0 && startFraction <= 1,
642638
`Progress fraction must be in range [0, 1], but ` +
643-
`got startFraction ${startFraction}`
644-
);
639+
`got startFraction ${startFraction}`);
645640
assert(
646641
endFraction >= 0 && endFraction <= 1,
647642
`Progress fraction must be in range [0, 1], but ` +
648-
`got endFraction ${endFraction}`
649-
);
643+
`got endFraction ${endFraction}`);
650644
assert(
651645
endFraction >= startFraction,
652646
`startFraction must be no more than endFraction, but ` +
653-
`got startFraction ${startFraction} and endFraction ${endFraction}`
654-
);
647+
`got startFraction ${startFraction} and endFraction ${
648+
endFraction}`);
655649
}
656650

657651
return Promise.all(promises.map(registerMonitor));

src/util_test.ts

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
* =============================================================================
1616
*/
1717

18+
import {scalar, tensor2d} from './ops/ops';
1819
import {inferShape} from './tensor_util_env';
1920
import * as util from './util';
20-
import {scalar, tensor2d} from './ops/ops';
2121

2222
describe('Util', () => {
2323
it('Correctly gets size from shape', () => {
@@ -76,8 +76,7 @@ describe('Util', () => {
7676

7777
it('infer shape 4d array', () => {
7878
const a = [
79-
[[[1], [2]], [[2], [3]], [[5], [6]]],
80-
[[[5], [6]], [[4], [5]], [[1], [2]]]
79+
[[[1], [2]], [[2], [3]], [[5], [6]]], [[[5], [6]], [[4], [5]], [[1], [2]]]
8180
];
8281
expect(inferShape(a)).toEqual([2, 3, 2, 1]);
8382
});
@@ -459,20 +458,21 @@ describe('util.hasEncodingLoss', () => {
459458
describe('util.toNestedArray', () => {
460459
it('2 dimensions', () => {
461460
const a = new Float32Array([1, 2, 3, 4, 5, 6]);
462-
expect(util.toNestedArray([2, 3], a))
463-
.toEqual([[1,2,3], [4,5,6]]);
461+
expect(util.toNestedArray([2, 3], a)).toEqual([[1, 2, 3], [4, 5, 6]]);
464462
});
465463

466464
it('3 dimensions (2x2x3)', () => {
467465
const a = new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
468-
expect(util.toNestedArray([2, 2, 3], a))
469-
.toEqual([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]);
466+
expect(util.toNestedArray([2, 2, 3], a)).toEqual([
467+
[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]
468+
]);
470469
});
471470

472471
it('3 dimensions (3x2x2)', () => {
473472
const a = new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
474-
expect(util.toNestedArray([3, 2, 2], a))
475-
.toEqual([[[0, 1],[2, 3]],[[4, 5],[6, 7]],[[8, 9],[10, 11]]]);
473+
expect(util.toNestedArray([3, 2, 2], a)).toEqual([
474+
[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]
475+
]);
476476
});
477477

478478
it('invalid dimension', () => {
@@ -482,8 +482,7 @@ describe('util.toNestedArray', () => {
482482

483483
it('tensor to nested array', () => {
484484
const x = tensor2d([1, 2, 3, 4], [2, 2]);
485-
expect(util.toNestedArray(x.shape, x.dataSync()))
486-
.toEqual([[1, 2], [3, 4]]);
485+
expect(util.toNestedArray(x.shape, x.dataSync())).toEqual([[1, 2], [3, 4]]);
487486
});
488487

489488
it('scalar to nested array', () => {
@@ -501,12 +500,12 @@ describe('util.monitorPromisesProgress', () => {
501500
it('Default progress from 0 to 1', (done) => {
502501
const expectFractions: number[] = [0.25, 0.50, 0.75, 1.00];
503502
const fractionList: number[] = [];
504-
const tasks = Array(4).fill(0).map(()=>{
503+
const tasks = Array(4).fill(0).map(() => {
505504
return Promise.resolve();
506505
});
507-
util.monitorPromisesProgress(tasks, (progress: number)=>{
508-
fractionList.push(parseFloat(progress.toFixed(2)));
509-
}).then(()=>{
506+
util.monitorPromisesProgress(tasks, (progress: number) => {
507+
fractionList.push(parseFloat(progress.toFixed(2)));
508+
}).then(() => {
510509
expect(fractionList).toEqual(expectFractions);
511510
done();
512511
});
@@ -517,12 +516,12 @@ describe('util.monitorPromisesProgress', () => {
517516
const endFraction = 0.8;
518517
const expectFractions: number[] = [0.35, 0.50, 0.65, 0.80];
519518
const fractionList: number[] = [];
520-
const tasks = Array(4).fill(0).map(()=>{
519+
const tasks = Array(4).fill(0).map(() => {
521520
return Promise.resolve();
522521
});
523-
util.monitorPromisesProgress(tasks, (progress: number)=>{
524-
fractionList.push(parseFloat(progress.toFixed(2)));
525-
}, startFraction, endFraction).then(()=>{
522+
util.monitorPromisesProgress(tasks, (progress: number) => {
523+
fractionList.push(parseFloat(progress.toFixed(2)));
524+
}, startFraction, endFraction).then(() => {
526525
expect(fractionList).toEqual(expectFractions);
527526
done();
528527
});
@@ -532,35 +531,35 @@ describe('util.monitorPromisesProgress', () => {
532531
expect(() => {
533532
const startFraction = -1;
534533
const endFraction = 1;
535-
const tasks = Array(4).fill(0).map(()=>{
534+
const tasks = Array(4).fill(0).map(() => {
536535
return Promise.resolve();
537536
});
538-
util.monitorPromisesProgress(tasks, (progress: number)=>{},
539-
startFraction, endFraction);
537+
util.monitorPromisesProgress(
538+
tasks, (progress: number) => {}, startFraction, endFraction);
540539
}).toThrowError();
541540
});
542541

543542
it('throws error when startFraction more than endFraction', () => {
544543
expect(() => {
545544
const startFraction = 0.8;
546545
const endFraction = 0.2;
547-
const tasks = Array(4).fill(0).map(()=>{
546+
const tasks = Array(4).fill(0).map(() => {
548547
return Promise.resolve();
549548
});
550-
util.monitorPromisesProgress(tasks, (progress: number)=>{},
551-
startFraction, endFraction);
549+
util.monitorPromisesProgress(
550+
tasks, (progress: number) => {}, startFraction, endFraction);
552551
}).toThrowError();
553552
});
554553

555554
it('throws error when promises is null', () => {
556555
expect(() => {
557-
util.monitorPromisesProgress(null, (progress: number)=>{});
556+
util.monitorPromisesProgress(null, (progress: number) => {});
558557
}).toThrowError();
559558
});
560559

561560
it('throws error when promises is empty array', () => {
562561
expect(() => {
563-
util.monitorPromisesProgress([], (progress: number)=>{});
562+
util.monitorPromisesProgress([], (progress: number) => {});
564563
}).toThrowError();
565564
});
566565
});

0 commit comments

Comments
 (0)