Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Add inTopK op #1734

Merged
merged 6 commits into from
Aug 7, 2019
Merged

Add inTopK op #1734

merged 6 commits into from
Aug 7, 2019

Conversation

syt123450
Copy link
Contributor

@syt123450 syt123450 commented May 2, 2019

FEATURE

This PR add inTopK op, which behaves the same way as tf.math.in_top_k in TensorFlow. This op help develop further metrics which depend on inTopK operation, such as, topKCategoricalAccuracy (feature requested in tensorflow/tfjs#27 ), sparseTopKCategoricalAccuracy (feature requested in tensorflow/tfjs#26). Relative PR tensorflow/tfjs-layers#537

This PR:

  • Add new inTopK op to src/ops
  • Register inTopK in src/ops/ops.ts
  • Add inTopK kernel to backend
  • Add shared inTopK implementation between webgl and cpu
  • Add relative tests for inTopK

Reference:


This change is Reviewable

Copy link
Contributor

@nsthorat nsthorat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 2 of 8 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @syt123450)


src/backends/webgl/backend_webgl.ts, line 1323 at r1 (raw file):

  inTopK<T extends Tensor, U extends Tensor>(
      predictions: T, targets: U, k: number): U {
    const predictionsVals = predictions.dataSync();

We made a mistake with topk, it actually should be async. Can you make these call .data() instead and return a Promise?

Copy link
Contributor

@nsthorat nsthorat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 1 of 8 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @syt123450)


src/backends/webgl/backend_webgl.ts, line 1323 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

We made a mistake with topk, it actually should be async. Can you make these call .data() instead and return a Promise?

Oh hm... we can't make the kernel's asynchronous.. let me think about this.

@syt123450
Copy link
Contributor Author

Hi @nsthorat , it seems that in backend_cpu, this.readSync is widely used to synchronously read data, so I replaced the previous .dataSync() with this.readSync, and all tests passed. While in backend_webgl, the .dataSync() still exists in nonMaxSuppression, floatPrecision, where, topk, if you figure out how to solve this synchronous call, please let me know, I would apply your way to fix it.

@nsthorat
Copy link
Contributor

nsthorat commented Aug 7, 2019

I'm going to merge this and send a follow up to remove this as a kernel, apologies. This will be fine for now.

Thanks for the PR!

@nsthorat nsthorat merged commit e4d7607 into tensorflow:master Aug 7, 2019
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants