
Commit 98221cc

mlazos and svekars authored

Add compiled optimizer + lr scheduler tutorial (#2874)

* Added the first version of LR sched tutorial

Co-authored-by: Svetlana Karslioglu <[email protected]>

1 parent f66b5b2 commit 98221cc

File tree

3 files changed: +133 -0 lines changed

.jenkins/metadata.json
recipes_source/compiling_optimizer_lr_scheduler.py
recipes_source/recipes_index.rst

.jenkins/metadata.json

Lines changed: 7 additions & 0 deletions
@@ -28,6 +28,10 @@
     "intermediate_source/model_parallel_tutorial.py": {
         "needs": "linux.16xlarge.nvidia.gpu"
     },
+    "advanced_source/pendulum.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu",
+        "_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."
+    },
     "intermediate_source/torchvision_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu",
         "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py."
@@ -36,6 +40,9 @@
         "needs": "linux.g5.4xlarge.nvidia.gpu",
         "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py."
     },
+    "recipes_source/compiling_optimizer_lr_scheduler.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "intermediate_source/torch_compile_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
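
As a point of reference, here is a minimal sketch (an assumption about how this metadata could be consumed, not code from the commit) of looking up the runner requirement declared for the new recipe:

import json

# Hypothetical lookup against the .jenkins/metadata.json shown above:
# map a tutorial source file to the CI runner it declares via "needs".
with open(".jenkins/metadata.json") as f:
    metadata = json.load(f)

entry = metadata.get("recipes_source/compiling_optimizer_lr_scheduler.py", {})
print(entry.get("needs", "default runner"))  # expected: linux.g5.4xlarge.nvidia.gpu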

recipes_source/compiling_optimizer_lr_scheduler.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
"""
(beta) Running the compiled optimizer with an LR Scheduler
============================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

#########################################################
# The optimizer is a key algorithm for training any deep learning model.
# In this example, we will show how to pair an optimizer that has been compiled with ``torch.compile``
# with an LR scheduler to accelerate training convergence.
#
# .. note::
#
#    This tutorial requires PyTorch 2.3.0 or later.

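Before the model setup, a small version guard can be useful when following along; this is an editorial sketch (not part of the committed file) that assumes the usual ``torch.__version__`` numbering:

import torch

# Bail out early on builds older than the PyTorch 2.3.0 requirement noted above.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
if (major, minor) < (2, 3):
    raise RuntimeError("This recipe requires PyTorch 2.3.0 or later")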

#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
#

import torch

# Create simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()

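As a quick sanity check (an editorial aside, not part of the committed recipe), one can confirm that the backward pass populated a gradient for every parameter:

# Every parameter should now carry a .grad tensor of matching shape.
assert all(p.grad is not None and p.grad.shape == p.shape for p in model.parameters())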

#####################################################################
# Setting up and running the compiled optimizer with LR Scheduler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll use the Adam optimizer with a LinearLR scheduler
# and create a helper function to wrap the ``step()`` call for each of them
# in ``torch.compile()``.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.


# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

# !!! IMPORTANT !!! Wrap the lr in a Tensor if we are pairing the
# optimizer with an LR Scheduler.
# Without this, torch.compile will recompile as the value of the LR
# changes.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()


# Warmup runs to compile the function
for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])

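To get a feel for what the compiled step buys, a rough timing sketch (an editorial aside, not part of the committed recipe) can compare it against an equivalent eager step using ``torch.utils.benchmark``; exact numbers will vary by GPU and PyTorch build:

from torch.utils import benchmark

# Build an eager twin of the compiled step above for comparison.
eager_opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
eager_sched = torch.optim.lr_scheduler.LinearLR(eager_opt, total_iters=5)

def eager_fn():
    eager_opt.step()
    eager_sched.step()

compiled_time = benchmark.Timer(stmt="fn()", globals={"fn": fn}).timeit(100)
eager_time = benchmark.Timer(stmt="fn()", globals={"fn": eager_fn}).timeit(100)
print(f"compiled: {compiled_time.median * 1e6:.1f} us | eager: {eager_time.median * 1e6:.1f} us")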

######################################################################
# Extension: What happens with a non-tensor LR?
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For the curious, we will show how to peek into what happens with ``torch.compile`` when we don't wrap the
# LR in a tensor.

# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()

# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)

# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
    fn()


######################################################################
# With this example, we can see that we recompile the optimizer a few times
# due to the guard failure on the ``lr`` in ``param_groups[0]``.
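Once done inspecting, the verbose output can be switched back off; the committed recipe leaves logging enabled, so this line is only an editorial aside:

# Disable the recompile logging enabled above.
torch._logging.set_logs(recompiles=False)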

######################################################################
# Conclusion
# ~~~~~~~~~~
#
# In this tutorial we showed how to pair an optimizer compiled with ``torch.compile``
# with an LR scheduler to accelerate training convergence. We used a model consisting
# of a simple sequence of linear layers with the Adam optimizer paired
# with a LinearLR scheduler to demonstrate the LR changing across iterations.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro into the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.
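
Beyond the commit itself, a minimal end-to-end sketch (editorial, assuming the ``model`` and ``input`` defined in the recipe above plus a made-up training target) shows how the compiled step slots into an ordinary training loop:

# Editorial sketch: a toy training loop around the compiled optimizer/scheduler step.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def step_fn():
    opt.step()
    sched.step()

target = torch.zeros(1024, device="cuda")  # hypothetical target, not from the recipe
loss_fn = torch.nn.MSELoss()

for i in range(10):
    opt.zero_grad()
    loss = loss_fn(model(input), target)
    loss.backward()
    step_fn()  # compiled opt.step() + sched.step()
    print(i, loss.item(), opt.param_groups[0]["lr"].item())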

recipes_source/recipes_index.rst

Lines changed: 9 additions & 0 deletions
@@ -307,6 +307,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/compiling_optimizer.html
    :tags: Model-Optimization

+.. (beta) Running the compiled optimizer with an LR Scheduler
+
+.. customcarditem::
+   :header: (beta) Running the compiled optimizer with an LR Scheduler
+   :card_description: Speed up training with LRScheduler and torch.compiled optimizer
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/compiling_optimizer_lr_scheduler.html
+   :tags: Model-Optimization
+
 .. Using User-Defined Triton Kernels with ``torch.compile``

 .. customcarditem::
