Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit cac5b15

Browse files
syt123450Nikhil Thorat
authored and
Nikhil Thorat
committed
Add booleanMask op (#1749)
FEATURE Add tf.booleanMask op. Feature requested in [tensorflow/tfjs#380](tensorflow/tfjs#380). Reference: * [tf.boolean_mask documentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/boolean_mask?hl=en) * [tf.boolean_mask tensorflow python implementation](https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/array_ops.py#L1274)
1 parent eb2ab45 commit cac5b15

File tree

4 files changed

+208
-0
lines changed

4 files changed

+208
-0
lines changed

src/ops/boolean_mask.ts

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Tensor} from '../tensor';
19+
import {convertToTensor} from '../tensor_util_env';
20+
import {TensorLike} from '../types';
21+
import * as util from '../util';
22+
23+
import {whereAsync} from './logical_ops';
24+
import {gather} from './segment_ops';
25+
26+
/**
27+
* Apply boolean mask to tensor.
28+
*
29+
* ```js
30+
* const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
31+
* const mask = tf.tensor1d([1, 0, 1], 'bool');
32+
* const result = await tf.booleanMask(tensor, mask);
33+
* result.print();
34+
* ```
35+
*
36+
* @param N-D tensor.
37+
* @param mask K-D boolean tensor, K <= N and K must be known statically.
38+
* @param axis A 0-D int Tensor representing the axis in tensor to mask from.
39+
* By default, axis is 0 which will mask from the first dimension.
40+
* Otherwise K + axis <= N.
41+
*/
42+
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
43+
async function booleanMask_(
44+
tensor: Tensor|TensorLike, mask: Tensor|TensorLike,
45+
axis?: number): Promise<Tensor> {
46+
const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
47+
const $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
48+
49+
const axisFrom = axis == null ? 0 : axis;
50+
const maskDim = $mask.rank;
51+
const tensorShape = $tensor.shape;
52+
53+
util.assert(maskDim > 0, () => 'mask cannot be scalar');
54+
util.assertShapesMatch(
55+
tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape,
56+
`mask's shape must match the first K dimensions of tensor's shape,`);
57+
58+
let leadingSize = 1;
59+
for (let i = axisFrom; i < axisFrom + maskDim; i++) {
60+
leadingSize *= tensorShape[i];
61+
}
62+
const targetTensorShape =
63+
tensorShape.slice(0, axisFrom)
64+
.concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
65+
const reshapedTensor = $tensor.reshape(targetTensorShape);
66+
const reshapedMask = $mask.reshape([-1]);
67+
const positivePositions = await whereAsync(reshapedMask);
68+
const indices = positivePositions.squeeze([1]);
69+
70+
const res = gather(reshapedTensor, indices, axisFrom);
71+
72+
// Ensure no memory leak.
73+
if (tensor !== $tensor) {
74+
$tensor.dispose();
75+
}
76+
if (mask !== $mask) {
77+
$mask.dispose();
78+
}
79+
indices.dispose();
80+
reshapedTensor.dispose();
81+
reshapedMask.dispose();
82+
positivePositions.dispose();
83+
84+
return res;
85+
}
86+
87+
export const booleanMask = booleanMask_;

src/ops/boolean_mask_test.ts

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '../index';
19+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20+
import {Tensor} from '../tensor';
21+
import {expectArraysClose} from '../test_util';
22+
23+
describeWithFlags('booleanMask', ALL_ENVS, () => {
24+
it('1d array, 1d mask, default axis', async () => {
25+
const array = tf.tensor1d([1, 2, 3]);
26+
const mask = tf.tensor1d([1, 0, 1], 'bool');
27+
const result = await tf.booleanMask(array, mask);
28+
expect(result.shape).toEqual([2]);
29+
expect(result.dtype).toBe('float32');
30+
expectArraysClose(await result.data(), [1, 3]);
31+
});
32+
33+
it('2d array, 1d mask, default axis', async () => {
34+
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
35+
const mask = tf.tensor1d([1, 0, 1], 'bool');
36+
const result = await tf.booleanMask(array, mask);
37+
expect(result.shape).toEqual([2, 2]);
38+
expect(result.dtype).toBe('float32');
39+
expectArraysClose(await result.data(), [1, 2, 5, 6]);
40+
});
41+
42+
it('2d array, 2d mask, default axis', async () => {
43+
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
44+
const mask = tf.tensor2d([1, 0, 1, 0, 1, 0], [3, 2], 'bool');
45+
const result = await tf.booleanMask(array, mask);
46+
expect(result.shape).toEqual([3]);
47+
expect(result.dtype).toBe('float32');
48+
expectArraysClose(await result.data(), [1, 3, 5]);
49+
});
50+
51+
it('2d array, 1d mask, axis=1', async () => {
52+
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
53+
const mask = tf.tensor1d([0, 1], 'bool');
54+
const axis = 1;
55+
const result = await tf.booleanMask(array, mask, axis);
56+
expect(result.shape).toEqual([3, 1]);
57+
expect(result.dtype).toBe('float32');
58+
expectArraysClose(await result.data(), [2, 4, 6]);
59+
});
60+
61+
it('accepts tensor-like object as array or mask', async () => {
62+
const array = [[1, 2], [3, 4], [5, 6]];
63+
const mask = [1, 0, 1];
64+
const result = await tf.booleanMask(array, mask);
65+
expect(result.shape).toEqual([2, 2]);
66+
expect(result.dtype).toBe('float32');
67+
expectArraysClose(await result.data(), [1, 2, 5, 6]);
68+
});
69+
70+
it('ensure no memory leak', async () => {
71+
const numTensorsBefore = tf.memory().numTensors;
72+
73+
const array = tf.tensor1d([1, 2, 3]);
74+
const mask = tf.tensor1d([1, 0, 1], 'bool');
75+
let resultPromise: Promise<Tensor> = null;
76+
77+
tf.tidy(() => {
78+
resultPromise = tf.booleanMask(array, mask);
79+
});
80+
81+
const result = await resultPromise;
82+
expect(result.shape).toEqual([2]);
83+
expect(result.dtype).toBe('float32');
84+
expectArraysClose(await result.data(), [1, 3]);
85+
array.dispose();
86+
mask.dispose();
87+
result.dispose();
88+
89+
const numTensorsAfter = tf.memory().numTensors;
90+
expect(numTensorsAfter).toBe(numTensorsBefore);
91+
});
92+
93+
it('should throw if mask is scalar', async () => {
94+
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
95+
const mask = tf.scalar(1, 'bool');
96+
let errorMessage = 'No error thrown.';
97+
try {
98+
await tf.booleanMask(array, mask);
99+
} catch (error) {
100+
errorMessage = error.message;
101+
}
102+
expect(errorMessage).toBe('mask cannot be scalar');
103+
});
104+
105+
it('should throw if array and mask shape miss match', async () => {
106+
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
107+
const mask = tf.tensor2d([1, 0], [1, 2], 'bool');
108+
let errorMessage = 'No error thrown.';
109+
try {
110+
await tf.booleanMask(array, mask);
111+
} catch (error) {
112+
errorMessage = error.message;
113+
}
114+
expect(errorMessage)
115+
.toBe(
116+
`mask's shape must match the first K ` +
117+
`dimensions of tensor's shape, Shapes 3,2 and 1,2 must match`);
118+
});
119+
});

src/ops/ops.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
export * from './batchnorm';
19+
export * from './boolean_mask';
1920
export * from './complex_ops';
2021
export * from './concat_split';
2122
export * from './conv';

src/tests.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import './ops/array_ops_test';
4141
import './ops/axis_util_test';
4242
import './ops/batchnorm_test';
4343
import './ops/binary_ops_test';
44+
import './ops/boolean_mask_test';
4445
import './ops/broadcast_util_test';
4546
import './ops/clone_test';
4647
import './ops/compare_ops_test';

0 commit comments

Comments
 (0)