@@ -25,7 +25,7 @@ import {assert} from '../util';
25
25
26
26
import { concatenateArrayBuffers , getModelArtifactsInfoForJSON } from './io_utils' ;
27
27
import { IORouter , IORouterRegistry } from './router_registry' ;
28
- import { IOHandler , ModelArtifacts , SaveResult , WeightsManifestConfig , WeightsManifestEntry } from './types' ;
28
+ import { IOHandler , LoadOptions , ModelArtifacts , OnProgressCallback , SaveResult , WeightsManifestConfig , WeightsManifestEntry } from './types' ;
29
29
import { loadWeightsAsArrayBuffer } from './weights_loader' ;
30
30
31
31
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream' ;
@@ -41,30 +41,37 @@ export class BrowserHTTPRequest implements IOHandler {
41
41
42
42
static readonly URL_SCHEME_REGEX = / ^ h t t p s ? : \/ \/ / ;
43
43
44
- constructor (
45
- path : string | string [ ] , requestInit ?: RequestInit ,
46
- private readonly weightPathPrefix ?: string , fetchFunc ?: Function ,
47
- private readonly onProgress ?: Function ) {
48
- if ( fetchFunc == null ) {
44
+ private readonly weightPathPrefix : string ;
45
+ private readonly onProgress : OnProgressCallback ;
46
+
47
+ constructor ( path : string | string [ ] , loadOptions ?: LoadOptions ) {
48
+ if ( loadOptions == null ) {
49
+ loadOptions = { } ;
50
+ }
51
+ this . weightPathPrefix = loadOptions . weightPathPrefix ;
52
+ this . onProgress = loadOptions . onProgress ;
53
+
54
+ if ( loadOptions . fetchFunc == null ) {
49
55
if ( typeof fetch === 'undefined' ) {
50
56
throw new Error (
51
57
'browserHTTPRequest is not supported outside the web browser ' +
52
58
'without a fetch polyfill.' ) ;
53
59
}
54
60
// Make sure fetch is always bound to window (the
55
61
// original object) when available.
56
- fetchFunc = fetch . bind ( typeof window === 'undefined' ? null : window ) ;
62
+ loadOptions . fetchFunc =
63
+ fetch . bind ( typeof window === 'undefined' ? null : window ) ;
57
64
} else {
58
65
assert (
59
- typeof fetchFunc === 'function' ,
66
+ typeof loadOptions . fetchFunc === 'function' ,
60
67
( ) => 'Must pass a function that matches the signature of ' +
61
68
'`fetch` (see ' +
62
69
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)' ) ;
63
70
}
64
71
65
72
this . fetchFunc = ( path : string , requestInits : RequestInit ) => {
66
73
// tslint:disable-next-line:no-any
67
- return fetchFunc ( path , requestInits ) . catch ( ( error : any ) => {
74
+ return loadOptions . fetchFunc ( path , requestInits ) . catch ( ( error : any ) => {
68
75
throw new Error ( `Request for ${ path } failed due to error: ${ error } ` ) ;
69
76
} ) ;
70
77
} ;
@@ -83,11 +90,12 @@ export class BrowserHTTPRequest implements IOHandler {
83
90
}
84
91
this . path = path ;
85
92
86
- if ( requestInit != null && requestInit . body != null ) {
93
+ if ( loadOptions . requestInit != null &&
94
+ loadOptions . requestInit . body != null ) {
87
95
throw new Error (
88
96
'requestInit is expected to have no pre-existing body, but has one.' ) ;
89
97
}
90
- this . requestInit = requestInit || { } ;
98
+ this . requestInit = loadOptions . requestInit || { } ;
91
99
}
92
100
93
101
async save ( modelArtifacts : ModelArtifacts ) : Promise < SaveResult > {
@@ -236,8 +244,11 @@ export class BrowserHTTPRequest implements IOHandler {
236
244
fetchURLs . push ( pathPrefix + path + suffix ) ;
237
245
} ) ;
238
246
} ) ;
239
- const buffers = await loadWeightsAsArrayBuffer (
240
- fetchURLs , this . requestInit , this . getFetchFunc ( ) , this . onProgress ) ;
247
+ const buffers = await loadWeightsAsArrayBuffer ( fetchURLs , {
248
+ requestInit : this . requestInit ,
249
+ fetchFunc : this . getFetchFunc ( ) ,
250
+ onProgress : this . onProgress
251
+ } ) ;
241
252
return [ weightSpecs , concatenateArrayBuffers ( buffers ) ] ;
242
253
}
243
254
@@ -278,7 +289,7 @@ export function isHTTPScheme(url: string): boolean {
278
289
}
279
290
280
291
export const httpRequestRouter : IORouter =
281
- ( url : string | string [ ] , onProgress ?: Function ) => {
292
+ ( url : string | string [ ] , onProgress ?: OnProgressCallback ) => {
282
293
if ( typeof fetch === 'undefined' ) {
283
294
// browserHTTPRequest uses `fetch`, if one wants to use it in node.js
284
295
// they have to setup a global fetch polyfill.
@@ -291,7 +302,7 @@ export const httpRequestRouter: IORouter =
291
302
isHTTP = isHTTPScheme ( url ) ;
292
303
}
293
304
if ( isHTTP ) {
294
- return browserHTTPRequest ( url , null , null , null , onProgress ) ;
305
+ return browserHTTPRequest ( url , { onProgress} ) ;
295
306
}
296
307
}
297
308
return null ;
@@ -434,17 +445,17 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
434
445
* topology (filename: 'model.json') and the weights of the model (filename:
435
446
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
436
447
* `body`, an Error will be thrown.
437
- * @param weightPathPrefix Optional, this specifies the path prefix for weight
438
- * files, by default this is calculated from the path param.
439
- * @param fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
440
- * the `fetch` from node-fetch can be used here.
441
- * @param onProgress Optional, progress callback function, fired periodically
442
- * before the load is completed.
448
+ * @param loadOptions Optional configuration for the loading. It includes the
449
+ * following fields:
450
+ * - weightPathPrefix Optional, this specifies the path prefix for weight
451
+ * files, by default this is calculated from the path param.
452
+ * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
453
+ * the `fetch` from node-fetch can be used here.
454
+ * - onProgress Optional, progress callback function, fired periodically
455
+ * before the load is completed.
443
456
* @returns An instance of `IOHandler`.
444
457
*/
445
458
export function browserHTTPRequest (
446
- path : string | string [ ] , requestInit ?: RequestInit , weightPathPrefix ?: string ,
447
- fetchFunc ?: Function , onProgress ?: Function ) : IOHandler {
448
- return new BrowserHTTPRequest (
449
- path , requestInit , weightPathPrefix , fetchFunc , onProgress ) ;
459
+ path : string | string [ ] , loadOptions ?: LoadOptions ) : IOHandler {
460
+ return new BrowserHTTPRequest ( path , loadOptions ) ;
450
461
}
0 commit comments