235
235
< div class ="pytorch-left-menu-search ">
236
236
237
237
< div class ="version ">
238
- < a href ='https://pytorch.org/docs/versions.html '> master (2.1.0a0+gitd1fbd33 ) ▼</ a >
238
+ < a href ='https://pytorch.org/docs/versions.html '> master (2.1.0a0+git35991df ) ▼</ a >
239
239
</ div >
240
240
241
241
@@ -478,6 +478,7 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
478
478
< span class ="kn "> import</ span > < span class ="nn "> inspect</ span >
479
479
< span class ="kn "> import</ span > < span class ="nn "> logging</ span >
480
480
< span class ="kn "> import</ span > < span class ="nn "> os</ span >
481
+ < span class ="kn "> import</ span > < span class ="nn "> re</ span >
481
482
< span class ="kn "> import</ span > < span class ="nn "> sys</ span >
482
483
< span class ="kn "> import</ span > < span class ="nn "> textwrap</ span >
483
484
< span class ="kn "> import</ span > < span class ="nn "> threading</ span >
@@ -1143,6 +1144,22 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1143
1144
< span class ="s2 "> "you can specify them separately instead."</ span >
1144
1145
< span class ="p "> )</ span >
1145
1146
1147
+ < span class ="nd "> @property</ span >
1148
+ < span class ="k "> def</ span > < span class ="nf "> serializable_spec</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ):</ span >
1149
+ < span class ="c1 "> # We need a serialization compatible format of the constraint so that it</ span >
1150
+ < span class ="c1 "> # can be savedin the graph module w/o breaking the module serialization.</ span >
1151
+ < span class ="c1 "> # The saved constraints will be used directly for the post-exporting pass</ span >
1152
+ < span class ="c1 "> # that converts constraints to runtime assertion. The saved constraints</ span >
1153
+ < span class ="c1 "> # will not be saved in the serialized module.</ span >
1154
+ < span class ="c1 "> # TODO: A better way is needed. Currently we use 't_id' to map the constraint,</ span >
1155
+ < span class ="c1 "> # which is not reliable</ span >
1156
+ < span class ="k "> return</ span > < span class ="p "> {</ span >
1157
+ < span class ="s2 "> "t_id"</ span > < span class ="p "> :</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> t_id</ span > < span class ="p "> ,</ span >
1158
+ < span class ="s2 "> "dim"</ span > < span class ="p "> :</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> dim</ span > < span class ="p "> ,</ span >
1159
+ < span class ="s2 "> "min"</ span > < span class ="p "> :</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> constraint_range</ span > < span class ="o "> .</ span > < span class ="n "> vr</ span > < span class ="o "> .</ span > < span class ="n "> lower</ span > < span class ="p "> ,</ span >
1160
+ < span class ="s2 "> "max"</ span > < span class ="p "> :</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> constraint_range</ span > < span class ="o "> .</ span > < span class ="n "> vr</ span > < span class ="o "> .</ span > < span class ="n "> upper</ span > < span class ="p "> ,</ span >
1161
+ < span class ="p "> }</ span >
1162
+
1146
1163
1147
1164
< div class ="viewcode-block " id ="export "> < a class ="viewcode-back " href ="../../../_dynamo.html#torch._dynamo.export "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> export</ span > < span class ="p "> (</ span >
1148
1165
< span class ="n "> f</ span > < span class ="p "> :</ span > < span class ="n "> Callable</ span > < span class ="p "> [</ span > < span class ="o "> ...</ span > < span class ="p "> ,</ span > < span class ="n "> Any</ span > < span class ="p "> ],</ span >
@@ -1246,6 +1263,7 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1246
1263
1247
1264
< span class ="n "> fake_mode</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
1248
1265
< span class ="n "> example_inputs</ span > < span class ="o "> =</ span > < span class ="p "> []</ span >
1266
+ < span class ="n "> var_to_range_map</ span > < span class ="o "> =</ span > < span class ="p "> {}</ span >
1249
1267
1250
1268
< span class ="k "> def</ span > < span class ="nf "> dynamo_normalization_capturing_compiler</ span > < span class ="p "> (</ span >
1251
1269
< span class ="n "> gm</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> fx</ span > < span class ="o "> .</ span > < span class ="n "> GraphModule</ span > < span class ="p "> ,</ span > < span class ="n "> inner_example_inputs</ span >
@@ -1256,9 +1274,11 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1256
1274
< span class ="p "> ),</ span > < span class ="s2 "> "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."</ span >
1257
1275
< span class ="n "> graph</ span > < span class ="o "> =</ span > < span class ="n "> gm</ span >
1258
1276
1259
- < span class ="k "> nonlocal</ span > < span class ="n "> fake_mode</ span > < span class ="p "> ,</ span > < span class ="n "> example_inputs</ span >
1277
+ < span class ="k "> nonlocal</ span > < span class ="n "> fake_mode</ span > < span class ="p "> ,</ span > < span class ="n "> example_inputs</ span > < span class =" p " > , </ span > < span class =" n " > var_to_range_map </ span >
1260
1278
< span class ="n "> fake_mode</ span > < span class ="o "> =</ span > < span class ="n "> _guards</ span > < span class ="o "> .</ span > < span class ="n "> detect_fake_mode</ span > < span class ="p "> (</ span > < span class ="n "> inner_example_inputs</ span > < span class ="p "> )</ span >
1261
1279
< span class ="n "> example_inputs</ span > < span class ="o "> =</ span > < span class ="n "> inner_example_inputs</ span >
1280
+ < span class ="k "> if</ span > < span class ="n "> fake_mode</ span > < span class ="ow "> and</ span > < span class ="n "> fake_mode</ span > < span class ="o "> .</ span > < span class ="n "> shape_env</ span > < span class ="p "> :</ span >
1281
+ < span class ="n "> var_to_range_map</ span > < span class ="o "> =</ span > < span class ="n "> fake_mode</ span > < span class ="o "> .</ span > < span class ="n "> shape_env</ span > < span class ="o "> .</ span > < span class ="n "> var_to_range</ span >
1262
1282
1263
1283
< span class ="k "> def</ span > < span class ="nf "> result_capturing_wrapper</ span > < span class ="p "> (</ span > < span class ="o "> *</ span > < span class ="n "> graph_inputs</ span > < span class ="p "> ):</ span >
1264
1284
< span class ="k "> nonlocal</ span > < span class ="n "> graph_captured_result</ span >
@@ -1378,6 +1398,21 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1378
1398
< span class ="n "> graph</ span > < span class ="p "> ,</ span >
1379
1399
< span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> transform</ span > < span class ="p "> ()</ span >
1380
1400
1401
+ < span class ="c1 "> # Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check</ span >
1402
+ < span class ="n "> new_graph</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "example_inputs"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> example_inputs</ span >
1403
+ < span class ="n "> new_graph</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "input_shape_constraints"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> (</ span >
1404
+ < span class ="p "> [</ span > < span class ="n "> constraint</ span > < span class ="o "> .</ span > < span class ="n "> serializable_spec</ span > < span class ="k "> for</ span > < span class ="n "> constraint</ span > < span class ="ow "> in</ span > < span class ="n "> constraints</ span > < span class ="p "> ]</ span >
1405
+ < span class ="k "> if</ span > < span class ="n "> constraints</ span >
1406
+ < span class ="k "> else</ span > < span class ="kc "> None</ span >
1407
+ < span class ="p "> )</ span >
1408
+ < span class ="n "> new_graph</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "inline_constraints"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> {</ span >
1409
+ < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> expr</ span > < span class ="p "> :</ span > < span class ="n "> var_to_range_map</ span > < span class ="p "> [</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> expr</ span > < span class ="p "> ]</ span >
1410
+ < span class ="k "> for</ span > < span class ="n "> node</ span > < span class ="ow "> in</ span > < span class ="n "> new_graph</ span > < span class ="o "> .</ span > < span class ="n "> graph</ span > < span class ="o "> .</ span > < span class ="n "> nodes</ span >
1411
+ < span class ="k "> if</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> op</ span > < span class ="o "> !=</ span > < span class ="s2 "> "placeholder"</ span > < span class ="ow "> and</ span > < span class ="s2 "> "val"</ span > < span class ="ow "> in</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span >
1412
+ < span class ="c1 "> # Find constraints frome unbacked symints</ span >
1413
+ < span class ="ow "> and</ span > < span class ="n "> re</ span > < span class ="o "> .</ span > < span class ="n "> match</ span > < span class ="p "> (</ span > < span class ="sa "> r</ span > < span class ="s2 "> "^i\d+$"</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> (</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]))</ span >
1414
+ < span class ="p "> }</ span >
1415
+
1381
1416
< span class ="k "> def</ span > < span class ="nf "> signature_to_fullargspec</ span > < span class ="p "> (</ span > < span class ="n "> sig</ span > < span class ="p "> :</ span > < span class ="n "> inspect</ span > < span class ="o "> .</ span > < span class ="n "> Signature</ span > < span class ="p "> ):</ span >
1382
1417
< span class ="c1 "> # Get a list of Parameter objects from the Signature object</ span >
1383
1418
< span class ="n "> params</ span > < span class ="o "> =</ span > < span class ="nb "> list</ span > < span class ="p "> (</ span > < span class ="n "> sig</ span > < span class ="o "> .</ span > < span class ="n "> parameters</ span > < span class ="o "> .</ span > < span class ="n "> values</ span > < span class ="p "> ())</ span >
0 commit comments