Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit e4d7607

Browse files
syt123450Nikhil Thorat
authored and
Nikhil Thorat
committed
Add inTopK op (#1734)
FEATURE This PR add inTopK op, which behaves the same way as [tf.math.in_top_k](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases) in TensorFlow. This op help develop further metrics which depend on inTopK operation, such as, topKCategoricalAccuracy (feature requested in [tensorflow/tfjs#27](tensorflow/tfjs#27) ), sparseTopKCategoricalAccuracy (feature requested in [tensorflow/tfjs#26](tensorflow/tfjs#26)). Relative PR [tensorflow/tfjs-layers#537](tensorflow/tfjs-layers#537) This PR: * Add new inTopK op to [src/ops](https://github.com/tensorflow/tfjs-core/tree/master/src/ops) * Register inTopK in [src/ops/ops.ts](https://github.com/tensorflow/tfjs-core/blob/master/src/ops/ops.ts) * Add inTopK kernel to backend * Add shared inTopK implementation between webgl and cpu * Add relative tests for inTopK Reference: * [TensorFlow in_top_k doc](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases) * [TensorFlow in_top_k implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/in_topk_op.cc) * [TensorFlow in_top_k test cases](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/nn_ops_test.cc#L442)
1 parent 8a6d4d5 commit e4d7607

File tree

8 files changed

+268
-0
lines changed

8 files changed

+268
-0
lines changed

src/backends/backend.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,11 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
241241
throw new Error('Not yet implemented');
242242
}
243243

244+
inTopK<T extends Tensor, U extends Tensor>(
245+
predictions: T, targets: U, k: number): U {
246+
throw new Error('Not yet implemented');
247+
}
248+
244249
min(x: Tensor, axes: number[]): Tensor {
245250
throw new Error('Not yet implemented');
246251
}

src/backends/cpu/backend_cpu.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
4141
import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend';
4242
import * as backend_util from '../backend_util';
4343
import * as complex_util from '../complex_util';
44+
import {inTopKImpl} from '../inTopK_impl';
4445
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
4546
import {split} from '../split_shared';
4647
import {tile} from '../tile_impl';
@@ -858,6 +859,18 @@ export class MathBackendCPU implements KernelBackend {
858859
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
859860
}
860861

862+
inTopK<T extends Tensor, U extends Tensor>(
863+
predictions: T, targets: U, k: number): U {
864+
this.assertNotComplex([predictions, targets], 'inTopK');
865+
866+
const predictionsVals = this.readSync(predictions.dataId) as TypedArray;
867+
const targetsVals = this.readSync(targets.dataId) as TypedArray;
868+
869+
return inTopKImpl(
870+
predictionsVals, predictions.shape, targetsVals, targets.shape,
871+
k) as U;
872+
}
873+
861874
min(x: Tensor, axes: number[]): Tensor {
862875
this.assertNotComplex(x, 'min');
863876

src/backends/inTopK_impl.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/** An implementation of the inTopK kernel shared between webgl and cpu. */
19+
20+
import {tensor} from '../ops/tensor_ops';
21+
import {Tensor} from '../tensor';
22+
import {TypedArray} from '../types';
23+
import {getTypedArrayFromDType} from '../util';
24+
25+
export function inTopKImpl<T extends Tensor>(
26+
predictionsVals: TypedArray, predictionsShape: number[],
27+
targetsVals: TypedArray, targetsShape: number[], k: number
28+
): T {
29+
// Reshape predictionsVals into a 2d tensor [batch, lastDim]
30+
// and look up topK along lastDim.
31+
const lastDim = predictionsShape[predictionsShape.length - 1];
32+
const [batch, size] = [predictionsVals.length / lastDim, lastDim];
33+
const precision = getTypedArrayFromDType('bool', batch);
34+
35+
for (let b = 0; b < batch; b++) {
36+
const offset = b * size;
37+
const vals = predictionsVals.subarray(offset, offset + size);
38+
const valAndInd: Array<{ value: number, index: number }> = [];
39+
for (let i = 0; i < vals.length; i++) {
40+
valAndInd.push({value: vals[i], index: i});
41+
}
42+
valAndInd.sort((a, b) => b.value - a.value);
43+
44+
precision[b] = 0;
45+
for (let i = 0; i < k; i++) {
46+
if (valAndInd[i].index === targetsVals[b]) {
47+
precision[b] = 1;
48+
break;
49+
}
50+
}
51+
}
52+
53+
// Output precision has the same shape as targets.
54+
return tensor(precision, targetsShape, 'bool') as T;
55+
}

src/backends/webgl/backend_webgl.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} fr
4444
import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';
4545
import * as backend_util from '../backend_util';
4646
import {mergeRealAndImagArrays} from '../complex_util';
47+
import {inTopKImpl} from '../inTopK_impl';
4748
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
4849
import {split} from '../split_shared';
4950
import {tile} from '../tile_impl';
@@ -1359,6 +1360,15 @@ export class MathBackendWebGL implements KernelBackend {
13591360
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
13601361
}
13611362

1363+
inTopK<T extends Tensor, U extends Tensor>(
1364+
predictions: T, targets: U, k: number): U {
1365+
const predictionsVals = predictions.dataSync();
1366+
const targetsVals = targets.dataSync();
1367+
return inTopKImpl(
1368+
predictionsVals, predictions.shape, targetsVals, targets.shape,
1369+
k) as U;
1370+
}
1371+
13621372
min(x: Tensor, axes: number[]): Tensor {
13631373
axis_util.assertAxesAreInnerMostDims('min', axes, x.rank);
13641374
const [outShape, reduceShape] =

src/ops/inTopK.ts

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {ENGINE} from '../engine';
19+
import {NumericTensor, Tensor} from '../tensor';
20+
import {convertToTensor} from '../tensor_util_env';
21+
import {TensorLike} from '../types';
22+
import {assert, assertShapesMatch} from '../util';
23+
24+
import {op} from './operation';
25+
26+
/**
27+
* Says whether the targets are in the top K predictions.
28+
*
29+
* ```js
30+
* const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
31+
* const targets = tf.tensor1d([2, 0]);
32+
* const precision = tf.inTopK(predictions, targets);
33+
* precision.print();
34+
* ```
35+
* @param predictions 2-D or higher `tf.Tensor` with last dimension being
36+
* at least `k`.
37+
* @param targets 1-D or higher `tf.Tensor`.
38+
* @param k Optional Number of top elements to look at for computing precision,
39+
* default to 1.
40+
*/
41+
/** @doc {heading: 'Operations', subheading: 'Evaluation'} */
42+
function inTopK_<T extends Tensor, U extends Tensor>(
43+
predictions: T|TensorLike, targets: U|TensorLike, k = 1): U {
44+
const $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
45+
const $targets = convertToTensor(targets, 'targets', 'inTopK');
46+
47+
assert(
48+
$predictions.rank > 1,
49+
() => 'inTopK() expects the predictions to be of rank 2 or higher, ' +
50+
`but got ${$predictions.rank}`);
51+
assert(
52+
$predictions.rank - 1 === $targets.rank,
53+
() => `predictions' rank should be 1 larger than ` +
54+
`targets' rank, but got predictions' rank ` +
55+
`${$predictions.rank} and targets' rank ${$targets.rank}`);
56+
assertShapesMatch(
57+
$predictions.shape.slice(0, $predictions.shape.length - 1),
58+
$targets.shape,
59+
`predictions's shape should be align with the targets' shape, ` +
60+
'except the last dimension.');
61+
const lastDim = $predictions.shape[$predictions.shape.length - 1];
62+
assert(
63+
k > 0 && k <= lastDim,
64+
() => `'k' passed to inTopK() must be > 0 && <= the predictions' last ` +
65+
`dimension (${lastDim}), but got ${k}`);
66+
67+
const precision = ENGINE.runKernel(
68+
b =>
69+
b.inTopK($predictions as NumericTensor, $targets as NumericTensor, k),
70+
{$predictions, $targets});
71+
72+
return precision as U;
73+
}
74+
75+
export const inTopK = op({inTopK_});

src/ops/inTopK_test.ts

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '../index';
19+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20+
import {expectArraysClose} from '../test_util';
21+
22+
import {tensor1d, tensor2d, tensor3d} from './tensor_ops';
23+
24+
describeWithFlags('inTopK', ALL_ENVS, async () => {
25+
it('predictions 2d array, targets 1d array, with default k', async () => {
26+
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
27+
const targets = tensor1d([2, 0]);
28+
const precision = tf.inTopK(predictions, targets);
29+
expect(precision.shape).toEqual([2]);
30+
expect(precision.dtype).toBe('bool');
31+
expectArraysClose(await precision.data(), [1, 0]);
32+
});
33+
34+
it('predictions 2d array, targets 1d array, with k=2', async () => {
35+
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
36+
const targets = tensor1d([2, 0]);
37+
const k = 2;
38+
const precision = tf.inTopK(predictions, targets, k);
39+
expect(precision.shape).toEqual([2]);
40+
expect(precision.dtype).toBe('bool');
41+
expectArraysClose(await precision.data(), [1, 1]);
42+
});
43+
44+
it('predictions 3d array, targets 2d array, with default k', async () => {
45+
const predictions =
46+
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]);
47+
const targets = tensor2d([[1, 2], [0, 1]]);
48+
const precision = tf.inTopK(predictions, targets);
49+
expect(precision.shape).toEqual([2, 2]);
50+
expect(precision.dtype).toBe('bool');
51+
expectArraysClose(await precision.data(), [1, 1, 1, 0]);
52+
});
53+
54+
it('predictions 3d array, targets 2d array, with k=2', async () => {
55+
const predictions =
56+
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]);
57+
const targets = tensor2d([[1, 2], [0, 1]]);
58+
const k = 2;
59+
const precision = tf.inTopK(predictions, targets, k);
60+
expect(precision.shape).toEqual([2, 2]);
61+
expect(precision.dtype).toBe('bool');
62+
expectArraysClose(await precision.data(), [1, 1, 1, 1]);
63+
});
64+
65+
it('lower-index element count first, with default k', async () => {
66+
const predictions = tensor2d([[1, 2, 2, 1]]);
67+
68+
const targets1 = tensor1d([1]);
69+
const precision1 = tf.inTopK(predictions, targets1);
70+
expect(precision1.shape).toEqual([1]);
71+
expect(precision1.dtype).toBe('bool');
72+
expectArraysClose(await precision1.data(), [1]);
73+
74+
const targets2 = tensor1d([2]);
75+
const precision2 = tf.inTopK(predictions, targets2);
76+
expect(precision2.shape).toEqual([1]);
77+
expect(precision2.dtype).toBe('bool');
78+
expectArraysClose(await precision2.data(), [0]);
79+
});
80+
81+
it('accept tensor-like object, with default k', async () => {
82+
const predictions = [[20, 10, 40, 30], [30, 50, -20, 10]];
83+
const targets = [2, 0];
84+
const precision = tf.inTopK(predictions, targets);
85+
expect(precision.shape).toEqual([2]);
86+
expect(precision.dtype).toBe('bool');
87+
expectArraysClose(await precision.data(), [1, 0]);
88+
});
89+
90+
it('throws when predictions_rank <2', () => {
91+
const predictions = tensor1d([20, 10, 40, 30]);
92+
const targets = [2];
93+
expect(() => tf.inTopK(predictions, targets)).toThrowError();
94+
});
95+
96+
it('throws when prediction_rank != targets_rank + 1', () => {
97+
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
98+
const targets = tensor2d([[0], [0]]);
99+
expect(() => tf.inTopK(predictions, targets)).toThrowError();
100+
});
101+
102+
it('throws when k > size of last dimension of predictions', () => {
103+
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
104+
const targets = tensor1d([2, 0]);
105+
const k = 5;
106+
expect(() => tf.inTopK(predictions, targets, k)).toThrowError();
107+
});
108+
});

src/ops/ops.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ export * from './gather_nd';
4848
export * from './diag';
4949
export * from './dropout';
5050
export * from './signal_ops';
51+
export * from './inTopK';
5152

5253
export {op} from './operation';
5354

src/tests.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ import './ops/dropout_test';
6161
import './ops/fused_test';
6262
import './ops/gather_nd_test';
6363
import './ops/image_ops_test';
64+
import './ops/inTopK_test';
6465
import './ops/linalg_ops_test';
6566
import './ops/logical_ops_test';
6667
import './ops/loss_ops_test';

0 commit comments

Comments
 (0)