|
17 | 17 |
|
18 | 18 | import {inferShape} from './tensor_util_env';
|
19 | 19 | import * as util from './util';
|
| 20 | +import {scalar, tensor2d} from './ops/ops'; |
20 | 21 |
|
21 | 22 | describe('Util', () => {
|
22 | 23 | it('Correctly gets size from shape', () => {
|
@@ -455,6 +456,47 @@ describe('util.hasEncodingLoss', () => {
|
455 | 456 | });
|
456 | 457 | });
|
457 | 458 |
|
| 459 | +describe('util.toNestedArray', () => { |
| 460 | + it('2 dimensions', () => { |
| 461 | + const a = new Float32Array([1, 2, 3, 4, 5, 6]); |
| 462 | + expect(util.toNestedArray([2, 3], a)) |
| 463 | + .toEqual([[1,2,3], [4,5,6]]); |
| 464 | + }); |
| 465 | + |
| 466 | + it('3 dimensions (2x2x3)', () => { |
| 467 | + const a = new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); |
| 468 | + expect(util.toNestedArray([2, 2, 3], a)) |
| 469 | + .toEqual([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]); |
| 470 | + }); |
| 471 | + |
| 472 | + it('3 dimensions (3x2x2)', () => { |
| 473 | + const a = new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); |
| 474 | + expect(util.toNestedArray([3, 2, 2], a)) |
| 475 | + .toEqual([[[0, 1],[2, 3]],[[4, 5],[6, 7]],[[8, 9],[10, 11]]]); |
| 476 | + }); |
| 477 | + |
| 478 | + it('invalid dimension', () => { |
| 479 | + const a = new Float32Array([1, 2, 3]); |
| 480 | + expect(() => util.toNestedArray([2, 2], a)).toThrowError(); |
| 481 | + }); |
| 482 | + |
| 483 | + it('tensor to nested array', () => { |
| 484 | + const x = tensor2d([1, 2, 3, 4], [2, 2]); |
| 485 | + expect(util.toNestedArray(x.shape, x.dataSync())) |
| 486 | + .toEqual([[1, 2], [3, 4]]); |
| 487 | + }); |
| 488 | + |
| 489 | + it('scalar to nested array', () => { |
| 490 | + const x = scalar(1); |
| 491 | + expect(util.toNestedArray(x.shape, x.dataSync())).toEqual([]); |
| 492 | + }); |
| 493 | + |
| 494 | + it('tensor with zero shape', () => { |
| 495 | + const a = new Float32Array([0, 1]); |
| 496 | + expect(util.toNestedArray([1, 0, 2], a)).toEqual([]); |
| 497 | + }); |
| 498 | +}); |
| 499 | + |
458 | 500 | describe('util.monitorPromisesProgress', () => {
|
459 | 501 | it('Default progress from 0 to 1', (done) => {
|
460 | 502 | const expectFractions: number[] = [0.25, 0.50, 0.75, 1.00];
|
|
0 commit comments