Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 707a669

Browse files
kedevkeddsmilkov
authored andcommitted
add tf.diag (#1256)
cla: yes Add tf.diag
1 parent 975e5f6 commit 707a669

File tree

8 files changed

+219
-2
lines changed

8 files changed

+219
-2
lines changed

src/backends/backend.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,10 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
609609
throw new Error('Not yet implemented');
610610
}
611611

612+
diag(x: Tensor): Tensor {
613+
throw new Error('Not yet implemented');
614+
}
615+
612616
fill<R extends Rank>(
613617
shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor<R> {
614618
throw new Error('Not yet implemented.');

src/backends/cpu/backend_cpu.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {warn} from '../../log';
2323
import * as array_ops_util from '../../ops/array_ops_util';
2424
import * as axis_util from '../../ops/axis_util';
2525
import * as broadcast_util from '../../ops/broadcast_util';
26+
import {complex, imag, real} from '../../ops/complex_ops';
2627
import * as concat_util from '../../ops/concat_util';
2728
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
2829
import * as erf_util from '../../ops/erf_util';
@@ -45,7 +46,6 @@ import {split} from '../split_shared';
4546
import {tile} from '../tile_impl';
4647
import {topkImpl} from '../topk_impl';
4748
import {whereImpl} from '../where_impl';
48-
import {real, imag, complex} from '../../ops/complex_ops';
4949

5050
function mapActivation(
5151
backend: MathBackendCPU, x: Tensor, activation: Activation,
@@ -343,6 +343,16 @@ export class MathBackendCPU implements KernelBackend {
343343
return buffer.toTensor().reshape(shape) as T;
344344
}
345345

346+
diag(x: Tensor): Tensor {
347+
const xVals = this.readSync(x.dataId) as TypedArray;
348+
const buffer = ops.buffer([x.size, x.size], x.dtype);
349+
const vals = buffer.values;
350+
for (let i = 0; i < xVals.length; i++) {
351+
vals[i * x.size + i] = xVals[i];
352+
}
353+
return buffer.toTensor();
354+
}
355+
346356
unstack(x: Tensor, axis: number): Tensor[] {
347357
const num = x.shape[axis];
348358
const outShape: number[] = new Array(x.rank - 1);

src/backends/webgl/backend_webgl.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import {warn} from '../../log';
2626
import {buffer} from '../../ops/array_ops';
2727
import * as array_ops_util from '../../ops/array_ops_util';
2828
import * as axis_util from '../../ops/axis_util';
29+
import {complex, imag, real} from '../../ops/complex_ops';
2930
import {computeOutShape} from '../../ops/concat_util';
3031
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
3132
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
@@ -78,6 +79,7 @@ import {CumSumProgram} from './cumsum_gpu';
7879
import {DecodeMatrixProgram} from './decode_matrix_gpu';
7980
import {DecodeMatrixPackedProgram} from './decode_matrix_packed_gpu';
8081
import {DepthToSpaceProgram} from './depth_to_space_gpu';
82+
import {DiagProgram} from './diag_gpu';
8183
import {EncodeFloatProgram} from './encode_float_gpu';
8284
import {EncodeFloatPackedProgram} from './encode_float_packed_gpu';
8385
import {EncodeMatrixProgram} from './encode_matrix_gpu';
@@ -131,7 +133,6 @@ import * as unary_packed_op from './unaryop_packed_gpu';
131133
import {UnaryOpPackedProgram} from './unaryop_packed_gpu';
132134
import {UnpackProgram} from './unpack_gpu';
133135
import * as webgl_util from './webgl_util';
134-
import {real, imag, complex} from '../../ops/complex_ops';
135136

136137
type KernelInfo = {
137138
name: string; query: Promise<number>;
@@ -2208,6 +2209,11 @@ export class MathBackendWebGL implements KernelBackend {
22082209
return this.compileAndRun(program, [indices]);
22092210
}
22102211

2212+
diag(x: Tensor): Tensor {
2213+
const program = new DiagProgram(x.size);
2214+
return this.compileAndRun(program, [x]);
2215+
}
2216+
22112217
nonMaxSuppression(
22122218
boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,
22132219
iouThreshold: number, scoreThreshold: number): Tensor1D {

src/backends/webgl/diag_gpu.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {GPGPUProgram} from './gpgpu_math';
19+
20+
export class DiagProgram implements GPGPUProgram {
21+
variableNames = ['X'];
22+
outputShape: number[];
23+
userCode: string;
24+
25+
constructor(size: number) {
26+
this.outputShape = [size, size];
27+
this.userCode = `
28+
void main() {
29+
ivec2 coords = getOutputCoords();
30+
float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
31+
setOutput(val);
32+
}
33+
`;
34+
}
35+
}

src/ops/diag.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {ENGINE} from '../engine';
19+
import {Tensor} from '../tensor';
20+
import {convertToTensor} from '../tensor_util_env';
21+
import {op} from './operation';
22+
23+
/**
24+
* Returns a diagonal tensor with a given diagonal values.
25+
*
26+
* Given a diagonal, this operation returns a tensor with the diagonal and
27+
* everything else padded with zeros.
28+
*
29+
* Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
30+
* of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
31+
*
32+
* ```js
33+
* const x = tf.tensor1d([1, 2, 3, 4]);
34+
*
35+
* tf.diag(x).print()
36+
* ```
37+
* ```js
38+
* const x = tf.tensor1d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])
39+
*
40+
* tf.diag(x).print()
41+
* ```
42+
* @param x The input tensor.
43+
*/
44+
function diag_(x: Tensor): Tensor {
45+
const $x = convertToTensor(x, 'x', 'diag').flatten();
46+
const outShape = [...x.shape, ...x.shape];
47+
return ENGINE.runKernel(backend => backend.diag($x), {$x}).reshape(outShape);
48+
}
49+
50+
export const diag = op({diag_});

src/ops/diag_test.ts

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import * as tf from '../index';
18+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
19+
import {expectArraysClose, expectArraysEqual} from '../test_util';
20+
21+
describeWithFlags('diag', ALL_ENVS, () => {
22+
it('1d', async () => {
23+
const m = tf.tensor1d([5]);
24+
const diag = tf.diag(m);
25+
expect(diag.shape).toEqual([1, 1]);
26+
expectArraysClose(await diag.data(), [5]);
27+
});
28+
it('2d', async () => {
29+
const m = tf.tensor2d([8, 2, 3, 4, 5, 1], [3, 2]);
30+
const diag = tf.diag(m);
31+
expect(diag.shape).toEqual([3, 2, 3, 2]);
32+
expectArraysClose(await diag.data(), [
33+
8, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0,
34+
0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 1
35+
]);
36+
});
37+
it('3d', async () => {
38+
const m = tf.tensor3d([8, 5, 5, 7, 9, 10, 15, 1, 2, 14, 12, 3], [2, 2, 3]);
39+
const diag = tf.diag(m);
40+
expect(diag.shape).toEqual([2, 2, 3, 2, 2, 3]);
41+
expectArraysClose(await diag.data(), [
42+
8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0,
43+
0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0,
44+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
45+
0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0,
46+
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
47+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0,
48+
0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
49+
]);
50+
});
51+
it('4d', async () => {
52+
const m = tf.tensor4d(
53+
[
54+
8, 5, 5, 7, 9, 10, 15, 1, 2, 14, 12, 3,
55+
9, 6, 6, 8, 10, 11, 16, 2, 3, 15, 13, 4
56+
],
57+
[2, 2, 3, 2]);
58+
const diag = tf.diag(m);
59+
expect(diag.shape).toEqual([2, 2, 3, 2, 2, 2, 3, 2]);
60+
expectArraysClose(await diag.data(), [
61+
8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62+
0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
63+
0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
64+
0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
65+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0,
66+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0,
67+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0,
68+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
69+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
70+
0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
71+
0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
72+
0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
73+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
74+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0,
75+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0,
76+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0,
77+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
78+
0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79+
0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
80+
0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
81+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
82+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0,
83+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
84+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0,
85+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
86+
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
87+
0, 0, 0, 4
88+
]);
89+
});
90+
it('int32', async () => {
91+
const m = tf.tensor1d([5, 3], 'int32');
92+
const diag = tf.diag(m);
93+
expect(diag.shape).toEqual([2, 2]);
94+
expect(diag.dtype).toBe('int32');
95+
expectArraysEqual(await diag.data(), [5, 0, 0, 3]);
96+
});
97+
it('bool', async () => {
98+
const m = tf.tensor1d([5, 3], 'bool');
99+
const diag = tf.diag(m);
100+
expect(diag.shape).toEqual([2, 2]);
101+
expect(diag.dtype).toBe('bool');
102+
expectArraysEqual(await diag.data(), [1, 0, 0, 1]);
103+
});
104+
it('complex', () => {
105+
const real = tf.tensor1d([2.25]);
106+
const imag = tf.tensor1d([4.75]);
107+
const m = tf.complex(real, imag);
108+
expect(() => tf.diag(m)).toThrowError();
109+
});
110+
});

src/ops/ops.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export * from './scatter_nd';
4444
export * from './spectral_ops';
4545
export * from './sparse_to_dense';
4646
export * from './gather_nd';
47+
export * from './diag';
4748
export * from './dropout';
4849
export * from './signal_ops';
4950

src/tests.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ import './ops/conv2d_test';
5555
import './ops/conv2d_transpose_test';
5656
import './ops/conv3d_test';
5757
import './ops/conv_util_test';
58+
import './ops/diag_test';
5859
import './ops/dropout_test';
5960
import './ops/fused_test';
6061
import './ops/gather_nd_test';

0 commit comments

Comments
 (0)