Skip to content

Commit 4d33109

Browse files
committed
save
2 parents b224e99 + 6625e66 commit 4d33109

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1403
-1815
lines changed

package.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@tensorflow/tfjs-core",
3-
"version": "1.0.0-alpha3",
3+
"version": "1.0.0-alpha4",
44
"description": "Hardware-accelerated JavaScript library for machine intelligence",
55
"private": false,
66
"main": "dist/index.js",
@@ -23,11 +23,11 @@
2323
"clang-format": "~1.2.4",
2424
"jasmine": "~3.1.0",
2525
"jasmine-core": "~3.1.0",
26-
"karma": "~2.0.2",
27-
"karma-browserstack-launcher": "~1.3.0",
26+
"karma": "~4.0.0",
27+
"karma-browserstack-launcher": "~1.4.0",
2828
"karma-chrome-launcher": "~2.2.0",
2929
"karma-jasmine": "~1.1.0",
30-
"karma-typescript": "~3.0.12",
30+
"karma-typescript": "~4.0.0",
3131
"npm-run-all": "~4.1.3",
3232
"rimraf": "~2.6.2",
3333
"rollup": "~0.58.2",

src/engine.ts

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,8 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
478478
gradients<T extends Tensor>(
479479
f: () => T, xs: Tensor[], dy?: T,
480480
allowNoGradients = false): {value: T, grads: Tensor[]} {
481-
util.assert(xs.length > 0, 'gradients() received an empty list of xs.');
481+
util.assert(
482+
xs.length > 0, () => 'gradients() received an empty list of xs.');
482483
if (dy != null && dy.dtype !== 'float32') {
483484
throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
484485
}
@@ -487,7 +488,7 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
487488
const y = f();
488489
util.assert(
489490
y instanceof Tensor,
490-
'The result y returned by f() must be a tensor.');
491+
() => 'The result y returned by f() must be a tensor.');
491492
// Filter out the nodes that don't connect x => y.
492493
const filteredTape = getFilteredNodesXToY(this.activeTape, xs, y);
493494
if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
@@ -512,11 +513,12 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
512513
(...args: Tensor[]) => T {
513514
util.assert(
514515
util.isFunction(f),
515-
'The f passed in customGrad(f) must be a function.');
516+
() => 'The f passed in customGrad(f) must be a function.');
516517
return (...inputs: Tensor[]): T => {
517518
util.assert(
518519
inputs.every(t => t instanceof Tensor),
519-
'The args passed in customGrad(f)(x1, x2,...) must all be tensors');
520+
() => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
521+
'tensors');
520522

521523
let gradientsFunc: (dy: T) => Tensor | Tensor[];
522524
let result: T;
@@ -528,11 +530,13 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
528530
const {value, gradFunc} = f(...inputs);
529531
util.assert(
530532
value instanceof Tensor,
531-
'The function f passed in customGrad(f) must return an ' +
533+
() =>
534+
'The function f passed in customGrad(f) must return an ' +
532535
'object where `obj.value` is a tensor');
533536
util.assert(
534537
util.isFunction(gradFunc),
535-
'The function f passed in customGrad(f) must return an ' +
538+
() =>
539+
'The function f passed in customGrad(f) must return an ' +
536540
'object where `obj.gradFunc` is a function.');
537541
gradientsFunc = gradFunc;
538542
return value;
@@ -545,14 +549,14 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
545549
const grads: Tensor[] = Array.isArray(res) ? res : [res];
546550
util.assert(
547551
grads.length === inputs.length,
548-
'The function f passed in customGrad(f) must return an object ' +
549-
'where `obj.gradFunc` is a function that returns the same ' +
550-
'number of tensors as inputs passed to f(...).');
552+
() => 'The function f passed in customGrad(f) must return an ' +
553+
'object where `obj.gradFunc` is a function that returns ' +
554+
'the same number of tensors as inputs passed to f(...).');
551555
util.assert(
552556
grads.every(t => t instanceof Tensor),
553-
'The function f passed in customGrad(f) must return an object ' +
554-
'where `obj.gradFunc` is a function that returns a list of ' +
555-
'only tensors.');
557+
() => 'The function f passed in customGrad(f) must return an ' +
558+
'object where `obj.gradFunc` is a function that returns a ' +
559+
'list of only tensors.');
556560
return grads;
557561
};
558562
this.addTapeNode(inputs, result, gradFunc);

src/environment.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ export class Environment {
322322
return this.get('WEBGL_PACK');
323323
} else if (feature === 'WEBGL_PACK_IMAGE_OPERATIONS') {
324324
return this.get('WEBGL_PACK');
325+
} else if (feature === 'WEBGL_PACK_REDUCE') {
326+
return this.get('WEBGL_PACK');
325327
} else if (feature === 'WEBGL_LAZILY_UNPACK') {
326328
return this.get('WEBGL_PACK');
327329
} else if (feature === 'WEBGL_CONV_IM2COL') {

src/environment_util.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ export interface Features {
4242
'WEBGL_PACK_ARRAY_OPERATIONS'?: boolean;
4343
// Whether we will pack image operations.
4444
'WEBGL_PACK_IMAGE_OPERATIONS'?: boolean;
45+
// Whether we will pack reduction ops.
46+
'WEBGL_PACK_REDUCE'?: boolean;
4547
// Whether we will use the im2col algorithm to speed up convolutions.
4648
'WEBGL_CONV_IM2COL'?: boolean;
4749
// The maximum texture dimension.
@@ -116,6 +118,7 @@ export const URL_PROPERTIES: URLProperty[] = [
116118
{name: 'WEBGL_PACK_BINARY_OPERATIONS', type: Type.BOOLEAN},
117119
{name: 'WEBGL_PACK_ARRAY_OPERATIONS', type: Type.BOOLEAN},
118120
{name: 'WEBGL_PACK_IMAGE_OPERATIONS', type: Type.BOOLEAN},
121+
{name: 'WEBGL_PACK_REDUCE', type: Type.BOOLEAN},
119122
{name: 'WEBGL_CONV_IM2COL', type: Type.BOOLEAN},
120123
{name: 'WEBGL_MAX_TEXTURE_SIZE', type: Type.NUMBER},
121124
{name: 'WEBGL_NUM_MB_BEFORE_PAGING', type: Type.NUMBER},

src/gradients.ts

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@ function gradScope<T extends TensorContainer>(
7272
/** @doc {heading: 'Training', subheading: 'Gradients'} */
7373
function grad<I extends Tensor, O extends Tensor>(f: (x: I) => O): (
7474
x: I, dy?: O) => I {
75-
util.assert(util.isFunction(f), 'The f passed in grad(f) must be a function');
75+
util.assert(
76+
util.isFunction(f), () => 'The f passed in grad(f) must be a function');
7677
return (x: I, dy?: O): I => {
7778
util.assert(
78-
x instanceof Tensor, 'The x passed in grad(f)(x) must be a tensor');
79+
x instanceof Tensor,
80+
() => 'The x passed in grad(f)(x) must be a tensor');
7981
util.assert(
8082
dy == null || dy instanceof Tensor,
81-
'The dy passed in grad(f)(x, dy) must be a tensor');
83+
() => 'The dy passed in grad(f)(x, dy) must be a tensor');
8284
return ENV.engine.tidy(() => {
8385
const {value, grads} = ENV.engine.gradients(() => f(x), [x], dy);
8486
if (dy != null) {
@@ -124,14 +126,14 @@ function grad<I extends Tensor, O extends Tensor>(f: (x: I) => O): (
124126
function grads<O extends Tensor>(f: (...args: Tensor[]) => O): (
125127
args: Tensor[], dy?: O) => Tensor[] {
126128
util.assert(
127-
util.isFunction(f), 'The f passed in grads(f) must be a function');
129+
util.isFunction(f), () => 'The f passed in grads(f) must be a function');
128130
return (args: Tensor[], dy?: O): Tensor[] => {
129131
util.assert(
130132
Array.isArray(args) && args.every(arg => arg instanceof Tensor),
131-
'The args passed in grads(f)(args) must be an array of tensors');
133+
() => 'The args passed in grads(f)(args) must be an array of tensors');
132134
util.assert(
133135
dy == null || dy instanceof Tensor,
134-
'The dy passed in grads(f)(args, dy) must be a tensor');
136+
() => 'The dy passed in grads(f)(args, dy) must be a tensor');
135137
return ENV.engine.tidy(() => {
136138
const {value, grads} = ENV.engine.gradients(() => f(...args), args, dy);
137139
if (dy != null) {
@@ -176,14 +178,15 @@ function valueAndGrad<I extends Tensor, O extends Tensor>(f: (x: I) => O): (
176178
grad: I;
177179
} {
178180
util.assert(
179-
util.isFunction(f), 'The f passed in valueAndGrad(f) must be a function');
181+
util.isFunction(f),
182+
() => 'The f passed in valueAndGrad(f) must be a function');
180183
return (x: I, dy?: O) => {
181184
util.assert(
182185
x instanceof Tensor,
183-
'The x passed in valueAndGrad(f)(x) must be a tensor');
186+
() => 'The x passed in valueAndGrad(f)(x) must be a tensor');
184187
util.assert(
185188
dy == null || dy instanceof Tensor,
186-
'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
189+
() => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
187190
const {grads, value} = ENV.engine.gradients(() => f(x), [x], dy);
188191
checkGrads(grads);
189192
return {grad: grads[0] as I, value: value as O};
@@ -227,14 +230,15 @@ function valueAndGrads<O extends Tensor>(f: (...args: Tensor[]) => O): (
227230
} {
228231
util.assert(
229232
util.isFunction(f),
230-
'The f passed in valueAndGrads(f) must be a function');
233+
() => 'The f passed in valueAndGrads(f) must be a function');
231234
return (args: Tensor[], dy?: O) => {
232235
util.assert(
233236
Array.isArray(args) && args.every(arg => arg instanceof Tensor),
234-
'The args passed in valueAndGrads(f)(args) must be array of tensors');
237+
() => 'The args passed in valueAndGrads(f)(args) must be array of ' +
238+
'tensors');
235239
util.assert(
236240
dy == null || dy instanceof Tensor,
237-
'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');
241+
() => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');
238242
const res = ENV.engine.gradients(() => f(...args), args, dy);
239243
if (dy != null) {
240244
util.assertShapesMatch(
@@ -273,11 +277,12 @@ function variableGrads(f: () => Scalar, varList?: Variable[]):
273277
{value: Scalar, grads: NamedTensorMap} {
274278
util.assert(
275279
util.isFunction(f),
276-
'The f passed in variableGrads(f) must be a function');
280+
() => 'The f passed in variableGrads(f) must be a function');
277281
util.assert(
278282
varList == null ||
279283
Array.isArray(varList) && varList.every(v => v instanceof Variable),
280-
'The varList passed in variableGrads(f, varList) must be an array ' +
284+
() =>
285+
'The varList passed in variableGrads(f, varList) must be an array ' +
281286
'of variables');
282287
if (varList == null) {
283288
// Get all of the trainable variables.
@@ -291,7 +296,8 @@ function variableGrads(f: () => Scalar, varList?: Variable[]):
291296
varList = varList.filter(variable => variable.trainable);
292297
util.assert(
293298
varList.length > 0,
294-
`variableGrads() expects at least one of the input variables to be ` +
299+
() =>
300+
`variableGrads() expects at least one of the input variables to be ` +
295301
`trainable, but none of the ${originalVarCount} variables is ` +
296302
`trainable.`);
297303

@@ -301,12 +307,12 @@ function variableGrads(f: () => Scalar, varList?: Variable[]):
301307

302308
util.assert(
303309
grads.some(g => g != null),
304-
'Cannot find a connection between any variable and the result of the ' +
305-
'loss function y=f(x). Please make sure the operations that use ' +
306-
'variables are inside the function f passed to minimize().');
310+
() => 'Cannot find a connection between any variable and the result of ' +
311+
'the loss function y=f(x). Please make sure the operations that ' +
312+
'use variables are inside the function f passed to minimize().');
307313
util.assert(
308314
value.rank === 0,
309-
`The f passed in variableGrads(f) must return a scalar, but it ` +
315+
() => `The f passed in variableGrads(f) must return a scalar, but it ` +
310316
`returned a rank-${value.rank} tensor`);
311317

312318
const namedGrads: NamedTensorMap = {};

src/io/browser_http.ts

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {assert} from '../util';
2525

2626
import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
2727
import {IORouter, IORouterRegistry} from './router_registry';
28-
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
28+
import {IOHandler, LoadOptions, ModelArtifacts, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
2929
import {loadWeightsAsArrayBuffer} from './weights_loader';
3030

3131
const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
@@ -41,52 +41,61 @@ export class BrowserHTTPRequest implements IOHandler {
4141

4242
static readonly URL_SCHEME_REGEX = /^https?:\/\//;
4343

44-
constructor(
45-
path: string|string[], requestInit?: RequestInit,
46-
private readonly weightPathPrefix?: string, fetchFunc?: Function,
47-
private readonly onProgress?: Function) {
48-
if (fetchFunc == null) {
44+
private readonly weightPathPrefix: string;
45+
private readonly onProgress: OnProgressCallback;
46+
47+
constructor(path: string|string[], loadOptions?: LoadOptions) {
48+
if (loadOptions == null) {
49+
loadOptions = {};
50+
}
51+
this.weightPathPrefix = loadOptions.weightPathPrefix;
52+
this.onProgress = loadOptions.onProgress;
53+
54+
if (loadOptions.fetchFunc == null) {
4955
if (typeof fetch === 'undefined') {
5056
throw new Error(
5157
'browserHTTPRequest is not supported outside the web browser ' +
5258
'without a fetch polyfill.');
5359
}
5460
// Make sure fetch is always bound to window (the
5561
// original object) when available.
56-
fetchFunc = fetch.bind(typeof window === 'undefined' ? null : window);
62+
loadOptions.fetchFunc =
63+
fetch.bind(typeof window === 'undefined' ? null : window);
5764
} else {
5865
assert(
59-
typeof fetchFunc === 'function',
60-
'Must pass a function that matches the signature of ' +
66+
typeof loadOptions.fetchFunc === 'function',
67+
() => 'Must pass a function that matches the signature of ' +
6168
'`fetch` (see ' +
6269
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
6370
}
6471

6572
this.fetchFunc = (path: string, requestInits: RequestInit) => {
6673
// tslint:disable-next-line:no-any
67-
return fetchFunc(path, requestInits).catch((error: any) => {
74+
return loadOptions.fetchFunc(path, requestInits).catch((error: any) => {
6875
throw new Error(`Request for ${path} failed due to error: ${error}`);
6976
});
7077
};
7178

7279
assert(
7380
path != null && path.length > 0,
74-
'URL path for browserHTTPRequest must not be null, undefined or ' +
81+
() =>
82+
'URL path for browserHTTPRequest must not be null, undefined or ' +
7583
'empty.');
7684

7785
if (Array.isArray(path)) {
7886
assert(
7987
path.length === 2,
80-
'URL paths for browserHTTPRequest must have a length of 2, ' +
88+
() => 'URL paths for browserHTTPRequest must have a length of 2, ' +
8189
`(actual length is ${path.length}).`);
8290
}
8391
this.path = path;
8492

85-
if (requestInit != null && requestInit.body != null) {
93+
if (loadOptions.requestInit != null &&
94+
loadOptions.requestInit.body != null) {
8695
throw new Error(
8796
'requestInit is expected to have no pre-existing body, but has one.');
8897
}
89-
this.requestInit = requestInit || {};
98+
this.requestInit = loadOptions.requestInit || {};
9099
}
91100

92101
async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
@@ -235,8 +244,11 @@ export class BrowserHTTPRequest implements IOHandler {
235244
fetchURLs.push(pathPrefix + path + suffix);
236245
});
237246
});
238-
const buffers = await loadWeightsAsArrayBuffer(
239-
fetchURLs, this.requestInit, this.getFetchFunc(), this.onProgress);
247+
const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
248+
requestInit: this.requestInit,
249+
fetchFunc: this.getFetchFunc(),
250+
onProgress: this.onProgress
251+
});
240252
return [weightSpecs, concatenateArrayBuffers(buffers)];
241253
}
242254

@@ -277,7 +289,7 @@ export function isHTTPScheme(url: string): boolean {
277289
}
278290

279291
export const httpRequestRouter: IORouter =
280-
(url: string|string[], onProgress?: Function) => {
292+
(url: string|string[], onProgress?: OnProgressCallback) => {
281293
if (typeof fetch === 'undefined') {
282294
// browserHTTPRequest uses `fetch`, if one wants to use it in node.js
283295
// they have to setup a global fetch polyfill.
@@ -290,7 +302,7 @@ export const httpRequestRouter: IORouter =
290302
isHTTP = isHTTPScheme(url);
291303
}
292304
if (isHTTP) {
293-
return browserHTTPRequest(url, null, null, null, onProgress);
305+
return browserHTTPRequest(url, {onProgress});
294306
}
295307
}
296308
return null;
@@ -433,17 +445,17 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
433445
* topology (filename: 'model.json') and the weights of the model (filename:
434446
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
435447
* `body`, an Error will be thrown.
436-
* @param weightPathPrefix Optional, this specifies the path prefix for weight
437-
* files, by default this is calculated from the path param.
438-
* @param fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
439-
* the `fetch` from node-fetch can be used here.
440-
* @param onProgress Optional, progress callback function, fired periodically
441-
* before the load is completed.
448+
* @param loadOptions Optional configuration for the loading. It includes the
449+
* following fields:
450+
* - weightPathPrefix Optional, this specifies the path prefix for weight
451+
* files, by default this is calculated from the path param.
452+
* - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
453+
* the `fetch` from node-fetch can be used here.
454+
* - onProgress Optional, progress callback function, fired periodically
455+
* before the load is completed.
442456
* @returns An instance of `IOHandler`.
443457
*/
444458
export function browserHTTPRequest(
445-
path: string|string[], requestInit?: RequestInit, weightPathPrefix?: string,
446-
fetchFunc?: Function, onProgress?: Function): IOHandler {
447-
return new BrowserHTTPRequest(
448-
path, requestInit, weightPathPrefix, fetchFunc, onProgress);
459+
path: string|string[], loadOptions?: LoadOptions): IOHandler {
460+
return new BrowserHTTPRequest(path, loadOptions);
449461
}

0 commit comments

Comments
 (0)