Skip to content

Commit 759b446

Browse files
Lewuathedsmilkov
authored andcommitted
Add util to get the non-flattened tensor data (tensorflow#1483)
INTERNAL #### Description <!-- Please describe the pull request here. Also, if this is an issue/bug fix, please add the issue link for reference here. --> MISC Provide a utility to get the non-flattened Tensor data. See tensorflow/tfjs#979
1 parent 2e3142b commit 759b446

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

src/util.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,42 @@ export function toTypedArray(
517517
}
518518
}
519519

520+
function createNestedArray(offset: number, shape: number[], a: TypedArray) {
521+
const ret = new Array();
522+
if (shape.length === 1) {
523+
const d = shape[0];
524+
for (let i = 0; i < d; i++) {
525+
ret[i] = a[offset + i];
526+
}
527+
} else {
528+
const d = shape[0];
529+
const rest = shape.slice(1);
530+
const len = rest.reduce((acc, c) => acc * c);
531+
for (let i = 0; i < d; i++) {
532+
ret[i] = createNestedArray(offset + i * len, rest, a);
533+
}
534+
}
535+
return ret;
536+
}
537+
538+
// Provide a nested array of TypedArray in given shape.
539+
export function toNestedArray(shape: number[], a: TypedArray) {
540+
if (shape.length === 0) {
541+
// Scalar type should be empty list.
542+
return [];
543+
}
544+
const size = shape.reduce((acc, c) => acc * c);
545+
if (size === 0) {
546+
// A tensor with shape zero should be turned into empty list.
547+
return [];
548+
}
549+
if (size !== a.length) {
550+
throw new Error(`[${shape}] does not match the input size.`);
551+
}
552+
553+
return createNestedArray(0, shape, a);
554+
}
555+
520556
function noConversionNeeded(a: TensorLike, dtype: DataType): boolean {
521557
return (a instanceof Float32Array && dtype === 'float32') ||
522558
(a instanceof Int32Array && dtype === 'int32') ||

src/util_test.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import {inferShape} from './tensor_util_env';
1919
import * as util from './util';
20+
import {scalar, tensor2d} from './ops/ops';
2021

2122
describe('Util', () => {
2223
it('Correctly gets size from shape', () => {
@@ -455,6 +456,47 @@ describe('util.hasEncodingLoss', () => {
455456
});
456457
});
457458

459+
describe('util.toNestedArray', () => {
460+
it('2 dimensions', () => {
461+
const a = new Float32Array([1, 2, 3, 4, 5, 6]);
462+
expect(util.toNestedArray([2, 3], a))
463+
.toEqual([[1,2,3], [4,5,6]]);
464+
});
465+
466+
it('3 dimensions (2x2x3)', () => {
467+
const a = new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
468+
expect(util.toNestedArray([2, 2, 3], a))
469+
.toEqual([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]);
470+
});
471+
472+
it('3 dimensions (3x2x2)', () => {
473+
const a = new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
474+
expect(util.toNestedArray([3, 2, 2], a))
475+
.toEqual([[[0, 1],[2, 3]],[[4, 5],[6, 7]],[[8, 9],[10, 11]]]);
476+
});
477+
478+
it('invalid dimension', () => {
479+
const a = new Float32Array([1, 2, 3]);
480+
expect(() => util.toNestedArray([2, 2], a)).toThrowError();
481+
});
482+
483+
it('tensor to nested array', () => {
484+
const x = tensor2d([1, 2, 3, 4], [2, 2]);
485+
expect(util.toNestedArray(x.shape, x.dataSync()))
486+
.toEqual([[1, 2], [3, 4]]);
487+
});
488+
489+
it('scalar to nested array', () => {
490+
const x = scalar(1);
491+
expect(util.toNestedArray(x.shape, x.dataSync())).toEqual([]);
492+
});
493+
494+
it('tensor with zero shape', () => {
495+
const a = new Float32Array([0, 1]);
496+
expect(util.toNestedArray([1, 0, 2], a)).toEqual([]);
497+
});
498+
});
499+
458500
describe('util.monitorPromisesProgress', () => {
459501
it('Default progress from 0 to 1', (done) => {
460502
const expectFractions: number[] = [0.25, 0.50, 0.75, 1.00];

0 commit comments

Comments
 (0)