
Commit 58e1c57

pipelining tutorials
1 parent 2b0e464 commit 58e1c57

2 files changed: +124 -0 lines
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
from dataclasses import dataclass
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe

# torchrun --standalone --nnodes 1 --nproc_per_node 2 pipelining_tut.py

rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
dist.init_process_group()

pp_group = dist.new_group()
stage_index = rank
num_stages = world_size

@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = 10000

class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

        # Using a ModuleDict lets us delete layers without affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)

        self.norm = nn.LayerNorm(model_args.dim)
        self.output = nn.Linear(model_args.dim, model_args.vocab_size)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, h)

        h = self.norm(h) if self.norm else h
        output = self.output(h).float() if self.output else h
        return output


# Manual usage
def manual_model_split(model) -> PipelineStage:
    # To be implemented
    stage = None
    return stage

# Tracer usage (pipeline API)
def tracer_model_split(model) -> PipelineStage:
    # Example microbatch input used by the tracer to record the model's graph
    x = torch.ones(32, 500, dtype=torch.long)
    pipe = pipeline(
        module=model,
        mb_args=(x,),
        split_spec={
            "layers.4": SplitPoint.BEGINNING,
        }
    )
    stage = pipe.build_stage(stage_index, device, pp_group)
    return stage

if __name__ == "__main__":
    model = Transformer(ModelArgs())
    if rank == 0:
        print(model)
    stage = tracer_model_split(model)

    print(stage)

    schedule = ScheduleGPipe(stage, n_microbatches=4)
    x = torch.ones(32, 500, dtype=torch.long)

    # Only the first stage receives the raw input; later stages receive
    # activations from their predecessor, and only the last stage returns output.
    if rank == 0:
        schedule.step(x)
    else:
        output = schedule.step()
        print(output)
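
The manual split above is intentionally left as a stub in this commit. As a rough
sketch only: for the two-rank setup used here, each rank could keep half of the
eight decoder layers, reusing the script's stage_index, num_stages, and device
globals. The helper name manual_model_split_sketch below is hypothetical.

# Hypothetical sketch (not part of the commit): build this rank's stage by
# deleting the modules that belong to the other rank, then wrapping what is
# left in a PipelineStage. Assumes exactly two stages of four layers each.
def manual_model_split_sketch(model) -> PipelineStage:
    if stage_index == 0:
        # First stage: keep the embedding and layers 0-3.
        for i in range(4, 8):
            del model.layers[str(i)]
        model.norm = None
        model.output = None
    elif stage_index == 1:
        # Second stage: keep layers 4-7, plus the final norm and output projection.
        for i in range(4):
            del model.layers[str(i)]
        model.tok_embeddings = None
    return PipelineStage(model, stage_index, num_stages, device)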
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
Pipeline manual
=================================================

**Author**: `Howard Huang <https://github.com/H-Huang>`_

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/pipelining_tutorial.rst>`__.

Prerequisites:

- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__

This tutorial uses a transformer model to demonstrate implementing distributed
pipeline parallelism with `torch.distributed.pipelining <https://pytorch.org/docs/main/distributed.pipelining.html>`__
APIs.

Basics
------

TBD

Step 1: Partition the Transformer Model
----------------------------------------

TBD
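
A rough sketch of this step, based on the tracer split in the accompanying
``pipelining_tut.py`` (``stage_index``, ``device``, and ``pp_group`` are assumed
to be set up as in that script):

.. code:: python

   import torch
   from torch.distributed.pipelining import pipeline, SplitPoint

   def tracer_model_split(model):
       # Example microbatch used by the tracer to record the model's graph.
       x = torch.ones(32, 500, dtype=torch.long)
       pipe = pipeline(
           module=model,
           mb_args=(x,),
           # Split before "layers.4", yielding two stages of four decoder layers each.
           split_spec={"layers.4": SplitPoint.BEGINNING},
       )
       return pipe.build_stage(stage_index, device, pp_group)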

Step 2: Define The Training Loop
--------------------------------

TBD
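
A rough sketch of the schedule-driven loop, based on the ``__main__`` block of the
accompanying script; ``stage`` and ``rank`` are assumed to come from Step 1 and the
process setup, and a real training loop would additionally need a loss function and
an optimizer:

.. code:: python

   import torch
   from torch.distributed.pipelining import ScheduleGPipe

   schedule = ScheduleGPipe(stage, n_microbatches=4)
   x = torch.ones(32, 500, dtype=torch.long)

   if rank == 0:
       # Only the first stage receives the raw input tokens.
       schedule.step(x)
   else:
       # Later stages receive activations; the last stage returns the output.
       output = schedule.step()
       print(output)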

Step 3: Launch the Distributed Processes
----------------------------------------

TBD
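
The process setup from the accompanying script, shown here as a sketch; the
``torchrun`` command is taken from the script's header comment:

.. code:: python

   import os
   import torch
   import torch.distributed as dist

   # Launch with: torchrun --standalone --nnodes 1 --nproc_per_node 2 pipelining_tut.py
   rank = int(os.environ["LOCAL_RANK"])
   world_size = int(os.environ["WORLD_SIZE"])
   device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
   dist.init_process_group()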
