"""
Reducing torch.compile cold start compilation time with regional compilation
============================================================================

**Author:** `Animesh Jain <https://github.com/anijain2305>`_

As deep learning models get larger, the compilation time of these models also
increases. This extended compilation time can result in a large startup time in
inference services or wasted resources in large-scale training. This recipe
shows an example of how to reduce the cold start compilation time by choosing to
compile a repeated region of the model instead of the entire model.

Prerequisites
----------------

* PyTorch 2.5 or later

Setup
-----
Before we begin, we need to install ``torch`` if it is not already
available.

.. code-block:: sh

   pip install torch

.. note::
   This feature is available starting with the 2.5 release. If you are using version 2.4,
   you can enable the configuration flag ``torch._dynamo.config.inline_inbuilt_nn_modules=True``
   to prevent recompilations during regional compilation. In version 2.5, this flag is enabled by default.
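
   For example, on a 2.4 installation you would set the flag before compiling your
   model. A minimal sketch (only the flag name comes from this recipe; the rest is
   illustrative):

   .. code-block:: python

      import torch._dynamo

      # Enabled by default starting with 2.5; on 2.4, set it explicitly to
      # avoid recompilations during regional compilation.
      torch._dynamo.config.inline_inbuilt_nn_modules = True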
"""

######################################################################
# Steps
# -----
#
# In this recipe, we will follow these steps:
#
# 1. Import all necessary libraries.
# 2. Define and initialize a neural network with repeated regions.
# 3. Understand the difference between the full model and the regional compilation.
# 4. Measure the compilation time of the full model and the regional compilation.
#
# First, let's import the necessary libraries for loading our data:
#
#
#

import torch
import torch.nn as nn
from time import perf_counter

##########################################################
# Next, let's define and initialize a neural network with repeated regions.
#
# Typically, neural networks are composed of repeated layers. For example, a
# large language model is composed of many Transformer blocks. In this recipe,
# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
# We will then create a ``Model`` which is composed of 64 instances of this
# ``Layer`` class.
#


class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        # A small repeated block: linear -> relu -> sigmoid -> linear -> relu.
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside the scope of
        # ``torch.compile``.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

####################################################
# Next, let's review the difference between the full model and the regional compilation.
#
# In full model compilation, the entire model is compiled as a whole. This is the common approach
# most users take with ``torch.compile``. In this example, we apply ``torch.compile`` to
# the ``Model`` object. This will effectively inline the 64 layers, producing a
# large graph to compile. You can look at the full graph by running this recipe
# with ``TORCH_LOGS=graph_code``.
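#
# For example, assuming this recipe is saved as ``regional_compilation.py``
# (the filename here is just for illustration), you could run:
#
# .. code-block:: console
#
#    TORCH_LOGS=graph_code python regional_compilation.py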
#
#

model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)


###################################################
# The regional compilation, on the other hand, compiles a region of the model.
# By strategically choosing to compile a repeated region of the model, we can compile a
# much smaller graph and then reuse the compiled graph for all the regions.
# In the example, ``torch.compile`` is applied only to the ``layers`` and not the full model.
#

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

# Applying compilation to a repeated region, instead of full model, leads to
# large savings in compile time. Here, we will just compile a layer instance and
# then reuse it 64 times in the ``Model`` object.
#
# Note that with repeated regions, some part of the model might not be compiled.
# For example, the ``self.linear`` in the ``Model`` is outside of the scope of
# regional compilation.
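#
# One quick way to see this split (a sanity check added here, not part of the
# original recipe) is to inspect the module types after construction, since
# ``torch.compile`` wraps a module in an ``OptimizedModule``:
#
# .. code-block:: python
#
#    from torch._dynamo.eval_frame import OptimizedModule
#
#    print(isinstance(regional_compiled_model.layers[0], OptimizedModule))  # True
#    print(isinstance(regional_compiled_model.linear, OptimizedModule))     # False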
#
# Also, note that there is a tradeoff between performance speedup and compile
# time. Full model compilation involves a larger graph and,
# theoretically, offers more scope for optimizations. However, for practical
# purposes and depending on the model, we have observed many cases with minimal
# speedup differences between the full model and regional compilation.


###################################################
# Next, let's measure the compilation time of the full model and the regional compilation.
#
# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
# In the code below, we measure the total time spent in the first invocation. While this method is not
# precise, it provides a good estimate since the majority of the time is spent in
# compilation.


def measure_latency(fn, input):
    # Reset the compiler caches so that each measurement starts cold.
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

############################################################################
# Conclusion
# -----------
#
# This recipe shows how to control the cold start compilation time if your model
# has repeated regions. This approach requires user modifications to apply ``torch.compile`` to
# the repeated regions instead of more commonly used full model compilation. We
# are continually working on reducing cold start compilation time.