-
Notifications
You must be signed in to change notification settings - Fork 943
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 4 of 4 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @syt123450)
src/ops/boolean_mask.ts, line 48 at r1 (raw file):
const $mask = convertToTensor(mask, 'mask', 'boolMask'); const axisFrom = axis === undefined ? 0 : axis;
axis == null
src/ops/boolean_mask.ts, line 54 at r1 (raw file):
util.assert( maskDim > 0, () => 'mask cannot be scalar'
are you using our clang-formatter here?
src/ops/boolean_mask.ts, line 58 at r1 (raw file):
util.assertShapesMatch( tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, 'mask\'s shape must match the first K dimensions of tensor\'s shape,'
use backticks so you dont have to escape
src/ops/boolean_mask.ts, line 70 at r1 (raw file):
const reshapedMask = $mask.reshape([-1]); const truePositions = whereImpl( [leadingSize], reshapedMask.dataSync());
this method should be async and you should use .data()
src/ops/boolean_mask.ts, line 78 at r1 (raw file):
const res = ENGINE.runKernel(b => b.gather(reshapedTensor as Tensor, gatherIndices, axisFrom),
no need to call runKernel here you can just call gather directly
Thanks for the PR! Do you know about our GCP test check? If you look above you will see "Trigger: ed916b23-8e01-4fda-85af-b7c8627d15a1 Failing after 1m — Summary". If you click "Details" you will see what the failure is. You should get that to pass before I can merge, thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @nsthorat , thanks for the review. Based on your suggestions, I updated accordingly and made the tests pass. Could you please take a look at this PR again?
Reviewable status: 0 of 1 approvals obtained
src/ops/boolean_mask.ts, line 48 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
axis == null
Done.
src/ops/boolean_mask.ts, line 54 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
are you using our clang-formatter here?
Done. Format code with vscode clang-format extension.
src/ops/boolean_mask.ts, line 70 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
this method should be async and you should use .data()
Done. Refactored booleanMask op to be async, updated relative tests. As booleanMask is an async function now, I directly use tf.whereAsync to replace previous whereImpl, tf.squeeze to replace util.squeeze.
src/ops/boolean_mask.ts, line 78 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
no need to call runKernel here you can just call gather directly
Done. Call gather op directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @syt123450)
src/ops/boolean_mask.ts, line 43 at r3 (raw file):
*/ /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ async function booleanMask_<T extends Tensor, U extends Tensor>(
You don't use any of these generics anywhere
src/ops/boolean_mask.ts, line 71 at r3 (raw file):
} export const booleanMask = booleanMask_;
since you're not wrappign this in op() anymore let's make sure that this method doesn't leak any tensors by using tidys() and disposes() and adding a unit test that checks with tf.memory().numTensors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @nsthorat , follow your reviews, I made relative update to this PR, could you take a look at this PR again? Thanks!
Reviewable status: 0 of 1 approvals obtained
src/ops/boolean_mask.ts, line 43 at r3 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
You don't use any of these generics anywhere
Done.
Removed the unused generics signature in this commit b3bbeac.
src/ops/boolean_mask.ts, line 71 at r3 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
since you're not wrappign this in op() anymore let's make sure that this method doesn't leak any tensors by using tidys() and disposes() and adding a unit test that checks with tf.memory().numTensors
Done.
I disposed some intermediate tensor to avoid potential memory leak, and added a relative unit test to ensure no memory leak in this commit 79b0266. The way to avoid memory leak was inspired by tf.whereAsync.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 1 of 1 files at r2.
Reviewable status:complete! 1 of 1 approvals obtained (waiting on @syt123450)
src/ops/boolean_mask_test.ts, line 77 at r4 (raw file):
let resultPromise: Promise<Tensor> = null; tf.tidy(() => {
lets remove this tidy here since this isn't normally how this will be used
Once you fix this LGTM and I will merge! In the future though it would be good to bug us about ops since we don't want to just add ops for ops sake, they should be used in a model :) |
Oh looks like someone asked for it, nevermind :) |
Is it`s possible to add set function, this method can only get elements, but I want to get and modify and save back to its original place, this will be very useful when doing image processing. |
Sorry for late reply, thanks for your following PR @nsthorat ! |
FEATURE
Add tf.booleanMask op. Feature requested in tensorflow/tfjs#380.
Reference:
This change is