Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 5736936

Browse files
authored
Merge branch 'master' into use_precomputed_tensor_size
2 parents 1f762f4 + b072cae commit 5736936

24 files changed

+828
-1939
lines changed

package.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@tensorflow/tfjs-core",
3-
"version": "1.0.0-alpha3",
3+
"version": "1.0.0",
44
"description": "Hardware-accelerated JavaScript library for machine intelligence",
55
"private": false,
66
"main": "dist/index.js",
@@ -23,11 +23,11 @@
2323
"clang-format": "~1.2.4",
2424
"jasmine": "~3.1.0",
2525
"jasmine-core": "~3.1.0",
26-
"karma": "~2.0.2",
27-
"karma-browserstack-launcher": "~1.3.0",
26+
"karma": "~4.0.0",
27+
"karma-browserstack-launcher": "~1.4.0",
2828
"karma-chrome-launcher": "~2.2.0",
2929
"karma-jasmine": "~1.1.0",
30-
"karma-typescript": "~3.0.12",
30+
"karma-typescript": "~4.0.0",
3131
"npm-run-all": "~4.1.3",
3232
"rimraf": "~2.6.2",
3333
"rollup": "~0.58.2",

scripts/publish-npm.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@ fi
5050
yarn build-npm
5151
./scripts/make-version # This is for safety in case you forgot to do 2).
5252
./scripts/tag-version
53-
npm publish --tag next # Remove --tag net when prereleases are done.
53+
npm publish
5454
echo 'Yay! Published a new package to npm.'

src/io/browser_files.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import {ENV} from '../environment';
2424
import {basename, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
2525
import {IORouter, IORouterRegistry} from './router_registry';
26-
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
26+
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2727

2828
const DEFAULT_FILE_NAME_PREFIX = 'model';
2929
const DEFAULT_JSON_EXTENSION_NAME = '.json';
@@ -71,8 +71,11 @@ export class BrowserDownloads implements IOHandler {
7171
paths: ['./' + this.weightDataFileName],
7272
weights: modelArtifacts.weightSpecs
7373
}];
74-
const modelTopologyAndWeightManifest = {
74+
const modelTopologyAndWeightManifest: ModelJSON = {
7575
modelTopology: modelArtifacts.modelTopology,
76+
format: modelArtifacts.format,
77+
generatedBy: modelArtifacts.generatedBy,
78+
convertedBy: modelArtifacts.convertedBy,
7679
weightsManifest
7780
};
7881
const modelTopologyAndWeightManifestURL =
@@ -124,8 +127,8 @@ class BrowserFiles implements IOHandler {
124127
const jsonReader = new FileReader();
125128
jsonReader.onload = (event: Event) => {
126129
// tslint:disable-next-line:no-any
127-
const modelJSON = JSON.parse((event.target as any).result);
128-
const modelTopology = modelJSON.modelTopology as {};
130+
const modelJSON = JSON.parse((event.target as any).result) as ModelJSON;
131+
const modelTopology = modelJSON.modelTopology;
129132
if (modelTopology == null) {
130133
reject(new Error(
131134
`modelTopology field is missing from file ${jsonFile.name}`));
@@ -136,8 +139,7 @@ class BrowserFiles implements IOHandler {
136139
resolve({modelTopology});
137140
}
138141

139-
const weightsManifest =
140-
modelJSON.weightsManifest as WeightsManifestConfig;
142+
const weightsManifest = modelJSON.weightsManifest;
141143
if (weightsManifest == null) {
142144
reject(new Error(
143145
`weightManifest field is missing from file ${jsonFile.name}`));

src/io/browser_files_test.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ const artifacts1: tf.io.ModelArtifacts = {
7474
modelTopology: modelTopology1,
7575
weightSpecs: weightSpecs1,
7676
weightData: weightData1,
77+
format: 'layers-model',
78+
generatedBy: 'TensorFlow.js v0.0.0',
79+
convertedBy: null
7780
};
7881

7982
describeWithFlags('browserDownloads', BROWSER_ENVS, () => {
@@ -127,6 +130,10 @@ describeWithFlags('browserDownloads', BROWSER_ENVS, () => {
127130
JSON.parse(await jsonContent.text());
128131
expect(modelTopologyAndWeightsManifest.modelTopology)
129132
.toEqual(modelTopology1);
133+
expect(modelTopologyAndWeightsManifest.format).toEqual('layers-model');
134+
expect(modelTopologyAndWeightsManifest.generatedBy)
135+
.toEqual('TensorFlow.js v0.0.0');
136+
expect(modelTopologyAndWeightsManifest.convertedBy).toEqual(null);
130137
const weightsManifest = modelTopologyAndWeightsManifest.weightsManifest as
131138
WeightsManifestConfig;
132139
expect(weightsManifest.length).toEqual(1);

src/io/browser_http.ts

Lines changed: 71 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -22,49 +22,55 @@
2222
*/
2323

2424
import {assert} from '../util';
25-
2625
import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
2726
import {IORouter, IORouterRegistry} from './router_registry';
28-
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
27+
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2928
import {loadWeightsAsArrayBuffer} from './weights_loader';
3029

3130
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
3231
const JSON_TYPE = 'application/json';
3332

3433
export class BrowserHTTPRequest implements IOHandler {
35-
protected readonly path: string|string[];
34+
protected readonly path: string;
3635
protected readonly requestInit: RequestInit;
3736

38-
private readonly fetchFunc: Function;
37+
private readonly fetchFunc: (path: string, init?: RequestInit) => Response;
3938

4039
readonly DEFAULT_METHOD = 'POST';
4140

4241
static readonly URL_SCHEME_REGEX = /^https?:\/\//;
4342

44-
constructor(
45-
path: string|string[], requestInit?: RequestInit,
46-
private readonly weightPathPrefix?: string, fetchFunc?: Function,
47-
private readonly onProgress?: Function) {
48-
if (fetchFunc == null) {
43+
private readonly weightPathPrefix: string;
44+
private readonly onProgress: OnProgressCallback;
45+
46+
constructor(path: string, loadOptions?: LoadOptions) {
47+
if (loadOptions == null) {
48+
loadOptions = {};
49+
}
50+
this.weightPathPrefix = loadOptions.weightPathPrefix;
51+
this.onProgress = loadOptions.onProgress;
52+
53+
if (loadOptions.fetchFunc == null) {
4954
if (typeof fetch === 'undefined') {
5055
throw new Error(
5156
'browserHTTPRequest is not supported outside the web browser ' +
5257
'without a fetch polyfill.');
5358
}
5459
// Make sure fetch is always bound to window (the
5560
// original object) when available.
56-
fetchFunc = fetch.bind(typeof window === 'undefined' ? null : window);
61+
loadOptions.fetchFunc =
62+
fetch.bind(typeof window === 'undefined' ? null : window);
5763
} else {
5864
assert(
59-
typeof fetchFunc === 'function',
65+
typeof loadOptions.fetchFunc === 'function',
6066
() => 'Must pass a function that matches the signature of ' +
6167
'`fetch` (see ' +
6268
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
6369
}
6470

6571
this.fetchFunc = (path: string, requestInits: RequestInit) => {
6672
// tslint:disable-next-line:no-any
67-
return fetchFunc(path, requestInits).catch((error: any) => {
73+
return loadOptions.fetchFunc(path, requestInits).catch((error: any) => {
6874
throw new Error(`Request for ${path} failed due to error: ${error}`);
6975
});
7076
};
@@ -83,11 +89,12 @@ export class BrowserHTTPRequest implements IOHandler {
8389
}
8490
this.path = path;
8591

86-
if (requestInit != null && requestInit.body != null) {
92+
if (loadOptions.requestInit != null &&
93+
loadOptions.requestInit.body != null) {
8794
throw new Error(
8895
'requestInit is expected to have no pre-existing body, but has one.');
8996
}
90-
this.requestInit = requestInit || {};
97+
this.requestInit = loadOptions.requestInit || {};
9198
}
9299

93100
async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
@@ -104,8 +111,11 @@ export class BrowserHTTPRequest implements IOHandler {
104111
paths: ['./model.weights.bin'],
105112
weights: modelArtifacts.weightSpecs,
106113
}];
107-
const modelTopologyAndWeightManifest = {
114+
const modelTopologyAndWeightManifest: ModelJSON = {
108115
modelTopology: modelArtifacts.modelTopology,
116+
format: modelArtifacts.format,
117+
generatedBy: modelArtifacts.generatedBy,
118+
convertedBy: modelArtifacts.convertedBy,
109119
weightsManifest
110120
};
111121

@@ -123,7 +133,7 @@ export class BrowserHTTPRequest implements IOHandler {
123133
'model.weights.bin');
124134
}
125135

126-
const response = await this.getFetchFunc()(this.path as string, init);
136+
const response = await this.getFetchFunc()(this.path, init);
127137

128138
if (response.ok) {
129139
return {
@@ -146,59 +156,37 @@ export class BrowserHTTPRequest implements IOHandler {
146156
* @returns The loaded model artifacts (if loading succeeds).
147157
*/
148158
async load(): Promise<ModelArtifacts> {
149-
return Array.isArray(this.path) ? this.loadBinaryModel() :
150-
this.loadJSONModel();
151-
}
152-
153-
/**
154-
* Loads the model topology file and build the in memory graph of the model.
155-
*/
156-
private async loadBinaryTopology(): Promise<ArrayBuffer> {
157-
const response = await this.getFetchFunc()(this.path[0], this.requestInit);
158-
159-
if (!response.ok) {
160-
throw new Error(`Request to ${this.path[0]} failed with error: ${
161-
response.statusText}`);
162-
}
163-
return await response.arrayBuffer();
164-
}
165-
166-
protected async loadBinaryModel(): Promise<ModelArtifacts> {
167-
const graphPromise = this.loadBinaryTopology();
168-
const manifestPromise =
169-
await this.getFetchFunc()(this.path[1], this.requestInit);
170-
if (!manifestPromise.ok) {
171-
throw new Error(`Request to ${this.path[1]} failed with error: ${
172-
manifestPromise.statusText}`);
173-
}
174-
175-
const results = await Promise.all([graphPromise, manifestPromise]);
176-
const [modelTopology, weightsManifestResponse] = results;
177-
178-
const weightsManifest =
179-
await weightsManifestResponse.json() as WeightsManifestConfig;
180-
181-
let weightSpecs: WeightsManifestEntry[];
182-
let weightData: ArrayBuffer;
183-
if (weightsManifest != null) {
184-
const results = await this.loadWeights(weightsManifest);
185-
[weightSpecs, weightData] = results;
186-
}
187-
188-
return {modelTopology, weightSpecs, weightData};
189-
}
190-
191-
protected async loadJSONModel(): Promise<ModelArtifacts> {
192159
const modelConfigRequest =
193-
await this.getFetchFunc()(this.path as string, this.requestInit);
160+
await this.getFetchFunc()(this.path, this.requestInit);
194161

195162
if (!modelConfigRequest.ok) {
196-
throw new Error(`Request to ${this.path} failed with error: ${
197-
modelConfigRequest.statusText}`);
163+
throw new Error(
164+
`Request to ${this.path} failed with status code ` +
165+
`${modelConfigRequest.status}. Please verify this URL points to ` +
166+
`the model JSON of the model to load.`);
198167
}
199-
const modelConfig = await modelConfigRequest.json();
200-
const modelTopology = modelConfig['modelTopology'];
201-
const weightsManifest = modelConfig['weightsManifest'];
168+
let modelConfig: ModelJSON;
169+
try {
170+
modelConfig = await modelConfigRequest.json();
171+
} catch (e) {
172+
let message = `Failed to parse model JSON of response from ${this.path}.`;
173+
// TODO(nsthorat): Remove this after some time when we're comfortable that
174+
// .pb files are mostly gone.
175+
if (this.path.endsWith('.pb')) {
176+
message += ' Your path contains a .pb file extension. ' +
177+
'Support for .pb models have been removed in TensorFlow.js 1.0 ' +
178+
'in favor of .json models. You can re-convert your Python ' +
179+
'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' +
180+
'or you can convert your.pb models with the \'pb2json\'' +
181+
'NPM script in the tensorflow/tfjs-converter repository.';
182+
} else {
183+
message += ' Please make sure the server is serving valid ' +
184+
'JSON for this request.';
185+
}
186+
throw new Error(message);
187+
}
188+
const modelTopology = modelConfig.modelTopology;
189+
const weightsManifest = modelConfig.weightsManifest;
202190

203191
// We do not allow both modelTopology and weightsManifest to be missing.
204192
if (modelTopology == null && weightsManifest == null) {
@@ -210,8 +198,6 @@ export class BrowserHTTPRequest implements IOHandler {
210198
let weightSpecs: WeightsManifestEntry[];
211199
let weightData: ArrayBuffer;
212200
if (weightsManifest != null) {
213-
const weightsManifest =
214-
modelConfig['weightsManifest'] as WeightsManifestConfig;
215201
const results = await this.loadWeights(weightsManifest);
216202
[weightSpecs, weightData] = results;
217203
}
@@ -236,8 +222,11 @@ export class BrowserHTTPRequest implements IOHandler {
236222
fetchURLs.push(pathPrefix + path + suffix);
237223
});
238224
});
239-
const buffers = await loadWeightsAsArrayBuffer(
240-
fetchURLs, this.requestInit, this.getFetchFunc(), this.onProgress);
225+
const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
226+
requestInit: this.requestInit,
227+
fetchFunc: this.getFetchFunc(),
228+
onProgress: this.onProgress
229+
});
241230
return [weightSpecs, concatenateArrayBuffers(buffers)];
242231
}
243232

@@ -278,7 +267,7 @@ export function isHTTPScheme(url: string): boolean {
278267
}
279268

280269
export const httpRequestRouter: IORouter =
281-
(url: string|string[], onProgress?: Function) => {
270+
(url: string, onProgress?: OnProgressCallback) => {
282271
if (typeof fetch === 'undefined') {
283272
// browserHTTPRequest uses `fetch`, if one wants to use it in node.js
284273
// they have to setup a global fetch polyfill.
@@ -291,7 +280,7 @@ export const httpRequestRouter: IORouter =
291280
isHTTP = isHTTPScheme(url);
292281
}
293282
if (isHTTP) {
294-
return browserHTTPRequest(url, null, null, null, onProgress);
283+
return browserHTTPRequest(url, {onProgress});
295284
}
296285
}
297286
return null;
@@ -419,9 +408,7 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
419408
* main()
420409
* ```
421410
*
422-
* @param path A single URL path or an Array of URL paths.
423-
* Currently, only a single URL path is supported. Array input is reserved
424-
* for future development.
411+
* @param path A URL path to the model.
425412
* Can be an absolute HTTP path (e.g.,
426413
* 'http://localhost:8000/model-upload)') or a relative path (e.g.,
427414
* './model-upload').
@@ -434,17 +421,17 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
434421
* topology (filename: 'model.json') and the weights of the model (filename:
435422
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
436423
* `body`, an Error will be thrown.
437-
* @param weightPathPrefix Optional, this specifies the path prefix for weight
438-
* files, by default this is calculated from the path param.
439-
* @param fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
440-
* the `fetch` from node-fetch can be used here.
441-
* @param onProgress Optional, progress callback function, fired periodically
442-
* before the load is completed.
424+
* @param loadOptions Optional configuration for the loading. It includes the
425+
* following fields:
426+
* - weightPathPrefix Optional, this specifies the path prefix for weight
427+
* files, by default this is calculated from the path param.
428+
* - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
429+
* the `fetch` from node-fetch can be used here.
430+
* - onProgress Optional, progress callback function, fired periodically
431+
* before the load is completed.
443432
* @returns An instance of `IOHandler`.
444433
*/
445434
export function browserHTTPRequest(
446-
path: string|string[], requestInit?: RequestInit, weightPathPrefix?: string,
447-
fetchFunc?: Function, onProgress?: Function): IOHandler {
448-
return new BrowserHTTPRequest(
449-
path, requestInit, weightPathPrefix, fetchFunc, onProgress);
435+
path: string, loadOptions?: LoadOptions): IOHandler {
436+
return new BrowserHTTPRequest(path, loadOptions);
450437
}

0 commit comments

Comments
 (0)