@@ -235,7 +235,7 @@
<div class="pytorch-left-menu-search">

<div class="version">
-  <a href='https://pytorch.org/docs/versions.html'>master (1.14.0a0+gita43e09c ) ▼</a>
+  <a href='https://pytorch.org/docs/versions.html'>master (1.14.0a0+gitf628f2e ) ▼</a>
</div>


@@ -299,14 +299,14 @@
</ul>
<p class="caption" role="heading"><span class="caption-text">torch.compile</span></p>
<ul>
- <li class="toctree-l1"><a class="reference internal" href="../dynamo/custom-backends.html">Custom Backends</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../dynamo/deep-dive.html">TorchDynamo Deeper Dive</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../dynamo/faq.html">Frequently Asked Questions</a></li>
+ <li class="toctree-l1"><a class="reference internal" href="../dynamo/index.html">TorchDynamo Overview</a></li>
+ <li class="toctree-l1"><a class="reference internal" href="../dynamo/installation.html">Installing TorchDynamo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dynamo/get-started.html">Getting Started</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dynamo/guards-overview.html">Guards Overview</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../dynamo/index.html">TorchDynamo Documentation</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../dynamo/installation.html">Installing TorchDynamo</a></li>
+ <li class="toctree-l1"><a class="reference internal" href="../dynamo/custom-backends.html">Custom Backends</a></li>
+ <li class="toctree-l1"><a class="reference internal" href="../dynamo/deep-dive.html">TorchDynamo Deeper Dive</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dynamo/troubleshooting.html">TorchDynamo Troubleshooting</a></li>
+ <li class="toctree-l1"><a class="reference internal" href="../dynamo/faq.html">Frequently Asked Questions</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Language Bindings</span></p>
<ul>
@@ -490,7 +490,7 @@ <h1>Source code for torch</h1><div class="highlight"><pre>

<span class="kn">from</span> <span class="nn">._six</span> <span class="kn">import</span> <span class="n">string_classes</span> <span class="k">as</span> <span class="n">_string_classes</span>

- <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Set</span><span class="p">,</span> <span class="n">Type</span><span class="p">,</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Any</span>
+ <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Set</span><span class="p">,</span> <span class="n">Type</span><span class="p">,</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span> <span class="nn">builtins</span>

<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span>
@@ -509,6 +509,7 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
    <span class="s1">'set_deterministic_debug_mode'</span><span class="p">,</span> <span class="s1">'get_deterministic_debug_mode'</span><span class="p">,</span>
    <span class="s1">'set_float32_matmul_precision'</span><span class="p">,</span> <span class="s1">'get_float32_matmul_precision'</span><span class="p">,</span>
    <span class="s1">'set_warn_always'</span><span class="p">,</span> <span class="s1">'is_warn_always_enabled'</span><span class="p">,</span> <span class="s1">'SymInt'</span><span class="p">,</span> <span class="s1">'SymFloat'</span><span class="p">,</span>
+     <span class="s1">'compile'</span><span class="p">,</span>
<span class="p">]</span>

<span class="c1">################################################################################</span>
@@ -1573,6 +1574,74 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
    <span class="n">lstsq</span><span class="p">,</span>
<span class="p">)</span>

+ <div class="viewcode-block" id="compile"><a class="viewcode-back" href="../generated/torch.compile.html#torch.compile">[docs]</a><span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
+             <span class="n">fullgraph</span><span class="p">:</span> <span class="n">builtins</span><span class="o">.</span><span class="n">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+             <span class="n">dynamic</span><span class="p">:</span> <span class="n">builtins</span><span class="o">.</span><span class="n">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+             <span class="n">backend</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"inductor"</span><span class="p">,</span>
+             <span class="n">mode</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+             <span class="n">passes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">builtins</span><span class="o">.</span><span class="n">int</span><span class="p">,</span> <span class="n">builtins</span><span class="o">.</span><span class="n">bool</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+             <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-></span> <span class="n">Callable</span><span class="p">:</span>
+     <span class="sd">"""</span>
+ <span class="sd">    Optimizes given model/function using Dynamo and specified backend</span>
+
+ <span class="sd">    Args:</span>
+ <span class="sd">       model (Callable): Module/function to optimize</span>
+ <span class="sd">       fullgraph (bool): Whether to require the whole model to be captured as a single graph (no graph breaks allowed)</span>
+ <span class="sd">       dynamic (bool): Use dynamic shape tracing</span>
+ <span class="sd">       backend (str or Callable): backend to be used</span>
+ <span class="sd">       mode (str): Can be either "default", "reduce-overhead" or "max-autotune"</span>
+ <span class="sd">       passes (dict): A dictionary of passes to the backend. Passes currently recognized by the inductor backend:</span>
+ <span class="sd">           - static-memory</span>
+ <span class="sd">           - matmul-tune</span>
+ <span class="sd">           - matmul-padding</span>
+ <span class="sd">           - triton-autotune</span>
+ <span class="sd">           - triton-bmm</span>
+ <span class="sd">           - triton-mm</span>
+ <span class="sd">           - triton-convolution</span>
+ <span class="sd">           - rematerialize-threshold</span>
+ <span class="sd">           - rematerialize-acc-threshold</span>
+
+ <span class="sd">    Example::</span>
+
+ <span class="sd">        @torch.compile(passes={"matmul-padding": True}, fullgraph=True)</span>
+ <span class="sd">        def foo(x):</span>
+ <span class="sd">            return torch.sin(x) + torch.cos(x)</span>
+
+ <span class="sd">    """</span>
+     <span class="c1"># Decorator mode</span>
+     <span class="k">if</span> <span class="n">model</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+         <span class="k">def</span> <span class="nf">fn</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">Callable</span><span class="p">):</span>
+             <span class="k">if</span> <span class="n">model</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                 <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Model can't be None"</span><span class="p">)</span>
+             <span class="k">return</span> <span class="nb">compile</span><span class="p">(</span><span class="n">model</span><span class="p">,</span>
+                            <span class="n">fullgraph</span><span class="o">=</span><span class="n">fullgraph</span><span class="p">,</span>
+                            <span class="n">dynamic</span><span class="o">=</span><span class="n">dynamic</span><span class="p">,</span>
+                            <span class="n">backend</span><span class="o">=</span><span class="n">backend</span><span class="p">,</span>
+                            <span class="n">mode</span><span class="o">=</span><span class="n">mode</span><span class="p">,</span>
+                            <span class="n">passes</span><span class="o">=</span><span class="n">passes</span><span class="p">,</span>
+                            <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+         <span class="k">return</span> <span class="n">fn</span>
+
+     <span class="kn">import</span> <span class="nn">torch._dynamo</span>
+     <span class="kn">from</span> <span class="nn">torch._dynamo.eval_frame</span> <span class="kn">import</span> <span class="n">lookup_backend</span>
+     <span class="kn">from</span> <span class="nn">torch._inductor.config</span> <span class="kn">import</span> <span class="n">InductorConfigContext</span>
+     <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">passes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+         <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Either mode or passes can be specified, but both can't be specified at the same time."</span><span class="p">)</span>
+     <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">passes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+         <span class="n">mode</span> <span class="o">=</span> <span class="s2">"default"</span>
+     <span class="k">if</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"inductor"</span><span class="p">:</span>
+         <span class="n">compile_fn</span> <span class="o">=</span> <span class="n">lookup_backend</span><span class="p">(</span><span class="n">backend</span><span class="p">)</span>
+         <span class="n">cm</span> <span class="o">=</span> <span class="n">InductorConfigContext</span><span class="p">(</span><span class="n">mode</span> <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">passes</span><span class="p">)</span>
+
+         <span class="k">def</span> <span class="nf">_compile_fn</span><span class="p">(</span><span class="n">model_</span><span class="p">,</span> <span class="n">inputs_</span><span class="p">):</span>
+             <span class="k">with</span> <span class="n">cm</span><span class="p">:</span>
+                 <span class="k">return</span> <span class="n">compile_fn</span><span class="p">(</span><span class="n">model_</span><span class="p">,</span> <span class="n">inputs_</span><span class="p">)</span>
+
+         <span class="n">_compile_fn</span><span class="o">.</span><span class="n">_torchdynamo_orig_callable</span> <span class="o">=</span> <span class="n">compile_fn</span>  <span class="c1"># type: ignore[attr-defined]</span>
+         <span class="n">backend</span> <span class="o">=</span> <span class="n">_compile_fn</span>
+     <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">_dynamo</span><span class="o">.</span><span class="n">optimize</span><span class="p">(</span><span class="n">backend</span><span class="o">=</span><span class="n">backend</span><span class="p">,</span> <span class="n">nopython</span><span class="o">=</span><span class="n">fullgraph</span><span class="p">,</span> <span class="n">dynamic</span><span class="o">=</span><span class="n">dynamic</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)(</span><span class="n">model</span><span class="p">)</span></div>
+
+
<span class="k">def</span> <span class="nf">_register_device_module</span><span class="p">(</span><span class="n">device_type</span><span class="p">,</span> <span class="n">module</span><span class="p">):</span>
    <span class="sa">r</span><span class="sd">"""Register an external runtime module of the specific :attr:`device_type`</span>
<span class="sd">    supported by torch.</span>
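For reference, a minimal usage sketch of the torch.compile wrapper added in this diff. It assumes only what the diff itself shows: a development build (the 1.14.0a0 master shown above) that exposes torch.compile(model=None, *, fullgraph, dynamic, backend, mode, passes, **kwargs) with "inductor" as the default backend; actually running it also requires a working TorchInductor toolchain.

    # Usage sketch; assumes this dev build exposes torch.compile as added in the diff above.
    import torch

    def pointwise(x):
        return torch.sin(x) + torch.cos(x)

    # Direct-call form: mode and passes are mutually exclusive, so only mode is given here.
    compiled_pointwise = torch.compile(pointwise, backend="inductor", mode="default")
    print(compiled_pointwise(torch.randn(8)))

    # Decorator form (the model is None path in the diff), reusing the docstring's own example.
    @torch.compile(passes={"matmul-padding": True}, fullgraph=True)
    def foo(x):
        return torch.sin(x) + torch.cos(x)

    print(foo(torch.randn(8)))

Per the function body in the diff, fullgraph is forwarded to torch._dynamo.optimize(nopython=...), and for the "inductor" backend the chosen mode or passes dict is applied through an InductorConfigContext wrapped around the backend call.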