@@ -124,3 +124,166 @@ describeWithFlags('frame', ALL_ENVS, () => {
124
124
expectArraysClose ( await output . data ( ) , [ 1 , 2 , 3 , 4 , 5 , 100 ] ) ;
125
125
} ) ;
126
126
} ) ;
127
+
128
+ describeWithFlags ( 'stft' , ALL_ENVS , ( ) => {
129
+ it ( '3 length with hann window' , async ( ) => {
130
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 ] ) ;
131
+ const frameLength = 3 ;
132
+ const frameStep = 1 ;
133
+ const output = tf . signal . stft ( input , frameLength , frameStep ) ;
134
+ expect ( output . shape ) . toEqual ( [ 3 , 3 ] ) ;
135
+ expectArraysClose ( await output . data ( ) , [
136
+ 1.0 , 0.0 , 0.0 , - 1.0 , - 1.0 , 0.0 ,
137
+ 1.0 , 0.0 , 0.0 , - 1.0 , - 1.0 , 0.0 ,
138
+ 1.0 , 0.0 , 0.0 , - 1.0 , - 1.0 , 0.0 ,
139
+ ] ) ;
140
+ } ) ;
141
+
142
+ it ( '3 length with hann window (sequencial number)' , async ( ) => {
143
+ const input = tf . tensor1d ( [ 1 , 2 , 3 , 4 , 5 ] ) ;
144
+ const frameLength = 3 ;
145
+ const frameStep = 1 ;
146
+ const output = tf . signal . stft ( input , frameLength , frameStep ) ;
147
+ expect ( output . shape ) . toEqual ( [ 3 , 3 ] ) ;
148
+ expectArraysClose ( await output . data ( ) , [
149
+ 2.0 , 0.0 , 0.0 , - 2.0 , - 2.0 , 0.0 ,
150
+ 3.0 , 0.0 , 0.0 , - 3.0 , - 3.0 , 0.0 ,
151
+ 4.0 , 0.0 , 0.0 , - 4.0 , - 4.0 , 0.0
152
+ ] ) ;
153
+ } ) ;
154
+
155
+ it ( '3 length, 2 step with hann window' , async ( ) => {
156
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 ] ) ;
157
+ const frameLength = 3 ;
158
+ const frameStep = 2 ;
159
+ const output = tf . signal . stft ( input , frameLength , frameStep ) ;
160
+ expect ( output . shape ) . toEqual ( [ 2 , 3 ] ) ;
161
+ expectArraysClose ( await output . data ( ) , [
162
+ 1.0 , 0.0 , 0.0 , - 1.0 , - 1.0 , 0.0 ,
163
+ 1.0 , 0.0 , 0.0 , - 1.0 , - 1.0 , 0.0
164
+ ] ) ;
165
+ } ) ;
166
+
167
+ it ( '3 fftLength, 5 frameLength, 2 step' , async ( ) => {
168
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 , 1 ] ) ;
169
+ const frameLength = 5 ;
170
+ const frameStep = 1 ;
171
+ const fftLength = 3 ;
172
+ const output = tf . signal . stft ( input , frameLength , frameStep , fftLength ) ;
173
+ expect ( output . shape [ 0 ] ) . toEqual ( 2 ) ;
174
+ expectArraysClose ( await output . data ( ) , [
175
+ 1.5 , 0.0 , - 0.749999 , 0.433 ,
176
+ 1.5 , 0.0 , - 0.749999 , 0.433
177
+ ] ) ;
178
+ } ) ;
179
+
180
+ it ( '5 length with hann window' , async ( ) => {
181
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 ] ) ;
182
+ const frameLength = 5 ;
183
+ const frameStep = 1 ;
184
+ const output = tf . signal . stft ( input , frameLength , frameStep ) ;
185
+ expect ( output . shape ) . toEqual ( [ 1 , 5 ] ) ;
186
+ expectArraysClose (
187
+ await output . data ( ) ,
188
+ [ 2.0 , 0.0 , 0.0 , - 1.7071068 , - 1.0 , 0.0 , 0.0 , 0.29289323 , 0.0 , 0.0 ] ) ;
189
+ } ) ;
190
+
191
+ it ( '5 length with hann window (sequential)' , async ( ) => {
192
+ const input = tf . tensor1d ( [ 1 , 2 , 3 , 4 , 5 ] ) ;
193
+ const frameLength = 5 ;
194
+ const frameStep = 1 ;
195
+ const output = tf . signal . stft ( input , frameLength , frameStep ) ;
196
+ expect ( output . shape ) . toEqual ( [ 1 , 5 ] ) ;
197
+ expectArraysClose (
198
+ await output . data ( ) ,
199
+ [ 6.0 , 0.0 , - 0.70710677 , - 5.1213202 , - 3.0 , 1.0 ,
200
+ 0.70710677 , 0.87867975 , 0.0 , 0.0 ] ) ;
201
+ } ) ;
202
+
203
+ it ( '3 length with hamming window' , async ( ) => {
204
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 ] ) ;
205
+ const frameLength = 3 ;
206
+ const frameStep = 1 ;
207
+ const fftLength = 3 ;
208
+ const output = tf . signal . stft ( input , frameLength , frameStep ,
209
+ fftLength , ( length ) => tf . signal . hammingWindow ( length ) ) ;
210
+ expect ( output . shape ) . toEqual ( [ 3 , 2 ] ) ;
211
+ expectArraysClose ( await output . data ( ) , [
212
+ 1.16 , 0.0 , - 0.46 , - 0.79674333 ,
213
+ 1.16 , 0.0 , - 0.46 , - 0.79674333 ,
214
+ 1.16 , 0.0 , - 0.46 , - 0.79674333
215
+ ] ) ;
216
+ } ) ;
217
+
218
+ it ( '3 length, 2 step with hamming window' , async ( ) => {
219
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 ] ) ;
220
+ const frameLength = 3 ;
221
+ const frameStep = 2 ;
222
+ const fftLength = 3 ;
223
+ const output = tf . signal . stft ( input , frameLength , frameStep ,
224
+ fftLength , ( length ) => tf . signal . hammingWindow ( length ) ) ;
225
+ expect ( output . shape ) . toEqual ( [ 2 , 2 ] ) ;
226
+ expectArraysClose ( await output . data ( ) , [
227
+ 1.16 , 0.0 , - 0.46 , - 0.79674333 ,
228
+ 1.16 , 0.0 , - 0.46 , - 0.79674333
229
+ ] ) ;
230
+ } ) ;
231
+
232
+ it ( '3 fftLength, 5 frameLength, 2 step with hamming window' , async ( ) => {
233
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 , 1 ] ) ;
234
+ const frameLength = 5 ;
235
+ const frameStep = 1 ;
236
+ const fftLength = 3 ;
237
+ const output = tf . signal . stft ( input , frameLength , frameStep ,
238
+ fftLength , ( length ) => tf . signal . hammingWindow ( length ) ) ;
239
+ expect ( output . shape ) . toEqual ( [ 2 , 2 ] ) ;
240
+ expectArraysClose ( await output . data ( ) , [
241
+ 1.619999 , 0.0 , - 0.69 , 0.39837 ,
242
+ 1.619999 , 0.0 , - 0.69 , 0.39837
243
+ ] ) ;
244
+ } ) ;
245
+
246
+ it ( '5 length with hann window (sequential)' , async ( ) => {
247
+ const input = tf . tensor1d ( [ 1 , 2 , 3 , 4 , 5 ] ) ;
248
+ const frameLength = 5 ;
249
+ const frameStep = 1 ;
250
+ const fftLength = 5 ;
251
+ const output = tf . signal . stft ( input , frameLength , frameStep ,
252
+ fftLength , ( length ) => tf . signal . hammingWindow ( length ) ) ;
253
+ expect ( output . shape ) . toEqual ( [ 1 , 3 ] ) ;
254
+ expectArraysClose (
255
+ await output . data ( ) ,
256
+ [ 6.72 , 0.0 , - 3.6371822 , - 1.1404576 , 0.4771822 , 0.39919350 ] ) ;
257
+ } ) ;
258
+
259
+ it ( '3 length without window function' , async ( ) => {
260
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 ] ) ;
261
+ const frameLength = 3 ;
262
+ const frameStep = 1 ;
263
+ const fftLength = 3 ;
264
+ const ident = ( length : number ) => tf . ones ( [ length ] ) . as1D ( ) ;
265
+ const output = tf . signal . stft ( input , frameLength , frameStep ,
266
+ fftLength , ident ) ;
267
+ expect ( output . shape ) . toEqual ( [ 3 , 2 ] ) ;
268
+ expectArraysClose ( await output . data ( ) , [
269
+ 3.0 , 0.0 , 0.0 , 0.0 ,
270
+ 3.0 , 0.0 , 0.0 , 0.0 ,
271
+ 3.0 , 0.0 , 0.0 , 0.0
272
+ ] ) ;
273
+ } ) ;
274
+
275
+ it ( '3 length, 2 step without window function' , async ( ) => {
276
+ const input = tf . tensor1d ( [ 1 , 1 , 1 , 1 , 1 ] ) ;
277
+ const frameLength = 3 ;
278
+ const frameStep = 2 ;
279
+ const fftLength = 3 ;
280
+ const ident = ( length : number ) => tf . ones ( [ length ] ) . as1D ( ) ;
281
+ const output = tf . signal . stft ( input , frameLength , frameStep ,
282
+ fftLength , ident ) ;
283
+ expect ( output . shape ) . toEqual ( [ 2 , 2 ] ) ;
284
+ expectArraysClose ( await output . data ( ) , [
285
+ 3.0 , 0.0 , 0.0 , 0.0 ,
286
+ 3.0 , 0.0 , 0.0 , 0.0
287
+ ] ) ;
288
+ } ) ;
289
+ } ) ;
0 commit comments