Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit e70a33f

Browse files
syt123450dsmilkov
authored andcommitted
Add non-default noiseShape support to dropout op (#1782)
This PR is a followup PR for [tfjs-core#1343](#1343) which added `dropout` op, dropout feature initially requested in [tfjs#117](tensorflow/tfjs#117). This PR cleans up a TODO, `implements non default noise shape`. The implementation of noiseShape aligns with TensorFlow Python's dropout API. The `noiseShape` feature would benefit further dropout related development, for example, make `tf.layers.dropout` support `noiseShape` configuration. **Relative PR:** * Make dropout layer support non default noiseShape and seed [tensorflow/tfjs-layers#556](tensorflow/tfjs-layers#556) **Reference:** * [TensorFlow tf.nn.dropout documentation](https://www.tensorflow.org/api_docs/python/tf/nn/dropout) * [TensorFlow tf.nn.dropout implementation](https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/nn_ops.py#L2982-L3054) * [TensorFlow _get_noise_shape implementation](https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/nn_ops.py#L2903-L2925)
1 parent 2ae5a8a commit e70a33f

File tree

5 files changed

+299
-58
lines changed

5 files changed

+299
-58
lines changed

src/ops/dropout.ts

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,61 @@
1515
* =============================================================================
1616
*/
1717

18-
import {Scalar, Tensor} from '../tensor';
19-
import {arraysEqual} from '../util';
18+
import {Tensor} from '../tensor';
19+
import {convertToTensor} from '../tensor_util_env';
20+
import {TensorLike} from '../types';
21+
import * as util from '../util';
2022

2123
import {randomUniform} from './array_ops';
22-
import {sub} from './binary_ops';
24+
import {getNoiseShape} from './dropout_util';
2325
import {op} from './operation';
2426

2527
/**
26-
* Sets entries in `x` to zero at random, while scaling the entire tensor.
28+
* Computes dropout.
29+
*
2730
* ```js
28-
* const x = tf.range(1, 21).reshape([10, 2]);
29-
* const rate = 0.5;
30-
* const seed = 23;
31-
* const noiseShape = null || x.shape;
32-
* const tensor = tf.dropout(x, rate, noiseShape, seed);
31+
* const x = tf.tensor1d([1, 2, 2, 1]);
32+
* const rate = 0.75;
33+
* const output = tf.dropout(x, rate);
34+
* output.print();
3335
* ```
34-
* @param x input tensor.
35-
* @param level fraction of the entries in the tensor that will be set to 0.
36-
* @param noiseShape shape of randomly generated keep/drop flags, must be
37-
* broadcastable to the shape of `x`.
38-
* @param seed random seed to ensure determinism.
39-
* @returns Result of the dropout operation.
36+
*
37+
* @param x A floating point Tensor or TensorLike.
38+
* @param rate A float in the range [0, 1). The probability that each element
39+
* of x is discarded.
40+
* @param noiseShape An array of numbers of type int32, representing the
41+
* shape for randomly generated keep/drop flags. If the noiseShape has null
42+
* value, it will be automatically replaced with the x's relative dimension
43+
* size. Optional.
44+
* @param seed Used to create random seeds. Optional.
45+
* @returns A Tensor of the same shape of x.
4046
*/
47+
/** @doc {heading: 'Operations', subheading: 'Dropout'} */
4148
function dropout_(
42-
x: Tensor, rate: Scalar|number, noiseShape?: number[],
43-
seed?: number): Tensor {
44-
if (noiseShape != null && !arraysEqual(x.shape, noiseShape)) {
45-
// TODO(VariableVasasMT): implement non default noise shape
46-
throw new Error(
47-
'Non-default noise shape is not implemented yet: ' +
48-
JSON.stringify(noiseShape));
49+
x: Tensor|TensorLike, rate: number, noiseShape?: number[],
50+
seed?: number|string): Tensor {
51+
const $x = convertToTensor(x, 'x', 'dropout');
52+
53+
util.assert(
54+
$x.dtype === 'float32',
55+
() => `x has to be a floating point tensor since it's going to be ` +
56+
`scaled, but got a ${$x.dtype} tensor instead.`);
57+
util.assert(
58+
rate >= 0 && rate < 1,
59+
() => `rate must be a float in the range [0, 1), but got ${rate}.`);
60+
61+
if (rate === 0) {
62+
return x instanceof Tensor ? $x.clone() : $x;
4963
}
5064

51-
let multiplier = randomUniform(x.shape, 0, 1, 'float32', seed).greater(rate);
52-
// Scale the kept elements, so the expected sum is unchanged.
53-
multiplier = multiplier.div(sub(1, rate) as Scalar);
54-
return x.mul(multiplier);
65+
const $noiseShape = getNoiseShape($x, noiseShape);
66+
const keepProb = 1 - rate;
67+
const multiplier = randomUniform($noiseShape, 0, 1, 'float32', seed)
68+
.add(keepProb)
69+
.floor()
70+
.div(keepProb);
71+
72+
return $x.mul(multiplier);
5573
}
5674

5775
export const dropout = op({dropout_});

src/ops/dropout_test.ts

Lines changed: 149 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,44 +14,161 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17+
1718
import * as tf from '../index';
1819
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
19-
import {Tensor} from '../tensor';
20-
21-
function countParams(x: Tensor): number {
22-
const shape = x.shape;
23-
if (shape.length > 0) {
24-
return shape.reduce((a: number, b: number) => a * b);
25-
} else {
26-
// Scalar.
27-
return 1;
28-
}
29-
}
20+
import {expectArraysClose} from '../test_util';
21+
22+
import {tensor1d, tensor2d} from './tensor_ops';
3023

3124
describeWithFlags('dropout', ALL_ENVS, () => {
32-
const dropoutLevels = [0, 0.75];
33-
const seed = 23;
34-
for (const dropoutLevel of dropoutLevels) {
35-
it(`Level = ${dropoutLevel}`, async () => {
36-
const x = tf.range(1, 21).reshape([10, 2]);
37-
const y = tf.dropout(x, tf.scalar(dropoutLevel), null, seed);
38-
expect(y.dtype).toEqual(x.dtype);
39-
expect(y.shape).toEqual(x.shape);
40-
const xValue = await x.data();
41-
const yValue = await y.data();
42-
let nKept = 0;
43-
for (let i = 0; i < xValue.length; ++i) {
44-
if (yValue[i] !== 0) {
45-
nKept++;
46-
expect(yValue[i]).toBeCloseTo(1 / (1 - dropoutLevel) * xValue[i]);
25+
it('x 1d array, rate 0', async () => {
26+
const x = tensor1d([1, 2, 2, 1]);
27+
const rate = 0;
28+
const output = tf.dropout(x, rate);
29+
expect(output.dtype).toEqual(x.dtype);
30+
expect(output.shape).toEqual(x.shape);
31+
expectArraysClose(await x.data(), await output.data());
32+
});
33+
34+
it('x 1d array, rate 0.75', async () => {
35+
const x = tensor1d([1, 2, 2, 1]);
36+
const rate = 0.75;
37+
const output = tf.dropout(x, rate);
38+
expect(output.dtype).toEqual(x.dtype);
39+
expect(output.shape).toEqual(x.shape);
40+
const xValues = await x.data();
41+
const outputValues = await output.data();
42+
for (let i = 0; i < xValues.length; i++) {
43+
if (outputValues[i] !== 0) {
44+
expect(outputValues[i]).toBeCloseTo(1 / (1 - rate) * xValues[i]);
45+
}
46+
}
47+
});
48+
49+
it('x 2d array, rate 0', async () => {
50+
const x = tensor2d([1, 5, 2, 4, 3, 6], [2, 3]);
51+
const rate = 0;
52+
const output = tf.dropout(x, rate);
53+
expect(output.dtype).toEqual(x.dtype);
54+
expect(output.shape).toEqual(x.shape);
55+
expectArraysClose(await x.data(), await output.data());
56+
});
57+
58+
it('x 2d array, rate 0.75', async () => {
59+
const x = tensor2d([1, 5, 2, 4, 3, 6], [2, 3]);
60+
const rate = 0.75;
61+
const output = tf.dropout(x, rate);
62+
expect(output.dtype).toEqual(x.dtype);
63+
expect(output.shape).toEqual(x.shape);
64+
const xValues = await x.data();
65+
const outputValues = await output.data();
66+
for (let i = 0; i < xValues.length; i++) {
67+
if (outputValues[i] !== 0) {
68+
expect(outputValues[i]).toBeCloseTo(1 / (1 - rate) * xValues[i]);
69+
}
70+
}
71+
});
72+
73+
it('x 1d array, rate 0.75, with noise shape length = 1', async () => {
74+
const x = tensor1d([1, 2, 2, 1]);
75+
const rate = 0.75;
76+
const noiseShape = [1];
77+
const output = tf.dropout(x, rate, noiseShape);
78+
expect(output.dtype).toEqual(x.dtype);
79+
expect(output.shape).toEqual(x.shape);
80+
const xValues = await x.data();
81+
const outputValues = await output.data();
82+
const maskedOutput = outputValues[0];
83+
for (let i = 0; i < xValues.length; i++) {
84+
if (maskedOutput === 0) {
85+
expect(outputValues[i]).toBe(maskedOutput);
86+
}
87+
if (outputValues[i] !== 0) {
88+
expect(outputValues[i]).toBeCloseTo(1 / (1 - rate) * xValues[i]);
89+
}
90+
}
91+
});
92+
93+
it('x 2d array, rate 0.75, with noise shape length = 2', async () => {
94+
const x = tensor2d([1, 5, 2, 4, 3, 6], [2, 3]);
95+
const rate = 0.75;
96+
const noiseShape = [2, 1];
97+
const output = tf.dropout(x, rate, noiseShape);
98+
expect(output.dtype).toEqual(x.dtype);
99+
expect(output.shape).toEqual(x.shape);
100+
const xValues = await x.data();
101+
const outputValues = await output.data();
102+
for (let i = 0; i < x.shape[0]; i++) {
103+
const maskedOutput = outputValues[i * x.shape[1]];
104+
if (maskedOutput !== 0) {
105+
expect(maskedOutput)
106+
.toBeCloseTo(1 / (1 - rate) * xValues[i * x.shape[1]]);
107+
} else {
108+
for (let j = 0; j < x.shape[1]; j++) {
109+
expect(outputValues[i * x.shape[1] + j]).toBe(maskedOutput);
47110
}
48111
}
49-
const numel = countParams(x);
50-
if (dropoutLevel === 0) {
51-
expect(nKept).toEqual(numel);
112+
}
113+
});
114+
115+
it('broadcast noise shape', async () => {
116+
const x = tensor2d([1, 5, 2, 4, 3, 6], [2, 3]);
117+
const rate = 0.75;
118+
// broadcast noise shape, same output as using noiseShape [2, 1]
119+
const noiseShape = [1];
120+
const output = tf.dropout(x, rate, noiseShape);
121+
expect(output.dtype).toEqual(x.dtype);
122+
expect(output.shape).toEqual(x.shape);
123+
const xValues = await x.data();
124+
const outputValues = await output.data();
125+
for (let i = 0; i < x.shape[0]; i++) {
126+
const maskedOutput = outputValues[i * x.shape[1]];
127+
if (maskedOutput !== 0) {
128+
expect(maskedOutput)
129+
.toBeCloseTo(1 / (1 - rate) * xValues[i * x.shape[1]]);
52130
} else {
53-
expect(nKept).toBeLessThan(numel);
131+
for (let j = 0; j < x.shape[1]; j++) {
132+
expect(outputValues[i * x.shape[1] + j]).toBe(maskedOutput);
133+
}
54134
}
55-
});
56-
}
135+
}
136+
});
137+
138+
it('x 1d array, rate 0.75, with seed', async () => {
139+
const x = tensor1d([1, 2, 2, 1]);
140+
const rate = 0.75;
141+
const seed = 23;
142+
const output = tf.dropout(x, rate, null, seed);
143+
expect(output.dtype).toEqual(x.dtype);
144+
expect(output.shape).toEqual(x.shape);
145+
const xValues = await x.data();
146+
const outputValues = await output.data();
147+
for (let i = 0; i < xValues.length; i++) {
148+
if (outputValues[i] !== 0) {
149+
expect(outputValues[i]).toBeCloseTo(1 / (1 - rate) * xValues[i]);
150+
}
151+
}
152+
});
153+
154+
it('x TensorLike object', async () => {
155+
const x = [1.0, 2.0, 2.0, 1.0];
156+
const rate = 0;
157+
const output = tf.dropout(x, rate);
158+
expect(output.dtype).toEqual('float32');
159+
expect(output.shape).toEqual([4]);
160+
expectArraysClose(await output.data(), x);
161+
});
162+
163+
it('throws when x.dtype != float32', async () => {
164+
const x = tensor1d([1, 2, 2, 1], 'int32');
165+
const rate = 0.75;
166+
expect(() => tf.dropout(x, rate)).toThrowError();
167+
});
168+
169+
it('throws when rate is not in the range [0, 1)', async () => {
170+
const x = tensor1d([1, 2, 2, 1]);
171+
const rate = 1.5;
172+
expect(() => tf.dropout(x, rate)).toThrowError();
173+
});
57174
});

src/ops/dropout_util.ts

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Tensor} from '../tensor';
19+
import * as util from '../util';
20+
21+
/**
22+
* Normalize noise shape based on provided tensor and noise shape.
23+
*
24+
* @param x Tensor.
25+
* @param noiseShape The shape for the randomly generated keep/drop flags, as
26+
* an array of numbers. Optional.
27+
* @returns Normalized noise shape.
28+
*/
29+
export function getNoiseShape(x: Tensor, noiseShape?: number[]): number[] {
30+
if (noiseShape == null) {
31+
return x.shape.slice();
32+
}
33+
if (util.arraysEqual(x.shape, noiseShape)) {
34+
return noiseShape;
35+
}
36+
if (x.shape.length === noiseShape.length) {
37+
const newDimension: number[] = [];
38+
for (let i = 0; i < x.shape.length; i++) {
39+
if (noiseShape[i] == null && x.shape[i] != null) {
40+
newDimension.push(x.shape[i]);
41+
} else {
42+
newDimension.push(noiseShape[i]);
43+
}
44+
}
45+
return newDimension;
46+
}
47+
48+
return noiseShape;
49+
}

src/ops/dropout_util_test.ts

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '../index';
19+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20+
import {getNoiseShape} from './dropout_util';
21+
22+
describeWithFlags('getNoiseShape', ALL_ENVS, () => {
23+
it('x.shape == noiseShape', async () => {
24+
const x = tf.ones([2, 3]);
25+
const noiseShape = [2, 3];
26+
const shape = getNoiseShape(x, noiseShape);
27+
expect(shape).toEqual([2, 3]);
28+
});
29+
30+
it('x.shape and noiseShape have same length, different value', async () => {
31+
const x = tf.ones([2, 3]);
32+
const noiseShape = [2, 1];
33+
const shape = getNoiseShape(x, noiseShape);
34+
expect(shape).toEqual([2, 1]);
35+
});
36+
37+
it('noiseShape has null value', async () => {
38+
const x = tf.ones([2, 3]);
39+
const noiseShape = [2, null];
40+
const shape = getNoiseShape(x, noiseShape);
41+
expect(shape).toEqual([2, 3]);
42+
});
43+
44+
it('x.shape and noiseShape has different length', async () => {
45+
const x = tf.ones([2, 3, 4]);
46+
const noiseShape = [2, 3];
47+
const shape = getNoiseShape(x, noiseShape);
48+
expect(shape).toEqual([2, 3]);
49+
});
50+
51+
it('noiseShape is null', async () => {
52+
const x = tf.ones([2, 3]);
53+
const shape = getNoiseShape(x, null);
54+
expect(shape).toEqual([2, 3]);
55+
});
56+
});

src/tests.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import './ops/conv3d_test';
5858
import './ops/conv_util_test';
5959
import './ops/diag_test';
6060
import './ops/dropout_test';
61+
import './ops/dropout_util_test';
6162
import './ops/fused_test';
6263
import './ops/gather_nd_test';
6364
import './ops/image_ops_test';

0 commit comments

Comments
 (0)