@@ -511,6 +511,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
511
511
expectArraysEqual ( b , [ 1 , 1 , 1 ] ) ;
512
512
} ) ;
513
513
514
+ it ( '1D complex dtype' , ( ) => {
515
+ const real = tf . tensor1d ( [ 1 , 2 , 3 ] , 'float32' ) ;
516
+ const imag = tf . tensor1d ( [ 1 , 2 , 3 ] , 'float32' ) ;
517
+ const a = tf . complex ( real , imag ) ;
518
+ const b = tf . onesLike ( a ) ;
519
+ expect ( b . dtype ) . toBe ( 'complex64' ) ;
520
+ expect ( b . shape ) . toEqual ( [ 3 ] ) ;
521
+ expectArraysEqual ( b , [ 1 , 0 , 1 , 0 , 1 , 0 ] ) ;
522
+ } ) ;
523
+
514
524
it ( '2D default dtype' , ( ) => {
515
525
const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 ] ) ;
516
526
const b = tf . onesLike ( a ) ;
@@ -543,6 +553,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
543
553
expectArraysEqual ( b , [ 1 , 1 , 1 , 1 ] ) ;
544
554
} ) ;
545
555
556
+ it ( '2D complex dtype' , ( ) => {
557
+ const real = tf . tensor2d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 ] , 'float32' ) ;
558
+ const imag = tf . tensor2d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 ] , 'float32' ) ;
559
+ const a = tf . complex ( real , imag ) ;
560
+ const b = tf . onesLike ( a ) ;
561
+ expect ( b . dtype ) . toBe ( 'complex64' ) ;
562
+ expect ( b . shape ) . toEqual ( [ 2 , 2 ] ) ;
563
+ expectArraysEqual ( b , [ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ] ) ;
564
+ } ) ;
565
+
546
566
it ( '3D default dtype' , ( ) => {
547
567
const a = tf . tensor3d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 , 1 ] ) ;
548
568
const b = tf . onesLike ( a ) ;
@@ -575,6 +595,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
575
595
expectArraysEqual ( b , [ 1 , 1 , 1 , 1 ] ) ;
576
596
} ) ;
577
597
598
+ it ( '3D complex dtype' , ( ) => {
599
+ const real = tf . tensor3d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 , 1 ] , 'float32' ) ;
600
+ const imag = tf . tensor3d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 , 1 ] , 'float32' ) ;
601
+ const a = tf . complex ( real , imag ) ;
602
+ const b = tf . onesLike ( a ) ;
603
+ expect ( b . dtype ) . toBe ( 'complex64' ) ;
604
+ expect ( b . shape ) . toEqual ( [ 2 , 2 , 1 ] ) ;
605
+ expectArraysEqual ( b , [ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ] ) ;
606
+ } ) ;
607
+
578
608
it ( '4D default dtype' , ( ) => {
579
609
const a = tf . tensor4d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 , 1 , 1 ] ) ;
580
610
const b = tf . onesLike ( a ) ;
@@ -615,6 +645,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
615
645
expectArraysClose ( b , [ 1 , 1 , 1 , 1 ] ) ;
616
646
} ) ;
617
647
648
+ it ( '4D complex dtype' , ( ) => {
649
+ const real = tf . tensor4d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 , 1 , 1 ] , 'float32' ) ;
650
+ const imag = tf . tensor4d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 , 1 , 1 ] , 'float32' ) ;
651
+ const a = tf . complex ( real , imag ) ;
652
+ const b = tf . onesLike ( a ) ;
653
+ expect ( b . dtype ) . toBe ( 'complex64' ) ;
654
+ expect ( b . shape ) . toEqual ( [ 2 , 2 , 1 , 1 ] ) ;
655
+ expectArraysEqual ( b , [ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ] ) ;
656
+ } ) ;
657
+
618
658
it ( '5D float32 dtype' , ( ) => {
619
659
const a = tf . tensor5d ( [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 2 , 1 , 1 ] , 'float32' ) ;
620
660
const b = tf . onesLike ( a ) ;
@@ -647,6 +687,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
647
687
expectArraysClose ( b , [ 1 , 1 , 1 , 1 ] ) ;
648
688
} ) ;
649
689
690
+ it ( '5D complex dtype' , ( ) => {
691
+ const real = tf . tensor5d ( [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 2 , 1 , 1 ] , 'float32' ) ;
692
+ const imag = tf . tensor5d ( [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 2 , 1 , 1 ] , 'float32' ) ;
693
+ const a = tf . complex ( real , imag ) ;
694
+ const b = tf . onesLike ( a ) ;
695
+ expect ( b . dtype ) . toBe ( 'complex64' ) ;
696
+ expect ( b . shape ) . toEqual ( [ 1 , 2 , 2 , 1 , 1 ] ) ;
697
+ expectArraysEqual ( b , [ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ] ) ;
698
+ } ) ;
699
+
650
700
it ( '6D int32 dtype' , ( ) => {
651
701
const a = tf . tensor6d ( [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 2 , 1 , 1 , 1 ] , 'int32' ) ;
652
702
const b = tf . onesLike ( a ) ;
@@ -679,6 +729,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
679
729
expectArraysClose ( b , [ 1 , 1 , 1 , 1 ] ) ;
680
730
} ) ;
681
731
732
+ it ( '6D complex dtype' , ( ) => {
733
+ const real = tf . tensor6d ( [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 2 , 1 , 1 , 1 ] , 'float32' ) ;
734
+ const imag = tf . tensor6d ( [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 2 , 1 , 1 , 1 ] , 'float32' ) ;
735
+ const a = tf . complex ( real , imag ) ;
736
+ const b = tf . onesLike ( a ) ;
737
+ expect ( b . dtype ) . toBe ( 'complex64' ) ;
738
+ expect ( b . shape ) . toEqual ( [ 1 , 2 , 2 , 1 , 1 , 1 ] ) ;
739
+ expectArraysEqual ( b , [ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 ] ) ;
740
+ } ) ;
741
+
682
742
it ( 'throws when passed a non-tensor' , ( ) => {
683
743
expect ( ( ) => tf . onesLike ( { } as tf . Tensor ) )
684
744
. toThrowError ( / A r g u m e n t ' x ' p a s s e d t o ' o n e s L i k e ' m u s t b e a T e n s o r / ) ;
0 commit comments