Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Add booleanMask op #1749

Merged
merged 15 commits into from
Aug 7, 2019
Merged

Add booleanMask op #1749

merged 15 commits into from
Aug 7, 2019

Conversation

syt123450
Copy link
Contributor

@syt123450 syt123450 commented May 8, 2019

FEATURE

Add tf.booleanMask op. Feature requested in tensorflow/tfjs#380.

Reference:


This change is Reviewable

Copy link
Contributor

@nsthorat nsthorat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 4 of 4 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @syt123450)


src/ops/boolean_mask.ts, line 48 at r1 (raw file):

  const $mask = convertToTensor(mask, 'mask', 'boolMask');

  const axisFrom = axis === undefined ? 0 : axis;

axis == null


src/ops/boolean_mask.ts, line 54 at r1 (raw file):

  util.assert(
      maskDim > 0,
      () => 'mask cannot be scalar'

are you using our clang-formatter here?


src/ops/boolean_mask.ts, line 58 at r1 (raw file):

  util.assertShapesMatch(
      tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape,
      'mask\'s shape must match the first K dimensions of tensor\'s shape,'

use backticks so you dont have to escape


src/ops/boolean_mask.ts, line 70 at r1 (raw file):

  const reshapedMask = $mask.reshape([-1]);
  const truePositions = whereImpl(
      [leadingSize], reshapedMask.dataSync());

this method should be async and you should use .data()


src/ops/boolean_mask.ts, line 78 at r1 (raw file):

  const res =
      ENGINE.runKernel(b =>
              b.gather(reshapedTensor as Tensor, gatherIndices, axisFrom),

no need to call runKernel here you can just call gather directly

@nsthorat
Copy link
Contributor

Thanks for the PR! Do you know about our GCP test check? If you look above you will see "Trigger: ed916b23-8e01-4fda-85af-b7c8627d15a1 Failing after 1m — Summary". If you click "Details" you will see what the failure is.

You should get that to pass before I can merge, thanks!

Copy link
Contributor Author

@syt123450 syt123450 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nsthorat , thanks for the review. Based on your suggestions, I updated accordingly and made the tests pass. Could you please take a look at this PR again?

Reviewable status: 0 of 1 approvals obtained


src/ops/boolean_mask.ts, line 48 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

axis == null

Done.


src/ops/boolean_mask.ts, line 54 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

are you using our clang-formatter here?

Done. Format code with vscode clang-format extension.


src/ops/boolean_mask.ts, line 70 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

this method should be async and you should use .data()

Done. Refactored booleanMask op to be async, updated relative tests. As booleanMask is an async function now, I directly use tf.whereAsync to replace previous whereImpl, tf.squeeze to replace util.squeeze.


src/ops/boolean_mask.ts, line 78 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

no need to call runKernel here you can just call gather directly

Done. Call gather op directly.

Copy link
Contributor

@nsthorat nsthorat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @syt123450)


src/ops/boolean_mask.ts, line 43 at r3 (raw file):

 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
async function booleanMask_<T extends Tensor, U extends Tensor>(

You don't use any of these generics anywhere


src/ops/boolean_mask.ts, line 71 at r3 (raw file):

}

export const booleanMask = booleanMask_;

since you're not wrappign this in op() anymore let's make sure that this method doesn't leak any tensors by using tidys() and disposes() and adding a unit test that checks with tf.memory().numTensors

Copy link
Contributor Author

@syt123450 syt123450 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nsthorat , follow your reviews, I made relative update to this PR, could you take a look at this PR again? Thanks!

Reviewable status: 0 of 1 approvals obtained


src/ops/boolean_mask.ts, line 43 at r3 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

You don't use any of these generics anywhere

Done.

Removed the unused generics signature in this commit b3bbeac.


src/ops/boolean_mask.ts, line 71 at r3 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

since you're not wrappign this in op() anymore let's make sure that this method doesn't leak any tensors by using tidys() and disposes() and adding a unit test that checks with tf.memory().numTensors

Done.

I disposed some intermediate tensor to avoid potential memory leak, and added a relative unit test to ensure no memory leak in this commit 79b0266. The way to avoid memory leak was inspired by tf.whereAsync.

Copy link
Contributor

@nsthorat nsthorat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 1 of 1 files at r2.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @syt123450)


src/ops/boolean_mask_test.ts, line 77 at r4 (raw file):

    let resultPromise: Promise<Tensor> = null;

    tf.tidy(() => {

lets remove this tidy here since this isn't normally how this will be used

@nsthorat
Copy link
Contributor

Once you fix this LGTM and I will merge!

In the future though it would be good to bug us about ops since we don't want to just add ops for ops sake, they should be used in a model :)

@nsthorat
Copy link
Contributor

Oh looks like someone asked for it, nevermind :)

@wu-jingtao
Copy link

Is it`s possible to add set function, this method can only get elements, but I want to get and modify and save back to its original place, this will be very useful when doing image processing.

@nsthorat nsthorat merged commit cac5b15 into tensorflow:master Aug 7, 2019
@syt123450
Copy link
Contributor Author

Sorry for late reply, thanks for your following PR @nsthorat !

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants