Skip to content

Commit 9534780

Browse files
LewuatheNikhil Thorat
authored and
Nikhil Thorat
committed
Add IRFFT ops. (tensorflow#1395)
FEATURE
1 parent b6c654a commit 9534780

File tree

2 files changed

+138
-3
lines changed

2 files changed

+138
-3
lines changed

src/ops/spectral_ops.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {complex, imag, real} from '../ops/complex_ops';
2020
import {op} from '../ops/operation';
2121
import {Tensor} from '../tensor';
2222
import {assert} from '../util';
23+
import {scalar} from './tensor_ops';
2324

2425
/**
2526
* Fast Fourier transform.
@@ -133,6 +134,55 @@ function rfft_(input: Tensor): Tensor {
133134
.reshape(outputShape);
134135
}
135136

137+
/**
138+
* Inversed real value input fast Fourier transform.
139+
*
140+
* Computes the 1-dimensional inversed discrete Fourier transform over the
141+
* inner-most dimension of the real input.
142+
*
143+
* ```js
144+
* const real = tf.tensor1d([1, 2, 3]);
145+
*
146+
* x.irfft().print();
147+
* ```
148+
* @param input The real value input to compute an irfft over.
149+
*/
150+
/**
151+
* @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
152+
*/
153+
function irfft_(input: Tensor): Tensor {
154+
const innerDimensionSize = input.shape[input.shape.length - 1];
155+
const batch = input.size / innerDimensionSize;
156+
157+
if (innerDimensionSize <= 2) {
158+
const complexInput = input.as2D(batch, innerDimensionSize);
159+
const ret = ENV.engine.runKernel(
160+
backend => backend.ifft(complexInput), {complexInput});
161+
return real(ret);
162+
} else {
163+
// The length of unique components of the DFT of a real-valued signal
164+
// is 2 * (input_len - 1)
165+
const outputShape = [batch, 2 * (innerDimensionSize - 1)];
166+
const realInput = real(input).as2D(batch, innerDimensionSize);
167+
const imagInput = imag(input).as2D(batch, innerDimensionSize);
168+
169+
const realConjugate =
170+
realInput.slice([0, 1], [batch, innerDimensionSize - 2]).reverse(1);
171+
const imagConjugate =
172+
imagInput.slice([0, 1], [batch, innerDimensionSize - 2])
173+
.reverse(1)
174+
.mul(scalar(-1));
175+
176+
const r = realInput.concat(realConjugate, 1);
177+
const i = imagInput.concat(imagConjugate, 1);
178+
const complexInput = complex(r, i).as2D(outputShape[0], outputShape[1]);
179+
const ret = ENV.engine.runKernel(
180+
backend => backend.ifft(complexInput), {complexInput});
181+
return real(ret);
182+
}
183+
}
184+
136185
export const fft = op({fft_});
137186
export const ifft = op({ifft_});
138187
export const rfft = op({rfft_});
188+
export const irfft = op({irfft_});

src/ops/spectral_ops_test.ts

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,93 @@ describeWithFlags('2D RFFT', WEBGL_ENVS, () => {
214214
});
215215

216216
it('should return the same value with TensorFlow (2x2x2 elements)', () => {
217-
const t1Real = tf.tensor3d([ 1, 2, 3, 4, 5, 6, 7, 8 ], [ 2, 2, 2 ]);
218-
expectArraysClose(tf.spectral.rfft(t1Real),
219-
[ 3, 0, -1, 0, 7, 0, -1, 0, 11, 0, -1, 0, 15, 0, -1, 0 ]);
217+
const t1Real = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
218+
expectArraysClose(
219+
tf.spectral.rfft(t1Real),
220+
[3, 0, -1, 0, 7, 0, -1, 0, 11, 0, -1, 0, 15, 0, -1, 0]);
221+
});
222+
});
223+
224+
describeWithFlags('1D IRFFT', ALL_ENVS, () => {
225+
it('should return the same value with TensorFlow (2 elements)', () => {
226+
const t1Real = tf.tensor1d([1, 2]);
227+
const t1Imag = tf.tensor1d([0, 0]);
228+
const t1 = tf.complex(t1Real, t1Imag);
229+
expectArraysClose(tf.spectral.irfft(t1), [1.5, -0.5]);
230+
});
231+
232+
it('should return the same value with TensorFlow (5 elements)', () => {
233+
const t1Real = tf.tensor1d([1, 2, 3, 4, 5]);
234+
const t1Imag = tf.tensor1d([0, 0, 0, 0, 0]);
235+
const t1 = tf.complex(t1Real, t1Imag);
236+
expectArraysClose(
237+
tf.spectral.irfft(t1),
238+
[3, -0.8535534, 0, -0.14644662, 0, -0.14644662, 0, -0.8535534]);
220239
});
240+
241+
it('should return the same value with TensorFlow (5 elements) with imag',
242+
() => {
243+
const t1Real = tf.tensor1d([1, 2, 3, 4, 5]);
244+
const t1Imag = tf.tensor1d([1, 2, 3, 4, 5]);
245+
const t1 = tf.complex(t1Real, t1Imag);
246+
expectArraysClose(
247+
tf.spectral.irfft(t1),
248+
[3, -2.6642137, 0.5, -0.45710677, 0, 0.16421354, -0.5, 0.95710677]);
249+
});
250+
});
251+
252+
describeWithFlags('2D IRFFT', ALL_ENVS, () => {
253+
it('should return the same value with TensorFlow (2x2 elements)', () => {
254+
const t1Real = tf.tensor2d([1, 2, 3, 4], [2, 2]);
255+
const t1Imag = tf.tensor2d([0, 0, 0, 0], [2, 2]);
256+
const t1 = tf.complex(t1Real, t1Imag);
257+
expectArraysClose(tf.spectral.irfft(t1), [1.5, -0.5, 3.5, -0.5]);
258+
});
259+
260+
it('should return the same value with TensorFlow (2x3 elements)', () => {
261+
const t1Real = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
262+
const t1Imag = tf.tensor2d([0, 0, 0, 0, 0, 0], [2, 3]);
263+
const t1 = tf.complex(t1Real, t1Imag);
264+
expectArraysClose(
265+
tf.spectral.irfft(t1), [2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5]);
266+
});
267+
268+
it('should return the same value with TensorFlow (2x3 elements) with imag',
269+
() => {
270+
const t1Real = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
271+
const t1Imag = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
272+
const t1 = tf.complex(t1Real, t1Imag);
273+
expectArraysClose(tf.spectral.irfft(t1), [2, -1.5, 0, 0.5, 5, -3, 0, 2]);
274+
});
275+
});
276+
277+
describeWithFlags('3D IRFFT', ALL_ENVS, () => {
278+
it('should return the same value with TensorFlow (2x2x2 elements)', () => {
279+
const t1Real = tf.tensor3d([1, 2, 3, 4, 1, 2, 3, 4], [2, 2, 2]);
280+
const t1Imag = tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 2]);
281+
const t1 = tf.complex(t1Real, t1Imag);
282+
expectArraysClose(
283+
tf.spectral.irfft(t1), [1.5, -0.5, 3.5, -0.5, 1.5, -0.5, 3.5, -0.5]);
284+
});
285+
286+
it('should return the same value with TensorFlow (2x2x3 elements)', () => {
287+
const t1Real = tf.tensor3d([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], [2, 2, 3]);
288+
const t1Imag = tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 3]);
289+
const t1 = tf.complex(t1Real, t1Imag);
290+
expectArraysClose(tf.spectral.irfft(t1), [
291+
2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5, 2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5
292+
]);
293+
});
294+
295+
it('should return the same value with TensorFlow (2x2x3 elements) with imag',
296+
() => {
297+
const t1Real =
298+
tf.tensor3d([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], [2, 2, 3]);
299+
const t1Imag =
300+
tf.tensor3d([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], [2, 2, 3]);
301+
const t1 = tf.complex(t1Real, t1Imag);
302+
expectArraysClose(
303+
tf.spectral.irfft(t1),
304+
[2, -1.5, 0, 0.5, 5, -3, 0, 2, 2, -1.5, 0, 0.5, 5, -3, 0, 2]);
305+
});
221306
});

0 commit comments

Comments
 (0)