@@ -87,6 +87,19 @@ describeWithFlags('concat1d', ALL_ENVS, () => {
87
87
const expected = [ 3 , 5 ] ;
88
88
expectArraysClose ( await result . data ( ) , expected ) ;
89
89
} ) ;
90
+
91
+ it ( 'concat complex input' , async ( ) => {
92
+ // [1+1j, 2+2j]
93
+ const c1 = tf . complex ( [ 1 , 2 ] , [ 1 , 2 ] ) ;
94
+ // [3+3j, 4+4j]
95
+ const c2 = tf . complex ( [ 3 , 4 ] , [ 3 , 4 ] ) ;
96
+
97
+ const axis = 0 ;
98
+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
99
+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 ] ;
100
+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
101
+ expectArraysClose ( await result . data ( ) , expected ) ;
102
+ } ) ;
90
103
} ) ;
91
104
92
105
describeWithFlags ( 'concat2d' , ALL_ENVS , ( ) => {
@@ -220,6 +233,32 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
220
233
expect ( res2 . shape ) . toEqual ( [ 0 , 15 ] ) ;
221
234
expectArraysEqual ( await res2 . data ( ) , [ ] ) ;
222
235
} ) ;
236
+
237
+ it ( 'concat complex input axis=0' , async ( ) => {
238
+ // [[1+1j, 2+2j], [3+3j, 4+4j]]
239
+ const c1 = tf . complex ( [ [ 1 , 2 ] , [ 3 , 4 ] ] , [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
240
+ // [[5+5j, 6+6j], [7+7j, 8+8j]]
241
+ const c2 = tf . complex ( [ [ 5 , 6 ] , [ 7 , 8 ] ] , [ [ 5 , 6 ] , [ 7 , 8 ] ] ) ;
242
+
243
+ const axis = 0 ;
244
+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
245
+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 8 , 8 ] ;
246
+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
247
+ expectArraysClose ( await result . data ( ) , expected ) ;
248
+ } ) ;
249
+
250
+ it ( 'concat complex input axis=1' , async ( ) => {
251
+ // [[1+1j, 2+2j], [3+3j, 4+4j]]
252
+ const c1 = tf . complex ( [ [ 1 , 2 ] , [ 3 , 4 ] ] , [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
253
+ // [[5+5j, 6+6j], [7+7j, 8+8j]]
254
+ const c2 = tf . complex ( [ [ 5 , 6 ] , [ 7 , 8 ] ] , [ [ 5 , 6 ] , [ 7 , 8 ] ] ) ;
255
+
256
+ const axis = 1 ;
257
+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
258
+ const expected = [ 1 , 1 , 2 , 2 , 5 , 5 , 6 , 6 , 3 , 3 , 4 , 4 , 7 , 7 , 8 , 8 ] ;
259
+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
260
+ expectArraysClose ( await result . data ( ) , expected ) ;
261
+ } ) ;
223
262
} ) ;
224
263
225
264
describeWithFlags ( 'concat3d' , ALL_ENVS , ( ) => {
@@ -460,6 +499,54 @@ describeWithFlags('concat3d', ALL_ENVS, () => {
460
499
expect ( values . shape ) . toEqual ( [ 2 , 3 , 1 ] ) ;
461
500
expectArraysClose ( await values . data ( ) , [ 1 , 2 , 3 , 4 , 5 , 6 ] ) ;
462
501
} ) ;
502
+
503
+ it ( 'concat complex input axis=0' , async ( ) => {
504
+ // [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
505
+ const c1 = tf . complex (
506
+ [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] , [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] ) ;
507
+ // [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
508
+ const c2 = tf . complex (
509
+ [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] , [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] ) ;
510
+
511
+ const axis = 0 ;
512
+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
513
+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 ,
514
+ 7 , 7 , 8 , 8 , 9 , 9 , 10 , 10 , 11 , 11 , 12 , 12 ] ;
515
+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
516
+ expectArraysClose ( await result . data ( ) , expected ) ;
517
+ } ) ;
518
+
519
+ it ( 'concat complex input axis=1' , async ( ) => {
520
+ // [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
521
+ const c1 = tf . complex (
522
+ [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] , [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] ) ;
523
+ // [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
524
+ const c2 = tf . complex (
525
+ [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] , [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] ) ;
526
+
527
+ const axis = 1 ;
528
+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
529
+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 ,
530
+ 7 , 7 , 8 , 8 , 9 , 9 , 10 , 10 , 11 , 11 , 12 , 12 ] ;
531
+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
532
+ expectArraysClose ( await result . data ( ) , expected ) ;
533
+ } ) ;
534
+
535
+ it ( 'concat complex input axis=1' , async ( ) => {
536
+ // [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
537
+ const c1 = tf . complex (
538
+ [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] , [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] ) ;
539
+ // [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
540
+ const c2 = tf . complex (
541
+ [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] , [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] ) ;
542
+
543
+ const axis = 2 ;
544
+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
545
+ const expected = [ 1 , 1 , 2 , 2 , 7 , 7 , 8 , 8 , 3 , 3 , 4 , 4 ,
546
+ 9 , 9 , 10 , 10 , 5 , 5 , 6 , 6 , 11 , 11 , 12 , 12 ] ;
547
+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
548
+ expectArraysClose ( await result . data ( ) , expected ) ;
549
+ } ) ;
463
550
} ) ;
464
551
465
552
describeWithFlags ( 'concat throws for non-tensors' , ALL_ENVS , ( ) => {
0 commit comments