Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 5aa35a3

Browse files
authored
[io] Add field user-defined metadata; Tweak API for tf.io.fromMemory() (#1864)
FEATURE - This is the first PR for adding support for user-defined metadata in model artifacts. Design doc has been circulated and discussed. - Add the field `userDefinedMetadata` to `ModelArtifacts` and `ModelJSON`. - Deprecate the old API of `tf.io.fromMemory()` which consisted of multiple arguments. The arguments are consolidated into on in the new API. - Add unit tests. Towards tensorflow/tfjs#1596
1 parent 4432e82 commit 5aa35a3

File tree

3 files changed

+89
-29
lines changed

3 files changed

+89
-29
lines changed

src/io/passthrough.ts

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,10 @@
2222
import {IOHandler, ModelArtifacts, SaveResult, TrainingConfig, WeightsManifestEntry} from './types';
2323

2424
class PassthroughLoader implements IOHandler {
25-
constructor(
26-
private readonly modelTopology?: {}|ArrayBuffer,
27-
private readonly weightSpecs?: WeightsManifestEntry[],
28-
private readonly weightData?: ArrayBuffer,
29-
private readonly trainingConfig?: TrainingConfig) {}
25+
constructor(private readonly modelArtifacts?: ModelArtifacts) {}
3026

3127
async load(): Promise<ModelArtifacts> {
32-
let result = {};
33-
if (this.modelTopology != null) {
34-
result = {modelTopology: this.modelTopology, ...result};
35-
}
36-
if (this.weightSpecs != null && this.weightSpecs.length > 0) {
37-
result = {weightSpecs: this.weightSpecs, ...result};
38-
}
39-
if (this.weightData != null && this.weightData.byteLength > 0) {
40-
result = {weightData: this.weightData, ...result};
41-
}
42-
if (this.trainingConfig != null) {
43-
result = {trainingConfig: this.trainingConfig, ...result};
44-
}
45-
return result;
28+
return this.modelArtifacts;
4629
}
4730
}
4831

@@ -67,7 +50,7 @@ class PassthroughSaver implements IOHandler {
6750
* modelTopology, weightSpecs, weightData));
6851
* ```
6952
*
70-
* @param modelTopology a object containing model topology (i.e., parsed from
53+
* @param modelArtifacts a object containing model topology (i.e., parsed from
7154
* the JSON format).
7255
* @param weightSpecs An array of `WeightsManifestEntry` objects describing the
7356
* names, shapes, types, and quantization of the weight data.
@@ -78,13 +61,39 @@ class PassthroughSaver implements IOHandler {
7861
* @returns A passthrough `IOHandler` that simply loads the provided data.
7962
*/
8063
export function fromMemory(
81-
modelTopology: {}, weightSpecs?: WeightsManifestEntry[],
64+
modelArtifacts: {}|ModelArtifacts, weightSpecs?: WeightsManifestEntry[],
8265
weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandler {
83-
// TODO(cais): The arguments should probably be consolidated into a single
84-
// object, with proper deprecation process. Even though this function isn't
85-
// documented, it is public and being used by some downstream libraries.
86-
return new PassthroughLoader(
87-
modelTopology, weightSpecs, weightData, trainingConfig);
66+
if (arguments.length === 1) {
67+
const isModelArtifacts =
68+
(modelArtifacts as ModelArtifacts).modelTopology != null ||
69+
(modelArtifacts as ModelArtifacts).weightSpecs != null;
70+
if (isModelArtifacts) {
71+
return new PassthroughLoader(modelArtifacts as ModelArtifacts);
72+
} else {
73+
// Legacy support: with only modelTopology.
74+
// TODO(cais): Remove this deprecated API.
75+
console.warn(
76+
'Please call tf.io.fromMemory() with only one argument. ' +
77+
'The argument should be of type ModelArtifacts. ' +
78+
'The multi-argument signature of tf.io.fromMemory() has been ' +
79+
'deprecated and will be removed in a future release.');
80+
return new PassthroughLoader({modelTopology: modelArtifacts as {}});
81+
}
82+
} else {
83+
// Legacy support.
84+
// TODO(cais): Remove this deprecated API.
85+
console.warn(
86+
'Please call tf.io.fromMemory() with only one argument. ' +
87+
'The argument should be of type ModelArtifacts. ' +
88+
'The multi-argument signature of tf.io.fromMemory() has been ' +
89+
'deprecated and will be removed in a future release.');
90+
return new PassthroughLoader({
91+
modelTopology: modelArtifacts as {},
92+
weightSpecs,
93+
weightData,
94+
trainingConfig
95+
});
96+
}
8897
}
8998

9099
/**

src/io/passthrough_test.ts

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,20 +114,61 @@ describeWithFlags('Passthrough Saver', BROWSER_ENVS, () => {
114114
});
115115

116116
describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
117+
it('load topology and weights: legacy signature', async () => {
118+
const passthroughHandler = tf.io.fromMemory(
119+
modelTopology1, weightSpecs1, weightData1);
120+
const modelArtifacts = await passthroughHandler.load();
121+
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
122+
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
123+
expect(modelArtifacts.weightData).toEqual(weightData1);
124+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
125+
});
126+
117127
it('load topology and weights', async () => {
118-
const passthroughHandler =
119-
tf.io.fromMemory(modelTopology1, weightSpecs1, weightData1);
128+
const passthroughHandler = tf.io.fromMemory({
129+
modelTopology: modelTopology1,
130+
weightSpecs: weightSpecs1,
131+
weightData: weightData1
132+
});
120133
const modelArtifacts = await passthroughHandler.load();
121134
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
122135
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
123136
expect(modelArtifacts.weightData).toEqual(weightData1);
137+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
124138
});
125139

126-
it('load model topology only', async () => {
140+
it('load model topology only: legacy signature', async () => {
127141
const passthroughHandler = tf.io.fromMemory(modelTopology1);
128142
const modelArtifacts = await passthroughHandler.load();
129143
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
130144
expect(modelArtifacts.weightSpecs).toEqual(undefined);
131145
expect(modelArtifacts.weightData).toEqual(undefined);
146+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
147+
});
148+
149+
it('load model topology only', async () => {
150+
const passthroughHandler = tf.io.fromMemory({
151+
modelTopology: modelTopology1
152+
});
153+
const modelArtifacts = await passthroughHandler.load();
154+
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
155+
expect(modelArtifacts.weightSpecs).toEqual(undefined);
156+
expect(modelArtifacts.weightData).toEqual(undefined);
157+
expect(modelArtifacts.userDefinedMetadata).toEqual(undefined);
158+
});
159+
160+
it('load topology, weights, and user-defined metadata', async () => {
161+
const userDefinedMetadata: {} = {'fooField': 'fooValue'};
162+
const passthroughHandler = tf.io.fromMemory({
163+
modelTopology: modelTopology1,
164+
weightSpecs: weightSpecs1,
165+
weightData: weightData1,
166+
userDefinedMetadata
167+
});
168+
const modelArtifacts = await passthroughHandler.load();
169+
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
170+
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
171+
expect(modelArtifacts.weightData).toEqual(weightData1);
172+
expect(modelArtifacts.userDefinedMetadata).toEqual(userDefinedMetadata);
132173
});
133174
});

src/io/types.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ export declare interface ModelArtifacts {
275275
* `tf.LayersModel` instance.)
276276
*/
277277
convertedBy?: string|null;
278+
279+
/**
280+
* User-defined metadata about the model.
281+
*/
282+
userDefinedMetadata?: {};
278283
}
279284

280285
/**
@@ -330,6 +335,11 @@ export declare interface ModelJSON {
330335
* `tf.LayersModel` instance.)
331336
*/
332337
convertedBy?: string|null;
338+
339+
/**
340+
* User-defined metadata about the model.
341+
*/
342+
userDefinedMetadata?: {};
333343
}
334344

335345
/**

0 commit comments

Comments
 (0)