Skip to content

Added optimizer compile recipe #2700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,13 @@ What's new in PyTorch tutorials?
:link: intermediate/torch_compile_tutorial.html
:tags: Model-Optimization

.. customcarditem::
:header: (beta) Compiling the Optimizer with torch.compile
:card_description: Speed up the optimizer using torch.compile
:image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: recipes/compiling_optimizer.html
:tags: Model-Optimization

.. customcarditem::
:header: Inductor CPU Backend Debugging and Profiling
:card_description: Learn the usage, debugging and performance profiling for ``torch.compile`` with Inductor CPU backend.
Expand Down Expand Up @@ -1046,6 +1053,7 @@ Additional Resources
intermediate/nvfuser_intro_tutorial
intermediate/ax_multiobjective_nas_tutorial
intermediate/torch_compile_tutorial
recipes/compiling_optimizer
intermediate/inductor_debug_cpu
intermediate/scaled_dot_product_attention_tutorial
beginner/knowledge_distillation_tutorial
Expand Down
71 changes: 71 additions & 0 deletions recipes_source/compiling_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
(beta) Compiling the optimizer with torch.compile
==========================================================================================


**Author:** `Michale Lazos <https://github.com/mlazos>`_
"""

######################################################################
# Summary
# ~~~~~~~~
#
# In this tutorial we will apply torch.compile to the optimizer to observe
# the GPU performance improvement
#
# .. note::
#
# This tutorial requires PyTorch 2.2.0 or later.
#


######################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example we'll use a simple sequence of linear layers.
# Since we are only benchmarking the optimizer, choice of model doesn't matter
# because optimizer performance is a function of the number of parameters.
#
# Depending on what machine you are using, your exact results may vary.

import torch

model = torch.nn.Sequential(
*[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
output = model(input)
output.sum().backward()

#############################################################################
# Setting up and running the optimizer benchmark
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# In this example, we'll use the Adam optimizer
# and create a helper function to wrap the step()
# in torch.compile()

opt = torch.optim.Adam(model.parameters(), lr=0.01)


@torch.compile()
def fn():
opt.step()


# Lets define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6


# Warmup runs to compile the function
for _ in range(5):
fn()

print(f"eager runtime: {benchmark_torch_function_in_microseconds(opt.step)}us")
print(f"compiled runtime: {benchmark_torch_function_in_microseconds(fn)}us")