Skip to content

Commit eab626e

Browse files
authored
Allow users to pass a list of TypedArray in tf.tensor() (tensorflow#1424)
Fixes tensorflow/tfjs#959 FEATURE
1 parent dce101e commit eab626e

File tree

10 files changed

+118
-68
lines changed

10 files changed

+118
-68
lines changed

src/kernels/backend_cpu_test.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ describe('backendCPU', () => {
6464
});
6565

6666
it('register string tensor with values and mismatched shape', () => {
67-
expect(() => tf.Tensor.make([4], {values: ['a', 'b', 'c']}, 'string'))
68-
.toThrowError();
67+
expect(() => tf.tensor(['a', 'b', 'c'], [4], 'string')).toThrowError();
6968
});
7069
});

src/kernels/backend_webgl_test.ts

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => {
6363
});
6464

6565
it('should work when the same input must be represented by' +
66-
'different textures', () => {
67-
const a = tf.tensor1d([1, 2]);
68-
const res = tf.dot(a, a);
69-
expectArraysClose(res, [5]);
70-
});
66+
'different textures',
67+
() => {
68+
const a = tf.tensor1d([1, 2]);
69+
const res = tf.dot(a, a);
70+
expectArraysClose(res, [5]);
71+
});
7172
});
7273

7374
describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
@@ -124,9 +125,7 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
124125
const backend = new MathBackendWebGL();
125126
tf.ENV.registerBackend('test-storage', () => backend);
126127
tf.setBackend('test-storage');
127-
128-
expect(() => tf.Tensor.make([4], {values: ['a', 'b', 'c']}, 'string'))
129-
.toThrowError();
128+
expect(() => tf.tensor(['a', 'b', 'c'], [4], 'string')).toThrowError();
130129
});
131130

132131
it('delayed storage, reading', () => {

src/ops/array_ops.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ function multinomial_(
288288
function oneHot_(
289289
indices: Tensor1D|TensorLike1D, depth: number, onValue = 1,
290290
offValue = 0): Tensor2D {
291-
const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
291+
const $indices =
292+
convertToTensor(indices, 'indices', 'oneHot', 'int32') as Tensor1D;
292293

293294
if (depth < 2) {
294295
throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
@@ -1061,7 +1062,7 @@ function expandDims_<R2 extends Rank>(
10611062
function depthToSpace_(
10621063
x: Tensor4D|TensorLike4D, blockSize: number,
10631064
dataFormat: 'NHWC'|'NCHW' = 'NHWC'): Tensor4D {
1064-
const $x = convertToTensor(x, 'x', 'depthToSpace');
1065+
const $x = convertToTensor(x, 'x', 'depthToSpace') as Tensor4D;
10651066

10661067
const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];
10671068
const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];

src/ops/tensor_ops.ts

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor
2020
import {convertToTensor, inferShape} from '../tensor_util_env';
2121
import {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TensorLike5D, TensorLike6D, TypedArray} from '../types';
2222
import {DataType, Rank, ShapeMap} from '../types';
23-
import {assertNonNull, assertShapesMatch, flatten, getArrayFromDType, inferDtype, isTypedArray, makeOnesTypedArray, makeZerosTypedArray, sizeFromShape, toTypedArray} from '../util';
23+
import {assert, assertNonNull, flatten, getArrayFromDType, inferDtype, isTypedArray, makeOnesTypedArray, makeZerosTypedArray, sizeFromShape, toTypedArray} from '../util';
24+
2425
import {complex} from './complex_ops';
2526
import {op} from './operation';
2627

@@ -68,19 +69,33 @@ function tensor<R extends Rank>(
6869
'an array of numbers/booleans/strings, or a TypedArray');
6970
}
7071
const inferredShape = inferShape(values);
71-
if (shape != null && inferredShape.length !== 1) {
72-
assertShapesMatch(
73-
shape, inferredShape,
74-
`Error creating a new Tensor. ` +
75-
`Inferred shape (${inferredShape}) does not match the ` +
76-
`provided shape (${shape}). `);
72+
if (shape != null) {
73+
const providedSize = sizeFromShape(shape);
74+
const inferredSize = sizeFromShape(inferredShape);
75+
assert(
76+
providedSize === inferredSize,
77+
() =>
78+
`Based on the provided shape, [${shape}], the tensor should have ` +
79+
`${providedSize} values but has ${inferredSize}`);
80+
81+
for (let i = 0; i < inferredShape.length; ++i) {
82+
const inferred = inferredShape[i];
83+
const flatDimsDontMatch = i === inferredShape.length - 1 ?
84+
inferred !== sizeFromShape(shape.slice(i)) :
85+
true;
86+
assert(
87+
inferredShape[i] === shape[i] || !flatDimsDontMatch,
88+
() => `Error creating a new Tensor. Inferred shape ` +
89+
`(${inferredShape}) does not match the provided ` +
90+
`shape (${shape}). `);
91+
}
7792
}
7893
if (!isTypedArray(values) && !Array.isArray(values)) {
7994
values = [values] as number[];
8095
}
8196
shape = shape || inferredShape;
8297
values = dtype !== 'string' ? toTypedArray(values, dtype, ENV.get('DEBUG')) :
83-
flatten(values) as string[];
98+
flatten(values as string[]) as string[];
8499
return Tensor.make(shape, {values}, dtype);
85100
}
86101

src/tensor.ts

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -410,14 +410,6 @@ export class Tensor<R extends Rank = Rank> {
410410
this.shape = shape.slice();
411411
this.dtype = dtype || 'float32';
412412
this.size = util.sizeFromShape(shape);
413-
if (values != null) {
414-
util.assert(
415-
this.size === values.length,
416-
`Based on the provided shape, [${shape}], and dtype ` +
417-
`${this.dtype}, the tensor should have ` +
418-
`${this.size} values but has ${values.length}`);
419-
}
420-
421413
this.strides = computeStrides(shape);
422414
this.dataId = dataId != null ? dataId : {};
423415
this.id = trackerFn().nextTensorId();

src/tensor_test.ts

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,64 @@ describeWithFlags('tensor', ALL_ENVS, () => {
575575
expectArraysClose(a, [1, 5, 2]);
576576
});
577577

578+
it('tf.tensor() from Float32Array and number[]', () => {
579+
const a = tf.tensor([
580+
new Float32Array([1, 2]), new Float32Array([3, 4]),
581+
new Float32Array([5, 6]), [7, 8]
582+
]);
583+
expect(a.dtype).toBe('float32');
584+
expect(a.shape).toEqual([4, 2]);
585+
expectArraysClose(a, [1, 2, 3, 4, 5, 6, 7, 8]);
586+
});
587+
588+
it('tf.tensor() from Int32Array and number[]', () => {
589+
const a = tf.tensor([
590+
new Int32Array([1, 2]), new Int32Array([3, 4]), new Int32Array([5, 6]),
591+
[7, 8]
592+
]);
593+
expect(a.dtype).toBe('int32');
594+
expect(a.shape).toEqual([4, 2]);
595+
expectArraysClose(a, [1, 2, 3, 4, 5, 6, 7, 8]);
596+
});
597+
598+
it('tf.tensor() from mixed TypedArray', () => {
599+
const a = tf.tensor([
600+
new Float32Array([1, 2]), new Int32Array([3, 4]), new Uint8Array([5, 6]),
601+
[7, 8]
602+
]);
603+
expect(a.dtype).toBe('float32');
604+
expect(a.shape).toEqual([4, 2]);
605+
expectArraysClose(a, [1, 2, 3, 4, 5, 6, 7, 8]);
606+
});
607+
608+
it('tf.tensor() from TypedArrays which are themselves 3D', () => {
609+
// 2 tensors, each with shape 20x20x3, as flat Float32Arrays.
610+
const img1 = new Float32Array(20 * 20 * 3);
611+
const img2 = new Float32Array(20 * 20 * 3);
612+
const t = tf.tensor([img1, img2], [2, 20, 20, 3]);
613+
expect(t.dtype).toBe('float32');
614+
expect(t.shape).toEqual([2, 20, 20, 3]);
615+
});
616+
617+
it('tf.tensor() from TypeedArrays which are themselves 3D, wrong shape',
618+
() => {
619+
const img1 = new Float32Array(20 * 20 * 3);
620+
const img2 = new Float32Array(20 * 20 * 3);
621+
expect(() => tf.tensor([img1, img2], [3, 20, 20, 3])).toThrowError();
622+
});
623+
624+
it('tf.tensor() from TypedArray + number[] fails due to wrong shape', () => {
625+
expect(() => tf.tensor([
626+
new Float32Array([1, 2]),
627+
new Float32Array([3, 4]),
628+
new Float32Array([5, 6]),
629+
// Should be of length 4
630+
[7, 8, 9, 10],
631+
]))
632+
.toThrowError(
633+
/Element arr\[3\] should have 2 elements, but has 4 elements/);
634+
});
635+
578636
it('default dtype from ascii string', () => {
579637
const a = tf.tensor('hello');
580638
expect(a.dtype).toBe('string');
@@ -1951,15 +2009,6 @@ describeWithFlags('tensor with 0 in shape', ALL_ENVS, () => {
19512009
expectArraysEqual(a, []);
19522010
});
19532011

1954-
it('1d throws when values are not empty', () => {
1955-
const values = new Float32Array([1, 2, 3]);
1956-
// Have to use Tensor.make since tensor1d() does not let us provide a shape.
1957-
expect(() => Tensor.make([0], {values}, 'float32'))
1958-
.toThrowError(
1959-
'Based on the provided shape, [0], and dtype float32, the tensor ' +
1960-
'should have 0 values but has 3');
1961-
});
1962-
19632012
it('2d of shape [0, 5]', () => {
19642013
const a = tf.tensor2d([], [0, 5]);
19652014
expect(a.dtype).toBe('float32');
@@ -1980,7 +2029,7 @@ describeWithFlags('tensor with 0 in shape', ALL_ENVS, () => {
19802029
const values = [1, 2, 3, 4];
19812030
expect(() => tf.tensor2d(values, [0, 5], 'float32'))
19822031
.toThrowError(
1983-
'Based on the provided shape, [0,5], and dtype float32, the ' +
2032+
'Based on the provided shape, [0,5], the ' +
19842033
'tensor should have 0 values but has 4');
19852034
});
19862035

@@ -1996,7 +2045,7 @@ describeWithFlags('tensor with 0 in shape', ALL_ENVS, () => {
19962045
const values = [1, 2, 3];
19972046
expect(() => tf.tensor3d(values, [0, 3, 0], 'float32'))
19982047
.toThrowError(
1999-
'Based on the provided shape, [0,3,0], and dtype float32, the ' +
2048+
'Based on the provided shape, [0,3,0], the ' +
20002049
'tensor should have 0 values but has 3');
20012050
});
20022051

@@ -2012,7 +2061,7 @@ describeWithFlags('tensor with 0 in shape', ALL_ENVS, () => {
20122061
const values = [1, 2, 3];
20132062
expect(() => tf.tensor4d(values, [1, 3, 0, 5], 'float32'))
20142063
.toThrowError(
2015-
'Based on the provided shape, [1,3,0,5], and dtype float32, the ' +
2064+
'Based on the provided shape, [1,3,0,5], the ' +
20162065
'tensor should have 0 values but has 3');
20172066
});
20182067

src/tensor_util_env.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export function inferShape(val: TensorLike): number[] {
3131
}
3232
const shape: number[] = [];
3333

34-
while (firstElem instanceof Array) {
34+
while (firstElem instanceof Array || isTypedArray(firstElem)) {
3535
shape.push(firstElem.length);
3636
firstElem = firstElem[0];
3737
}
@@ -45,11 +45,11 @@ export function inferShape(val: TensorLike): number[] {
4545
function deepAssertShapeConsistency(
4646
val: TensorLike, shape: number[], indices: number[]) {
4747
indices = indices || [];
48-
if (!(val instanceof Array)) {
48+
if (!(val instanceof Array) && !isTypedArray(val)) {
4949
assert(
5050
shape.length === 0,
5151
() => `Element arr[${indices.join('][')}] is a primitive, ` +
52-
`but should be an array of ${shape[0]} elements`);
52+
`but should be an array/TypedArray of ${shape[0]} elements`);
5353
return;
5454
}
5555
assert(
@@ -100,15 +100,15 @@ export function convertToTensor<T extends Tensor>(
100100
typeof x !== 'boolean' && typeof x !== 'string') {
101101
throw new Error(
102102
`Argument '${argName}' passed to '${functionName}' must be a ` +
103-
`Tensor or TensorLike, but got '${x.constructor.name}'`);
103+
`Tensor or TensorLike, but got '${(x as {}).constructor.name}'`);
104104
}
105105
const inferredShape = inferShape(x);
106106
if (!isTypedArray(x) && !Array.isArray(x)) {
107107
x = [x] as number[];
108108
}
109109
const values = inferredDtype !== 'string' ?
110110
toTypedArray(x, inferredDtype as DataType, ENV.get('DEBUG')) :
111-
flatten(x) as string[];
111+
flatten(x as string[]) as string[];
112112
return Tensor.make(inferredShape, {values}, inferredDtype);
113113
}
114114

src/types.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ export function sumOutType(type: DataType): DataType {
120120
}
121121

122122
/** @docalias TypedArray|Array */
123-
export type TensorLike = TypedArray|number|boolean|string|RegularArray<number>|
123+
export type TensorLike =
124+
TypedArray|number|boolean|string|RegularArray<number|number[]|TypedArray>|
124125
RegularArray<boolean>|RegularArray<string>;
125126
/** @docalias TypedArray|Array */
126127
export type TensorLike1D = TypedArray|number[]|boolean[]|string[];

src/util.ts

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,10 @@ export function assertNonNull(a: TensorLike): void {
9898
// NOTE: We explicitly type out what T extends instead of any so that
9999
// util.flatten on a nested array of number doesn't try to infer T as a
100100
// number[][], causing us to explicitly type util.flatten<number>().
101-
export function flatten<T extends number|boolean|string|Promise<number>>(
101+
export function
102+
flatten<T extends number|boolean|string|Promise<number>|TypedArray>(
102103
arr: T|RecursiveArray<T>, ret: T[] = []): T[] {
103-
if (Array.isArray(arr)) {
104+
if (Array.isArray(arr) || isTypedArray(arr)) {
104105
for (let i = 0; i < arr.length; ++i) {
105106
flatten(arr[i], ret);
106107
}
@@ -376,7 +377,7 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean {
376377
return true;
377378
}
378379

379-
export function isTypedArray(a: TensorLike): boolean {
380+
export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array {
380381
return a instanceof Float32Array || a instanceof Int32Array ||
381382
a instanceof Uint8Array;
382383
}
@@ -422,24 +423,18 @@ export function isNumber(value: {}): boolean {
422423
}
423424

424425
export function inferDtype(values: TensorLike): DataType {
426+
if (values instanceof Array) {
427+
return inferDtype(values[0]);
428+
}
425429
if (values instanceof Float32Array) {
426430
return 'float32';
427431
} else if (values instanceof Int32Array || values instanceof Uint8Array) {
428432
return 'int32';
429-
} else if (
430-
isNumber(values) ||
431-
values instanceof Array &&
432-
isNumber(getFirstElemFromNestedArray(values))) {
433+
} else if (isNumber(values)) {
433434
return 'float32';
434-
} else if (
435-
isString(values) ||
436-
values instanceof Array &&
437-
isString(getFirstElemFromNestedArray(values))) {
435+
} else if (isString(values)) {
438436
return 'string';
439-
} else if (
440-
isBoolean(values) ||
441-
values instanceof Array &&
442-
isBoolean(getFirstElemFromNestedArray(values))) {
437+
} else if (isBoolean(values)) {
443438
return 'bool';
444439
}
445440
return 'float32';
@@ -549,10 +544,3 @@ export function now(): number {
549544
'in the browser or in Node.js');
550545
}
551546
}
552-
553-
function getFirstElemFromNestedArray(arr: TensorLike): {} {
554-
while (arr instanceof Array) {
555-
arr = arr[0];
556-
}
557-
return arr;
558-
}

src/util_test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ describe('util.flatten', () => {
103103
'a', 'b', 'c', 'd', 'e'
104104
]);
105105
});
106+
107+
it('mixed TypedArray and number[]', () => {
108+
const data =
109+
[new Float32Array([1, 2]), 3, [4, 5, new Float32Array([6, 7])]];
110+
expect(util.flatten(data)).toEqual([1, 2, 3, 4, 5, 6, 7]);
111+
});
106112
});
107113

108114
describe('util.bytesFromStringArray', () => {

0 commit comments

Comments
 (0)