Skip to content
This repository was archived by the owner on Oct 17, 2021. It is now read-only.

Make tfjs-layers compatible with the closure compiler #442

Merged
merged 16 commits into from
Jan 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/activations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
// Layer activation functions
import * as tfc from '@tensorflow/tfjs-core';
import {serialization, Tensor, tidy} from '@tensorflow/tfjs-core';

import {getScalar} from './backend/state';
import * as K from './backend/tfjs_backend';
import {ActivationIdentifier} from './keras_format/activation_config';
Expand All @@ -36,6 +35,7 @@ export abstract class Activation extends serialization.Serializable {
* Reference: https://arxiv.org/abs/1511.07289
*/
export class Elu extends Activation {
/** @nocollapse */
static readonly className = 'elu';
/**
* Calculate the activation function.
Expand All @@ -58,6 +58,7 @@ serialization.registerClass(Elu);
* - To be used together with the dropout variant "AlphaDropout".
*/
export class Selu extends Activation {
/** @nocollapse */
static readonly className = 'selu';
apply(x: Tensor): Tensor {
return tfc.selu(x);
Expand All @@ -69,6 +70,7 @@ serialization.registerClass(Selu);
* Rectified linear unit
*/
export class Relu extends Activation {
/** @nocollapse */
static readonly className = 'relu';
apply(x: Tensor): Tensor {
return tfc.relu(x);
Expand All @@ -80,6 +82,7 @@ serialization.registerClass(Relu);
* Rectified linear unit activation maxing out at 6.0.
*/
export class Relu6 extends Activation {
/** @nocollapse */
static readonly className = 'relu6';
apply(x: Tensor): Tensor {
return tidy(() => tfc.minimum(getScalar(6.0), tfc.relu(x)));
Expand All @@ -89,6 +92,7 @@ serialization.registerClass(Relu6);

/** Linear activation (no-op) */
export class Linear extends Activation {
/** @nocollapse */
static readonly className = 'linear';
apply(x: Tensor): Tensor {
return x;
Expand All @@ -100,6 +104,7 @@ serialization.registerClass(Linear);
* Sigmoid activation function.
*/
export class Sigmoid extends Activation {
/** @nocollapse */
static readonly className = 'sigmoid';
apply(x: Tensor): Tensor {
return tfc.sigmoid(x);
Expand All @@ -111,6 +116,7 @@ serialization.registerClass(Sigmoid);
* Segment-wise linear approximation of sigmoid.
*/
export class HardSigmoid extends Activation {
/** @nocollapse */
static readonly className = 'hardSigmoid';
apply(x: Tensor): Tensor {
return K.hardSigmoid(x);
Expand All @@ -122,6 +128,7 @@ serialization.registerClass(HardSigmoid);
* Softplus activation function.
*/
export class Softplus extends Activation {
/** @nocollapse */
static readonly className = 'softplus';
apply(x: Tensor): Tensor {
return tfc.softplus(x);
Expand All @@ -133,6 +140,7 @@ serialization.registerClass(Softplus);
* Softsign activation function.
*/
export class Softsign extends Activation {
/** @nocollapse */
static readonly className = 'softsign';
apply(x: Tensor): Tensor {
return K.softsign(x);
Expand All @@ -144,6 +152,7 @@ serialization.registerClass(Softsign);
* Hyperbolic tangent function.
*/
export class Tanh extends Activation {
/** @nocollapse */
static readonly className = 'tanh';
apply(x: Tensor): Tensor {
return tfc.tanh(x);
Expand All @@ -155,6 +164,7 @@ serialization.registerClass(Tanh);
* Softmax activation function
*/
export class Softmax extends Activation {
/** @nocollapse */
static readonly className = 'softmax';
/**
* Calculate the activation function.
Expand Down Expand Up @@ -189,11 +199,15 @@ export function deserializeActivation(
export function getActivation(identifier: ActivationIdentifier|
serialization.ConfigDict|Activation): Activation {
if (identifier == null) {
const config = {className: 'linear', config: {}};
const config: serialization.ConfigDict = {};
config.className = 'linear';
config.config = {};
return deserializeActivation(config);
}
if (typeof identifier === 'string') {
const config = {className: identifier, config: {}};
const config: serialization.ConfigDict = {};
config.className = identifier;
config.config = {};
return deserializeActivation(config);
} else if (identifier instanceof Activation) {
return identifier;
Expand Down
8 changes: 4 additions & 4 deletions src/backend/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ export function getUid(prefix = ''): string {
return prefix + _uidPrefixes[prefix].toString();
}

const scalarCache: {[typeKey: string]: {[key: number]: Scalar}} = {
float32: {},
int32: {}
};
const scalarCache: {[typeKey: string]: {[key: number]: Scalar}} = {};

const DEFAULT_DTYPE: DataType = 'float32';

Expand All @@ -54,6 +51,9 @@ export function getScalar(value: number, dtype?: DataType): Scalar {
if (dtype === undefined) {
dtype = DEFAULT_DTYPE;
}
if (scalarCache[dtype] == null) {
scalarCache[dtype] = {};
}
if (scalarCache[dtype][value] == null) {
scalarCache[dtype][value] = scalar(value, dtype);
keep(scalarCache[dtype][value]);
Expand Down
4 changes: 4 additions & 0 deletions src/constraints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ export interface MaxNormArgs {
* 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
*/
export class MaxNorm extends Constraint {
/** @nocollapse */
static readonly className = 'MaxNorm';
private maxValue: number;
private axis: number;
Expand Down Expand Up @@ -125,6 +126,7 @@ export interface UnitNormArgs {
* Constrains the weights incident to each hidden unit to have unit norm.
*/
export class UnitNorm extends Constraint {
/** @nocollapse */
static readonly className = 'UnitNorm';
private axis: number;
private readonly defaultAxis = 0;
Expand All @@ -149,6 +151,7 @@ serialization.registerClass(UnitNorm);
 * Constrains the weight to be non-negative.
*/
export class NonNeg extends Constraint {
/** @nocollapse */
static readonly className = 'NonNeg';

apply(w: Tensor): Tensor {
Expand Down Expand Up @@ -192,6 +195,7 @@ export interface MinMaxNormArgs {
}

export class MinMaxNorm extends Constraint {
/** @nocollapse */
static readonly className = 'MinMaxNorm';
private minValue: number;
private maxValue: number;
Expand Down
29 changes: 14 additions & 15 deletions src/engine/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
/* Original source: keras/engine/topology.py */

import {Scalar, serialization, Tensor, tidy, util} from '@tensorflow/tfjs-core';

import {getUid} from '../backend/state';
import {NotImplementedError, RuntimeError, ValueError} from '../errors';
import {Shape} from '../keras_format/common';
Expand Down Expand Up @@ -772,14 +771,13 @@ export abstract class Container extends Layer {
*/
private updatedConfig(): serialization.ConfigDict {
const theConfig = this.getConfig();
const modelConfig: serialization.ConfigDict = {
className: this.getClassName(),
config: theConfig,
kerasVersion: `tfjs-layers ${layersVersion}`,
// TODO(nielsene): Replace something like K.backend() once
// possible.
backend: 'TensorFlow.js'
};
const modelConfig: serialization.ConfigDict = {};
modelConfig.className = this.getClassName();
modelConfig.config = theConfig;
modelConfig.kerasVersion = `tfjs-layers ${layersVersion}`;
// TODO(nielsene): Replace something like K.backend() once
// possible.
modelConfig.backend = 'TensorFlow.js';
return modelConfig;
}

Expand Down Expand Up @@ -1202,12 +1200,12 @@ export abstract class Container extends Layer {
}
}
}
layerConfigs.push({
name: layer.name,
className: layerClassName,
config: layerConfig,
inboundNodes: filteredInboundNodes
});
const dict: serialization.ConfigDict = {};
dict.name = layer.name;
dict.className = layerClassName;
dict.config = layerConfig;
dict.inboundNodes = filteredInboundNodes;
layerConfigs.push(dict);
}
config['layers'] = layerConfigs;
// Gather info about inputs and outputs
Expand Down Expand Up @@ -1261,6 +1259,7 @@ export abstract class Container extends Layer {
* @returns A model instance.
* @throws ValueError: In case of improperly formatted config dict.
*/
/** @nocollapse */
static fromConfig<T extends serialization.Serializable>(
cls: serialization.SerializableConstructor<T>,
config: serialization.ConfigDict,
Expand Down
3 changes: 2 additions & 1 deletion src/engine/input_layer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {DisposeResult, Layer, Node, SymbolicTensor} from './topology';
* If only inputShape is provided, then the batchInputShape is determined by
* the batchSize argument and the inputShape: [batchSize].concat(inputShape).
*/
export interface InputLayerArgs {
export declare interface InputLayerArgs {
/** Input shape, not including the batch axis. */
inputShape?: Shape;
/** Optional input batch size (integer or null). */
Expand Down Expand Up @@ -75,6 +75,7 @@ export interface InputLayerArgs {
* ```
*/
export class InputLayer extends Layer {
/** @nocollapse */
static readonly className = 'InputLayer';
sparse: boolean;
constructor(args: InputLayerArgs) {
Expand Down
2 changes: 1 addition & 1 deletion src/engine/topology.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ export class Node {
}

/** Constructor arguments for Layer. */
export interface LayerArgs {
export declare interface LayerArgs {
/**
* If defined, will be used to create an input layer to insert before this
* layer. If both `inputShape` and `batchInputShape` are defined,
Expand Down
4 changes: 3 additions & 1 deletion src/engine/training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ export interface ModelCompileArgs {
*/
/** @doc {heading: 'Models', subheading: 'Classes'} */
export class Model extends Container implements tfc.InferenceModel {
/** @nocollapse */
static className = 'Model';
optimizer: Optimizer;
loss: string|string[]|{[outputName: string]: string}|LossOrMetricFn|
Expand Down Expand Up @@ -1455,7 +1456,8 @@ export class Model extends Container implements tfc.InferenceModel {
const losses = trainFunction(inputs.concat(targets));
const lossValues: number[] = [];
for (const loss of losses) {
lossValues.push((await loss.data())[0]);
const v = await loss.data();
lossValues.push(v[0]);
}
tfc.dispose(losses);
return singletonOrArray(lossValues);
Expand Down
Loading