
Commit 52fe948

anijain2305 and svekars authored
Apply suggestions from code review
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent e77a954 commit 52fe948

File tree: 1 file changed (+46, -37 lines)


recipes_source/recipes/regional_compilation.py

Lines changed: 46 additions & 37 deletions
@@ -2,14 +2,17 @@
 Reducing torch.compile cold start compilation time with regional compilation
 ============================================================================
 
-Introduction
-------------
+**Author:** `Animesh Jain <https://github.com/anijain2305>`_
 As deep learning models get larger, the compilation time of these models also
-increase. This increase in compilation time can lead to a large startup time in
-inference services or wasted resources in large scale training. This recipe
+increases. This extended compilation time can result in a large startup time in
+inference services or wasted resources in large-scale training. This recipe
 shows an example of how to reduce the cold start compilation time by choosing to
 compile a repeated region of the model instead of the entire model.
 
+Prerequisites
+----------------
+
+* Pytorch 2.5 or later
 Setup
 -----
 Before we begin, we need to install ``torch`` if it is not already
@@ -19,6 +22,10 @@
 
     pip install torch
 
+.. note::
+   This feature is available starting with the 2.5 release. If you are using version 2.4,
+   you can enable the configuration flag ``torch._dynamo.config.inline_inbuilt_nn_modules=True``
+   to prevent recompilations during regional compilation. In version 2.5, this flag is enabled by default.
 """
 
 
@@ -27,28 +34,30 @@
 # Steps
 # -----
 #
-# 1. Import all necessary libraries
+# In this recipe, we will follow these steps:
+#
+# 1. Import all necessary libraries.
 # 2. Define and initialize a neural network with repeated regions.
 # 3. Understand the difference between the full model and the regional compilation.
 # 4. Measure the compilation time of the full model and the regional compilation.
 #
-# 1. Import necessary libraries for loading our data
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# First, let's import the necessary libraries for loading our data:
+#
 #
 #
 
 import torch
 import torch.nn as nn
 from time import perf_counter
 
+##########################################################
+# Next, let's define and initialize a neural network with repeated regions.
 #
-# 2. Define and initialize a neural network with repeated regions.
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# Typically neural networks are composed of repeated layers. For example, a
+# Typically, neural networks are composed of repeated layers. For example, a
 # large language model is composed of many Transformer blocks. In this recipe,
-# we will create a `Layer` `nn.Module` class as a proxy for a repeated region.
-# We will then create a `Model` which is composed of 64 instances of this
-# `Layer` class.
+# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
+# We will then create a ``Model`` which is composed of 64 instances of this
+# ``Layer`` class.
 #
 class Layer(torch.nn.Module):
     def __init__(self):
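The hunk above cuts off at the start of ``Layer.__init__``, so the class bodies are not shown in this diff. As a hedged sketch of the pattern the recipe describes (only the names ``Layer``, ``Model``, and ``apply_regional_compilation`` come from the diff; the layer internals below are illustrative assumptions), the repeated region can be wrapped with ``torch.compile`` while the rest of the model stays eager:

import torch
import torch.nn as nn


class Layer(nn.Module):
    # Illustrative repeated block; the real recipe's layer internals may differ.
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear2(torch.relu(self.linear1(x))))


class Model(nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        # self.linear stays outside the compiled region, as noted later in the diff.
        self.linear = nn.Linear(10, 10)
        if apply_regional_compilation:
            # Regional compilation: compile only the repeated Layer instances,
            # so a single small graph is compiled and reused 64 times.
            self.layers = nn.ModuleList([torch.compile(Layer()) for _ in range(64)])
        else:
            self.layers = nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x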
@@ -83,53 +92,51 @@ def forward(self, x):
             x = layer(x)
         return x
 
-#
-# 3. Understand the difference between the full model and the regional compilation.
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+####################################################
+# Next, let's review the difference between the full model and the regional compilation.
 #
-# In full model compilation, the full model is compiled as a whole. This is how
-# most users use torch.compile. In this example, we can apply torch.compile to
-# the `model` object. This will effectively inline the 64 layers, producing a
+# In full model compilation, the entire model is compiled as a whole. This is the common approach
+# most users take with ``torch.compile``. In this example, we apply ``torch.compile`` to
+# the ``Model`` object. This will effectively inline the 64 layers, producing a
 # large graph to compile. You can look at the full graph by running this recipe
-# with `TORCH_LOGS=graph_code`.
+# with ``TORCH_LOGS=graph_code``.
 #
 #
 
 model = Model(apply_regional_compilation=False).cuda()
 full_compiled_model = torch.compile(model)
 
 
-#
+###################################################
 # The regional compilation, on the other hand, compiles a region of the model.
-# By wisely choosing to compile a repeated region of the model, we can compile a
-# much smaller graph and then reuse the compiled graph for all the regions. We
-# can apply regional compilation in the example as follows. `torch.compile` is
-# applied only to the `layers` and not the full model.
+# By strategically choosing to compile a repeated region of the model, we can compile a
+# much smaller graph and then reuse the compiled graph for all the regions.
+# In the example, ``torch.compile`` is applied only to the ``layers`` and not the full model.
 #
 
 regional_compiled_model = Model(apply_regional_compilation=True).cuda()
 
 # Applying compilation to a repeated region, instead of full model, leads to
 # large savings in compile time. Here, we will just compile a layer instance and
-# then reuse it 64 times in the `model` object.
+# then reuse it 64 times in the ``Model`` object.
 #
 # Note that with repeated regions, some part of the model might not be compiled.
-# For example, the `self.linear` in the `Model` is outside of the scope of
+# For example, the ``self.linear`` in the ``Model`` is outside of the scope of
 # regional compilation.
 #
 # Also, note that there is a tradeoff between performance speedup and compile
-# time. The full model compilation has larger graph and therefore,
-# theoretically, has more scope for optimizations. However for practical
+# time. Full model compilation involves a larger graph and,
+# theoretically, offers more scope for optimizations. However, for practical
 # purposes and depending on the model, we have observed many cases with minimal
 # speedup differences between the full model and regional compilation.
 
 
+###################################################
+# Next, let's measure the compilation time of the full model and the regional compilation.
 #
-# 4. Measure the compilation time of the full model and the regional compilation.
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# `torch.compile` is a JIT compiler, i.e., it compiles on the first invocation.
-# Here, we measure the total time spent in the first invocation. This is not
-# precise, but it gives a good idea because the majority of time is spent in
+# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
+# In the code below, we measure the total time spent in the first invocation. While this method is not
+# precise, it provides a good estimate since the majority of the time is spent in
 # compilation.
 
 def measure_latency(fn, input):
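The body of ``measure_latency`` sits outside the changed lines, so the diff does not show it. As a hedged sketch only (the recipe's actual helper may reset compiler caches and differ in detail), one simple way to time the first invocation, which includes compilation for a ``torch.compile``'d model, looks like this:

from time import perf_counter

import torch


def measure_latency(fn, input):
    # Time a single call; for a torch.compile'd model the first call includes compilation.
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure pending GPU work does not leak into the timing
    start = perf_counter()
    fn(input)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    end = perf_counter()
    return end - start


# Example usage mirroring the recipe's printout (requires the models defined above and a GPU):
# input = torch.randn(10, 10, device="cuda")
# regional_compilation_latency = measure_latency(regional_compiled_model, input)
# print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")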
@@ -150,11 +157,13 @@ def measure_latency(fn, input):
 print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")
 
 ############################################################################
+# Conclusion
+# -----------
+#
 # This recipe shows how to control the cold start compilation time if your model
-# has repeated regions. This requires user changes to apply `torch.compile` to
+# has repeated regions. This approach requires user modifications to apply `torch.compile` to
 # the repeated regions instead of more commonly used full model compilation. We
-# are continually working on reducing cold start compilation time. So, please
-# stay tuned for our next tutorials.
+# are continually working on reducing cold start compilation time.
 #
 # This feature is available with 2.5 release. If you are on 2.4, you can use a
 # config flag - `torch._dynamo.config.inline_inbuilt_nn_modules=True` to avoid
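The note added in the first hunk, restated in this conclusion, mentions a PyTorch 2.4 workaround flag. As a small hedged sketch (not part of the commit), enabling that flag before compiling a repeated region could look like the snippet below; ``torch._dynamo.config.inline_inbuilt_nn_modules`` is a real Dynamo config option, while the tiny module used here is purely illustrative:

import torch
import torch._dynamo

# On PyTorch 2.4, enable inlining of built-in nn.Module methods so that applying
# torch.compile to a repeated region does not keep recompiling per instance.
# On 2.5 and later this flag is already True by default.
torch._dynamo.config.inline_inbuilt_nn_modules = True

# Illustrative repeated region: compile one small module and reuse it.
layer = torch.compile(torch.nn.Linear(16, 16))
out = layer(torch.randn(4, 16))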

0 commit comments