
Commit 215d81e

Regional compilation recipe
1 parent 01d2270 commit 215d81e

3 files changed, +176 -0 lines changed

recipes_source/recipes/README.txt

Lines changed: 4 additions & 0 deletions
@@ -56,3 +56,7 @@ PyTorch Recipes
 14. amp_recipe.py
     Automatic Mixed Precision
     https://pytorch.org/tutorials/recipes/amp_recipe.html
+
+15. regional_compilation.py
+    Reducing torch.compile cold start compilation time with regional compilation
+    https://pytorch.org/tutorials/recipes/regional_compilation.html
recipes_source/recipes/regional_compilation.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
"""
Reducing torch.compile cold start compilation time with regional compilation
============================================================================

Introduction
------------
As deep learning models get larger, the compilation time of these models also
increases. This increase in compilation time can lead to a large startup time in
inference services or wasted resources in large-scale training. This recipe
shows an example of how to reduce the cold start compilation time by choosing to
compile a repeated region of the model instead of the entire model.

Setup
-----
Before we begin, we need to install ``torch`` if it is not already
available.

.. code-block:: sh

   pip install torch

"""


######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries.
# 2. Define and initialize a neural network with repeated regions.
# 3. Understand the difference between full model compilation and regional compilation.
# 4. Measure the compilation time of the full model and the regional compilation.
#
# 1. Import necessary libraries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#

import torch
import torch.nn as nn
from time import perf_counter

#
# 2. Define and initialize a neural network with repeated regions.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Typically, neural networks are composed of repeated layers. For example, a
# large language model is composed of many Transformer blocks. In this recipe,
# we will create a `Layer` `nn.Module` class as a proxy for a repeated region.
# We will then create a `Model` which is composed of 64 instances of this
# `Layer` class.
#
class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b

class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList([torch.compile(Layer()) for _ in range(64)])
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, self.linear is outside the scope of `torch.compile`.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

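# As a quick eager-mode check (a minimal sketch, before any compilation is
# involved), the model maps a ``(batch, 10)`` input to a ``(batch, 10)`` output:
#
# .. code-block:: python
#
#    eager_model = Model(apply_regional_compilation=False)
#    out = eager_model(torch.randn(10, 10))
#    print(out.shape)  # expected: torch.Size([10, 10])
#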
#
# 3. Understand the difference between full model compilation and regional compilation.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In full model compilation, the full model is compiled as a whole. This is how
# most users use torch.compile. In this example, we apply torch.compile to
# the `model` object. This effectively inlines the 64 layers, producing a
# large graph to compile. You can look at the full graph by running this recipe
# with `TORCH_LOGS=graph_code`.
#
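# For example, assuming this recipe is saved as ``regional_compilation.py``, the
# graph code can be dumped by running ``TORCH_LOGS=graph_code python regional_compilation.py``.
#
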
model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)


#
# Regional compilation, on the other hand, compiles a region of the model.
# By wisely choosing to compile a repeated region of the model, we can compile a
# much smaller graph and then reuse the compiled graph for all the regions. We
# can apply regional compilation in the example as follows: `torch.compile` is
# applied only to the `layers` and not to the full model.
#

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

# Applying compilation to a repeated region, instead of the full model, leads to
# large savings in compile time. Here, we compile a layer instance and
# then reuse it 64 times in the `model` object.
#
# Note that with repeated regions, some part of the model might not be compiled.
# For example, the `self.linear` in the `Model` is outside the scope of
# regional compilation.
#
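# As a quick sanity check (purely illustrative), we can print the submodule types:
# the repeated layers are wrapped by `torch.compile`, while `self.linear` remains a
# plain `torch.nn.Linear`.
print(type(regional_compiled_model.layers[0]))
print(type(regional_compiled_model.linear))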

# Also, note that there is a tradeoff between performance speedup and compile
# time. Full model compilation operates on a larger graph and therefore,
# theoretically, has more scope for optimizations. However, for practical
# purposes and depending on the model, we have observed many cases with minimal
# speedup differences between full model and regional compilation.

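#
# A minimal sketch of how one might compare steady-state latency once compilation
# has already happened (assuming warmed-up models; the numbers depend on the model
# and hardware):
#
# .. code-block:: python
#
#    def measure_inference_latency(fn, input, iters=100):
#        fn(input)  # warm up so compilation is excluded from the measurement
#        torch.cuda.synchronize()
#        start = perf_counter()
#        for _ in range(iters):
#            fn(input)
#        torch.cuda.synchronize()
#        return (perf_counter() - start) / iters
#
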
#
# 4. Measure the compilation time of the full model and the regional compilation.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.compile` is a JIT compiler, i.e., it compiles on the first invocation.
# Here, we measure the total time spent in the first invocation. This is not
# precise, but it gives a good estimate because the majority of the time is
# spent in compilation.

def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start

input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

############################################################################
# This recipe shows how to control the cold start compilation time if your model
# has repeated regions. This approach requires user changes to apply `torch.compile`
# to the repeated regions instead of the more commonly used full model compilation.
# We are continually working on reducing cold start compilation time, so please
# stay tuned for our next tutorials.
#
# This feature is available starting with the 2.5 release. If you are on 2.4, you
# can set the config flag `torch._dynamo.config.inline_inbuilt_nn_modules=True` to
# avoid recompilations with regional compilation. In 2.5, this flag is turned on by
# default.
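#
# A minimal sketch of the 2.4 workaround mentioned above (on 2.5 and later the flag
# is already on by default, so this is not needed):
#
# .. code-block:: python
#
#    import torch._dynamo
#
#    torch._dynamo.config.inline_inbuilt_nn_modules = True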

recipes_source/recipes_index.rst

Lines changed: 10 additions & 0 deletions
@@ -339,6 +339,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/torch_compile_caching_tutorial.html
    :tags: Model-Optimization

+.. Reducing Cold Start Compilation Time with Regional Compilation
+
+.. customcarditem::
+   :header: Reducing torch.compile cold start compilation time with regional compilation
+   :card_description: Learn how to use regional compilation to control cold start compile time
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/recipes/regional_compilation.html
+   :tags: Model-Optimization
+
 .. Intel(R) Extension for PyTorch*

 .. customcarditem::
@@ -452,6 +461,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/recipes/amp_recipe
    /recipes/recipes/tuning_guide
    /recipes/recipes/xeon_run_cpu
+   /recipes/recipes/regional_compilation
    /recipes/recipes/intel_extension_for_pytorch
    /recipes/compiling_optimizer
    /recipes/torch_compile_backend_ipex
