44
44
font-size : 10px ;
45
45
}
46
46
47
- # trendline-container svg {
47
+ div [ id *= ' trendline-container' ] svg {
48
48
overflow : visible;
49
49
border-bottom : 1px solid # ccc ;
50
50
border-left : 1px solid # ccc ;
51
51
}
52
52
53
- # trendline-container .label {
53
+ div [ id *= ' trendline-container' ] .label {
54
54
font-size : 14px ;
55
55
font-weight : bold;
56
56
}
57
57
58
- # trendline-container path {
58
+ div [ id *= ' trendline-container' ] path {
59
59
fill : none;
60
60
stroke : # 222 ;
61
61
}
62
62
63
- # trendline {
63
+ . trendline {
64
64
position : relative;
65
65
margin-top : 20px ;
66
66
}
67
67
68
- # trendline # yMax ,
69
- # trendline # yMin {
68
+ . trendline . yMax ,
69
+ . trendline . yMin {
70
70
position : absolute;
71
71
right : calc (100% + 6px );
72
72
font-size : 11px ;
73
73
white-space : nowrap;
74
74
}
75
75
76
- # trendline # yMin {
76
+ . trendline . yMin {
77
77
bottom : 0 ;
78
78
}
79
79
80
- # trendline # yMax {
80
+ . trendline . yMax {
81
81
top : 0 ;
82
82
}
83
83
@@ -146,11 +146,31 @@ <h2>TensorFlow.js Model Benchmark</h2>
146
146
< tbody >
147
147
</ tbody >
148
148
</ table >
149
- < div class ="box " id ="trendline-container ">
150
- < div class ="label "> </ div >
151
- < div id ="trendline ">
152
- < div id ="yMax "> </ div >
153
- < div id ="yMin "> 0 ms</ div >
149
+ < div class ="box " id ="perf-trendline-container ">
150
+ < div class ="label "> Inference times</ div >
151
+ < div class ="trendline ">
152
+ < div class ="yMax "> </ div >
153
+ < div class ="yMin "> </ div >
154
+ < svg >
155
+ < path > </ path >
156
+ </ svg >
157
+ </ div >
158
+ </ div >
159
+ < div class ="box " id ="mem-trendline-container ">
160
+ < div class ="label "> Number of tensors</ div >
161
+ < div class ="trendline ">
162
+ < div class ="yMax "> </ div >
163
+ < div class ="yMin "> </ div >
164
+ < svg >
165
+ < path > </ path >
166
+ </ svg >
167
+ </ div >
168
+ </ div >
169
+ < div class ="box " id ="bytes-trendline-container ">
170
+ < div class ="label "> Number of bytes used</ div >
171
+ < div class ="trendline ">
172
+ < div class ="yMax "> </ div >
173
+ < div class ="yMin "> </ div >
154
174
< svg >
155
175
< path > </ path >
156
176
</ svg >
@@ -322,7 +342,7 @@ <h2>TensorFlow.js Model Benchmark</h2>
322
342
const modalDiv = document . getElementById ( 'modal-msg' ) ;
323
343
const timeTable = document . querySelector ( '#timings tbody' ) ;
324
344
const envDiv = document . getElementById ( 'env' ) ;
325
- let model , isAsync , predict ;
345
+ let model , isAsync , predict , chartWidth ;
326
346
327
347
async function showMsg ( message ) {
328
348
if ( message != null ) {
@@ -413,15 +433,30 @@ <h2>TensorFlow.js Model Benchmark</h2>
413
433
appendRow ( timeTable , 'Model load' , printTime ( elapsed ) ) ;
414
434
}
415
435
436
+ const chartHeight = 150 ;
437
+ function populateTrendline ( node , data , forceYMinToZero = false , yFormatter = d => d ) {
438
+ node . querySelector ( "svg" ) . setAttribute ( "width" , chartWidth ) ;
439
+ node . querySelector ( "svg" ) . setAttribute ( "height" , chartHeight ) ;
440
+
441
+ const yMax = Math . max ( ...data ) ;
442
+ const yMin = forceYMinToZero ? 0 : Math . min ( ...data ) ;
443
+
444
+ node . querySelector ( ".yMin" ) . textContent = yFormatter ( yMin ) ;
445
+ node . querySelector ( ".yMax" ) . textContent = yFormatter ( yMax ) ;
446
+
447
+ const xIncrement = chartWidth / ( data . length - 1 ) ;
448
+ node . querySelector ( "path" )
449
+ . setAttribute ( "d" , `M${ data . map ( ( d , i ) => `${ i * xIncrement } ,${ chartHeight - ( ( d - yMin ) / ( yMax - yMin ) ) * chartHeight } ` ) . join ( 'L' ) } ` ) ;
450
+ }
451
+
416
452
async function measureAveragePredictTime ( ) {
417
- document . querySelector ( "#trendline-container .label" ) . textContent = `Inference times over ${ state . numRuns } runs` ;
418
453
await showMsg ( `Running predict ${ state . numRuns } times` ) ;
419
- const chartHeight = 150 ;
420
- const chartWidth = document . querySelector ( "#trendline-container" ) . getBoundingClientRect ( ) . width ;
421
- document . querySelector ( "#trendline-container svg" ) . setAttribute ( "width" , chartWidth ) ;
422
- document . querySelector ( "#trendline-container svg" ) . setAttribute ( "height" , chartHeight ) ;
454
+ chartWidth = document . querySelector ( "#perf-trendline-container" ) . getBoundingClientRect ( ) . width ;
423
455
424
456
const times = [ ] ;
457
+ const numTensors = [ ] ;
458
+ const numBytes = [ ] ;
459
+
425
460
for ( let i = 0 ; i < state . numRuns ; i ++ ) {
426
461
const start = performance . now ( ) ;
427
462
let res = predict ( model ) ;
@@ -434,18 +469,21 @@ <h2>TensorFlow.js Model Benchmark</h2>
434
469
}
435
470
436
471
times . push ( performance . now ( ) - start ) ;
472
+ const memInfo = tf . memory ( ) ;
473
+ numTensors . push ( memInfo . numTensors ) ;
474
+ numBytes . push ( memInfo . numBytes ) ;
437
475
}
438
476
439
- const average = times . reduce ( ( acc , curr ) => acc + curr , 0 ) / times . length ;
440
- const max = Math . max ( ...times ) ;
441
- const min = Math . min ( ...times ) ;
442
- const xIncrement = chartWidth / times . length ;
477
+ const forceInferenceTrendYMinToZero = true ;
478
+ populateTrendline ( document . querySelector ( "#perf-trendline-container" ) , times , forceInferenceTrendYMinToZero , printTime ) ;
479
+ populateTrendline ( document . querySelector ( "#mem-trendline-container" ) , numTensors ) ;
443
480
444
- document . querySelector ( "#trendline-container #yMax" ) . textContent = printTime ( max ) ;
445
- document . querySelector ( "#trendline-container path" )
446
- . setAttribute ( "d" , `M${ times . map ( ( d , i ) => `${ i * xIncrement } ,${ chartHeight - ( d / max ) * chartHeight } ` ) . join ( 'L' ) } ` ) ;
481
+ const forceBytesTrendlineYMinToZero = false ;
482
+ populateTrendline ( document . querySelector ( "#bytes-trendline-container" ) , numBytes , forceBytesTrendlineYMinToZero , d => `${ ( d / 1e6 ) . toPrecision ( 3 ) } MB` ) ;
447
483
448
484
await showMsg ( null ) ;
485
+ const average = times . reduce ( ( acc , curr ) => acc + curr , 0 ) / times . length ;
486
+ const min = Math . min ( ...times ) ;
449
487
appendRow ( timeTable , `Subsequent average (${ state . numRuns } runs)` , printTime ( average ) ) ;
450
488
appendRow ( timeTable , 'Best time' , printTime ( min ) ) ;
451
489
}
0 commit comments