Skip to content

Commit 6625e66

Browse files
authored
Refactor IO-related code to use the new LoadOptions interface (tensorflow#1554)
- Refactor several IO-related functions to use the new `LoadOptions` interface. - Move `monitorProgress` to the io-specific source directory. DEV
1 parent e258761 commit 6625e66

7 files changed

+235
-170
lines changed

src/io/browser_http.ts

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {assert} from '../util';
2525

2626
import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
2727
import {IORouter, IORouterRegistry} from './router_registry';
28-
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
28+
import {IOHandler, LoadOptions, ModelArtifacts, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2929
import {loadWeightsAsArrayBuffer} from './weights_loader';
3030

3131
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
@@ -41,30 +41,37 @@ export class BrowserHTTPRequest implements IOHandler {
4141

4242
static readonly URL_SCHEME_REGEX = /^https?:\/\//;
4343

44-
constructor(
45-
path: string|string[], requestInit?: RequestInit,
46-
private readonly weightPathPrefix?: string, fetchFunc?: Function,
47-
private readonly onProgress?: Function) {
48-
if (fetchFunc == null) {
44+
private readonly weightPathPrefix: string;
45+
private readonly onProgress: OnProgressCallback;
46+
47+
constructor(path: string|string[], loadOptions?: LoadOptions) {
48+
if (loadOptions == null) {
49+
loadOptions = {};
50+
}
51+
this.weightPathPrefix = loadOptions.weightPathPrefix;
52+
this.onProgress = loadOptions.onProgress;
53+
54+
if (loadOptions.fetchFunc == null) {
4955
if (typeof fetch === 'undefined') {
5056
throw new Error(
5157
'browserHTTPRequest is not supported outside the web browser ' +
5258
'without a fetch polyfill.');
5359
}
5460
// Make sure fetch is always bound to window (the
5561
// original object) when available.
56-
fetchFunc = fetch.bind(typeof window === 'undefined' ? null : window);
62+
loadOptions.fetchFunc =
63+
fetch.bind(typeof window === 'undefined' ? null : window);
5764
} else {
5865
assert(
59-
typeof fetchFunc === 'function',
66+
typeof loadOptions.fetchFunc === 'function',
6067
() => 'Must pass a function that matches the signature of ' +
6168
'`fetch` (see ' +
6269
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
6370
}
6471

6572
this.fetchFunc = (path: string, requestInits: RequestInit) => {
6673
// tslint:disable-next-line:no-any
67-
return fetchFunc(path, requestInits).catch((error: any) => {
74+
return loadOptions.fetchFunc(path, requestInits).catch((error: any) => {
6875
throw new Error(`Request for ${path} failed due to error: ${error}`);
6976
});
7077
};
@@ -83,11 +90,12 @@ export class BrowserHTTPRequest implements IOHandler {
8390
}
8491
this.path = path;
8592

86-
if (requestInit != null && requestInit.body != null) {
93+
if (loadOptions.requestInit != null &&
94+
loadOptions.requestInit.body != null) {
8795
throw new Error(
8896
'requestInit is expected to have no pre-existing body, but has one.');
8997
}
90-
this.requestInit = requestInit || {};
98+
this.requestInit = loadOptions.requestInit || {};
9199
}
92100

93101
async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
@@ -236,8 +244,11 @@ export class BrowserHTTPRequest implements IOHandler {
236244
fetchURLs.push(pathPrefix + path + suffix);
237245
});
238246
});
239-
const buffers = await loadWeightsAsArrayBuffer(
240-
fetchURLs, this.requestInit, this.getFetchFunc(), this.onProgress);
247+
const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
248+
requestInit: this.requestInit,
249+
fetchFunc: this.getFetchFunc(),
250+
onProgress: this.onProgress
251+
});
241252
return [weightSpecs, concatenateArrayBuffers(buffers)];
242253
}
243254

@@ -278,7 +289,7 @@ export function isHTTPScheme(url: string): boolean {
278289
}
279290

280291
export const httpRequestRouter: IORouter =
281-
(url: string|string[], onProgress?: Function) => {
292+
(url: string|string[], onProgress?: OnProgressCallback) => {
282293
if (typeof fetch === 'undefined') {
283294
// browserHTTPRequest uses `fetch`, if one wants to use it in node.js
284295
// they have to setup a global fetch polyfill.
@@ -291,7 +302,7 @@ export const httpRequestRouter: IORouter =
291302
isHTTP = isHTTPScheme(url);
292303
}
293304
if (isHTTP) {
294-
return browserHTTPRequest(url, null, null, null, onProgress);
305+
return browserHTTPRequest(url, {onProgress});
295306
}
296307
}
297308
return null;
@@ -434,17 +445,17 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
434445
* topology (filename: 'model.json') and the weights of the model (filename:
435446
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
436447
* `body`, an Error will be thrown.
437-
* @param weightPathPrefix Optional, this specifies the path prefix for weight
438-
* files, by default this is calculated from the path param.
439-
* @param fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
440-
* the `fetch` from node-fetch can be used here.
441-
* @param onProgress Optional, progress callback function, fired periodically
442-
* before the load is completed.
448+
* @param loadOptions Optional configuration for the loading. It includes the
449+
* following fields:
450+
* - weightPathPrefix Optional, this specifies the path prefix for weight
451+
* files, by default this is calculated from the path param.
452+
* - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
453+
* the `fetch` from node-fetch can be used here.
454+
* - onProgress Optional, progress callback function, fired periodically
455+
* before the load is completed.
443456
* @returns An instance of `IOHandler`.
444457
*/
445458
export function browserHTTPRequest(
446-
path: string|string[], requestInit?: RequestInit, weightPathPrefix?: string,
447-
fetchFunc?: Function, onProgress?: Function): IOHandler {
448-
return new BrowserHTTPRequest(
449-
path, requestInit, weightPathPrefix, fetchFunc, onProgress);
459+
path: string|string[], loadOptions?: LoadOptions): IOHandler {
460+
return new BrowserHTTPRequest(path, loadOptions);
450461
}

src/io/browser_http_test.ts

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,11 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {
303303
it('Save topology and weights, PUT method, extra headers', (done) => {
304304
const testStartDate = new Date();
305305
const handler = tf.io.browserHTTPRequest('model-upload-test', {
306-
method: 'PUT',
307-
headers:
308-
{'header_key_1': 'header_value_1', 'header_key_2': 'header_value_2'}
306+
requestInit: {
307+
method: 'PUT',
308+
headers:
309+
{'header_key_1': 'header_value_1', 'header_key_2': 'header_value_2'}
310+
}
309311
});
310312
handler.save(artifacts1)
311313
.then(saveResult => {
@@ -392,7 +394,9 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {
392394

393395
it('Existing body leads to Error', () => {
394396
expect(() => tf.io.browserHTTPRequest('model-upload-test', {
395-
body: 'existing body'
397+
requestInit: {
398+
body: 'existing body'
399+
}
396400
})).toThrowError(/requestInit is expected to have no pre-existing body/);
397401
});
398402

@@ -518,8 +522,11 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
518522
},
519523
requestInits);
520524

521-
const handler = tf.io.browserHTTPRequest(
522-
'./model.json', {headers: {'header_key_1': 'header_value_1'}});
525+
const handler = tf.io.browserHTTPRequest('./model.json', {
526+
requestInit: {
527+
headers: {'header_key_1': 'header_value_1'}
528+
}
529+
});
523530
const modelArtifacts = await handler.load();
524531
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
525532
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
@@ -899,8 +906,8 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
899906
requestInits);
900907

901908
const handler = tf.io.browserHTTPRequest(
902-
['./model.pb', './weights_manifest.json'],
903-
{headers: {'header_key_1': 'header_value_1'}});
909+
['./model.pb', './weights_manifest.json'], {
910+
requestInit: {headers: {'header_key_1': 'header_value_1'}}});
904911
const modelArtifacts = await handler.load();
905912
expect(modelArtifacts.modelTopology).toEqual(modelData);
906913
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
@@ -1089,7 +1096,8 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
10891096
},
10901097
requestInits);
10911098
const handler = tf.io.browserHTTPRequest(
1092-
['path1/model.pb', 'path2/weights_manifest.json'], {}, 'path3/');
1099+
['path1/model.pb', 'path2/weights_manifest.json'],
1100+
{weightPathPrefix: 'path3/'});
10931101
const modelArtifacts = await handler.load();
10941102
expect(modelArtifacts.modelTopology).toEqual(modelData);
10951103
expect(modelArtifacts.weightSpecs)
@@ -1171,7 +1179,10 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
11711179
}
11721180

11731181
const handler = tf.io.browserHTTPRequest(
1174-
'./model.json', {credentials: 'include'}, null, customFetch);
1182+
'./model.json', {
1183+
requestInit: {credentials: 'include'},
1184+
fetchFunc: customFetch
1185+
});
11751186
const modelArtifacts = await handler.load();
11761187
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
11771188
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);

src/io/progress.ts

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {assert} from '../util';
19+
20+
import {OnProgressCallback} from './types';
21+
22+
/**
23+
* Monitor Promise.all progress, fire onProgress callback function.
24+
*
25+
* @param promises Promise list going to be monitored
26+
* @param onProgress Callback function. Fired when a promise resolved.
27+
* @param startFraction Optional fraction start. Default to 0.
28+
* @param endFraction Optional fraction end. Default to 1.
29+
*/
30+
export function monitorPromisesProgress(
31+
promises: Array<Promise<{}|void>>, onProgress: OnProgressCallback,
32+
startFraction?: number, endFraction?: number) {
33+
checkPromises(promises);
34+
startFraction = startFraction == null ? 0 : startFraction;
35+
endFraction = endFraction == null ? 1 : endFraction;
36+
checkFraction(startFraction, endFraction);
37+
let resolvedPromise = 0;
38+
39+
const registerMonitor = (promise: Promise<{}>) => {
40+
promise.then(value => {
41+
const fraction = startFraction +
42+
++resolvedPromise / promises.length * (endFraction - startFraction);
43+
// pass fraction as parameter to callback function.
44+
onProgress(fraction);
45+
return value;
46+
});
47+
return promise;
48+
};
49+
50+
function checkPromises(promises: Array<Promise<{}|void>>): void {
51+
assert(
52+
promises != null && Array.isArray(promises) && promises.length > 0,
53+
() => 'promises must be a none empty array');
54+
}
55+
56+
function checkFraction(startFraction: number, endFraction: number): void {
57+
assert(
58+
startFraction >= 0 && startFraction <= 1,
59+
() => `Progress fraction must be in range [0, 1], but ` +
60+
`got startFraction ${startFraction}`);
61+
assert(
62+
endFraction >= 0 && endFraction <= 1,
63+
() => `Progress fraction must be in range [0, 1], but ` +
64+
`got endFraction ${endFraction}`);
65+
assert(
66+
endFraction >= startFraction,
67+
() => `startFraction must be no more than endFraction, but ` +
68+
`got startFraction ${startFraction} and endFraction ` +
69+
`${endFraction}`);
70+
}
71+
72+
return Promise.all(promises.map(registerMonitor));
73+
}

src/io/progress_test.ts

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {monitorPromisesProgress} from './progress';
19+
20+
describe('util.monitorPromisesProgress', () => {
21+
it('Default progress from 0 to 1', (done) => {
22+
const expectFractions: number[] = [0.25, 0.50, 0.75, 1.00];
23+
const fractionList: number[] = [];
24+
const tasks = Array(4).fill(0).map(() => {
25+
return Promise.resolve();
26+
});
27+
monitorPromisesProgress(tasks, (progress: number) => {
28+
fractionList.push(parseFloat(progress.toFixed(2)));
29+
}).then(() => {
30+
expect(fractionList).toEqual(expectFractions);
31+
done();
32+
});
33+
});
34+
35+
it('Progress with pre-defined range', (done) => {
36+
const startFraction = 0.2;
37+
const endFraction = 0.8;
38+
const expectFractions: number[] = [0.35, 0.50, 0.65, 0.80];
39+
const fractionList: number[] = [];
40+
const tasks = Array(4).fill(0).map(() => {
41+
return Promise.resolve();
42+
});
43+
monitorPromisesProgress(tasks, (progress: number) => {
44+
fractionList.push(parseFloat(progress.toFixed(2)));
45+
}, startFraction, endFraction).then(() => {
46+
expect(fractionList).toEqual(expectFractions);
47+
done();
48+
});
49+
});
50+
51+
it('throws error when progress fraction is out of range', () => {
52+
expect(() => {
53+
const startFraction = -1;
54+
const endFraction = 1;
55+
const tasks = Array(4).fill(0).map(() => {
56+
return Promise.resolve();
57+
});
58+
monitorPromisesProgress(
59+
tasks, (progress: number) => {}, startFraction, endFraction);
60+
}).toThrowError();
61+
});
62+
63+
it('throws error when startFraction more than endFraction', () => {
64+
expect(() => {
65+
const startFraction = 0.8;
66+
const endFraction = 0.2;
67+
const tasks = Array(4).fill(0).map(() => {
68+
return Promise.resolve();
69+
});
70+
monitorPromisesProgress(
71+
tasks, (progress: number) => {}, startFraction, endFraction);
72+
}).toThrowError();
73+
});
74+
75+
it('throws error when promises is null', () => {
76+
expect(() => {
77+
monitorPromisesProgress(null, (progress: number) => {});
78+
}).toThrowError();
79+
});
80+
81+
it('throws error when promises is empty array', () => {
82+
expect(() => {
83+
monitorPromisesProgress([], (progress: number) => {});
84+
}).toThrowError();
85+
});
86+
});

0 commit comments

Comments
 (0)