22
22
*/
23
23
24
24
import { assert } from '../util' ;
25
-
26
25
import { concatenateArrayBuffers , getModelArtifactsInfoForJSON } from './io_utils' ;
27
26
import { IORouter , IORouterRegistry } from './router_registry' ;
28
- import { IOHandler , ModelArtifacts , SaveResult , WeightsManifestConfig , WeightsManifestEntry } from './types' ;
27
+ import { IOHandler , LoadOptions , ModelArtifacts , ModelJSON , OnProgressCallback , SaveResult , WeightsManifestConfig , WeightsManifestEntry } from './types' ;
29
28
import { loadWeightsAsArrayBuffer } from './weights_loader' ;
30
29
31
30
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream' ;
32
31
const JSON_TYPE = 'application/json' ;
33
32
34
33
export class BrowserHTTPRequest implements IOHandler {
35
- protected readonly path : string | string [ ] ;
34
+ protected readonly path : string ;
36
35
protected readonly requestInit : RequestInit ;
37
36
38
- private readonly fetchFunc : Function ;
37
+ private readonly fetchFunc : ( path : string , init ?: RequestInit ) => Response ;
39
38
40
39
readonly DEFAULT_METHOD = 'POST' ;
41
40
42
41
static readonly URL_SCHEME_REGEX = / ^ h t t p s ? : \/ \/ / ;
43
42
44
- constructor (
45
- path : string | string [ ] , requestInit ?: RequestInit ,
46
- private readonly weightPathPrefix ?: string , fetchFunc ?: Function ,
47
- private readonly onProgress ?: Function ) {
48
- if ( fetchFunc == null ) {
43
+ private readonly weightPathPrefix : string ;
44
+ private readonly onProgress : OnProgressCallback ;
45
+
46
+ constructor ( path : string , loadOptions ?: LoadOptions ) {
47
+ if ( loadOptions == null ) {
48
+ loadOptions = { } ;
49
+ }
50
+ this . weightPathPrefix = loadOptions . weightPathPrefix ;
51
+ this . onProgress = loadOptions . onProgress ;
52
+
53
+ if ( loadOptions . fetchFunc == null ) {
49
54
if ( typeof fetch === 'undefined' ) {
50
55
throw new Error (
51
56
'browserHTTPRequest is not supported outside the web browser ' +
52
57
'without a fetch polyfill.' ) ;
53
58
}
54
59
// Make sure fetch is always bound to window (the
55
60
// original object) when available.
56
- fetchFunc = fetch . bind ( typeof window === 'undefined' ? null : window ) ;
61
+ loadOptions . fetchFunc =
62
+ fetch . bind ( typeof window === 'undefined' ? null : window ) ;
57
63
} else {
58
64
assert (
59
- typeof fetchFunc === 'function' ,
65
+ typeof loadOptions . fetchFunc === 'function' ,
60
66
( ) => 'Must pass a function that matches the signature of ' +
61
67
'`fetch` (see ' +
62
68
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)' ) ;
63
69
}
64
70
65
71
this . fetchFunc = ( path : string , requestInits : RequestInit ) => {
66
72
// tslint:disable-next-line:no-any
67
- return fetchFunc ( path , requestInits ) . catch ( ( error : any ) => {
73
+ return loadOptions . fetchFunc ( path , requestInits ) . catch ( ( error : any ) => {
68
74
throw new Error ( `Request for ${ path } failed due to error: ${ error } ` ) ;
69
75
} ) ;
70
76
} ;
@@ -83,11 +89,12 @@ export class BrowserHTTPRequest implements IOHandler {
83
89
}
84
90
this . path = path ;
85
91
86
- if ( requestInit != null && requestInit . body != null ) {
92
+ if ( loadOptions . requestInit != null &&
93
+ loadOptions . requestInit . body != null ) {
87
94
throw new Error (
88
95
'requestInit is expected to have no pre-existing body, but has one.' ) ;
89
96
}
90
- this . requestInit = requestInit || { } ;
97
+ this . requestInit = loadOptions . requestInit || { } ;
91
98
}
92
99
93
100
async save ( modelArtifacts : ModelArtifacts ) : Promise < SaveResult > {
@@ -104,8 +111,11 @@ export class BrowserHTTPRequest implements IOHandler {
104
111
paths : [ './model.weights.bin' ] ,
105
112
weights : modelArtifacts . weightSpecs ,
106
113
} ] ;
107
- const modelTopologyAndWeightManifest = {
114
+ const modelTopologyAndWeightManifest : ModelJSON = {
108
115
modelTopology : modelArtifacts . modelTopology ,
116
+ format : modelArtifacts . format ,
117
+ generatedBy : modelArtifacts . generatedBy ,
118
+ convertedBy : modelArtifacts . convertedBy ,
109
119
weightsManifest
110
120
} ;
111
121
@@ -123,7 +133,7 @@ export class BrowserHTTPRequest implements IOHandler {
123
133
'model.weights.bin' ) ;
124
134
}
125
135
126
- const response = await this . getFetchFunc ( ) ( this . path as string , init ) ;
136
+ const response = await this . getFetchFunc ( ) ( this . path , init ) ;
127
137
128
138
if ( response . ok ) {
129
139
return {
@@ -146,59 +156,37 @@ export class BrowserHTTPRequest implements IOHandler {
146
156
* @returns The loaded model artifacts (if loading succeeds).
147
157
*/
148
158
async load ( ) : Promise < ModelArtifacts > {
149
- return Array . isArray ( this . path ) ? this . loadBinaryModel ( ) :
150
- this . loadJSONModel ( ) ;
151
- }
152
-
153
- /**
154
- * Loads the model topology file and build the in memory graph of the model.
155
- */
156
- private async loadBinaryTopology ( ) : Promise < ArrayBuffer > {
157
- const response = await this . getFetchFunc ( ) ( this . path [ 0 ] , this . requestInit ) ;
158
-
159
- if ( ! response . ok ) {
160
- throw new Error ( `Request to ${ this . path [ 0 ] } failed with error: ${
161
- response . statusText } `) ;
162
- }
163
- return await response . arrayBuffer ( ) ;
164
- }
165
-
166
- protected async loadBinaryModel ( ) : Promise < ModelArtifacts > {
167
- const graphPromise = this . loadBinaryTopology ( ) ;
168
- const manifestPromise =
169
- await this . getFetchFunc ( ) ( this . path [ 1 ] , this . requestInit ) ;
170
- if ( ! manifestPromise . ok ) {
171
- throw new Error ( `Request to ${ this . path [ 1 ] } failed with error: ${
172
- manifestPromise . statusText } `) ;
173
- }
174
-
175
- const results = await Promise . all ( [ graphPromise , manifestPromise ] ) ;
176
- const [ modelTopology , weightsManifestResponse ] = results ;
177
-
178
- const weightsManifest =
179
- await weightsManifestResponse . json ( ) as WeightsManifestConfig ;
180
-
181
- let weightSpecs : WeightsManifestEntry [ ] ;
182
- let weightData : ArrayBuffer ;
183
- if ( weightsManifest != null ) {
184
- const results = await this . loadWeights ( weightsManifest ) ;
185
- [ weightSpecs , weightData ] = results ;
186
- }
187
-
188
- return { modelTopology, weightSpecs, weightData} ;
189
- }
190
-
191
- protected async loadJSONModel ( ) : Promise < ModelArtifacts > {
192
159
const modelConfigRequest =
193
- await this . getFetchFunc ( ) ( this . path as string , this . requestInit ) ;
160
+ await this . getFetchFunc ( ) ( this . path , this . requestInit ) ;
194
161
195
162
if ( ! modelConfigRequest . ok ) {
196
- throw new Error ( `Request to ${ this . path } failed with error: ${
197
- modelConfigRequest . statusText } `) ;
163
+ throw new Error (
164
+ `Request to ${ this . path } failed with status code ` +
165
+ `${ modelConfigRequest . status } . Please verify this URL points to ` +
166
+ `the model JSON of the model to load.` ) ;
198
167
}
199
- const modelConfig = await modelConfigRequest . json ( ) ;
200
- const modelTopology = modelConfig [ 'modelTopology' ] ;
201
- const weightsManifest = modelConfig [ 'weightsManifest' ] ;
168
+ let modelConfig : ModelJSON ;
169
+ try {
170
+ modelConfig = await modelConfigRequest . json ( ) ;
171
+ } catch ( e ) {
172
+ let message = `Failed to parse model JSON of response from ${ this . path } .` ;
173
+ // TODO(nsthorat): Remove this after some time when we're comfortable that
174
+ // .pb files are mostly gone.
175
+ if ( this . path . endsWith ( '.pb' ) ) {
176
+ message += ' Your path contains a .pb file extension. ' +
177
+ 'Support for .pb models have been removed in TensorFlow.js 1.0 ' +
178
+ 'in favor of .json models. You can re-convert your Python ' +
179
+ 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' +
180
+ 'or you can convert your.pb models with the \'pb2json\'' +
181
+ 'NPM script in the tensorflow/tfjs-converter repository.' ;
182
+ } else {
183
+ message += ' Please make sure the server is serving valid ' +
184
+ 'JSON for this request.' ;
185
+ }
186
+ throw new Error ( message ) ;
187
+ }
188
+ const modelTopology = modelConfig . modelTopology ;
189
+ const weightsManifest = modelConfig . weightsManifest ;
202
190
203
191
// We do not allow both modelTopology and weightsManifest to be missing.
204
192
if ( modelTopology == null && weightsManifest == null ) {
@@ -210,8 +198,6 @@ export class BrowserHTTPRequest implements IOHandler {
210
198
let weightSpecs : WeightsManifestEntry [ ] ;
211
199
let weightData : ArrayBuffer ;
212
200
if ( weightsManifest != null ) {
213
- const weightsManifest =
214
- modelConfig [ 'weightsManifest' ] as WeightsManifestConfig ;
215
201
const results = await this . loadWeights ( weightsManifest ) ;
216
202
[ weightSpecs , weightData ] = results ;
217
203
}
@@ -236,8 +222,11 @@ export class BrowserHTTPRequest implements IOHandler {
236
222
fetchURLs . push ( pathPrefix + path + suffix ) ;
237
223
} ) ;
238
224
} ) ;
239
- const buffers = await loadWeightsAsArrayBuffer (
240
- fetchURLs , this . requestInit , this . getFetchFunc ( ) , this . onProgress ) ;
225
+ const buffers = await loadWeightsAsArrayBuffer ( fetchURLs , {
226
+ requestInit : this . requestInit ,
227
+ fetchFunc : this . getFetchFunc ( ) ,
228
+ onProgress : this . onProgress
229
+ } ) ;
241
230
return [ weightSpecs , concatenateArrayBuffers ( buffers ) ] ;
242
231
}
243
232
@@ -278,7 +267,7 @@ export function isHTTPScheme(url: string): boolean {
278
267
}
279
268
280
269
export const httpRequestRouter : IORouter =
281
- ( url : string | string [ ] , onProgress ?: Function ) => {
270
+ ( url : string , onProgress ?: OnProgressCallback ) => {
282
271
if ( typeof fetch === 'undefined' ) {
283
272
// browserHTTPRequest uses `fetch`, if one wants to use it in node.js
284
273
// they have to setup a global fetch polyfill.
@@ -291,7 +280,7 @@ export const httpRequestRouter: IORouter =
291
280
isHTTP = isHTTPScheme ( url ) ;
292
281
}
293
282
if ( isHTTP ) {
294
- return browserHTTPRequest ( url , null , null , null , onProgress ) ;
283
+ return browserHTTPRequest ( url , { onProgress} ) ;
295
284
}
296
285
}
297
286
return null ;
@@ -419,9 +408,7 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
419
408
* main()
420
409
* ```
421
410
*
422
- * @param path A single URL path or an Array of URL paths.
423
- * Currently, only a single URL path is supported. Array input is reserved
424
- * for future development.
411
+ * @param path A URL path to the model.
425
412
* Can be an absolute HTTP path (e.g.,
426
413
* 'http://localhost:8000/model-upload)') or a relative path (e.g.,
427
414
* './model-upload').
@@ -434,17 +421,17 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
434
421
* topology (filename: 'model.json') and the weights of the model (filename:
435
422
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
436
423
* `body`, an Error will be thrown.
437
- * @param weightPathPrefix Optional, this specifies the path prefix for weight
438
- * files, by default this is calculated from the path param.
439
- * @param fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
440
- * the `fetch` from node-fetch can be used here.
441
- * @param onProgress Optional, progress callback function, fired periodically
442
- * before the load is completed.
424
+ * @param loadOptions Optional configuration for the loading. It includes the
425
+ * following fields:
426
+ * - weightPathPrefix Optional, this specifies the path prefix for weight
427
+ * files, by default this is calculated from the path param.
428
+ * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
429
+ * the `fetch` from node-fetch can be used here.
430
+ * - onProgress Optional, progress callback function, fired periodically
431
+ * before the load is completed.
443
432
* @returns An instance of `IOHandler`.
444
433
*/
445
434
export function browserHTTPRequest (
446
- path : string | string [ ] , requestInit ?: RequestInit , weightPathPrefix ?: string ,
447
- fetchFunc ?: Function , onProgress ?: Function ) : IOHandler {
448
- return new BrowserHTTPRequest (
449
- path , requestInit , weightPathPrefix , fetchFunc , onProgress ) ;
435
+ path : string , loadOptions ?: LoadOptions ) : IOHandler {
436
+ return new BrowserHTTPRequest ( path , loadOptions ) ;
450
437
}
0 commit comments