
Commit d14ff1c

committed: update
1 parent 45a4087 commit d14ff1c

File tree

1 file changed: +79 -67 lines changed


intermediate_source/compiled_autograd_tutorial.py

Lines changed: 79 additions & 67 deletions
@@ -4,58 +4,57 @@
 Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
 ==========================================================================

+**Author:** `Simon Fan <https://github.com/xmfan>`_
+
+.. grid:: 2
+
+    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+       :class-card: card-prerequisites
+
+       * How compiled autograd interacts with torch.compile
+       * How to use the compiled autograd API
+       * How to inspect logs using TORCH_LOGS
+
+    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+       :class-card: card-prerequisites
+
+       * PyTorch 2.4
+       * `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_ familiarity
+
 """

 ######################################################################
+# Overview
+# ------------
 # Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4
-# that allows the capture of a larger backward graph. It is highly recommended
-# to familiarize yourself with `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_.
+# that allows the capture of a larger backward graph.
 #
-
-######################################################################
 # Doesn't torch.compile already capture the backward graph?
 # ------------
-# Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
-# - Graph breaks in the forward lead to graph breaks in the backward
-# - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
+# And it does, **partially**. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
+# 1. Graph breaks in the forward lead to graph breaks in the backward
+# 2. `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
 #
 # Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
 # it to capture the full backward graph at runtime. Models with these two characteristics should try
 # Compiled Autograd, and potentially observe better performance.
 #
 # However, Compiled Autograd has its own limitations:
-# - Dynamic autograd structure leads to recompiles
+# 1. Additional runtime overhead at the start of the backward
+# 2. Dynamic autograd structure leads to recompiles
+#
+# .. note:: Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to `Compiled Autograd Landing Page <https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY>`_.
 #

-######################################################################
-# Tutorial output cells setup
-# ------------
-#
-
-import os
-
-class ScopedLogging:
-    def __init__(self):
-        assert "TORCH_LOGS" not in os.environ
-        assert "TORCH_LOGS_FORMAT" not in os.environ
-        os.environ["TORCH_LOGS"] = "compiled_autograd_verbose"
-        os.environ["TORCH_LOGS_FORMAT"] = "short"
-
-    def __del__(self):
-        del os.environ["TORCH_LOGS"]
-        del os.environ["TORCH_LOGS_FORMAT"]
-

 ######################################################################
-# Basic Usage
+# Setup
 # ------------
-#
+# In this tutorial, we'll base our examples on this toy model.
+#

 import torch

-# NOTE: Must be enabled before using the decorator
-torch._dynamo.config.compiled_autograd = True
-
 class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
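
This hunk ends partway through the toy model. For reference, a minimal sketch of what the complete model and input might look like, pieced together from the surrounding hunks; the ``Linear(10, 10)`` layer size is an assumption based on the ``f32[10, 10]`` hook metadata that appears later in the diff, not the committed code:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed layer shape; the f32[10, 10] entries in the later log suggest a 10x10 weight.
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

model = Model()
x = torch.randn(10)  # matches the input used in the "Basic usage" hunk
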
@@ -64,24 +63,30 @@ def __init__(self):
     def forward(self, x):
         return self.linear(x)

+
+######################################################################
+# Basic usage
+# ------------
+# .. note:: The ``torch._dynamo.config.compiled_autograd = True`` config must be enabled before calling the torch.compile API.
+#
+
+model = Model()
+x = torch.randn(10)
+
+torch._dynamo.config.compiled_autograd = True
 @torch.compile
 def train(model, x):
     loss = model(x).sum()
     loss.backward()

-model = Model()
-x = torch.randn(10)
 train(model, x)

 ######################################################################
 # Inspecting the compiled autograd logs
 # ------------
-# Run the script with either TORCH_LOGS environment variables
-#
-# - To only print the compiled autograd graph, use `TORCH_LOGS="compiled_autograd" python example.py`
-# - To sacrifice some performance, in order to print the graph with more tensor medata and recompile reasons, use `TORCH_LOGS="compiled_autograd_verbose" python example.py`
-#
-# Logs can also be enabled through the private API torch._logging._internal.set_logs.
+# Run the script with the TORCH_LOGS environment variable:
+# - To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
+# - To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``
 #

 @torch.compile
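
As a side note, the removed lines above also mention enabling these logs programmatically; a minimal sketch of that approach follows. ``torch._logging._internal.set_logs`` is a private API and may change between releases, so prefer the ``TORCH_LOGS`` environment variable in scripts you intend to keep:

import torch

# Private, unstable API referenced by the removed line above; roughly equivalent
# in effect to running with TORCH_LOGS="compiled_autograd_verbose".
torch._logging._internal.set_logs(compiled_autograd_verbose=True)
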
@@ -92,13 +97,11 @@ def train(model, x):
 train(model, x)

 ######################################################################
-# The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names that are prefixed by aot0_,
-# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0.
-#
-# NOTE: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd basically
-# generated some python code to represent the entire C++ autograd execution.
+# The compiled autograd graph should now be logged to stderr. Certain graph nodes will have names that are prefixed by ``aot0_``;
+# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0, e.g. ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
 #
-"""
+
+stderr_output = """
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
 ===== Compiled autograd graph =====
@@ -152,6 +155,10 @@ def forward(self, inputs, sizes, scalars, hooks):
         return []
 """

+######################################################################
+# .. note:: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd generates some python code to represent the entire C++ autograd execution.
+#
+
 ######################################################################
 # Compiling the forward and backward pass using different flags
 # ------------
@@ -163,7 +170,7 @@ def train(model, x):
     torch.compile(lambda: loss.backward(), fullgraph=True)()

 ######################################################################
-# Or you can use the context manager, which will apply to all autograd calls within it
+# Or you can use the context manager, which will apply to all autograd calls within its scope.
 #

 def train(model, x):
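
The body of ``train`` is cut off by this hunk. A sketch of how the context-manager form could look, assuming the ``torch._dynamo.compiled_autograd.enable`` context manager (the exact call is not shown in this diff):

import torch
import torch._dynamo.compiled_autograd  # assumed module providing the context manager

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    # Assumption: enable() wraps autograd calls in its scope with compiled autograd,
    # compiling the captured backward graph with the given compiler.
    with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
        loss.backward()
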
@@ -174,7 +181,7 @@ def train(model, x):


 ######################################################################
-# Demonstrating the limitations of AOTAutograd addressed by Compiled Autograd
+# Compiled Autograd addresses certain limitations of AOTAutograd
 # ------------
 # 1. Graph breaks in the forward lead to graph breaks in the backward
 #
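
The definition of ``fn`` falls between this hunk and the next. A sketch of a forward with two graph breaks, consistent with the "3 backward graphs" discussion in the next hunk (illustrative only; the real ``fn`` is not shown here):

import torch

@torch.compile(backend="aot_eager")
def fn(x):
    temp = x + 10
    torch._dynamo.graph_break()  # first graph break
    temp = temp + 10
    torch._dynamo.graph_break()  # second graph break
    return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
loss = fn(x)
loss.backward()  # with base torch.compile, the two breaks yield three backward graphs
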
@@ -208,7 +215,12 @@ def fn(x):


 ######################################################################
-# 2. `Backward hooks are not captured
+# In the ``1. base torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
+# Whereas in ``2. torch.compile with compiled autograd``, we see that a full backward graph was traced despite the graph breaks.
+#
+
+######################################################################
+# 2. Backward hooks are not captured
 #

 @torch.compile(backend="aot_eager")
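
The hook registration itself sits outside this hunk. A sketch of what such an example might look like; the hook, shapes, and compile flags are assumptions, loosely matching the ``tensor_pre_hook`` and ``f32[10, 10]`` entries in the log that follows:

import torch

torch._dynamo.config.compiled_autograd = True

@torch.compile(backend="aot_eager")
def fn(x):
    return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad + 10)  # tensor hook that compiled autograd can capture
loss = fn(x)
torch.compile(lambda: loss.backward(), backend="aot_eager")()
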
@@ -223,19 +235,19 @@ def fn(x):
 loss.backward()

 ######################################################################
-# There is a `call_hook` node in the graph, which dynamo will inline
+# There should be a ``call_hook`` node in the graph, which dynamo will later inline
 #

-"""
+stderr_output = """
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
-===== Compiled autograd graph =====
-<eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
-    def forward(self, inputs, sizes, scalars, hooks):
-        ...
-        getitem_2 = hooks[0]; hooks = None
-        call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
-        ...
+===== Compiled autograd graph =====
+<eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
+    def forward(self, inputs, sizes, scalars, hooks):
+        ...
+        getitem_2 = hooks[0]; hooks = None
+        call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
+        ...
 """

 ######################################################################
@@ -250,10 +262,10 @@ def forward(self, inputs, sizes, scalars, hooks):
 torch.compile(lambda: loss.backward(), backend="eager")()

 ######################################################################
-# You should see some cache miss logs (recompiles):
+# You should see some recompile messages: **Cache miss due to new autograd node**.
 #

-"""
+stderr_output = """
 Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
 ...
 Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
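
The loop that produces these cache misses is not visible in the diff. A sketch of code that would trigger them; the specific ops are assumptions, chosen so that each iteration introduces a new autograd node type such as ``SubBackward0``:

import torch

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)

for op in [torch.add, torch.sub, torch.mul, torch.div]:
    loss = op(x, x).sum()
    # Each new backward node type (AddBackward0, SubBackward0, ...) misses the
    # compiled autograd cache and triggers a recompile.
    torch.compile(lambda: loss.backward(), backend="eager")()
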
@@ -268,18 +280,17 @@ def forward(self, inputs, sizes, scalars, hooks):
 # 2. Due to dynamic shapes
 #

-torch._logging._internal.set_logs(compiled_autograd_verbose=True)
 torch._dynamo.config.compiled_autograd = True
 for i in [10, 100, 10]:
     x = torch.randn(i, i, requires_grad=True)
     loss = x.sum()
     torch.compile(lambda: loss.backward(), backend="eager")()

 ######################################################################
-# You should see some cache miss logs (recompiles):
+# You should see some recompile messages: **Cache miss due to changed shapes**.
 #

-"""
+stderr_output = """
 ...
 Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
 Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
@@ -289,8 +300,9 @@ def forward(self, inputs, sizes, scalars, hooks):
 """

 ######################################################################
-# Compatibility and rough edges
-# ------------
-#
-# Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features.
-# For the latest status on a particular feature, refer to: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY.
+# Conclusion
+# ----------
+# In this tutorial, we went over the high-level ecosystem of torch.compile with compiled autograd, the basics of compiled autograd, and a few common recompilation reasons.
+#
+# For feedback on this tutorial, please file an issue on https://github.com/pytorch/tutorials.
+#

0 commit comments
