Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit f135c5a

Browse files
LewuatheNikhil Thorat
authored and
Nikhil Thorat
committed
Add stft op (#1746)
FEATURE Add tf.signal.stft op. One remaining TODO is passing the FFT length parameter, because rfft does not yet support it; it can be passed through once rfft supports it. See: tensorflow/tfjs#1362
1 parent 6a94672 commit f135c5a

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

src/ops/signal_ops.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import {op} from '../ops/operation';
1919
import {Tensor, Tensor1D} from '../tensor';
2020

21+
import {mul} from './binary_ops';
2122
import {concat} from './concat_split';
2223
import {slice} from './slice';
24+
import {rfft} from './spectral_ops';
2325
import {fill, tensor1d, tensor2d} from './tensor_ops';
2426

2527
/**
@@ -94,6 +96,45 @@ function frame_(
9496
return concat(output).as2D(output.length, frameLength);
9597
}
9698

99+
/**
 * Computes the Short-time Fourier Transform (STFT) of a signal.
 * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
 *
 * The signal is split into frames with `frame`, each frame is multiplied by
 * the window produced by `windowFn(frameLength)`, and `rfft` is applied to
 * every windowed frame on its own. The per-frame results are concatenated in
 * frame order.
 *
 * ```js
 * const input = tf.tensor1d([1, 1, 1, 1, 1])
 * tf.signal.stft(input, 3, 1).print();
 * ```
 * @param signal 1-dimensional real value tensor.
 * @param frameLength The window length of samples.
 * @param frameStep The number of samples to step.
 * @param fftLength The size of the FFT to apply. When omitted (or null), the
 *     smallest power of two that is >= `frameLength` is used.
 * @param windowFn A callable that takes a window length and returns 1-d
 *     tensor. Defaults to `hannWindow`.
 */
/**
 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
 */
function stft_(
    signal: Tensor1D, frameLength: number, frameStep: number,
    fftLength?: number,
    windowFn: (length: number) => Tensor1D = hannWindow): Tensor {
  // Resolve the FFT size first so the loop below always sees a number.
  const fftSize =
      fftLength == null ? enclosingPowerOfTwo(frameLength) : fftLength;

  const frames = frame(signal, frameLength, frameStep);
  const windowedFrames = mul(frames, windowFn(frameLength));

  // Transform one frame (one row of the framed signal) at a time.
  const frameCount = frames.shape[0];
  const spectra: Tensor[] = [];
  for (let row = 0; row < frameCount; row++) {
    const singleFrame = windowedFrames.slice([row, 0], [1, frameLength]);
    spectra.push(rfft(singleFrame, fftSize));
  }
  return concat(spectra);
}
132+
133+
/**
 * Returns the smallest power of two (2**N, integer N >= 0) that is >= value.
 *
 * The result is found by repeated doubling rather than through
 * `Math.log(value) / Math.log(2)`: that quotient is a floating-point
 * approximation and can round to just above the true exponent when `value`
 * is itself a power of two, which would make the result twice as large as
 * needed. Doubling is exact for every power of two a double can represent.
 *
 * Values <= 1 (and NaN) yield 1.
 *
 * @param value The length that must be enclosed.
 * @returns A power of two that is not smaller than `value`.
 */
function enclosingPowerOfTwo(value: number): number {
  let power = 1;
  // Terminates for Infinity too: `power` overflows to Infinity and the
  // comparison becomes false.
  while (power < value) {
    power *= 2;
  }
  return power;
}
137+
97138
function cosineWindow(windowLength: number, a: number, b: number): Tensor1D {
98139
const even = 1 - windowLength % 2;
99140
const newValues = new Float32Array(windowLength);
@@ -107,3 +148,4 @@ function cosineWindow(windowLength: number, a: number, b: number): Tensor1D {
107148
export const hannWindow = op({hannWindow_});
108149
export const hammingWindow = op({hammingWindow_});
109150
export const frame = op({frame_});
151+
export const stft = op({stft_});

src/ops/signal_ops_test.ts

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,166 @@ describeWithFlags('frame', ALL_ENVS, () => {
124124
expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 100]);
125125
});
126126
});
127+
128+
// Tests for tf.signal.stft. Expected values are interleaved (real, imag)
// pairs, so each row of the output contributes twice as many numbers as the
// row length reported by `output.shape`.
describeWithFlags('stft', ALL_ENVS, () => {
  it('3 length with hann window', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1]);
    const frameLength = 3;
    const frameStep = 1;
    const output = tf.signal.stft(input, frameLength, frameStep);
    expect(output.shape).toEqual([3, 3]);
    expectArraysClose(await output.data(), [
      1.0, 0.0, 0.0, -1.0, -1.0, 0.0,
      1.0, 0.0, 0.0, -1.0, -1.0, 0.0,
      1.0, 0.0, 0.0, -1.0, -1.0, 0.0,
    ]);
  });

  it('3 length with hann window (sequential number)', async () => {
    const input = tf.tensor1d([1, 2, 3, 4, 5]);
    const frameLength = 3;
    const frameStep = 1;
    const output = tf.signal.stft(input, frameLength, frameStep);
    expect(output.shape).toEqual([3, 3]);
    expectArraysClose(await output.data(), [
      2.0, 0.0, 0.0, -2.0, -2.0, 0.0,
      3.0, 0.0, 0.0, -3.0, -3.0, 0.0,
      4.0, 0.0, 0.0, -4.0, -4.0, 0.0
    ]);
  });

  it('3 length, 2 step with hann window', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1]);
    const frameLength = 3;
    const frameStep = 2;
    const output = tf.signal.stft(input, frameLength, frameStep);
    expect(output.shape).toEqual([2, 3]);
    expectArraysClose(await output.data(), [
      1.0, 0.0, 0.0, -1.0, -1.0, 0.0,
      1.0, 0.0, 0.0, -1.0, -1.0, 0.0
    ]);
  });

  // Title fixed: this case steps by 1 sample, not 2.
  it('3 fftLength, 5 frameLength, 1 step', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1, 1]);
    const frameLength = 5;
    const frameStep = 1;
    const fftLength = 3;
    const output = tf.signal.stft(input, frameLength, frameStep, fftLength);
    // Same configuration as the hamming-window variant below, which asserts
    // the full [2, 2] shape.
    expect(output.shape).toEqual([2, 2]);
    expectArraysClose(await output.data(), [
      1.5, 0.0, -0.749999, 0.433,
      1.5, 0.0, -0.749999, 0.433
    ]);
  });

  it('5 length with hann window', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1]);
    const frameLength = 5;
    const frameStep = 1;
    const output = tf.signal.stft(input, frameLength, frameStep);
    expect(output.shape).toEqual([1, 5]);
    expectArraysClose(
        await output.data(),
        [2.0, 0.0, 0.0, -1.7071068, -1.0, 0.0, 0.0, 0.29289323, 0.0, 0.0]);
  });

  it('5 length with hann window (sequential)', async () => {
    const input = tf.tensor1d([1, 2, 3, 4, 5]);
    const frameLength = 5;
    const frameStep = 1;
    const output = tf.signal.stft(input, frameLength, frameStep);
    expect(output.shape).toEqual([1, 5]);
    expectArraysClose(
        await output.data(),
        [6.0, 0.0, -0.70710677, -5.1213202, -3.0, 1.0,
         0.70710677, 0.87867975, 0.0, 0.0]);
  });

  it('3 length with hamming window', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1]);
    const frameLength = 3;
    const frameStep = 1;
    const fftLength = 3;
    const output = tf.signal.stft(input, frameLength, frameStep,
        fftLength, (length) => tf.signal.hammingWindow(length));
    expect(output.shape).toEqual([3, 2]);
    expectArraysClose(await output.data(), [
      1.16, 0.0, -0.46, -0.79674333,
      1.16, 0.0, -0.46, -0.79674333,
      1.16, 0.0, -0.46, -0.79674333
    ]);
  });

  it('3 length, 2 step with hamming window', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1]);
    const frameLength = 3;
    const frameStep = 2;
    const fftLength = 3;
    const output = tf.signal.stft(input, frameLength, frameStep,
        fftLength, (length) => tf.signal.hammingWindow(length));
    expect(output.shape).toEqual([2, 2]);
    expectArraysClose(await output.data(), [
      1.16, 0.0, -0.46, -0.79674333,
      1.16, 0.0, -0.46, -0.79674333
    ]);
  });

  // Title fixed: this case steps by 1 sample, not 2.
  it('3 fftLength, 5 frameLength, 1 step with hamming window', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1, 1]);
    const frameLength = 5;
    const frameStep = 1;
    const fftLength = 3;
    const output = tf.signal.stft(input, frameLength, frameStep,
        fftLength, (length) => tf.signal.hammingWindow(length));
    expect(output.shape).toEqual([2, 2]);
    expectArraysClose(await output.data(), [
      1.619999, 0.0, -0.69, 0.39837,
      1.619999, 0.0, -0.69, 0.39837
    ]);
  });

  // Title fixed: this case uses the hamming window (the previous title
  // duplicated the hann-window test above).
  it('5 length with hamming window (sequential)', async () => {
    const input = tf.tensor1d([1, 2, 3, 4, 5]);
    const frameLength = 5;
    const frameStep = 1;
    const fftLength = 5;
    const output = tf.signal.stft(input, frameLength, frameStep,
        fftLength, (length) => tf.signal.hammingWindow(length));
    expect(output.shape).toEqual([1, 3]);
    expectArraysClose(
        await output.data(),
        [6.72, 0.0, -3.6371822, -1.1404576, 0.4771822, 0.39919350]);
  });

  it('3 length without window function', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1]);
    const frameLength = 3;
    const frameStep = 1;
    const fftLength = 3;
    // All-ones window: multiplying by it leaves every frame unchanged.
    const ident = (length: number) => tf.ones([length]).as1D();
    const output = tf.signal.stft(input, frameLength, frameStep,
        fftLength, ident);
    expect(output.shape).toEqual([3, 2]);
    expectArraysClose(await output.data(), [
      3.0, 0.0, 0.0, 0.0,
      3.0, 0.0, 0.0, 0.0,
      3.0, 0.0, 0.0, 0.0
    ]);
  });

  it('3 length, 2 step without window function', async () => {
    const input = tf.tensor1d([1, 1, 1, 1, 1]);
    const frameLength = 3;
    const frameStep = 2;
    const fftLength = 3;
    // All-ones window: multiplying by it leaves every frame unchanged.
    const ident = (length: number) => tf.ones([length]).as1D();
    const output = tf.signal.stft(input, frameLength, frameStep,
        fftLength, ident);
    expect(output.shape).toEqual([2, 2]);
    expectArraysClose(await output.data(), [
      3.0, 0.0, 0.0, 0.0,
      3.0, 0.0, 0.0, 0.0
    ]);
  });
});

0 commit comments

Comments
 (0)