Skip to content

Commit 0f647e2

Browse files
davidsoergelNikhil Thorat
authored andcommitted
Provide types to enforce the structure of model.json files. (tensorflow#1597)
We were reminded of this by discussion on tensorflow#1596 (and, related, tensorflow/tfjs-converter#320). My view is that those two PRs should make it into 1.0. This one is less urgent because it's a non-breaking change that just helps enforce our intentions. Obviously, this PR can't enforce anything on the Python side. Ultimately we should probably do all this via a language-independent JSON Schema. INTERNAL
1 parent b8d937e commit 0f647e2

File tree

5 files changed

+83
-39
lines changed

5 files changed

+83
-39
lines changed

src/io/browser_files.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import {ENV} from '../environment';
2424
import {basename, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
2525
import {IORouter, IORouterRegistry} from './router_registry';
26-
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
26+
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2727

2828
const DEFAULT_FILE_NAME_PREFIX = 'model';
2929
const DEFAULT_JSON_EXTENSION_NAME = '.json';
@@ -71,7 +71,7 @@ export class BrowserDownloads implements IOHandler {
7171
paths: ['./' + this.weightDataFileName],
7272
weights: modelArtifacts.weightSpecs
7373
}];
74-
const modelTopologyAndWeightManifest = {
74+
const modelTopologyAndWeightManifest: ModelJSON = {
7575
modelTopology: modelArtifacts.modelTopology,
7676
format: modelArtifacts.format,
7777
generatedBy: modelArtifacts.generatedBy,
@@ -127,8 +127,8 @@ class BrowserFiles implements IOHandler {
127127
const jsonReader = new FileReader();
128128
jsonReader.onload = (event: Event) => {
129129
// tslint:disable-next-line:no-any
130-
const modelJSON = JSON.parse((event.target as any).result);
131-
const modelTopology = modelJSON.modelTopology as {};
130+
const modelJSON = JSON.parse((event.target as any).result) as ModelJSON;
131+
const modelTopology = modelJSON.modelTopology;
132132
if (modelTopology == null) {
133133
reject(new Error(
134134
`modelTopology field is missing from file ${jsonFile.name}`));
@@ -139,8 +139,7 @@ class BrowserFiles implements IOHandler {
139139
resolve({modelTopology});
140140
}
141141

142-
const weightsManifest =
143-
modelJSON.weightsManifest as WeightsManifestConfig;
142+
const weightsManifest = modelJSON.weightsManifest;
144143
if (weightsManifest == null) {
145144
reject(new Error(
146145
`weightManifest field is missing from file ${jsonFile.name}`));

src/io/browser_http.ts

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
*/
2323

2424
import {assert} from '../util';
25-
2625
import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
2726
import {IORouter, IORouterRegistry} from './router_registry';
28-
import {IOHandler, LoadOptions, ModelArtifacts, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
27+
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2928
import {loadWeightsAsArrayBuffer} from './weights_loader';
3029

3130
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
@@ -112,7 +111,7 @@ export class BrowserHTTPRequest implements IOHandler {
112111
paths: ['./model.weights.bin'],
113112
weights: modelArtifacts.weightSpecs,
114113
}];
115-
const modelTopologyAndWeightManifest = {
114+
const modelTopologyAndWeightManifest: ModelJSON = {
116115
modelTopology: modelArtifacts.modelTopology,
117116
format: modelArtifacts.format,
118117
generatedBy: modelArtifacts.generatedBy,
@@ -166,7 +165,7 @@ export class BrowserHTTPRequest implements IOHandler {
166165
`${modelConfigRequest.status}. Please verify this URL points to ` +
167166
`the model JSON of the model to load.`);
168167
}
169-
let modelConfig;
168+
let modelConfig: ModelJSON;
170169
try {
171170
modelConfig = await modelConfigRequest.json();
172171
} catch (e) {
@@ -186,9 +185,8 @@ export class BrowserHTTPRequest implements IOHandler {
186185
}
187186
throw new Error(message);
188187
}
189-
190-
const modelTopology = modelConfig['modelTopology'];
191-
const weightsManifest = modelConfig['weightsManifest'];
188+
const modelTopology = modelConfig.modelTopology;
189+
const weightsManifest = modelConfig.weightsManifest;
192190

193191
// We do not allow both modelTopology and weightsManifest to be missing.
194192
if (modelTopology == null && weightsManifest == null) {
@@ -200,8 +198,6 @@ export class BrowserHTTPRequest implements IOHandler {
200198
let weightSpecs: WeightsManifestEntry[];
201199
let weightData: ArrayBuffer;
202200
if (weightsManifest != null) {
203-
const weightsManifest =
204-
modelConfig['weightsManifest'] as WeightsManifestConfig;
205201
const results = await this.loadWeights(weightsManifest);
206202
[weightSpecs, weightData] = results;
207203
}

src/io/io.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {browserHTTPRequest, isHTTPScheme} from './browser_http';
2525
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
2626
import {fromMemory, withSaveHandler} from './passthrough';
2727
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
28-
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelStoreManager, ModelFormat, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
28+
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2929
import {loadWeights, weightsLoaderFactory} from './weights_loader';
3030

3131
export {copyModel, listModels, moveModel, removeModel} from './model_management';
@@ -45,7 +45,7 @@ export {
4545
LoadOptions,
4646
loadWeights,
4747
ModelArtifacts,
48-
ModelFormat,
48+
ModelJSON,
4949
ModelStoreManager,
5050
OnProgressCallback,
5151
registerLoadRouter,

src/io/local_storage.ts

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {assert} from '../util';
2020
import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils';
2121
import {ModelStoreManagerRegistry} from './model_management';
2222
import {IORouter, IORouterRegistry} from './router_registry';
23-
import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, ModelFormat, SaveResult} from './types';
23+
import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, SaveResult} from './types';
2424

2525
const PATH_SEPARATOR = '/';
2626
const PATH_PREFIX = 'tensorflowjs_models';
@@ -57,16 +57,20 @@ export function purgeLocalStorageArtifacts(): string[] {
5757
return purgedModelPaths;
5858
}
5959

60-
function getModelKeys(path: string):
61-
{info: string, topology: string, weightSpecs: string, weightData: string,
62-
modelMetadata: string} {
60+
function getModelKeys(path: string): {
61+
info: string,
62+
topology: string,
63+
weightSpecs: string,
64+
weightData: string,
65+
modelMetadata: string
66+
} {
6367
return {
6468
info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
6569
topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
6670
weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
6771
weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
68-
modelMetadata: [
69-
PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
72+
modelMetadata:
73+
[PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
7074
};
7175
}
7276

@@ -218,12 +222,9 @@ export class BrowserLocalStorage implements IOHandler {
218222
// Load meta-data fields.
219223
const metadataString = this.LS.getItem(this.keys.modelMetadata);
220224
if (metadataString != null) {
221-
const metadata = JSON.parse(metadataString) as {
222-
format: string,
223-
generatedBy: string,
224-
convertedBy: string
225-
};
226-
out.format = metadata.format as ModelFormat;
225+
const metadata = JSON.parse(metadataString) as
226+
{format: string, generatedBy: string, convertedBy: string};
227+
out.format = metadata['format'];
227228
out.generatedBy = metadata['generatedBy'];
228229
out.convertedBy = metadata['convertedBy'];
229230
}

src/io/types.ts

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,22 +156,23 @@ export declare interface ModelArtifactsInfo {
156156
weightDataBytes?: number;
157157
}
158158

159-
export declare type ModelFormat = 'graph-model'|'layers-model';
160-
161159
/**
162160
* The serialized artifacts of a model, including topology and weights.
163161
*
164162
* The `modelTopology`, `weightSpecs` and `weightData` fields of this interface
165163
* are optional, in order to support topology- or weights-only saving and
166164
* loading.
165+
*
166+
* Note this interface is used internally in IOHandlers. For the file format
167+
* written to disk as `model.json`, see `ModelJSON`.
167168
*/
168169
export declare interface ModelArtifacts {
169170
/**
170171
* Model topology.
171172
*
172173
* For Keras-style `tf.Model`s, this is a JSON object.
173-
* For TensorFlow-style models (e.g., `FrozenModel`), this is a binary buffer
174-
* carrying the `GraphDef` protocol buffer.
174+
* For TensorFlow-style models (e.g., `SavedModel`), this is the JSON
175+
* encoding of the `GraphDef` protocol buffer.
175176
*/
176177
modelTopology?: {}|ArrayBuffer;
177178

@@ -192,7 +193,7 @@ export declare interface ModelArtifacts {
192193
* Hard-coded format name for models saved from TensorFlow.js or converted
193194
* by TensorFlow.js Converter.
194195
*/
195-
format?: ModelFormat;
196+
format?: string;
196197

197198
/**
198199
* What library is responsible for originally generating this artifact.
@@ -209,15 +210,62 @@ export declare interface ModelArtifacts {
209210
*
210211
* A value of `null` means the model artifacts are generated without any
211212
* conversion process (e.g., saved directly from a TensorFlow.js
212-
* `LayersModel` instance.)
213+
* `tf.LayersModel` instance.)
213214
*/
214215
convertedBy?: string|null;
215216
}
216217

217-
// TODO(cais): Create interface spec for JSON model artifact, e.g., the format
218-
// for the content in a model.json file saved by TensorFlow.js. It should be the
219-
// same as the `ModelArtifacts` interface, but without the binary `weightData`
220-
// field. The `weightSpec` field should be replaced with a `weightManifest` one.
218+
/**
219+
* The on-disk format of the `model.json` file.
220+
*
221+
* TF.js 1.0 always populates the optional fields when writing model.json.
222+
* Prior versions did not provide those fields.
223+
*/
224+
export declare interface ModelJSON {
225+
/**
226+
* Model topology.
227+
*
228+
* For Keras-style `tf.Model`s, this is a JSON object.
229+
* For TensorFlow-style models (e.g., `SavedModel`), this is the JSON
230+
* encoding of the `GraphDef` protocol buffer.
231+
*/
232+
modelTopology: {};
233+
234+
/**
235+
* Weights manifest.
236+
*
237+
* The weights manifest consists of an ordered list of weight-manifest
238+
* groups. Each weight-manifest group consists of a number of weight values
239+
* stored in a number of paths. See the documentation of
240+
* `WeightsManifestConfig` for more details.
241+
*/
242+
weightsManifest: WeightsManifestConfig;
243+
244+
/**
245+
* Hard-coded format name for models saved from TensorFlow.js or converted
246+
* by TensorFlow.js Converter.
247+
*/
248+
format?: string;
249+
250+
/**
251+
* What library is responsible for originally generating this artifact.
252+
*
253+
* Used for debugging purposes. E.g., 'TensorFlow.js v1.0.0'.
254+
*/
255+
generatedBy?: string;
256+
257+
/**
258+
* What library or tool is responsible for converting the original model
259+
* to this format, applicable only if the model is output by a converter.
260+
*
261+
* Used for debugging purposes. E.g., 'TensorFlow.js Converter v1.0.0'.
262+
*
263+
* A value of `null` means the model artifacts are generated without any
264+
* conversion process (e.g., saved directly from a TensorFlow.js
265+
* `tf.LayersModel` instance.)
266+
*/
267+
convertedBy?: string|null;
268+
}
221269

222270
/**
223271
* Type definition for handlers of loading operations.

0 commit comments

Comments
 (0)