Add compiled optimizer + lr scheduler tutorial #2874
Merged
Changes from all commits (17 commits)
86c047f  Initial pass at lr sched tutorial (mlazos)
c30649c  Apply suggestions from code review (mlazos)
67d5e45  Convert to py, comments (mlazos)
f8eafd7  Merge with other changes (mlazos)
16b41c2  Remove rst (mlazos)
185a009  Merge branch 'main' into mlazos/compiled-opt-lr-sched (mlazos)
8a73fc0  PR comments (mlazos)
7a0f792  Move conclusion to end (mlazos)
258955c  Update version number (mlazos)
9d07b46  Put on linux.g5.4xlarge.nvidia.gpu (svekars)
8b0edb0  Update (svekars)
6f41382  Update (svekars)
d759583  Update metadata.json (svekars)
dc62ced  Update compiling_optimizer_lr_scheduler.py (svekars)
45c53ea  Update compiling_optimizer_lr_scheduler.py (svekars)
63d47f4  Merge branch 'main' into mlazos/compiled-opt-lr-sched (svekars)
56c0b83  Update compiling_optimizer_lr_scheduler.py (svekars)
compiling_optimizer_lr_scheduler.py
@@ -0,0 +1,117 @@
""" | ||
(beta) Running the compiled optimizer with an LR Scheduler | ||
============================================================ | ||
|
||
**Author:** `Michael Lazos <https://github.com/mlazos>`_ | ||
""" | ||
|
||
#########################################################
# The optimizer is a key algorithm for training any deep learning model.
# In this example, we will show how to pair the optimizer, which has been compiled using ``torch.compile``,
# with an LR scheduler to accelerate training convergence.
#
# .. note::
#
#    This tutorial requires PyTorch 2.3.0 or later.

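######################################################################
# A minimal sketch of one way to check this version requirement at runtime,
# assuming ``torch.__version__`` begins with a standard ``major.minor``
# numeric prefix (for example, ``"2.3.0+cu121"``).

import sys

import torch

# compare only the leading major.minor components of the version string
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
if (major, minor) < (2, 3):
    print("This tutorial requires PyTorch 2.3.0 or later; exiting.")
    sys.exit(0)
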
#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
#

import torch

# Create simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()

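######################################################################
# As a quick sanity check, we can confirm that the backward pass above
# populated a gradient for every parameter before handing them to the
# optimizer. This is a small illustrative check, not something the
# optimizer requires.

# every linear layer's weight should now carry a gradient
assert all(p.grad is not None for p in model.parameters())
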
#####################################################################
# Setting up and running the compiled optimizer with LR Scheduler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll use the Adam optimizer with the LinearLR scheduler
# and create a helper function to wrap the ``step()`` call for each of them
# in ``torch.compile()``.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.


# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

# !!! IMPORTANT !!! Wrap the lr in a Tensor if we are pairing the
# optimizer with an LR Scheduler.
# Without this, torch.compile will recompile as the value of the LR
# changes.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)


@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()


# Warmup runs to compile the function
for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])

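######################################################################
# Note that because the LR was wrapped in a tensor, the ``lr`` entry in
# ``param_groups`` is itself a tensor, as the printed values above show.
# Below is a small sketch of how you might read it back as a plain Python
# number for logging purposes; this is an optional convenience, not
# something the optimizer or the scheduler requires.

# .item() converts the 0-dim LR tensor to a Python float
print(opt.param_groups[0]["lr"].item())

# the scheduler exposes the same value(s) through get_last_lr()
print(sched.get_last_lr())
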
######################################################################
# Extension: What happens with a non-tensor LR?
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For the curious, we will show how to peek into what happens with ``torch.compile`` when we don't wrap the
# LR in a tensor.

# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)


@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()


# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)

# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
    fn()


######################################################################
# With this example, we can see that we recompile the optimizer a few times
# due to the guard failure on the ``lr`` in ``param_groups[0]``.

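######################################################################
# If you plan to run further experiments in the same process, one way
# (sketched here) to return to a quiet baseline is to turn the recompile
# logging back off and clear the ``torch.compile`` caches.

import torch._dynamo

# disable the recompile logging enabled above
torch._logging.set_logs(recompiles=False)

# drop compiled artifacts so later experiments start from a clean slate
torch._dynamo.reset()
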
######################################################################
# Conclusion
# ~~~~~~~~~~
#
# In this tutorial we showed how to pair the optimizer compiled with ``torch.compile``
# with an LR Scheduler to accelerate training convergence. We used a model consisting
# of a simple sequence of linear layers with the Adam optimizer paired
# with a LinearLR scheduler to demonstrate the LR changing across iterations.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an introduction to the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.