@@ -254,12 +254,13 @@ def testRankOneFloatTensor(self):
254
254
self .diagOp (x , np .float64 , expected_ans )
255
255
256
256
def testRankOneComplexTensor (self ):
257
- x = np .array ([1.1 + 1.1j , 2.2 + 2.2j , 3.3 + 3.3j ], dtype = np .complex64 )
258
- expected_ans = np .array (
259
- [[1.1 + 1.1j , 0 + 0j , 0 + 0j ],
260
- [0 + 0j , 2.2 + 2.2j , 0 + 0j ],
261
- [0 + 0j , 0 + 0j , 3.3 + 3.3j ]], dtype = np .complex64 )
262
- self .diagOp (x , np .complex64 , expected_ans )
257
+ for dtype in [np .complex64 , np .complex128 ]:
258
+ x = np .array ([1.1 + 1.1j , 2.2 + 2.2j , 3.3 + 3.3j ], dtype = dtype )
259
+ expected_ans = np .array (
260
+ [[1.1 + 1.1j , 0 + 0j , 0 + 0j ],
261
+ [0 + 0j , 2.2 + 2.2j , 0 + 0j ],
262
+ [0 + 0j , 0 + 0j , 3.3 + 3.3j ]], dtype = dtype )
263
+ self .diagOp (x , dtype , expected_ans )
263
264
264
265
def testRankTwoIntTensor (self ):
265
266
x = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
@@ -286,17 +287,18 @@ def testRankTwoFloatTensor(self):
286
287
self .diagOp (x , np .float64 , expected_ans )
287
288
288
289
def testRankTwoComplexTensor (self ):
289
- x = np .array ([[1.1 + 1.1j , 2.2 + 2.2j , 3.3 + 3.3j ],
290
- [4.4 + 4.4j , 5.5 + 5.5j , 6.6 + 6.6j ]], dtype = np .complex64 )
291
- expected_ans = np .array (
292
- [[[[1.1 + 1.1j , 0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j , 0 + 0j ]],
293
- [[0 + 0j , 2.2 + 2.2j , 0 + 0j ], [0 + 0j , 0 + 0j , 0 + 0j ]],
294
- [[0 + 0j , 0 + 0j , 3.3 + 3.3j ], [0 + 0j , 0 + 0j , 0 + 0j ]]],
295
- [[[0 + 0j , 0 + 0j , 0 + 0j ], [4.4 + 4.4j , 0 + 0j , 0 + 0j ]],
296
- [[0 + 0j , 0 + 0j , 0 + 0j ], [0 + 0j , 5.5 + 5.5j , 0 + 0j ]],
297
- [[0 + 0j , 0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j , 6.6 + 6.6j ]]]],
298
- dtype = np .complex64 )
299
- self .diagOp (x , np .complex64 , expected_ans )
290
+ for dtype in [np .complex64 , np .complex128 ]:
291
+ x = np .array ([[1.1 + 1.1j , 2.2 + 2.2j , 3.3 + 3.3j ],
292
+ [4.4 + 4.4j , 5.5 + 5.5j , 6.6 + 6.6j ]], dtype = dtype )
293
+ expected_ans = np .array (
294
+ [[[[1.1 + 1.1j , 0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j , 0 + 0j ]],
295
+ [[0 + 0j , 2.2 + 2.2j , 0 + 0j ], [0 + 0j , 0 + 0j , 0 + 0j ]],
296
+ [[0 + 0j , 0 + 0j , 3.3 + 3.3j ], [0 + 0j , 0 + 0j , 0 + 0j ]]],
297
+ [[[0 + 0j , 0 + 0j , 0 + 0j ], [4.4 + 4.4j , 0 + 0j , 0 + 0j ]],
298
+ [[0 + 0j , 0 + 0j , 0 + 0j ], [0 + 0j , 5.5 + 5.5j , 0 + 0j ]],
299
+ [[0 + 0j , 0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j , 6.6 + 6.6j ]]]],
300
+ dtype = dtype )
301
+ self .diagOp (x , dtype , expected_ans )
300
302
301
303
def testRankThreeFloatTensor (self ):
302
304
x = np .array ([[[1.1 , 2.2 ], [3.3 , 4.4 ]],
@@ -314,28 +316,29 @@ def testRankThreeFloatTensor(self):
314
316
self .diagOp (x , np .float64 , expected_ans )
315
317
316
318
def testRankThreeComplexTensor (self ):
317
- x = np .array ([[[1.1 + 1.1j , 2.2 + 2.2j ], [3.3 + 3.3j , 4.4 + 4.4j ]],
318
- [[5.5 + 5.5j , 6.6 + 6.6j ], [7.7 + 7.7j , 8.8 + 8.8j ]]],
319
- dtype = np .complex64 )
320
- expected_ans = np .array (
321
- [[[[[[1.1 + 1.1j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
322
- [[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]]],
323
- [[[0 + 0j , 2.2 + 2.2j ], [0 + 0j , 0 + 0j ]],
324
- [[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]]]],
325
- [[[[0 + 0j , 0 + 0j ], [3.3 + 3.3j , 0 + 0j ]],
319
+ for dtype in [np .complex64 , np .complex128 ]:
320
+ x = np .array ([[[1.1 + 1.1j , 2.2 + 2.2j ], [3.3 + 3.3j , 4.4 + 4.4j ]],
321
+ [[5.5 + 5.5j , 6.6 + 6.6j ], [7.7 + 7.7j , 8.8 + 8.8j ]]],
322
+ dtype = dtype )
323
+ expected_ans = np .array (
324
+ [[[[[[1.1 + 1.1j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
325
+ [[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]]],
326
+ [[[0 + 0j , 2.2 + 2.2j ], [0 + 0j , 0 + 0j ]],
327
+ [[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]]]],
328
+ [[[[0 + 0j , 0 + 0j ], [3.3 + 3.3j , 0 + 0j ]],
326
329
[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]]],
327
- [[[0 + 0j , 0 + 0j ], [0 + 0j , 4.4 + 4.4j ]],
328
- [[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]]]]],
329
- [[[[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
330
- [[5.5 + 5.5j , 0 + 0j ], [0 + 0j , 0 + 0j ]]],
331
- [[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
332
- [[0 + 0j , 6.6 + 6.6j ], [0 + 0j , 0 + 0j ]]]],
333
- [[[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
330
+ [[[0 + 0j , 0 + 0j ], [0 + 0j , 4.4 + 4.4j ]],
331
+ [[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]]]]],
332
+ [[[[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
333
+ [[5.5 + 5.5j , 0 + 0j ], [0 + 0j , 0 + 0j ]]],
334
+ [[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
335
+ [[0 + 0j , 6.6 + 6.6j ], [0 + 0j , 0 + 0j ]]]],
336
+ [[[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
334
337
[[0 + 0j , 0 + 0j ], [7.7 + 7.7j , 0 + 0j ]]],
335
- [[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
336
- [[0 + 0j , 0 + 0j ], [0 + 0j , 8.8 + 8.8j ]]]]]],
337
- dtype = np . complex64 )
338
- self .diagOp (x , np . complex64 , expected_ans )
338
+ [[[0 + 0j , 0 + 0j ], [0 + 0j , 0 + 0j ]],
339
+ [[0 + 0j , 0 + 0j ], [0 + 0j , 8.8 + 8.8j ]]]]]],
340
+ dtype = dtype )
341
+ self .diagOp (x , dtype , expected_ans )
339
342
340
343
341
344
class DiagPartOpTest (tf .test .TestCase ):
0 commit comments