Skip to content

Commit 134616d

Browse files
LewuatheNikhil Thorat
authored and
Nikhil Thorat
committed
Support complex type in onesLike. (tensorflow#1609)
FEATURE
1 parent 3ed9016 commit 134616d

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

src/ops/array_ops_test.ts

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
511511
expectArraysEqual(b, [1, 1, 1]);
512512
});
513513

514+
it('1D complex dtype', () => {
515+
const real = tf.tensor1d([1, 2, 3], 'float32');
516+
const imag = tf.tensor1d([1, 2, 3], 'float32');
517+
const a = tf.complex(real, imag);
518+
const b = tf.onesLike(a);
519+
expect(b.dtype).toBe('complex64');
520+
expect(b.shape).toEqual([3]);
521+
expectArraysEqual(b, [1, 0, 1, 0, 1, 0]);
522+
});
523+
514524
it('2D default dtype', () => {
515525
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
516526
const b = tf.onesLike(a);
@@ -543,6 +553,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
543553
expectArraysEqual(b, [1, 1, 1, 1]);
544554
});
545555

556+
it('2D complex dtype', () => {
557+
const real = tf.tensor2d([1, 2, 3, 4], [2, 2], 'float32');
558+
const imag = tf.tensor2d([1, 2, 3, 4], [2, 2], 'float32');
559+
const a = tf.complex(real, imag);
560+
const b = tf.onesLike(a);
561+
expect(b.dtype).toBe('complex64');
562+
expect(b.shape).toEqual([2, 2]);
563+
expectArraysEqual(b, [1, 0, 1, 0, 1, 0, 1, 0]);
564+
});
565+
546566
it('3D default dtype', () => {
547567
const a = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
548568
const b = tf.onesLike(a);
@@ -575,6 +595,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
575595
expectArraysEqual(b, [1, 1, 1, 1]);
576596
});
577597

598+
it('3D complex dtype', () => {
599+
const real = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'float32');
600+
const imag = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'float32');
601+
const a = tf.complex(real, imag);
602+
const b = tf.onesLike(a);
603+
expect(b.dtype).toBe('complex64');
604+
expect(b.shape).toEqual([2, 2, 1]);
605+
expectArraysEqual(b, [1, 0, 1, 0, 1, 0, 1, 0]);
606+
});
607+
578608
it('4D default dtype', () => {
579609
const a = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1]);
580610
const b = tf.onesLike(a);
@@ -615,6 +645,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
615645
expectArraysClose(b, [1, 1, 1, 1]);
616646
});
617647

648+
it('4D complex dtype', () => {
649+
const real = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1], 'float32');
650+
const imag = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1], 'float32');
651+
const a = tf.complex(real, imag);
652+
const b = tf.onesLike(a);
653+
expect(b.dtype).toBe('complex64');
654+
expect(b.shape).toEqual([2, 2, 1, 1]);
655+
expectArraysEqual(b, [1, 0, 1, 0, 1, 0, 1, 0]);
656+
});
657+
618658
it('5D float32 dtype', () => {
619659
const a = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'float32');
620660
const b = tf.onesLike(a);
@@ -647,6 +687,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
647687
expectArraysClose(b, [1, 1, 1, 1]);
648688
});
649689

690+
it('5D complex dtype', () => {
691+
const real = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'float32');
692+
const imag = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'float32');
693+
const a = tf.complex(real, imag);
694+
const b = tf.onesLike(a);
695+
expect(b.dtype).toBe('complex64');
696+
expect(b.shape).toEqual([1, 2, 2, 1, 1]);
697+
expectArraysEqual(b, [1, 0, 1, 0, 1, 0, 1, 0]);
698+
});
699+
650700
it('6D int32 dtype', () => {
651701
const a = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'int32');
652702
const b = tf.onesLike(a);
@@ -679,6 +729,16 @@ describeWithFlags('onesLike', ALL_ENVS, () => {
679729
expectArraysClose(b, [1, 1, 1, 1]);
680730
});
681731

732+
it('6D complex dtype', () => {
733+
const real = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'float32');
734+
const imag = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'float32');
735+
const a = tf.complex(real, imag);
736+
const b = tf.onesLike(a);
737+
expect(b.dtype).toBe('complex64');
738+
expect(b.shape).toEqual([1, 2, 2, 1, 1, 1]);
739+
expectArraysEqual(b, [1, 0, 1, 0, 1, 0, 1, 0]);
740+
});
741+
682742
it('throws when passed a non-tensor', () => {
683743
expect(() => tf.onesLike({} as tf.Tensor))
684744
.toThrowError(/Argument 'x' passed to 'onesLike' must be a Tensor/);

src/ops/tensor_ops.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, Tens
2222
import {DataType, Rank, ShapeMap} from '../types';
2323
import {assert, assertNonNull, assertNonNegativeIntegerDimensions, flatten, inferDtype, isTypedArray, makeOnesTypedArray, makeZerosTypedArray, sizeFromShape, toTypedArray} from '../util';
2424

25-
import {complex} from './complex_ops';
25+
import {complex, real, imag} from './complex_ops';
2626
import {op} from './operation';
2727

2828
/**
@@ -447,6 +447,11 @@ function fill<R extends Rank>(
447447
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
448448
function onesLike_<T extends Tensor>(x: T|TensorLike): T {
449449
const $x = convertToTensor(x, 'x', 'onesLike');
450+
if ($x.dtype === 'complex64') {
451+
const r = onesLike(real($x));
452+
const i = zerosLike(imag($x));
453+
return complex(r, i);
454+
}
450455
return ENV.engine.runKernel(backend => backend.onesLike($x), {$x}, null) as T;
451456
}
452457

0 commit comments

Comments
 (0)