15
15
}
16
16
17
17
body {
18
- margin : 30px 0 0 30px ;
18
+ margin : 20px 100px ;
19
+ }
20
+
21
+ h2 {
22
+ margin-bottom : 30px ;
23
+ }
24
+
25
+ # kernels {
26
+ max-width : 750px ;
19
27
}
20
28
21
29
# container {
25
33
}
26
34
27
35
.box {
28
- margin-right : 20 px ;
29
- margin-bottom : 20 px ;
36
+ margin-right : 30 px ;
37
+ margin-bottom : 30 px ;
30
38
}
31
39
32
40
.box pre {
36
44
font-size : 10px ;
37
45
}
38
46
47
+ # trendline-container svg {
48
+ overflow : visible;
49
+ border-bottom : 1px solid # ccc ;
50
+ border-left : 1px solid # ccc ;
51
+ }
52
+
53
+ # trendline-container .label {
54
+ font-size : 14px ;
55
+ font-weight : bold;
56
+ }
57
+
58
+ # trendline-container path {
59
+ fill : none;
60
+ stroke : # 222 ;
61
+ }
62
+
63
+ # trendline {
64
+ position : relative;
65
+ margin-top : 20px ;
66
+ }
67
+
68
+ # trendline # yMax , # trendline # yMin {
69
+ position : absolute;
70
+ right : calc (100% + 6px );
71
+ font-size : 11px ;
72
+ white-space : nowrap;
73
+ }
74
+
75
+ # trendline # yMin {
76
+ bottom : 0 ;
77
+ }
78
+
79
+ # trendline # yMax {
80
+ top : 0 ;
81
+ }
82
+
39
83
# modal-msg {
40
84
border-radius : 5px ;
41
85
background-color : black;
48
92
}
49
93
50
94
.table {
51
- margin-right : 20 px ;
52
- margin-bottom : 20 px ;
95
+ margin-right : 30 px ;
96
+ margin-bottom : 30 px ;
53
97
border : 1px solid # ccc ;
54
98
border-collapse : collapse;
55
99
border-spacing : 0 ;
87
131
< h2 > TensorFlow.js Model Benchmark</ h2 >
88
132
< div id ="modal-msg "> </ div >
89
133
< div id ="container ">
90
- < div class ="box ">
91
- < pre id ="env "> </ pre >
134
+ < div id ="stats ">
135
+ < div class ="box ">
136
+ < pre id ="env "> </ pre >
137
+ </ div >
138
+ < table class ="table " id ="timings ">
139
+ < thead >
140
+ < tr >
141
+ < th > Type</ th >
142
+ < th > Value</ th >
143
+ </ tr >
144
+ </ thead >
145
+ < tbody >
146
+ </ tbody >
147
+ </ table >
148
+ < div class ="box " id ="trendline-container ">
149
+ < div class ="label "> </ div >
150
+ < div id ="trendline ">
151
+ < div id ="yMax "> </ div >
152
+ < div id ="yMin "> 0 ms</ div >
153
+ < svg > < path > </ path > </ svg >
154
+ </ div >
155
+ </ div >
92
156
</ div >
93
- < table class ="table " id ="timings ">
94
- < thead >
95
- < tr >
96
- < th > Type</ th >
97
- < th > Value</ th >
98
- </ tr >
99
- </ thead >
100
- < tbody >
101
- </ tbody >
102
- </ table >
103
157
< table class ="table " id ="kernels ">
104
158
< thead >
105
159
< tr >
@@ -131,19 +185,21 @@ <h2>TensorFlow.js Model Benchmark</h2>
131
185
//////////////////////////////////
132
186
// Place model prediction code here.
133
187
//////////////////////////////////
188
+ if ( isAsync ) {
189
+ return model . executeAsync ( zeros ) ;
190
+ }
134
191
return model . predict ( zeros ) ;
135
192
}
136
193
</ script >
137
194
< script >
138
195
'use strict' ;
139
196
const state = {
140
- numRuns : 20 ,
141
-
197
+ numRuns : 50 ,
142
198
} ;
143
199
const modalDiv = document . getElementById ( 'modal-msg' ) ;
144
200
const timeTable = document . querySelector ( '#timings tbody' ) ;
145
201
const envDiv = document . getElementById ( 'env' ) ;
146
- let model ;
202
+ let model , isAsync ;
147
203
148
204
async function showMsg ( message ) {
149
205
if ( message != null ) {
@@ -169,6 +225,10 @@ <h2>TensorFlow.js Model Benchmark</h2>
169
225
envDiv . innerHTML += `<br/>${ JSON . stringify ( tf . ENV . features , null , 2 ) } ` ;
170
226
}
171
227
228
+ function printTime ( elapsed ) {
229
+ return elapsed . toFixed ( 1 ) + ' ms' ;
230
+ }
231
+
172
232
function printMemory ( bytes ) {
173
233
if ( bytes < 1024 ) {
174
234
return bytes + ' B' ;
@@ -200,12 +260,14 @@ <h2>TensorFlow.js Model Benchmark</h2>
200
260
if ( res instanceof Promise ) {
201
261
res = await res ;
202
262
}
263
+
203
264
if ( res instanceof tf . Tensor ) {
204
- await res . data ( ) ;
265
+ res . dataSync ( ) ;
205
266
}
267
+
206
268
const elapsed = performance . now ( ) - start ;
207
269
await showMsg ( null ) ;
208
- appendRow ( timeTable , 'Warmup ' , elapsed . toFixed ( 1 ) + ' ms' ) ;
270
+ appendRow ( timeTable , '1st inference ' , printTime ( start ) ) ;
209
271
}
210
272
211
273
function sleep ( timeMs ) {
@@ -216,34 +278,66 @@ <h2>TensorFlow.js Model Benchmark</h2>
216
278
await showMsg ( 'Loading the model' ) ;
217
279
const start = performance . now ( ) ;
218
280
model = await load ( ) ;
281
+ isAsync = model . executor != null && model . executor . isControlFlowModel ;
282
+
219
283
const elapsed = performance . now ( ) - start ;
220
284
await showMsg ( null ) ;
221
- appendRow ( timeTable , 'Model load' , elapsed . toFixed ( 1 ) + ' ms' ) ;
285
+ appendRow ( timeTable , 'Model load' , printTime ( elapsed ) ) ;
222
286
}
223
287
224
288
async function measureAveragePredictTime ( ) {
289
+ document . querySelector ( "#trendline-container .label" ) . textContent = `Inference times over ${ state . numRuns } runs` ;
225
290
await showMsg ( `Running predict ${ state . numRuns } times` ) ;
226
- const start = performance . now ( ) ;
227
- let res ;
291
+ const chartHeight = 150 ;
292
+ const chartWidth = document . querySelector ( "#trendline-container" ) . getBoundingClientRect ( ) . width ;
293
+ document . querySelector ( "#trendline-container svg" ) . setAttribute ( "width" , chartWidth ) ;
294
+ document . querySelector ( "#trendline-container svg" ) . setAttribute ( "height" , chartHeight ) ;
295
+
296
+ const times = [ ] ;
228
297
for ( let i = 0 ; i < state . numRuns ; i ++ ) {
229
- res = predict ( model ) ;
298
+ const start = performance . now ( ) ;
299
+ let res = predict ( model ) ;
300
+ if ( res instanceof Promise ) {
301
+ res = await res ;
302
+ }
303
+
304
+ if ( res instanceof tf . Tensor ) {
305
+ res . dataSync ( ) ;
306
+ }
307
+
308
+ times . push ( performance . now ( ) - start ) ;
230
309
}
310
+
311
+ const average = times . reduce ( ( acc , curr ) => acc + curr , 0 ) / times . length ;
312
+ const max = Math . max ( ...times ) ;
313
+ const min = Math . min ( ...times ) ;
314
+ const xIncrement = chartWidth / times . length ;
315
+
316
+ document . querySelector ( "#trendline-container #yMax" ) . textContent = printTime ( max ) ;
317
+ document . querySelector ( "#trendline-container path" )
318
+ . setAttribute ( "d" , `M${ times . map ( ( d , i ) => `${ i * xIncrement } ,${ chartHeight - ( d / max ) * chartHeight } ` ) . join ( 'L' ) } ` ) ;
319
+
320
+ await showMsg ( null ) ;
321
+ appendRow ( timeTable , `Subsequent average (${ state . numRuns } runs)` , printTime ( average ) ) ;
322
+ appendRow ( timeTable , 'Best time' , printTime ( min ) ) ;
323
+ }
324
+
325
+ async function profileMemory ( ) {
326
+ await showMsg ( 'Profile memory' ) ;
327
+ const start = performance . now ( ) ;
328
+ let res ;
329
+ const data = await tf . profile ( ( ) => res = predict ( model ) ) ;
231
330
if ( res instanceof Promise ) {
232
331
res = await res ;
233
332
}
333
+
234
334
if ( res instanceof tf . Tensor ) {
235
335
res . dataSync ( ) ;
236
336
}
237
- const elapsed = ( performance . now ( ) - start ) / state . numRuns ;
238
- await showMsg ( null ) ;
239
- appendRow ( timeTable , `Predict (${ state . numRuns } runs)` , elapsed . toFixed ( 1 ) + ' ms' ) ;
240
- }
241
-
242
- async function profileMemory ( ) {
243
- await showMsg ( 'Profile memory' ) ;
244
- const data = await tf . profile ( ( ) => predict ( model ) ) ;
337
+ const elapsed = performance . now ( ) - start ;
245
338
await showMsg ( null ) ;
246
339
appendRow ( timeTable , 'Peak memory' , printMemory ( data . peakBytes ) ) ;
340
+ appendRow ( timeTable , '2nd inference' , printTime ( elapsed ) ) ;
247
341
}
248
342
249
343
function showKernelTime ( kernels ) {
@@ -285,12 +379,14 @@ <h2>TensorFlow.js Model Benchmark</h2>
285
379
}
286
380
}
287
381
let res = predict ( model ) ;
288
- if ( res instanceof Promise ) {
382
+ if ( res instanceof Promise ) {
289
383
res = await res ;
290
384
}
291
- if ( res instanceof tf . Tensor ) {
385
+
386
+ if ( res instanceof tf . Tensor ) {
292
387
res . dataSync ( ) ;
293
388
}
389
+
294
390
await showMsg ( null ) ;
295
391
await sleep ( 10 ) ;
296
392
kernels = kernels . sort ( ( a , b ) => b . time - a . time ) ;
0 commit comments