Skip to content

Commit 91ee923

Browse files
Enable complex128 dtype for diag()
Change: 133749428
1 parent 54bd703 commit 91ee923

File tree

3 files changed

+44
-39
lines changed

3 files changed

+44
-39
lines changed

tensorflow/core/kernels/diag_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ REGISTER_DIAGOP(float);
123123
REGISTER_DIAGOP(int32);
124124
REGISTER_DIAGOP(int64);
125125
REGISTER_DIAGOP(complex64);
126+
REGISTER_DIAGOP(complex128);
126127

127128
#undef REGISTER_DIAGOP
128129

@@ -190,6 +191,7 @@ REGISTER_DIAGPARTOP(float);
190191
REGISTER_DIAGPARTOP(int32);
191192
REGISTER_DIAGPARTOP(int64);
192193
REGISTER_DIAGPARTOP(complex64);
194+
REGISTER_DIAGPARTOP(complex128);
193195

194196
#undef REGISTER_DIAGPARTOP
195197

tensorflow/core/ops/array_ops.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ y: a tensor of the same shape and type as x but filled with zeros.
401401
REGISTER_OP("Diag")
402402
.Input("diagonal: T")
403403
.Output("output: T")
404-
.Attr("T: {float, double, int32, int64, complex64}")
404+
.Attr("T: {float, double, int32, int64, complex64, complex128}")
405405
.SetShapeFn([](InferenceContext* c) {
406406
ShapeHandle in = c->input(0);
407407
TF_RETURN_IF_ERROR(c->WithRankAtMost(in, 3, &in));
@@ -439,7 +439,7 @@ diagonal: Rank k tensor where k is at most 3.
439439
REGISTER_OP("DiagPart")
440440
.Input("input: T")
441441
.Output("diagonal: T")
442-
.Attr("T: {float, double, int32, int64, complex64}")
442+
.Attr("T: {float, double, int32, int64, complex64, complex128}")
443443
.SetShapeFn([](InferenceContext* c) {
444444
ShapeHandle in = c->input(0);
445445
if (!c->RankKnown(in)) {

tensorflow/python/kernel_tests/diag_op_test.py

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,13 @@ def testRankOneFloatTensor(self):
254254
self.diagOp(x, np.float64, expected_ans)
255255

256256
def testRankOneComplexTensor(self):
257-
x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype = np.complex64)
258-
expected_ans = np.array(
259-
[[1.1 + 1.1j, 0 + 0j, 0 + 0j],
260-
[0 + 0j, 2.2 + 2.2j, 0 + 0j],
261-
[0 + 0j, 0 + 0j, 3.3 + 3.3j]], dtype = np.complex64)
262-
self.diagOp(x, np.complex64, expected_ans)
257+
for dtype in [np.complex64, np.complex128]:
258+
x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype)
259+
expected_ans = np.array(
260+
[[1.1 + 1.1j, 0 + 0j, 0 + 0j],
261+
[0 + 0j, 2.2 + 2.2j, 0 + 0j],
262+
[0 + 0j, 0 + 0j, 3.3 + 3.3j]], dtype=dtype)
263+
self.diagOp(x, dtype, expected_ans)
263264

264265
def testRankTwoIntTensor(self):
265266
x = np.array([[1, 2, 3], [4, 5, 6]])
@@ -286,17 +287,18 @@ def testRankTwoFloatTensor(self):
286287
self.diagOp(x, np.float64, expected_ans)
287288

288289
def testRankTwoComplexTensor(self):
289-
x = np.array([[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
290-
[4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], dtype = np.complex64)
291-
expected_ans = np.array(
292-
[[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
293-
[[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
294-
[[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]],
295-
[[[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]],
296-
[[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]],
297-
[[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
298-
dtype = np.complex64)
299-
self.diagOp(x, np.complex64, expected_ans)
290+
for dtype in [np.complex64, np.complex128]:
291+
x = np.array([[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
292+
[4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], dtype=dtype)
293+
expected_ans = np.array(
294+
[[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
295+
[[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
296+
[[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]],
297+
[[[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]],
298+
[[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]],
299+
[[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
300+
dtype=dtype)
301+
self.diagOp(x, dtype, expected_ans)
300302

301303
def testRankThreeFloatTensor(self):
302304
x = np.array([[[1.1, 2.2], [3.3, 4.4]],
@@ -314,28 +316,29 @@ def testRankThreeFloatTensor(self):
314316
self.diagOp(x, np.float64, expected_ans)
315317

316318
def testRankThreeComplexTensor(self):
317-
x = np.array([[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
318-
[[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
319-
dtype = np.complex64)
320-
expected_ans = np.array(
321-
[[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]],
322-
[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
323-
[[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]],
324-
[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]],
325-
[[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]],
319+
for dtype in [np.complex64, np.complex128]:
320+
x = np.array([[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
321+
[[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
322+
dtype=dtype)
323+
expected_ans = np.array(
324+
[[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]],
325+
[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
326+
[[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]],
327+
[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]],
328+
[[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]],
326329
[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
327-
[[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]],
328-
[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]]],
329-
[[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
330-
[[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]]],
331-
[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
332-
[[0 + 0j, 6.6 + 6.6j], [0 + 0j, 0 + 0j]]]],
333-
[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
330+
[[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]],
331+
[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]]],
332+
[[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
333+
[[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]]],
334+
[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
335+
[[0 + 0j, 6.6 + 6.6j], [0 + 0j, 0 + 0j]]]],
336+
[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
334337
[[0 + 0j, 0 + 0j], [7.7 + 7.7j, 0 + 0j]]],
335-
[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
336-
[[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
337-
dtype = np.complex64)
338-
self.diagOp(x, np.complex64, expected_ans)
338+
[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
339+
[[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
340+
dtype=dtype)
341+
self.diagOp(x, dtype, expected_ans)
339342

340343

341344
class DiagPartOpTest(tf.test.TestCase):

0 commit comments

Comments
 (0)