235
235
< div class ="pytorch-left-menu-search ">
236
236
237
237
< div class ="version ">
238
- < a href ='https://pytorch.org/docs/versions.html '> master (2.0.0a0+git299ada9 ) ▼</ a >
238
+ < a href ='https://pytorch.org/docs/versions.html '> master (2.0.0a0+gitd51ca38 ) ▼</ a >
239
239
</ div >
240
240
241
241
328
328
< li class ="toctree-l1 "> < a class ="reference internal " href ="../../../autograd.html "> torch.autograd</ a > </ li >
329
329
< li class ="toctree-l1 "> < a class ="reference internal " href ="../../../library.html "> torch.library</ a > </ li >
330
330
< li class ="toctree-l1 "> < a class ="reference internal " href ="../../../cuda.html "> torch.cuda</ a > </ li >
331
+ < li class ="toctree-l1 "> < a class ="reference internal " href ="../../../mps.html "> torch.mps</ a > </ li >
331
332
< li class ="toctree-l1 "> < a class ="reference internal " href ="../../../backends.html "> torch.backends</ a > </ li >
332
333
< li class ="toctree-l1 "> < a class ="reference internal " href ="../../../distributed.html "> torch.distributed</ a > </ li >
333
334
< li class ="toctree-l1 "> < a class ="reference internal " href ="../../../distributed.algorithms.join.html "> torch.distributed.algorithms.join</ a > </ li >
@@ -510,7 +511,7 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
510
511
< span class ="kn "> from</ span > < span class ="nn "> .exc</ span > < span class ="kn "> import</ span > < span class ="n "> ResetRequired</ span >
511
512
< span class ="kn "> from</ span > < span class ="nn "> .mutation_guard</ span > < span class ="kn "> import</ span > < span class ="n "> install_generation_tagging_init</ span >
512
513
< span class ="kn "> from</ span > < span class ="nn "> .types</ span > < span class ="kn "> import</ span > < span class ="n "> DynamoCallback</ span >
513
- < span class ="kn "> from</ span > < span class ="nn "> .utils</ span > < span class ="kn "> import</ span > < span class ="n "> compile_times</ span >
514
+ < span class ="kn "> from</ span > < span class ="nn "> .utils</ span > < span class ="kn "> import</ span > < span class ="n "> compile_times</ span > < span class =" p " > , </ span > < span class =" n " > fake_mode_from_tensors </ span >
514
515
515
516
< span class ="n "> log</ span > < span class ="o "> =</ span > < span class ="n "> logging</ span > < span class ="o "> .</ span > < span class ="n "> getLogger</ span > < span class ="p "> (</ span > < span class ="vm "> __name__</ span > < span class ="p "> )</ span >
516
517
@@ -991,6 +992,7 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
991
992
< span class ="n "> f</ span > < span class ="o "> =</ span > < span class ="n "> innermost_fn</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
992
993
993
994
< span class ="n "> graph</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
995
+ < span class ="n "> compile_time_inputs</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
994
996
< span class ="n "> out_guards</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
995
997
< span class ="n "> graph_captured_input</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
996
998
< span class ="n "> graph_captured_result</ span > < span class ="p "> :</ span > < span class ="n "> Optional</ span > < span class ="p "> [</ span > < span class ="n "> Tuple</ span > < span class ="p "> [</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="o "> ...</ span > < span class ="p "> ]]</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
@@ -1033,9 +1035,11 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1033
1035
< span class ="n "> gm</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> fx</ span > < span class ="o "> .</ span > < span class ="n "> GraphModule</ span > < span class ="p "> ,</ span > < span class ="n "> example_inputs</ span >
1034
1036
< span class ="p "> ):</ span >
1035
1037
< span class ="k "> nonlocal</ span > < span class ="n "> graph</ span >
1038
+ < span class ="k "> nonlocal</ span > < span class ="n "> compile_time_inputs</ span >
1036
1039
1037
1040
< span class ="k "> assert</ span > < span class ="n "> graph</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="s2 "> "whole graph export entails exactly one graph"</ span >
1038
1041
< span class ="n "> graph</ span > < span class ="o "> =</ span > < span class ="n "> gm</ span >
1042
+ < span class ="n "> compile_time_inputs</ span > < span class ="o "> =</ span > < span class ="n "> example_inputs</ span >
1039
1043
1040
1044
< span class ="k "> def</ span > < span class ="nf "> result_capturing_wrapper</ span > < span class ="p "> (</ span > < span class ="o "> *</ span > < span class ="n "> graph_inputs</ span > < span class ="p "> ):</ span >
1041
1045
< span class ="k "> nonlocal</ span > < span class ="n "> graph_captured_result</ span >
@@ -1092,6 +1096,8 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1092
1096
< span class ="n "> arg</ span > < span class ="o "> =</ span > < span class ="nb "> next</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> old_args_gen</ span > < span class ="p "> )</ span >
1093
1097
< span class ="k "> if</ span > < span class ="s2 "> "val"</ span > < span class ="ow "> in</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> current_node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> :</ span >
1094
1098
< span class ="n "> arg</ span > < span class ="o "> .</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> current_node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]</ span >
1099
+ < span class ="k "> if</ span > < span class ="s2 "> "tensor_dict"</ span > < span class ="ow "> in</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> current_node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> :</ span >
1100
+ < span class ="n "> arg</ span > < span class ="o "> .</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "tensor_dict"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> current_node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "tensor_dict"</ span > < span class ="p "> ]</ span >
1095
1101
< span class ="k "> return</ span > < span class ="n "> arg</ span >
1096
1102
1097
1103
< span class ="k "> def</ span > < span class ="nf "> output</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span > < span class ="n "> args</ span > < span class ="p "> ,</ span > < span class ="n "> kwargs</ span > < span class ="p "> ):</ span >
@@ -1100,22 +1106,28 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1100
1106
< span class ="n "> new_result_flat</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="n "> lookup</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="p "> ]</ span > < span class ="k "> for</ span > < span class ="n "> i</ span > < span class ="ow "> in</ span > < span class ="n "> matched_output_elements_positions</ span > < span class ="p "> ]</ span >
1101
1107
< span class ="k "> return</ span > < span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="n "> output</ span > < span class ="p "> (</ span > < span class ="n "> target</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> new_result_flat</ span > < span class ="p "> ,),</ span > < span class ="p "> {})</ span >
1102
1108
1103
- < span class ="k "> def</ span > < span class ="nf "> run_node</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> n</ span > < span class ="p "> ):</ span >
1104
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> current_node</ span > < span class ="o "> =</ span > < span class ="n "> n</ span >
1105
- < span class ="k "> return</ span > < span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="n "> run_node</ span > < span class ="p "> (</ span > < span class ="n "> n</ span > < span class ="p "> )</ span >
1106
-
1107
1109
< span class ="k "> if</ span > < span class ="n "> aten_graph</ span > < span class ="p "> :</ span >
1108
1110
< span class ="c1 "> # Running graph with interpreter is needed for propagating the stack_trace</ span >
1109
1111
< span class ="k "> def</ span > < span class ="nf "> graph_with_interpreter</ span > < span class ="p "> (</ span > < span class ="o "> *</ span > < span class ="n "> args</ span > < span class ="p "> ):</ span >
1110
1112
< span class ="k "> with</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> fx</ span > < span class ="o "> .</ span > < span class ="n "> traceback</ span > < span class ="o "> .</ span > < span class ="n "> preserve_node_meta</ span > < span class ="p "> ():</ span >
1111
1113
< span class ="k "> return</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> fx</ span > < span class ="o "> .</ span > < span class ="n "> Interpreter</ span > < span class ="p "> (</ span > < span class ="n "> graph</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> run</ span > < span class ="p "> (</ span > < span class ="o "> *</ span > < span class ="n "> args</ span > < span class ="p "> )</ span >
1112
1114
1113
- < span class ="n "> graph</ span > < span class ="o "> =</ span > < span class ="n "> make_fx</ span > < span class ="p "> (</ span >
1114
- < span class ="n "> graph_with_interpreter</ span > < span class ="p "> ,</ span >
1115
- < span class ="n "> decomposition_table</ span > < span class ="o "> =</ span > < span class ="n "> decomposition_table</ span > < span class ="p "> ,</ span >
1116
- < span class ="n "> tracing_mode</ span > < span class ="o "> =</ span > < span class ="n "> tracing_mode</ span > < span class ="p "> ,</ span >
1117
- < span class ="n "> _allow_non_fake_inputs</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1118
- < span class ="p "> )(</ span > < span class ="o "> *</ span > < span class ="n "> graph_captured_input</ span > < span class ="p "> )</ span >
1115
+ < span class ="k "> if</ span > < span class ="n "> tracing_mode</ span > < span class ="o "> ==</ span > < span class ="s2 "> "real"</ span > < span class ="p "> :</ span >
1116
+ < span class ="n "> graph</ span > < span class ="o "> =</ span > < span class ="n "> make_fx</ span > < span class ="p "> (</ span >
1117
+ < span class ="n "> graph_with_interpreter</ span > < span class ="p "> ,</ span >
1118
+ < span class ="n "> decomposition_table</ span > < span class ="o "> =</ span > < span class ="n "> decomposition_table</ span > < span class ="p "> ,</ span >
1119
+ < span class ="p "> )(</ span > < span class ="o "> *</ span > < span class ="n "> graph_captured_input</ span > < span class ="p "> )</ span >
1120
+ < span class ="k "> elif</ span > < span class ="n "> tracing_mode</ span > < span class ="o "> ==</ span > < span class ="s2 "> "symbolic"</ span > < span class ="p "> :</ span >
1121
+ < span class ="c1 "> # For dynamic shape, we need to make_fx through the graph with fake tensors under FakeTensorMode</ span >
1122
+ < span class ="c1 "> # The fake tensors may contain the fine grain dynamic shape passed down from dynamo</ span >
1123
+ < span class ="n "> fake_mode</ span > < span class ="o "> =</ span > < span class ="n "> fake_mode_from_tensors</ span > < span class ="p "> (</ span > < span class ="n "> compile_time_inputs</ span > < span class ="p "> )</ span >
1124
+ < span class ="k "> with</ span > < span class ="n "> fake_mode</ span > < span class ="p "> :</ span >
1125
+ < span class ="n "> graph</ span > < span class ="o "> =</ span > < span class ="n "> make_fx</ span > < span class ="p "> (</ span >
1126
+ < span class ="n "> graph_with_interpreter</ span > < span class ="p "> ,</ span >
1127
+ < span class ="n "> decomposition_table</ span > < span class ="o "> =</ span > < span class ="n "> decomposition_table</ span > < span class ="p "> ,</ span >
1128
+ < span class ="p "> )(</ span > < span class ="o "> *</ span > < span class ="n "> compile_time_inputs</ span > < span class ="p "> )</ span >
1129
+ < span class ="k "> else</ span > < span class ="p "> :</ span >
1130
+ < span class ="k "> raise</ span > < span class ="ne "> AssertionError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Unknown tracing mode </ span > < span class ="si "> {</ span > < span class ="n "> tracing_mode</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
1119
1131
1120
1132
< span class ="n "> new_graph</ span > < span class ="o "> =</ span > < span class ="n "> ChangeInputOutputSignature</ span > < span class ="p "> (</ span >
1121
1133
< span class ="n "> graph</ span > < span class ="p "> ,</ span >
0 commit comments