
Commit d90401e
Make tfjs-layers compatible with the closure compiler (#442)
- Add `/** @nocollapse */` on static properties to tell Closure not to remove them.
- Prepend `declare` in front of interfaces and types that are external (used in JSON serialization/deserialization) so Closure doesn't rename their properties.
- Instead of object literals, explicitly assign each field separately to a config object and mark it as `serialization.ConfigDict` so Closure doesn't rename its properties.

Similar changes in tfjs-core: tensorflow/tfjs-core#1521
1 parent eca26e6 · commit d90401e

24 files changed (+190, -84 lines)
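
Before the per-file diffs, a minimal TypeScript sketch of the three patterns described above. It is not taken from the commit; `MyActivation`, `MyActivationArgs`, and `toConfig` are illustrative names.

import {serialization} from '@tensorflow/tfjs-core';

// Pattern 1: `/** @nocollapse */` keeps Closure from collapsing the static
// property into a module-level variable, so registry code can still read
// `className` off the constructor at runtime.
export class MyActivation extends serialization.Serializable {
  /** @nocollapse */
  static readonly className = 'myActivation';

  getConfig(): serialization.ConfigDict {
    return {};
  }
}
serialization.registerClass(MyActivation);

// Pattern 2: `declare` marks the interface as describing external data
// (it crosses the JSON boundary), so Closure leaves its property names alone.
export declare interface MyActivationArgs {
  alpha?: number;
}

// Pattern 3: build configs by assigning fields one at a time to a value
// typed as `serialization.ConfigDict`, instead of using an object literal
// whose keys Closure may rename.
export function toConfig(alpha: number): serialization.ConfigDict {
  const config: serialization.ConfigDict = {};
  config.className = MyActivation.className;
  config.alpha = alpha;
  return config;
}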

src/activations.ts

Lines changed: 17 additions & 3 deletions
@@ -11,7 +11,6 @@
 // Layer activation functions
 import * as tfc from '@tensorflow/tfjs-core';
 import {serialization, Tensor, tidy} from '@tensorflow/tfjs-core';
-
 import {getScalar} from './backend/state';
 import * as K from './backend/tfjs_backend';
 import {ActivationIdentifier} from './keras_format/activation_config';
@@ -36,6 +35,7 @@ export abstract class Activation extends serialization.Serializable {
  * Reference: https://arxiv.org/abs/1511.07289
  */
 export class Elu extends Activation {
+  /** @nocollapse */
   static readonly className = 'elu';
   /**
    * Calculate the activation function.
@@ -58,6 +58,7 @@ serialization.registerClass(Elu);
  * - To be used together with the dropout variant "AlphaDropout".
  */
 export class Selu extends Activation {
+  /** @nocollapse */
   static readonly className = 'selu';
   apply(x: Tensor): Tensor {
     return tfc.selu(x);
@@ -69,6 +70,7 @@ serialization.registerClass(Selu);
  * Rectified linear unit
  */
 export class Relu extends Activation {
+  /** @nocollapse */
   static readonly className = 'relu';
   apply(x: Tensor): Tensor {
     return tfc.relu(x);
@@ -80,6 +82,7 @@ serialization.registerClass(Relu);
  * Rectified linear unit activation maxing out at 6.0.
  */
 export class Relu6 extends Activation {
+  /** @nocollapse */
   static readonly className = 'relu6';
   apply(x: Tensor): Tensor {
     return tidy(() => tfc.minimum(getScalar(6.0), tfc.relu(x)));
@@ -89,6 +92,7 @@ serialization.registerClass(Relu6);
 
 //* Linear activation (no-op) */
 export class Linear extends Activation {
+  /** @nocollapse */
   static readonly className = 'linear';
   apply(x: Tensor): Tensor {
     return x;
@@ -100,6 +104,7 @@ serialization.registerClass(Linear);
  * Sigmoid activation function.
  */
 export class Sigmoid extends Activation {
+  /** @nocollapse */
   static readonly className = 'sigmoid';
   apply(x: Tensor): Tensor {
     return tfc.sigmoid(x);
@@ -111,6 +116,7 @@ serialization.registerClass(Sigmoid);
  * Segment-wise linear approximation of sigmoid.
  */
 export class HardSigmoid extends Activation {
+  /** @nocollapse */
   static readonly className = 'hardSigmoid';
   apply(x: Tensor): Tensor {
     return K.hardSigmoid(x);
@@ -122,6 +128,7 @@ serialization.registerClass(HardSigmoid);
  * Softplus activation function.
  */
 export class Softplus extends Activation {
+  /** @nocollapse */
   static readonly className = 'softplus';
   apply(x: Tensor): Tensor {
     return tfc.softplus(x);
@@ -133,6 +140,7 @@ serialization.registerClass(Softplus);
  * Softsign activation function.
  */
 export class Softsign extends Activation {
+  /** @nocollapse */
   static readonly className = 'softsign';
   apply(x: Tensor): Tensor {
     return K.softsign(x);
@@ -144,6 +152,7 @@ serialization.registerClass(Softsign);
  * Hyperbolic tangent function.
  */
 export class Tanh extends Activation {
+  /** @nocollapse */
   static readonly className = 'tanh';
   apply(x: Tensor): Tensor {
     return tfc.tanh(x);
@@ -155,6 +164,7 @@ serialization.registerClass(Tanh);
  * Softmax activation function
  */
 export class Softmax extends Activation {
+  /** @nocollapse */
   static readonly className = 'softmax';
   /**
    * Calculate the activation function.
@@ -189,11 +199,15 @@ export function deserializeActivation(
 export function getActivation(identifier: ActivationIdentifier|
                               serialization.ConfigDict|Activation): Activation {
   if (identifier == null) {
-    const config = {className: 'linear', config: {}};
+    const config: serialization.ConfigDict = {};
+    config.className = 'linear';
+    config.config = {};
     return deserializeActivation(config);
   }
   if (typeof identifier === 'string') {
-    const config = {className: identifier, config: {}};
+    const config: serialization.ConfigDict = {};
+    config.className = identifier;
+    config.config = {};
    return deserializeActivation(config);
   } else if (identifier instanceof Activation) {
     return identifier;
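
A hedged usage sketch of `getActivation` after this hunk. The relative import path is internal and the `null` call is illustrative (tfjs-layers did not compile with strict null checks at the time); each call form maps to one branch above.

import {serialization} from '@tensorflow/tfjs-core';
import {Activation, getActivation} from './activations';

const relu = getActivation('relu');        // string branch: field-by-field ConfigDict
const fallback = getActivation(null);      // null branch: defaults to 'linear'
const cfg: serialization.ConfigDict = {};
cfg.className = 'softmax';
cfg.config = {};
const softmax = getActivation(cfg);        // ConfigDict branch
const same = getActivation(relu);          // Activation instance is returned as-is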

src/backend/state.ts

Lines changed: 4 additions & 4 deletions
@@ -40,10 +40,7 @@ export function getUid(prefix = ''): string {
   return prefix + _uidPrefixes[prefix].toString();
 }
 
-const scalarCache: {[typeKey: string]: {[key: number]: Scalar}} = {
-  float32: {},
-  int32: {}
-};
+const scalarCache: {[typeKey: string]: {[key: number]: Scalar}} = {};
 
 const DEFAULT_DTYPE: DataType = 'float32';
 
@@ -54,6 +51,9 @@ export function getScalar(value: number, dtype?: DataType): Scalar {
   if (dtype === undefined) {
     dtype = DEFAULT_DTYPE;
   }
+  if (scalarCache[dtype] == null) {
+    scalarCache[dtype] = {};
+  }
   if (scalarCache[dtype][value] == null) {
     scalarCache[dtype][value] = scalar(value, dtype);
     keep(scalarCache[dtype][value]);
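
For context, a brief sketch of how the now lazily initialized cache behaves for callers such as `Relu6`'s `getScalar(6.0)` above; the module path is internal and shown only for illustration.

import {getScalar} from './backend/state';

// The first call for a given (dtype, value) pair creates the per-dtype
// bucket, makes the scalar, and keeps it; later calls are cache hits.
const six = getScalar(6.0);             // creates scalarCache['float32'][6]
const sameSix = getScalar(6.0);         // returns the cached tensor
console.log(six === sameSix);           // true
const intSix = getScalar(6, 'int32');   // first 'int32' use creates its bucket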

src/constraints.ts

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,7 @@ export interface MaxNormArgs {
  * 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
  */
 export class MaxNorm extends Constraint {
+  /** @nocollapse */
   static readonly className = 'MaxNorm';
   private maxValue: number;
   private axis: number;
@@ -125,6 +126,7 @@ export interface UnitNormArgs {
  * Constrains the weights incident to each hidden unit to have unit norm.
  */
 export class UnitNorm extends Constraint {
+  /** @nocollapse */
   static readonly className = 'UnitNorm';
   private axis: number;
   private readonly defaultAxis = 0;
@@ -149,6 +151,7 @@ serialization.registerClass(UnitNorm);
  * Constains the weight to be non-negative.
  */
 export class NonNeg extends Constraint {
+  /** @nocollapse */
   static readonly className = 'NonNeg';
 
   apply(w: Tensor): Tensor {
@@ -192,6 +195,7 @@ export interface MinMaxNormArgs {
 }
 
 export class MinMaxNorm extends Constraint {
+  /** @nocollapse */
   static readonly className = 'MinMaxNorm';
   private minValue: number;
   private maxValue: number;

src/engine/container.ts

Lines changed: 14 additions & 15 deletions
@@ -11,7 +11,6 @@
 /* Original source: keras/engine/topology.py */
 
 import {Scalar, serialization, Tensor, tidy, util} from '@tensorflow/tfjs-core';
-
 import {getUid} from '../backend/state';
 import {NotImplementedError, RuntimeError, ValueError} from '../errors';
 import {Shape} from '../keras_format/common';
@@ -772,14 +771,13 @@ export abstract class Container extends Layer {
    */
   private updatedConfig(): serialization.ConfigDict {
     const theConfig = this.getConfig();
-    const modelConfig: serialization.ConfigDict = {
-      className: this.getClassName(),
-      config: theConfig,
-      kerasVersion: `tfjs-layers ${layersVersion}`,
-      // TODO(nielsene): Replace something like K.backend() once
-      // possible.
-      backend: 'TensorFlow.js'
-    };
+    const modelConfig: serialization.ConfigDict = {};
+    modelConfig.className = this.getClassName();
+    modelConfig.config = theConfig;
+    modelConfig.kerasVersion = `tfjs-layers ${layersVersion}`;
+    // TODO(nielsene): Replace something like K.backend() once
+    // possible.
+    modelConfig.backend = 'TensorFlow.js';
     return modelConfig;
   }
 
@@ -1202,12 +1200,12 @@ export abstract class Container extends Layer {
         }
       }
     }
-    layerConfigs.push({
-      name: layer.name,
-      className: layerClassName,
-      config: layerConfig,
-      inboundNodes: filteredInboundNodes
-    });
+    const dict: serialization.ConfigDict = {};
+    dict.name = layer.name;
+    dict.className = layerClassName;
+    dict.config = layerConfig;
+    dict.inboundNodes = filteredInboundNodes;
+    layerConfigs.push(dict);
   }
   config['layers'] = layerConfigs;
   // Gather info about inputs and outputs
@@ -1261,6 +1259,7 @@ export abstract class Container extends Layer {
    * @returns A model instance.
    * @throws ValueError: In case of improperly formatted config dict.
    */
+  /** @nocollapse */
   static fromConfig<T extends serialization.Serializable>(
       cls: serialization.SerializableConstructor<T>,
       config: serialization.ConfigDict,
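
A minimal sketch, assuming the usual tfjs-core deserialization flow, of why the static factory also needs `/** @nocollapse */`: deserializers receive the constructor and invoke `fromConfig` by property name, so collapsing the static into the module scope would break the lookup. `MyContainer` is an illustrative stand-in, not part of tfjs-layers.

import {serialization} from '@tensorflow/tfjs-core';

class MyContainer extends serialization.Serializable {
  /** @nocollapse */
  static className = 'MyContainer';

  getConfig(): serialization.ConfigDict {
    return {};
  }

  /** @nocollapse */
  static fromConfig<T extends serialization.Serializable>(
      cls: serialization.SerializableConstructor<T>,
      config: serialization.ConfigDict): T {
    // Real containers rebuild their layer graph from `config`; this stub
    // only shows the call shape.
    return new cls();
  }
}

// Deserialization code looks the method up as a property on the constructor:
const instance = MyContainer.fromConfig(MyContainer, {});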

src/engine/input_layer.ts

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,7 @@ import {DisposeResult, Layer, Node, SymbolicTensor} from './topology';
  * If only inputShape is provided, then the batchInputShape is determined by
  * the batchSize argument and the inputShape: [batchSize].concat(inputShape).
  */
-export interface InputLayerArgs {
+export declare interface InputLayerArgs {
   /** Input shape, not including the batch axis. */
   inputShape?: Shape;
   /** Optional input batch size (integer or null). */
@@ -75,6 +75,7 @@ export interface InputLayerArgs {
  * ```
  */
 export class InputLayer extends Layer {
+  /** @nocollapse */
   static readonly className = 'InputLayer';
   sparse: boolean;
   constructor(args: InputLayerArgs) {

src/engine/topology.ts

Lines changed: 1 addition & 1 deletion
@@ -343,7 +343,7 @@ export class Node {
 }
 
 /** Constructor arguments for Layer. */
-export interface LayerArgs {
+export declare interface LayerArgs {
   /**
    * If defined, will be used to create an input layer to insert before this
    * layer. If both `inputShape` and `batchInputShape` are defined,
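
A hedged sketch of why `declare` matters for these args interfaces: the objects cross the JSON boundary, and JSON keys are never renamed, so the compiled property names must match the source names. `MyLayerArgs` is illustrative, not part of tfjs-layers.

// With `declare`, the compiler treats these properties as external, so
// Closure will not rename `trainable` while parsed JSON still uses it.
export declare interface MyLayerArgs {
  name?: string;
  trainable?: boolean;
}

function parseArgs(json: string): MyLayerArgs {
  // Without `declare`, renamed properties would silently read as undefined
  // on objects parsed from JSON or passed in by user code.
  return JSON.parse(json) as MyLayerArgs;
}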

src/engine/training.ts

Lines changed: 3 additions & 1 deletion
@@ -441,6 +441,7 @@ export interface ModelCompileArgs {
  */
 /** @doc {heading: 'Models', subheading: 'Classes'} */
 export class Model extends Container implements tfc.InferenceModel {
+  /** @nocollapse */
   static className = 'Model';
   optimizer: Optimizer;
   loss: string|string[]|{[outputName: string]: string}|LossOrMetricFn|
@@ -1455,7 +1456,8 @@ export class Model extends Container implements tfc.InferenceModel {
     const losses = trainFunction(inputs.concat(targets));
     const lossValues: number[] = [];
     for (const loss of losses) {
-      lossValues.push((await loss.data())[0]);
+      const v = await loss.data();
+      lossValues.push(v[0]);
     }
     tfc.dispose(losses);
     return singletonOrArray(lossValues);
