Skip to content

Commit dace205

Browse files
committed
pipelining tutorials
1 parent c043bab commit dace205

File tree

2 files changed

+241
-0
lines changed

2 files changed

+241
-0
lines changed

index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,13 @@ Welcome to PyTorch Tutorials
744744
:link: intermediate/rpc_param_server_tutorial.html
745745
:tags: Parallel-and-Distributed-Training
746746

747+
.. customcarditem::
748+
:header: Introduction to Distributed Pipeline Parallelism
749+
:card_description: Demonstrate how to implement pipeline parallelism using torch.distributed.pipelining
750+
:image: _static/img/thumbnails/cropped/Introduction-to-Distributed-Pipeline-Parallelism.png
751+
:link: intermediate/pipelining_tutorial.html
752+
:tags: Parallel-and-Distributed-Training
753+
747754
.. customcarditem::
748755
:header: Implementing Batch RPC Processing Using Asynchronous Executions
749756
:card_description: Learn how to use rpc.functions.async_execution to implement batch RPC
@@ -1128,6 +1135,7 @@ Additional Resources
11281135
intermediate/FSDP_tutorial
11291136
intermediate/FSDP_adavnced_tutorial
11301137
intermediate/TP_tutorial
1138+
intermediate/pipelining_tutorial
11311139
intermediate/process_group_cpp_extension_tutorial
11321140
intermediate/rpc_tutorial
11331141
intermediate/rpc_param_server_tutorial
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
Introduction to Distributed Pipeline Parallelism
2+
================================================
3+
**Authors**: `Howard Huang <https://github.com/H-Huang>`_
4+
5+
.. note::
6+
|edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/pipelining_tutorial.rst>`__.
7+
8+
This tutorial uses a gpt-style transformer model to demonstrate implementing distributed
9+
pipeline parallelism with `torch.distributed.pipelining <https://pytorch.org/docs/main/distributed.pipelining.html>`__
10+
APIs.
11+
12+
.. grid:: 2
13+
14+
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
15+
16+
* How to use ``torch.distributed.pipelining`` APIs
17+
* How to apply pipeline parallelism to a transformer model
18+
* How to utilize different schedules on a set of microbatches
19+
20+
21+
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
22+
23+
* Familiarity with `basic distributed training <https://pytorch.org/tutorials/beginner/dist_overview.html>`__ in PyTorch
24+
25+
Setup
26+
-----
27+
28+
With ``torch.distributed.pipelining`` we will be partitioning the execution of a model and scheduling computation on micro-batches. We will be using a simplified version
29+
of a transformer decoder model. The model architecture is for educational purposes and has multiple transformer decoder layers as we want to demonstrate how to split the model into different
30+
chunks. First, let us define the model:
31+
32+
.. code:: python
33+
34+
import torch
35+
import torch.nn as nn
36+
from dataclasses import dataclass
37+
38+
@dataclass
39+
class ModelArgs:
40+
dim: int = 512
41+
n_layers: int = 8
42+
n_heads: int = 8
43+
vocab_size: int = 10000
44+
45+
class Transformer(nn.Module):
46+
def __init__(self, model_args: ModelArgs):
47+
super().__init__()
48+
49+
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
50+
51+
# Using a ModuleDict lets us delete layers witout affecting names,
52+
# ensuring checkpoints will correctly save and load.
53+
self.layers = torch.nn.ModuleDict()
54+
for layer_id in range(model_args.n_layers):
55+
self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)
56+
57+
self.norm = nn.LayerNorm(model_args.dim)
58+
self.output = nn.Linear(model_args.dim, model_args.vocab_size)
59+
60+
def forward(self, tokens: torch.Tensor):
61+
# Handling layers being 'None' at runtime enables easy pipeline splitting
62+
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
63+
64+
for layer in self.layers.values():
65+
h = layer(h, h)
66+
67+
h = self.norm(h) if self.norm else h
68+
output = self.output(h).float() if self.output else h
69+
return output
70+
71+
Then, we need to import the necessary libraries in our script and initialize the distributed training process. In this case, we are defining some global variables to use
72+
later in the script:
73+
74+
.. code:: python
75+
76+
import os
77+
import torch.distributed as dist
78+
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe
79+
80+
global rank, device, pp_group, stage_index, num_stages
81+
def init_distributed():
82+
global rank, device, pp_group, stage_index, num_stages
83+
rank = int(os.environ["LOCAL_RANK"])
84+
world_size = int(os.environ["WORLD_SIZE"])
85+
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
86+
dist.init_process_group()
87+
88+
pp_group = dist.new_group()
89+
stage_index = rank
90+
num_stages = world_size
91+
92+
The ``rank``, ``world_size``, and ``init_process_group()`` code should seem familiar to you as those are commonly used in
93+
all distributed programs. The globals specific to pipeline parallelism include ``pp_group`` which is the process
94+
group that will be used for send/recv communications, ``stage_index`` which, in this example, is a single rank
95+
per stage so the index is equivalent to the rank, and ``num_stages`` which is equivalent to world_size.
96+
97+
The ``num_stages`` is used to set the number of stages that will be used in the pipeline parallelism schedule. For example,
98+
for ``num_stages=4``, a microbatch will need to go through 4 forwards and 4 backwards before it is completed. The ``stage_index``
99+
is necessary for the framework to know how to communicate between stages. For example, for the first stage (``stage_index=0``), it will
100+
use data from the dataloader and does not need to receive data from any previous peers to perform its computation.
101+
102+
103+
Step 1: Partition the Transformer Model
104+
---------------------------------------
105+
106+
There are two different ways of partitioning the model:
107+
108+
First is the manual mode in which we can manually create two instances of the model by deleting portions of
109+
attributes of the model. In this example for a 2 stage (2 ranks) the model is cut in half.
110+
111+
.. code:: python
112+
113+
def manual_model_split(model, example_input_microbatch, model_args) -> PipelineStage:
114+
if stage_index == 0:
115+
# prepare the first stage model
116+
for i in range(4, 8):
117+
del model.layers[str(i)]
118+
model.norm = None
119+
model.output = None
120+
stage_input_microbatch = example_input_microbatch
121+
122+
elif stage_index == 1:
123+
# prepare the second stage model
124+
for i in range(4):
125+
del model.layers[str(i)]
126+
model.tok_embeddings = None
127+
stage_input_microbatch = torch.randn(example_input_microbatch.shape[0], model_args.dim)
128+
129+
stage = PipelineStage(
130+
model,
131+
stage_index,
132+
num_stages,
133+
device,
134+
input_args=stage_input_microbatch,
135+
)
136+
return stage
137+
138+
As we can see the first stage does not have the layer norm or the output layer, and it only includes the first four transformer blocks.
139+
The second stage does not have the input embedding layers, but includes the output layers and the final four transformer blocks. The function
140+
then returns the ``PipelineStage`` for the current rank.
141+
142+
The second method is the tracer-based mode which automatically splits the model based on a ``split_spec`` argument. Using the pipeline specification, we can instruct
143+
``torch.distributed.pipelining`` where to split the model. In the following code block,
144+
we are splitting before the before 4th transformer decoder layer, mirroring the manual split described above. Similarly,
145+
we can retrieve a ``PipelineStage`` by calling ``build_stage`` after this splitting is done.
146+
147+
.. code:: python
148+
def tracer_model_split(model, example_input_microbatch) -> PipelineStage:
149+
pipe = pipeline(
150+
module=model,
151+
mb_args=(example_input_microbatch,),
152+
split_spec={
153+
"layers.4": SplitPoint.BEGINNING,
154+
}
155+
)
156+
stage = pipe.build_stage(stage_index, device, pp_group)
157+
return stage
158+
159+
160+
Step 2: Define The Main Execution
161+
---------------------------------
162+
163+
In the main function we will create a particular pipeline schedule that the stages should follow. ``torch.distributed.pipelining``
164+
supports multiple schedules including supports multiple schedules, including single-stage-per-rank schedules ``GPipe`` and ``1F1B``,
165+
as well as multiple-stage-per-rank schedules such as ``Interleaved1F1B`` and ``LoopedBFS``.
166+
167+
.. code:: python
168+
169+
if __name__ == "__main__":
170+
init_distributed()
171+
num_microbatches = 4
172+
model_args = ModelArgs()
173+
model = Transformer(model_args)
174+
175+
# Dummy data
176+
x = torch.ones(32, 500, dtype=torch.long)
177+
y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)
178+
example_input_microbatch = x.chunk(num_microbatches)[0]
179+
180+
# Option 1: Manual model splitting
181+
stage = manual_model_split(model, example_input_microbatch, model_args)
182+
183+
# Option 2: Tracer model splitting
184+
# stage = tracer_model_split(model, example_input_microbatch)
185+
186+
x = x.to(device)
187+
y = y.to(device)
188+
189+
def tokenwise_loss_fn(outputs, targets):
190+
loss_fn = nn.CrossEntropyLoss()
191+
outputs = outputs.view(-1, model_args.vocab_size)
192+
targets = targets.view(-1)
193+
return loss_fn(outputs, targets)
194+
195+
schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=token_loss_fn)
196+
197+
if rank == 0:
198+
schedule.step(x)
199+
elif rank == 1:
200+
losses = []
201+
output = schedule.step(target=y, losses=losses)
202+
dist.destroy_process_group()
203+
204+
In the example above, we are using the manual method to split the model, but the code can be uncommented to also try the
205+
tracer-based model splitting function. In our schedule, we need to pass in the number of microbatches and
206+
the loss function used to evaluate the targets.
207+
208+
The ``.step()`` function processes the entire minibatch and automatically splits it into microbatches based
209+
on the ``n_microbatches`` passed previously. The microbatches are then operated on according to the schedule class.
210+
In the example above, we are using GPipe, which follows a simple all-forwards and then all-backwards schedule. The output
211+
returned from rank 1 will be the same as if the model was on a single GPU and run with the entire batch. Similarly,
212+
we can pass in a ``losses`` container to store the corresponding losses for each microbatch.
213+
214+
Step 3: Launch the Distributed Processes
215+
----------------------------------------
216+
217+
Finally, we are ready to run the script. We will use ``torchrun`` to create a single host, 2-process job.
218+
Our script is already written in a way rank 0 that performs the required logic for pipeline stage 0, and rank 1
219+
performs the logic for pipeline stage 1.
220+
221+
``torchrun --standalone --nnodes 1 --nproc_per_node 2 pipelining_tutorial.py``
222+
223+
Conclusion
224+
----------
225+
226+
In this tutorial, we have learned how to implement distributed pipeline parallelism using PyTorch's ``torch.distributed.pipelining`` APIs.
227+
We explored setting up the environment, defining a transformer model, and partitioning it for distributed training.
228+
We discussed two methods of model partitioning, manual and tracer-based, and demonstrated how to schedule computations on
229+
micro-batches across different stages. Finally, we covered the execution of the pipeline schedule and the launch of distributed
230+
processes using ``torchrun``.
231+
232+
For a production ready usage of pipeline parallelism as well as composition with other distributed techniques, see also
233+
`TorchTitan end to end example of 3D parallelism <https://github.com/pytorch/torchtitan>`__.

0 commit comments

Comments
 (0)