@@ -49,17 +49,22 @@ export class BrowserHTTPRequest implements IOHandler {
49
49
}
50
50
// Make sure fetch is always bound to window (the
51
51
// original object) when available.
52
- this . fetchFunc =
53
- fetch . bind ( typeof window === 'undefined' ? null : window ) ;
52
+ fetchFunc = fetch . bind ( typeof window === 'undefined' ? null : window ) ;
54
53
} else {
55
54
assert (
56
55
typeof fetchFunc === 'function' ,
57
56
'Must pass a function that matches the signature of ' +
58
57
'`fetch` (see ' +
59
58
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)' ) ;
60
- this . fetchFunc = fetchFunc ;
61
59
}
62
60
61
+ this . fetchFunc = ( path : string , requestInits : RequestInit ) => {
62
+ // tslint:disable-next-line:no-any
63
+ return fetchFunc ( path , requestInits ) . catch ( ( error : any ) => {
64
+ throw new Error ( `Request for ${ path } failed due to error: ${ error } ` ) ;
65
+ } ) ;
66
+ } ;
67
+
63
68
assert (
64
69
path != null && path . length > 0 ,
65
70
'URL path for browserHTTPRequest must not be null, undefined or ' +
@@ -145,26 +150,44 @@ export class BrowserHTTPRequest implements IOHandler {
145
150
* Loads the model topology file and build the in memory graph of the model.
146
151
*/
147
152
private async loadBinaryTopology ( ) : Promise < ArrayBuffer > {
148
- try {
149
- const response =
150
- await this . getFetchFunc ( ) ( this . path [ 0 ] , this . requestInit ) ;
151
- if ( ! response . ok ) {
152
- throw new Error (
153
- `BrowserHTTPRequest.load() failed due to HTTP response: ${
154
- response . statusText } `) ;
155
- }
156
- return await response . arrayBuffer ( ) ;
157
- } catch ( error ) {
158
- throw new Error ( `${ this . path [ 0 ] } not found. ${ error } ` ) ;
153
+ const response = await this . getFetchFunc ( ) (
154
+ this . path [ 0 ] , this . addAcceptHeader ( 'application/octet-stream' ) ) ;
155
+ this . verifyContentType (
156
+ response , 'model topology' , 'application/octet-stream' ) ;
157
+
158
+ if ( ! response . ok ) {
159
+ throw new Error ( `Request to ${ this . path [ 0 ] } failed with error: ${
160
+ response . statusText } `) ;
161
+ }
162
+ return await response . arrayBuffer ( ) ;
163
+ }
164
+
165
+ private addAcceptHeader ( mimeType : string ) : RequestInit {
166
+ const requestOptions = Object . assign ( { } , this . requestInit || { } ) ;
167
+ const headers = Object . assign ( { } , requestOptions . headers || { } ) ;
168
+ // tslint:disable-next-line:no-any
169
+ ( headers as any ) [ 'Accept' ] = mimeType ;
170
+ requestOptions . headers = headers ;
171
+ return requestOptions ;
172
+ }
173
+
174
+ private verifyContentType ( response : Response , target : string , type : string ) {
175
+ const contentType = response . headers . get ( 'content-type' ) ;
176
+ if ( ! contentType || contentType . indexOf ( type ) === - 1 ) {
177
+ throw new Error ( `Request to ${ response . url } for ${
178
+ target } failed. Expected content type ${ type } but got ${
179
+ contentType } .`) ;
159
180
}
160
181
}
161
182
162
183
protected async loadBinaryModel ( ) : Promise < ModelArtifacts > {
163
184
const graphPromise = this . loadBinaryTopology ( ) ;
164
- const manifestPromise =
165
- await this . getFetchFunc ( ) ( this . path [ 1 ] , this . requestInit ) ;
185
+ const manifestPromise = await this . getFetchFunc ( ) (
186
+ this . path [ 1 ] , this . addAcceptHeader ( 'application/json' ) ) ;
187
+ this . verifyContentType (
188
+ manifestPromise , 'weights manifest' , 'application/json' ) ;
166
189
if ( ! manifestPromise . ok ) {
167
- throw new Error ( `BrowserHTTPRequest.load() failed due to HTTP response : ${
190
+ throw new Error ( `Request to ${ this . path [ 1 ] } failed with error : ${
168
191
manifestPromise . statusText } `) ;
169
192
}
170
193
@@ -185,10 +208,13 @@ export class BrowserHTTPRequest implements IOHandler {
185
208
}
186
209
187
210
protected async loadJSONModel ( ) : Promise < ModelArtifacts > {
188
- const modelConfigRequest =
189
- await this . getFetchFunc ( ) ( this . path as string , this . requestInit ) ;
211
+ const modelConfigRequest = await this . getFetchFunc ( ) (
212
+ this . path as string , this . addAcceptHeader ( 'application/json' ) ) ;
213
+ this . verifyContentType (
214
+ modelConfigRequest , 'model topology' , 'application/json' ) ;
215
+
190
216
if ( ! modelConfigRequest . ok ) {
191
- throw new Error ( `BrowserHTTPRequest.load() failed due to HTTP response : ${
217
+ throw new Error ( `Request to ${ this . path } failed with error : ${
192
218
modelConfigRequest . statusText } `) ;
193
219
}
194
220
const modelConfig = await modelConfigRequest . json ( ) ;
@@ -231,12 +257,10 @@ export class BrowserHTTPRequest implements IOHandler {
231
257
fetchURLs . push ( pathPrefix + path + suffix ) ;
232
258
} ) ;
233
259
} ) ;
234
-
235
260
return [
236
261
weightSpecs ,
237
262
concatenateArrayBuffers ( await loadWeightsAsArrayBuffer (
238
- fetchURLs , this . requestInit , this . getFetchFunc ( ) ,
239
- this . onProgress ) )
263
+ fetchURLs , this . requestInit , this . getFetchFunc ( ) , this . onProgress ) )
240
264
] ;
241
265
}
242
266
@@ -276,24 +300,25 @@ export function isHTTPScheme(url: string): boolean {
276
300
return url . match ( BrowserHTTPRequest . URL_SCHEME_REGEX ) != null ;
277
301
}
278
302
279
- export const httpRequestRouter : IORouter = ( url : string | string [ ] , onProgress ?: Function ) => {
280
- if ( typeof fetch === 'undefined' ) {
281
- // browserHTTPRequest uses `fetch`, if one wants to use it in node.js
282
- // they have to setup a global fetch polyfill.
283
- return null ;
284
- } else {
285
- let isHTTP = true ;
286
- if ( Array . isArray ( url ) ) {
287
- isHTTP = url . every ( urlItem => isHTTPScheme ( urlItem ) ) ;
288
- } else {
289
- isHTTP = isHTTPScheme ( url ) ;
290
- }
291
- if ( isHTTP ) {
292
- return browserHTTPRequest ( url , null , null , null , onProgress ) ;
293
- }
294
- }
295
- return null ;
296
- } ;
303
+ export const httpRequestRouter : IORouter =
304
+ ( url : string | string [ ] , onProgress ?: Function ) => {
305
+ if ( typeof fetch === 'undefined' ) {
306
+ // browserHTTPRequest uses `fetch`, if one wants to use it in node.js
307
+ // they have to setup a global fetch polyfill.
308
+ return null ;
309
+ } else {
310
+ let isHTTP = true ;
311
+ if ( Array . isArray ( url ) ) {
312
+ isHTTP = url . every ( urlItem => isHTTPScheme ( urlItem ) ) ;
313
+ } else {
314
+ isHTTP = isHTTPScheme ( url ) ;
315
+ }
316
+ if ( isHTTP ) {
317
+ return browserHTTPRequest ( url , null , null , null , onProgress ) ;
318
+ }
319
+ }
320
+ return null ;
321
+ } ;
297
322
IORouterRegistry . registerSaveRouter ( httpRequestRouter ) ;
298
323
IORouterRegistry . registerLoadRouter ( httpRequestRouter ) ;
299
324
@@ -444,6 +469,6 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
444
469
export function browserHTTPRequest (
445
470
path : string | string [ ] , requestInit ?: RequestInit , weightPathPrefix ?: string ,
446
471
fetchFunc ?: Function , onProgress ?: Function ) : IOHandler {
447
- return new BrowserHTTPRequest ( path , requestInit , weightPathPrefix , fetchFunc ,
448
- onProgress ) ;
472
+ return new BrowserHTTPRequest (
473
+ path , requestInit , weightPathPrefix , fetchFunc , onProgress ) ;
449
474
}
0 commit comments