Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 975e5f6

Browse files
LewuatheNikhil Thorat
authored andcommitted
Support complex type in concat op (#1829)
FEATURE Modify concat op to support complex type. It is also necessary to support the arbitrary shape of the returned complex value in stft op.
1 parent 5cc5267 commit 975e5f6

File tree

4 files changed

+108
-1
lines changed

4 files changed

+108
-1
lines changed

src/backends/cpu/backend_cpu.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import {split} from '../split_shared';
4545
import {tile} from '../tile_impl';
4646
import {topkImpl} from '../topk_impl';
4747
import {whereImpl} from '../where_impl';
48+
import {real, imag, complex} from '../../ops/complex_ops';
4849

4950
function mapActivation(
5051
backend: MathBackendCPU, x: Tensor, activation: Activation,
@@ -380,7 +381,11 @@ export class MathBackendCPU implements KernelBackend {
380381
}
381382

382383
concat(tensors: Tensor[], axis: number): Tensor {
383-
this.assertNotComplex(tensors, 'concat');
384+
if (tensors[0].dtype === 'complex64') {
385+
const reals = tensors.map((t) => real(t));
386+
const imags = tensors.map((t) => imag(t));
387+
return complex(this.concat(reals, axis), this.concat(imags, axis));
388+
}
384389
const tensors2D = tensors.map(t => {
385390
const innerSize = util.sizeFromShape(t.shape.slice(axis));
386391
return t.as2D(-1, innerSize);

src/backends/webgl/backend_webgl.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ import * as unary_packed_op from './unaryop_packed_gpu';
131131
import {UnaryOpPackedProgram} from './unaryop_packed_gpu';
132132
import {UnpackProgram} from './unpack_gpu';
133133
import * as webgl_util from './webgl_util';
134+
import {real, imag, complex} from '../../ops/complex_ops';
134135

135136
type KernelInfo = {
136137
name: string; query: Promise<number>;
@@ -798,6 +799,11 @@ export class MathBackendWebGL implements KernelBackend {
798799
}
799800

800801
concat(tensors: Tensor[], axis: number): Tensor {
802+
if (tensors[0].dtype === 'complex64') {
803+
const reals = tensors.map((t) => real(t));
804+
const imags = tensors.map((t) => imag(t));
805+
return complex(this.concat(reals, axis), this.concat(imags, axis));
806+
}
801807
if (this.shouldExecuteOnCPU(tensors)) {
802808
return this.cpuBackend.concat(tensors, axis);
803809
}

src/ops/concat_split.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,15 @@ function concat4d_(
163163
function concat_<T extends Tensor>(tensors: Array<T|TensorLike>, axis = 0): T {
164164
assert(tensors.length >= 1, () => 'Pass at least one tensor to concat');
165165
let $tensors = convertToTensorArray(tensors, 'tensors', 'concat');
166+
if ($tensors[0].dtype === 'complex64') {
167+
$tensors.forEach(tensor => {
168+
if (tensor.dtype !== 'complex64') {
169+
throw new Error(`Cannot concatenate complex64 tensors with a tensor
170+
with dtype ${tensor.dtype}. `);
171+
}
172+
});
173+
}
174+
166175
axis = parseAxisParam(axis, $tensors[0].shape)[0];
167176
const outShape = computeOutShape($tensors.map(t => t.shape), axis);
168177
if (sizeFromShape(outShape) === 0) {

src/ops/concat_test.ts

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@ describeWithFlags('concat1d', ALL_ENVS, () => {
8787
const expected = [3, 5];
8888
expectArraysClose(await result.data(), expected);
8989
});
90+
91+
it('concat complex input', async() => {
92+
// [1+1j, 2+2j]
93+
const c1 = tf.complex([1, 2], [1, 2]);
94+
// [3+3j, 4+4j]
95+
const c2 = tf.complex([3, 4], [3, 4]);
96+
97+
const axis = 0;
98+
const result = tf.concat([c1, c2], axis);
99+
const expected = [1, 1, 2, 2, 3, 3, 4, 4];
100+
expect(result.dtype).toEqual('complex64');
101+
expectArraysClose(await result.data(), expected);
102+
});
90103
});
91104

92105
describeWithFlags('concat2d', ALL_ENVS, () => {
@@ -220,6 +233,32 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
220233
expect(res2.shape).toEqual([0, 15]);
221234
expectArraysEqual(await res2.data(), []);
222235
});
236+
237+
it('concat complex input axis=0', async() => {
238+
// [[1+1j, 2+2j], [3+3j, 4+4j]]
239+
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
240+
// [[5+5j, 6+6j], [7+7j, 8+8j]]
241+
const c2 = tf.complex([[5, 6], [7, 8]], [[5, 6], [7, 8]]);
242+
243+
const axis = 0;
244+
const result = tf.concat([c1, c2], axis);
245+
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8];
246+
expect(result.dtype).toEqual('complex64');
247+
expectArraysClose(await result.data(), expected);
248+
});
249+
250+
it('concat complex input axis=1', async() => {
251+
// [[1+1j, 2+2j], [3+3j, 4+4j]]
252+
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
253+
// [[5+5j, 6+6j], [7+7j, 8+8j]]
254+
const c2 = tf.complex([[5, 6], [7, 8]], [[5, 6], [7, 8]]);
255+
256+
const axis = 1;
257+
const result = tf.concat([c1, c2], axis);
258+
const expected = [1, 1, 2, 2, 5, 5, 6, 6, 3, 3, 4, 4, 7, 7, 8, 8];
259+
expect(result.dtype).toEqual('complex64');
260+
expectArraysClose(await result.data(), expected);
261+
});
223262
});
224263

225264
describeWithFlags('concat3d', ALL_ENVS, () => {
@@ -460,6 +499,54 @@ describeWithFlags('concat3d', ALL_ENVS, () => {
460499
expect(values.shape).toEqual([2, 3, 1]);
461500
expectArraysClose(await values.data(), [1, 2, 3, 4, 5, 6]);
462501
});
502+
503+
it('concat complex input axis=0', async() => {
504+
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
505+
const c1 = tf.complex(
506+
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
507+
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
508+
const c2 = tf.complex(
509+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
510+
511+
const axis = 0;
512+
const result = tf.concat([c1, c2], axis);
513+
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
514+
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
515+
expect(result.dtype).toEqual('complex64');
516+
expectArraysClose(await result.data(), expected);
517+
});
518+
519+
it('concat complex input axis=1', async() => {
520+
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
521+
const c1 = tf.complex(
522+
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
523+
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
524+
const c2 = tf.complex(
525+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
526+
527+
const axis = 1;
528+
const result = tf.concat([c1, c2], axis);
529+
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
530+
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
531+
expect(result.dtype).toEqual('complex64');
532+
expectArraysClose(await result.data(), expected);
533+
});
534+
535+
it('concat complex input axis=1', async() => {
536+
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
537+
const c1 = tf.complex(
538+
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
539+
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
540+
const c2 = tf.complex(
541+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
542+
543+
const axis = 2;
544+
const result = tf.concat([c1, c2], axis);
545+
const expected = [1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
546+
9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12];
547+
expect(result.dtype).toEqual('complex64');
548+
expectArraysClose(await result.data(), expected);
549+
});
463550
});
464551

465552
describeWithFlags('concat throws for non-tensors', ALL_ENVS, () => {

0 commit comments

Comments
 (0)