Added optimizer compile recipe #2700
Merged
Changes from 12 commits (19 commits in total):

64d73e1  Added optimizer compile tutorial (mlazos)
79b2087  Apply suggestions from code review (mlazos)
4f3ef4b  More suggestions (mlazos)
d90487f  Remove bad index changes (mlazos)
741d46b  Abort if device is not supported (mlazos)
ef6cce0  Merge branch 'main' into mlazos/compile-opt (svekars)
a408310  Merge branch 'main' into mlazos/compile-opt (svekars)
6f3b736  Merge branch 'main' into mlazos/compile-opt (svekars)
c4e7e37  Merge branch 'main' into mlazos/compile-opt (mlazos)
572dc44  Merge branch 'main' into mlazos/compile-opt (svekars)
7a89abf  Merge branch 'main' into mlazos/compile-opt (svekars)
c26a574  Merge branch 'main' into mlazos/compile-opt (svekars)
16fd39a  Update recipes_source/compiling_optimizer.py (malfet)
86f884f  Merge branch 'main' into mlazos/compile-opt (malfet)
329bf93  Make optimizer recipe an rst, since there is something strange happen… (mlazos)
375f5a6  Merge branch 'mlazos/compile-opt' of github.com:pytorch/tutorials int… (mlazos)
7a5420a  Merge branch 'main' into mlazos/compile-opt (svekars)
6799a8d  Update compiling_optimizer.rst (svekars)
5d71ad6  Merge branch 'main' into mlazos/compile-opt (svekars)
recipes_source/compiling_optimizer.py
@@ -0,0 +1,92 @@
"""
(beta) Compiling the optimizer with torch.compile
==========================================================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

######################################################################
#
# The optimizer is a key algorithm for training any deep learning model.
# Since it is responsible for updating every model parameter, it can often
# become the bottleneck in training performance for large models. In this recipe,
# we will apply ``torch.compile`` to the optimizer to observe the GPU performance
# improvement.
#
# .. note::
#
#    This tutorial requires PyTorch 2.2.0 or later.
#

######################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
# Since we are only benchmarking the optimizer, the choice of model doesn't matter
# because optimizer performance is a function of the number of parameters.
#
# Depending on what machine you are using, your exact results may vary.

import torch

model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
output = model(input)
output.sum().backward()
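
# Since optimizer cost scales with the number of parameters, it can be helpful
# to see how many parameters this model actually has. This check is an
# illustrative addition and is not required for the benchmark below.
num_params = sum(p.numel() for p in model.parameters())
print(f"number of parameters: {num_params}")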

#############################################################################
# Setting up and running the optimizer benchmark
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# In this example, we'll use the Adam optimizer,
# and create a helper function to wrap the ``step()`` call
# in ``torch.compile()``.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices with compute capability >= 7.0.

# exit cleanly if we are on a device that doesn't support torch.compile
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

opt = torch.optim.Adam(model.parameters(), lr=0.01)


@torch.compile()
def fn():
    opt.step()


# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6
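
# The helper above forwards *args/**kwargs into the timed statement, so it can
# time any callable. As a quick illustration (an optional addition, not part of
# the optimizer measurement itself), it can also time the model's forward pass:
forward_runtime = benchmark_torch_function_in_microseconds(model, input)
print(f"forward runtime: {forward_runtime}us")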

# Warmup runs to compile the function
for _ in range(5):
    fn()

eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
compiled_runtime = benchmark_torch_function_in_microseconds(fn)

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")

# Sample Results:
# eager runtime: 747.2437149845064us
# compiled runtime: 392.07384741178us
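
For context, here is a minimal sketch of how the compiled step could be used inside a full training loop. The loop body and iteration count are illustrative placeholders, not part of the recipe above:

# Illustrative training loop reusing the model, optimizer, and compiled step
# defined in the recipe; the number of iterations is arbitrary.
for _ in range(10):
    output = model(input)
    loss = output.sum()
    opt.zero_grad()
    loss.backward()
    fn()  # runs the compiled opt.step()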