Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit eb2ab45

Browse files
kgrytedsmilkov
authored andcommitted
Add support for generating pseudorandom numbers drawn from a ga… (#1365)
cla: yes * Add support for generating pseudorandom numbers drawn from a gamma distribution
1 parent 707a669 commit eb2ab45

File tree

4 files changed

+255
-11
lines changed

4 files changed

+255
-11
lines changed

src/ops/array_ops.ts

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import * as util from '../util';
2323
import {getAxesPermutation, getInnerMostAxes} from './axis_util';
2424
import {concat} from './concat_split';
2525
import {op} from './operation';
26-
import {MPRandGauss, UniformRandom} from './rand';
26+
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
2727
import {zeros, zerosLike} from './tensor_ops';
2828

2929
/**
@@ -166,6 +166,41 @@ function truncatedNormal_<R extends Rank>(
166166
return res.toTensor();
167167
}
168168

169+
/**
170+
* Creates a `tf.Tensor` with values sampled from a gamma distribution.
171+
*
172+
* ```js
173+
* tf.randomGamma([2, 2], 1).print();
174+
* ```
175+
*
176+
* @param shape An array of integers defining the output tensor shape.
177+
* @param alpha The shape parameter of the gamma distribution.
178+
* @param beta The inverse scale parameter of the gamma distribution. Defaults
179+
* to 1.
180+
* @param dtype The data type of the output. Defaults to float32.
181+
* @param seed The seed for the random number generator.
182+
*/
183+
/** @doc {heading: 'Tensors', subheading: 'Random'} */
184+
function randomGamma_<R extends Rank>(
185+
shape: ShapeMap[R], alpha: number, beta = 1,
186+
dtype: 'float32'|'int32' = 'float32', seed?: number): Tensor<R> {
187+
if (beta == null) {
188+
beta = 1;
189+
}
190+
if (dtype == null) {
191+
dtype = 'float32';
192+
}
193+
if (dtype !== 'float32' && dtype !== 'int32') {
194+
throw new Error(`Unsupported data type ${dtype}`);
195+
}
196+
const rgamma = new RandGamma(alpha, beta, dtype, seed);
197+
const res = buffer(shape, dtype);
198+
for (let i = 0; i < res.values.length; i++) {
199+
res.values[i] = rgamma.nextValue();
200+
}
201+
return res.toTensor();
202+
}
203+
169204
/**
170205
* Creates a `tf.Tensor` with values sampled from a uniform distribution.
171206
*
@@ -1102,6 +1137,7 @@ export const pad3d = op({pad3d_});
11021137
export const pad4d = op({pad4d_});
11031138
export const rand = op({rand_});
11041139
export const randomNormal = op({randomNormal_});
1140+
export const randomGamma = op({randomGamma_});
11051141
export const randomUniform = op({randomUniform_});
11061142
export const reshape = op({reshape_});
11071143
export const spaceToBatchND = op({spaceToBatchND_});

src/ops/array_ops_test.ts

Lines changed: 107 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,11 +1176,113 @@ describeWithFlags('truncatedNormal', ALL_ENVS, () => {
11761176
});
11771177
});
11781178

1179+
describeWithFlags('randomGamma', ALL_ENVS, () => {
1180+
it('should return a random 1D float32 array', async () => {
1181+
const shape: [number] = [10];
1182+
1183+
// Ensure defaults to float32 w/o type:
1184+
let result = tf.randomGamma(shape, 2, 2);
1185+
expect(result.dtype).toBe('float32');
1186+
expectValuesInRange(await result.data(), 0, 30);
1187+
1188+
result = tf.randomGamma(shape, 2, 2, 'float32');
1189+
expect(result.dtype).toBe('float32');
1190+
expectValuesInRange(await result.data(), 0, 30);
1191+
});
1192+
1193+
it('should return a random 1D int32 array', async () => {
1194+
const shape: [number] = [10];
1195+
const result = tf.randomGamma(shape, 2, 2, 'int32');
1196+
expect(result.dtype).toBe('int32');
1197+
expectValuesInRange(await result.data(), 0, 30);
1198+
});
1199+
1200+
it('should return a random 2D float32 array', async () => {
1201+
const shape: [number, number] = [3, 4];
1202+
1203+
// Ensure defaults to float32 w/o type:
1204+
let result = tf.randomGamma(shape, 2, 2);
1205+
expect(result.dtype).toBe('float32');
1206+
expectValuesInRange(await result.data(), 0, 30);
1207+
1208+
result = tf.randomGamma(shape, 2, 2, 'float32');
1209+
expect(result.dtype).toBe('float32');
1210+
expectValuesInRange(await result.data(), 0, 30);
1211+
});
1212+
1213+
it('should return a random 2D int32 array', async () => {
1214+
const shape: [number, number] = [3, 4];
1215+
const result = tf.randomGamma(shape, 2, 2, 'int32');
1216+
expect(result.dtype).toBe('int32');
1217+
expectValuesInRange(await result.data(), 0, 30);
1218+
});
1219+
1220+
it('should return a random 3D float32 array', async () => {
1221+
const shape: [number, number, number] = [3, 4, 5];
1222+
1223+
// Ensure defaults to float32 w/o type:
1224+
let result = tf.randomGamma(shape, 2, 2);
1225+
expect(result.dtype).toBe('float32');
1226+
expectValuesInRange(await result.data(), 0, 30);
1227+
1228+
result = tf.randomGamma(shape, 2, 2, 'float32');
1229+
expect(result.dtype).toBe('float32');
1230+
expectValuesInRange(await result.data(), 0, 30);
1231+
});
1232+
1233+
it('should return a random 3D int32 array', async () => {
1234+
const shape: [number, number, number] = [3, 4, 5];
1235+
const result = tf.randomGamma(shape, 2, 2, 'int32');
1236+
expect(result.dtype).toBe('int32');
1237+
expectValuesInRange(await result.data(), 0, 30);
1238+
});
1239+
1240+
it('should return a random 4D float32 array', async () => {
1241+
const shape: [number, number, number, number] = [3, 4, 5, 6];
1242+
1243+
// Ensure defaults to float32 w/o type:
1244+
let result = tf.randomGamma(shape, 2, 2);
1245+
expect(result.dtype).toBe('float32');
1246+
expectValuesInRange(await result.data(), 0, 30);
1247+
1248+
result = tf.randomGamma(shape, 2, 2, 'float32');
1249+
expect(result.dtype).toBe('float32');
1250+
expectValuesInRange(await result.data(), 0, 30);
1251+
});
1252+
1253+
it('should return a random 4D int32 array', async () => {
1254+
const shape: [number, number, number, number] = [3, 4, 5, 6];
1255+
const result = tf.randomGamma(shape, 2, 2, 'int32');
1256+
expect(result.dtype).toBe('int32');
1257+
expectValuesInRange(await result.data(), 0, 30);
1258+
});
1259+
1260+
it('should return a random 5D float32 array', async () => {
1261+
const shape: [number, number, number, number, number] = [2, 3, 4, 5, 6];
1262+
1263+
// Ensure defaults to float32 w/o type:
1264+
let result = tf.randomGamma(shape, 2, 2);
1265+
expect(result.dtype).toBe('float32');
1266+
expectValuesInRange(await result.data(), 0, 30);
1267+
1268+
result = tf.randomGamma(shape, 2, 2, 'float32');
1269+
expect(result.dtype).toBe('float32');
1270+
expectValuesInRange(await result.data(), 0, 30);
1271+
});
1272+
1273+
it('should return a random 5D int32 array', async () => {
1274+
const shape: [number, number, number, number, number] = [2, 3, 4, 5, 6];
1275+
const result = tf.randomGamma(shape, 2, 2, 'int32');
1276+
expect(result.dtype).toBe('int32');
1277+
expectValuesInRange(await result.data(), 0, 30);
1278+
});
1279+
});
1280+
11791281
describeWithFlags('randomUniform', ALL_ENVS, () => {
11801282
it('should return a random 1D float32 array', async () => {
11811283
const shape: [number] = [10];
11821284

1183-
// Enusre defaults to float32 w/o type:
1285+
// Ensure defaults to float32 w/o type:
11841286
let result = tf.randomUniform(shape, 0, 2.5);
11851287
expect(result.dtype).toBe('float32');
11861288
expectValuesInRange(await result.data(), 0, 2.5);
@@ -1207,7 +1309,7 @@ describeWithFlags('randomUniform', ALL_ENVS, () => {
12071309
it('should return a random 2D float32 array', async () => {
12081310
const shape: [number, number] = [3, 4];
12091311

1210-
// Enusre defaults to float32 w/o type:
1312+
// Ensure defaults to float32 w/o type:
12111313
let result = tf.randomUniform(shape, 0, 2.5);
12121314
expect(result.dtype).toBe('float32');
12131315
expectValuesInRange(await result.data(), 0, 2.5);
@@ -1234,7 +1336,7 @@ describeWithFlags('randomUniform', ALL_ENVS, () => {
12341336
it('should return a random 3D float32 array', async () => {
12351337
const shape: [number, number, number] = [3, 4, 5];
12361338

1237-
// Enusre defaults to float32 w/o type:
1339+
// Ensure defaults to float32 w/o type:
12381340
let result = tf.randomUniform(shape, 0, 2.5);
12391341
expect(result.dtype).toBe('float32');
12401342
expectValuesInRange(await result.data(), 0, 2.5);
@@ -1261,7 +1363,7 @@ describeWithFlags('randomUniform', ALL_ENVS, () => {
12611363
it('should return a random 4D float32 array', async () => {
12621364
const shape: [number, number, number, number] = [3, 4, 5, 6];
12631365

1264-
// Enusre defaults to float32 w/o type:
1366+
// Ensure defaults to float32 w/o type:
12651367
let result = tf.randomUniform(shape, 0, 2.5);
12661368
expect(result.dtype).toBe('float32');
12671369
expectValuesInRange(await result.data(), 0, 2.5);
@@ -1288,7 +1390,7 @@ describeWithFlags('randomUniform', ALL_ENVS, () => {
12881390
it('should return a random 5D float32 array', async () => {
12891391
const shape: [number, number, number, number, number] = [2, 3, 4, 5, 6];
12901392

1291-
// Enusre defaults to float32 w/o type:
1393+
// Ensure defaults to float32 w/o type:
12921394
let result = tf.randomUniform(shape, 0, 2.5);
12931395
expect(result.dtype).toBe('float32');
12941396
expectValuesInRange(await result.data(), 0, 2.5);

src/ops/rand.ts

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,20 @@ export interface RandomBase {
2121
nextValue(): number;
2222
}
2323

24+
export interface RandomGamma {
25+
nextValue(): number;
26+
}
27+
2428
export interface RandNormalDataTypes {
2529
float32: Float32Array;
2630
int32: Int32Array;
2731
}
2832

33+
export interface RandGammaDataTypes {
34+
float32: Float32Array;
35+
int32: Int32Array;
36+
}
37+
2938
// https://en.wikipedia.org/wiki/Marsaglia_polar_method
3039
export class MPRandGauss implements RandomBase {
3140
private mean: number;
@@ -53,7 +62,7 @@ export class MPRandGauss implements RandomBase {
5362
this.random = seedrandom.alea(seedValue.toString());
5463
}
5564

56-
/** Returns next sample from a gaussian distribution. */
65+
/** Returns next sample from a Gaussian distribution. */
5766
public nextValue(): number {
5867
if (!isNaN(this.nextVal)) {
5968
const value = this.nextVal;
@@ -86,7 +95,7 @@ export class MPRandGauss implements RandomBase {
8695
return this.convertValue(resultX);
8796
}
8897

89-
/** Handles proper rounding for non floating point numbers. */
98+
/** Handles proper rounding for non-floating-point numbers. */
9099
private convertValue(value: number): number {
91100
if (this.dtype == null || this.dtype === 'float32') {
92101
return value;
@@ -100,6 +109,68 @@ export class MPRandGauss implements RandomBase {
100109
}
101110
}
102111

112+
// Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating
113+
// Gamma Variables."
114+
export class RandGamma implements RandomGamma {
115+
private alpha: number;
116+
private beta: number;
117+
private d: number;
118+
private c: number;
119+
private dtype?: keyof RandGammaDataTypes;
120+
private randu: seedrandom.prng;
121+
private randn: MPRandGauss;
122+
123+
constructor(
124+
alpha: number, beta: number, dtype: keyof RandGammaDataTypes,
125+
seed?: number) {
126+
this.alpha = alpha;
127+
this.beta = 1 / beta; // convert rate to scale parameter
128+
this.dtype = dtype;
129+
130+
const seedValue = seed ? seed : Math.random();
131+
this.randu = seedrandom.alea(seedValue.toString());
132+
this.randn = new MPRandGauss(0, 1, dtype, false, this.randu());
133+
134+
if (alpha < 1) {
135+
this.d = alpha + (2 / 3);
136+
} else {
137+
this.d = alpha - (1 / 3);
138+
}
139+
this.c = 1 / Math.sqrt(9 * this.d);
140+
}
141+
142+
/** Returns next sample from a gamma distribution. */
143+
public nextValue(): number {
144+
let x2: number, v0: number, v1: number, x: number, u: number, v: number;
145+
while (true) {
146+
do {
147+
x = this.randn.nextValue();
148+
v = 1 + (this.c * x);
149+
} while (v <= 0);
150+
v *= v * v;
151+
x2 = x * x;
152+
v0 = 1 - (0.331 * x2 * x2);
153+
v1 = (0.5 * x2) + (this.d * (1 - v + Math.log(v)));
154+
u = this.randu();
155+
if (u < v0 || Math.log(u) < v1) {
156+
break;
157+
}
158+
}
159+
v = (1 / this.beta) * this.d * v;
160+
if (this.alpha < 1) {
161+
v *= Math.pow(this.randu(), 1 / this.alpha);
162+
}
163+
return this.convertValue(v);
164+
}
165+
/** Handles proper rounding for non-floating-point numbers. */
166+
private convertValue(value: number): number {
167+
if (this.dtype === 'float32') {
168+
return value;
169+
}
170+
return Math.round(value);
171+
}
172+
}
173+
103174
export class UniformRandom implements RandomBase {
104175
private min: number;
105176
private range: number;

src/ops/rand_test.ts

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
* =============================================================================
1616
*/
1717

18-
import {MPRandGauss, UniformRandom} from './rand';
18+
import {expectValuesInRange} from '../test_util';
19+
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
1920
import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';
2021

2122
function isFloat(n: number): boolean {
@@ -31,7 +32,7 @@ describe('MPRandGauss', () => {
3132
expect(isFloat(rand.nextValue())).toBe(true);
3233
});
3334

34-
it('should handle create a mean/stdv of float32 numbers', () => {
35+
it('should handle a mean/stdv of float32 numbers', () => {
3536
const rand =
3637
new MPRandGauss(0, 1.5, 'float32', false /* truncated */, SEED);
3738
const values = [];
@@ -48,7 +49,7 @@ describe('MPRandGauss', () => {
4849
expect(isFloat(rand.nextValue())).toBe(false);
4950
});
5051

51-
it('should handle create a mean/stdv of int32 numbers', () => {
52+
it('should handle a mean/stdv of int32 numbers', () => {
5253
const rand = new MPRandGauss(0, 2, 'int32', false /* truncated */, SEED);
5354
const values = [];
5455
const size = 10000;
@@ -69,6 +70,40 @@ describe('MPRandGauss', () => {
6970
});
7071
});
7172

73+
describe('RandGamma', () => {
74+
const SEED = 2002;
75+
76+
it('should default to float32 numbers', () => {
77+
const rand = new RandGamma(2, 2, 'float32');
78+
expect(isFloat(rand.nextValue())).toBe(true);
79+
});
80+
81+
it('should handle an alpha/beta of float32 numbers', () => {
82+
const rand = new RandGamma(2, 2, 'float32', SEED);
83+
const values = [];
84+
const size = 10000;
85+
for (let i = 0; i < size; i++) {
86+
values.push(rand.nextValue());
87+
}
88+
expectValuesInRange(values, 0, 30);
89+
});
90+
91+
it('should handle int32 numbers', () => {
92+
const rand = new RandGamma(2, 2, 'int32');
93+
expect(isFloat(rand.nextValue())).toBe(false);
94+
});
95+
96+
it('should handle an alpha/beta of int32 numbers', () => {
97+
const rand = new RandGamma(2, 2, 'int32', SEED);
98+
const values = [];
99+
const size = 10000;
100+
for (let i = 0; i < size; i++) {
101+
values.push(rand.nextValue());
102+
}
103+
expectValuesInRange(values, 0, 30);
104+
});
105+
});
106+
72107
describe('UniformRandom', () => {
73108
it('float32, no seed', () => {
74109
const min = 0.2;

0 commit comments

Comments
 (0)