Skip to content

Commit 3be0717

Browse files
authored
Add fields to io.ModelArtifacts and surrounding save/load logic (tensorflow#1596)
- Add new fields to `ModelArtifacts` interface - format - generatedBy - convertedBy - Add the logic to save and load these new fields, with tests, to: - Local Storage - HTTP - Browser file downloads - IndexedDB (test only; no non-test code change needed) Towards: tensorflow/tfjs#1285 Towards: tensorflow/tfjs#1286 FEATURE
1 parent b41d0be commit 3be0717

9 files changed

+111
-4
lines changed

src/io/browser_files.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ export class BrowserDownloads implements IOHandler {
7373
}];
7474
const modelTopologyAndWeightManifest = {
7575
modelTopology: modelArtifacts.modelTopology,
76+
format: modelArtifacts.format,
77+
generatedBy: modelArtifacts.generatedBy,
78+
convertedBy: modelArtifacts.convertedBy,
7679
weightsManifest
7780
};
7881
const modelTopologyAndWeightManifestURL =

src/io/browser_files_test.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ const artifacts1: tf.io.ModelArtifacts = {
7474
modelTopology: modelTopology1,
7575
weightSpecs: weightSpecs1,
7676
weightData: weightData1,
77+
format: 'layers-model',
78+
generatedBy: 'TensorFlow.js v0.0.0',
79+
convertedBy: null
7780
};
7881

7982
describeWithFlags('browserDownloads', BROWSER_ENVS, () => {
@@ -127,6 +130,10 @@ describeWithFlags('browserDownloads', BROWSER_ENVS, () => {
127130
JSON.parse(await jsonContent.text());
128131
expect(modelTopologyAndWeightsManifest.modelTopology)
129132
.toEqual(modelTopology1);
133+
expect(modelTopologyAndWeightsManifest.format).toEqual('layers-model');
134+
expect(modelTopologyAndWeightsManifest.generatedBy)
135+
.toEqual('TensorFlow.js v0.0.0');
136+
expect(modelTopologyAndWeightsManifest.convertedBy).toEqual(null);
130137
const weightsManifest = modelTopologyAndWeightsManifest.weightsManifest as
131138
WeightsManifestConfig;
132139
expect(weightsManifest.length).toEqual(1);

src/io/browser_http.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ export class BrowserHTTPRequest implements IOHandler {
114114
}];
115115
const modelTopologyAndWeightManifest = {
116116
modelTopology: modelArtifacts.modelTopology,
117+
format: modelArtifacts.format,
118+
generatedBy: modelArtifacts.generatedBy,
119+
convertedBy: modelArtifacts.convertedBy,
117120
weightsManifest
118121
};
119122

src/io/browser_http_test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {
192192
modelTopology: modelTopology1,
193193
weightSpecs: weightSpecs1,
194194
weightData: weightData1,
195+
format: 'layers-model',
196+
generatedBy: 'TensorFlow.js v0.0.0',
197+
convertedBy: null
195198
};
196199

197200
let requestInits: RequestInit[] = [];
@@ -338,6 +341,9 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {
338341
jsonFileReader.onload = (event: Event) => {
339342
// tslint:disable-next-line:no-any
340343
const modelJSON = JSON.parse((event.target as any).result);
344+
expect(modelJSON.format).toEqual('layers-model');
345+
expect(modelJSON.generatedBy).toEqual('TensorFlow.js v0.0.0');
346+
expect(modelJSON.convertedBy).toEqual(null);
341347
expect(modelJSON.modelTopology).toEqual(modelTopology1);
342348
expect(modelJSON.weightsManifest.length).toEqual(1);
343349
expect(modelJSON.weightsManifest[0].weights).toEqual(weightSpecs1);

src/io/indexed_db_test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => {
7575
modelTopology: modelTopology1,
7676
weightSpecs: weightSpecs1,
7777
weightData: weightData1,
78+
format: 'layers-model',
79+
generatedBy: 'TensorFlow.js v0.0.0',
80+
convertedBy: null
7881
};
7982

8083
const weightSpecs2: tf.io.WeightsManifestEntry[] = [
@@ -113,6 +116,9 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => {
113116
const loadedArtifacts = await handler.load();
114117
expect(loadedArtifacts.modelTopology).toEqual(modelTopology1);
115118
expect(loadedArtifacts.weightSpecs).toEqual(weightSpecs1);
119+
expect(loadedArtifacts.format).toEqual('layers-model');
120+
expect(loadedArtifacts.generatedBy).toEqual('TensorFlow.js v0.0.0');
121+
expect(loadedArtifacts.convertedBy).toEqual(null);
116122
expectArrayBuffersEqual(loadedArtifacts.weightData, weightData1);
117123
});
118124

src/io/io.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {browserHTTPRequest, isHTTPScheme} from './browser_http';
2525
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
2626
import {fromMemory, withSaveHandler} from './passthrough';
2727
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
28-
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
28+
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelStoreManager, ModelFormat, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2929
import {loadWeights, weightsLoaderFactory} from './weights_loader';
3030

3131
export {copyModel, listModels, moveModel, removeModel} from './model_management';
@@ -45,6 +45,7 @@ export {
4545
LoadOptions,
4646
loadWeights,
4747
ModelArtifacts,
48+
ModelFormat,
4849
ModelStoreManager,
4950
OnProgressCallback,
5051
registerLoadRouter,

src/io/local_storage.ts

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ import {assert} from '../util';
2020
import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils';
2121
import {ModelStoreManagerRegistry} from './model_management';
2222
import {IORouter, IORouterRegistry} from './router_registry';
23-
import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, SaveResult} from './types';
23+
import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, ModelFormat, SaveResult} from './types';
2424

2525
const PATH_SEPARATOR = '/';
2626
const PATH_PREFIX = 'tensorflowjs_models';
2727
const INFO_SUFFIX = 'info';
2828
const MODEL_TOPOLOGY_SUFFIX = 'model_topology';
2929
const WEIGHT_SPECS_SUFFIX = 'weight_specs';
3030
const WEIGHT_DATA_SUFFIX = 'weight_data';
31+
const MODEL_METADATA_SUFFIX = 'model_metadata';
3132

3233
/**
3334
* Purge all tensorflow.js-saved model artifacts from local storage.
@@ -57,12 +58,15 @@ export function purgeLocalStorageArtifacts(): string[] {
5758
}
5859

5960
function getModelKeys(path: string):
60-
{info: string, topology: string, weightSpecs: string, weightData: string} {
61+
{info: string, topology: string, weightSpecs: string, weightData: string,
62+
modelMetadata: string} {
6163
return {
6264
info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
6365
topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
6466
weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
65-
weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR)
67+
weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
68+
modelMetadata: [
69+
PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
6670
};
6771
}
6872

@@ -146,6 +150,11 @@ export class BrowserLocalStorage implements IOHandler {
146150
this.LS.setItem(
147151
this.keys.weightData,
148152
arrayBufferToBase64String(modelArtifacts.weightData));
153+
this.LS.setItem(this.keys.modelMetadata, JSON.stringify({
154+
format: modelArtifacts.format,
155+
generatedBy: modelArtifacts.generatedBy,
156+
convertedBy: modelArtifacts.convertedBy
157+
}));
149158

150159
return {modelArtifactsInfo};
151160
} catch (err) {
@@ -206,6 +215,19 @@ export class BrowserLocalStorage implements IOHandler {
206215
}
207216
out.weightSpecs = weightSpecs;
208217

218+
// Load meta-data fields.
219+
const metadataString = this.LS.getItem(this.keys.modelMetadata);
220+
if (metadataString != null) {
221+
const metadata = JSON.parse(metadataString) as {
222+
format: string,
223+
generatedBy: string,
224+
convertedBy: string
225+
};
226+
out.format = metadata.format as ModelFormat;
227+
out.generatedBy = metadata['generatedBy'];
228+
out.convertedBy = metadata['convertedBy'];
229+
}
230+
209231
// Load weight data.
210232
const weightDataBase64 = this.LS.getItem(this.keys.weightData);
211233
if (weightDataBase64 == null) {

src/io/local_storage_test.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,20 @@ describeWithFlags('LocalStorage', BROWSER_ENVS, () => {
6868
}
6969
];
7070
const weightData1 = new ArrayBuffer(16);
71+
7172
const artifacts1: tf.io.ModelArtifacts = {
7273
modelTopology: modelTopology1,
7374
weightSpecs: weightSpecs1,
7475
weightData: weightData1,
76+
format: 'layers-model',
77+
generatedBy: 'TensorFlow.js v0.0.0',
78+
convertedBy: null
79+
};
80+
81+
const artifactsV0: tf.io.ModelArtifacts = {
82+
modelTopology: modelTopology1,
83+
weightSpecs: weightSpecs1,
84+
weightData: weightData1
7585
};
7686

7787
function findOverflowingByteSize(): number {
@@ -162,6 +172,23 @@ describeWithFlags('LocalStorage', BROWSER_ENVS, () => {
162172
expect(loaded.modelTopology).toEqual(modelTopology1);
163173
expect(loaded.weightSpecs).toEqual(weightSpecs1);
164174
expect(loaded.weightData).toEqual(weightData1);
175+
expect(loaded.format).toEqual('layers-model');
176+
expect(loaded.generatedBy).toEqual('TensorFlow.js v0.0.0');
177+
expect(loaded.convertedBy).toEqual(null);
178+
});
179+
180+
it('Save-load round trip succeeds: v0 format', async () => {
181+
const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0];
182+
183+
await handler1.save(artifactsV0);
184+
const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0];
185+
const loaded = await handler2.load();
186+
expect(loaded.modelTopology).toEqual(modelTopology1);
187+
expect(loaded.weightSpecs).toEqual(weightSpecs1);
188+
expect(loaded.weightData).toEqual(weightData1);
189+
expect(loaded.format).toEqual(undefined);
190+
expect(loaded.generatedBy).toEqual(undefined);
191+
expect(loaded.convertedBy).toEqual(undefined);
165192
});
166193

167194
it('Loading nonexistent model fails.', done => {

src/io/types.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ export declare interface ModelArtifactsInfo {
156156
weightDataBytes?: number;
157157
}
158158

159+
export declare type ModelFormat = 'graph-model'|'layers-model';
160+
159161
/**
160162
* The serialized artifacts of a model, including topology and weights.
161163
*
@@ -185,8 +187,38 @@ export declare interface ModelArtifacts {
185187
* by `weightSpecs`.
186188
*/
187189
weightData?: ArrayBuffer;
190+
191+
/**
192+
* Hard-coded format name for models saved from TensorFlow.js or converted
193+
* by TensorFlow.js Converter.
194+
*/
195+
format?: ModelFormat;
196+
197+
/**
198+
* What library is responsible for originally generating this artifact.
199+
*
200+
* Used for debugging purposes. E.g., 'TensorFlow.js v1.0.0'.
201+
*/
202+
generatedBy?: string;
203+
204+
/**
205+
* What library or tool is responsible for converting the original model
206+
* to this format, applicable only if the model is output by a converter.
207+
*
208+
* Used for debugging purposes. E.g., 'TensorFlow.js Converter v1.0.0'.
209+
*
210+
* A value of `null` means the model artifacts are generated without any
211+
* conversion process (e.g., saved directly from a TensorFlow.js
212+
* `LayersModel` instance.)
213+
*/
214+
convertedBy?: string|null;
188215
}
189216

217+
// TODO(cais): Create interface spec for JSON model artifact, e.g., the format
218+
// for the content in a model.json file saved by TensorFlow.js. It should be the
219+
// same as the `ModelArtifacts` interface, but without the binary `weightData`
220+
// field. The `weightSpec` field should be replaced with a `weightManifest` one.
221+
190222
/**
191223
* Type definition for handlers of loading operations.
192224
*/

0 commit comments

Comments
 (0)