
Commit b87d98d

Add Context Parallel tutorial
1 parent f2fcf6f commit b87d98d

File tree

2 files changed (+217, -0)


prototype_source/context_parallel.rst

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
Introduction to Context Parallel
======================================
**Authors**: `Xilun Wu <https://github.com/XilunWu>`_, `Chien-Chin Huang <https://github.com/fegin>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/context_parallel.rst>`__.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * `Context Parallel APIs <https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel>`__
       * `1M sequence training in torchtitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.7 or later

Introduction
------------

Context Parallel is an approach used in LLM training to reduce peak activation memory by sharding the long input sequence across multiple devices.
It removes the constraint on input sequence length imposed by the peak memory needed to store activations in Transformer blocks.

The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
Ring Attention rotates the KV shards among the devices and computes the partial attention scores,
repeating until all KV shards have been used on each device.
We implemented two Ring Attention variants: `pass-KV <https://arxiv.org/abs/2411.01783>`__ and `all-to-all <https://openreview.net/forum?id=WsRHpHH4s0>`__.
The pass-KV approach all-gathers the KV shards while performing the local SDPA (Scaled Dot Product Attention), and computes the remaining attention once the communication completes.
The all-to-all approach uses interleaved all-to-all collectives to ring-shuffle the KV shards, overlapping the SDPA computation with the all-to-all communication
needed for the next SDPA.
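
To make the rotation concrete, here is a simplified, single-process sketch of the Ring Attention idea. It is our own illustration rather than PyTorch's implementation: it ignores the causal mask, numerical stabilization, and real inter-device communication, and only shows how the partial attention results computed against each KV shard combine into the exact full-attention output.

.. code:: python

    import torch


    def ring_attention_sketch(q_shards, k_shards, v_shards, scale):
        # q_shards[r], k_shards[r], v_shards[r]: the (local_seq_len, dim) shards "owned" by rank r
        world_size = len(q_shards)
        outputs = []
        for rank in range(world_size):  # each iteration simulates one rank
            q = q_shards[rank]
            num = torch.zeros_like(q)          # running softmax numerator
            den = torch.zeros(q.shape[0], 1)   # running softmax denominator
            for step in range(world_size):
                # the rank holds one KV shard per step; real Ring Attention overlaps
                # this computation with passing the shard to the next rank in the ring
                idx = (rank + step) % world_size
                k, v = k_shards[idx], v_shards[idx]
                scores = torch.exp(q @ k.T * scale)  # partial (unnormalized) attention scores
                num += scores @ v
                den += scores.sum(dim=-1, keepdim=True)
            outputs.append(num / den)  # combine the partial results for this rank's queries
        return torch.cat(outputs)


    # quick check against full attention on the unsharded tensors
    dim, shards = 16, 4
    q, k, v = (torch.rand(8 * shards, dim) for _ in range(3))
    out = ring_attention_sketch(q.chunk(shards), k.chunk(shards), v.chunk(shards), dim**-0.5)
    ref = torch.softmax(q @ k.T * dim**-0.5, dim=-1) @ v
    assert torch.allclose(out, ref, atol=1e-5)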

The Context Parallel APIs consist of two parts:

1. ``context_parallel()`` allows users to create a Python context where the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
   will be automatically replaced with Ring Attention. To shard Tensors along a dimension, simply pass the Tensors and their sharding dimensions to
   the ``buffers`` and ``buffer_seq_dims`` arguments, respectively.
2. ``set_rotate_method()`` allows users to choose between the pass-KV approach and the all-to-all approach.
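
As a condensed preview of how the two APIs fit together (the full, runnable example is developed step by step in the sections below; ``device_mesh``, ``q``, ``k``, and ``v`` are placeholders defined there):

.. code:: python

    import torch.nn.functional as F
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # optional; otherwise the pass-KV (all-gather) rotation is used

    # q, k, v have shape (batch, nheads, seq_len, head_dim), so dim 2 is the sequence dimension
    with context_parallel(device_mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)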

Setup
---------------------

With ``torch.distributed.tensor.experimental.context_parallel()``, users can easily shard the Tensor input and parallelize the execution of the SDPA function.
To better demonstrate the usage of this API, we start with a simple code snippet that performs SDPA and then parallelize it using the API:

.. code:: python

    import torch
    import torch.nn.functional as F

    from torch.nn.attention import sdpa_kernel, SDPBackend


    def sdpa_example():
        assert torch.cuda.is_available()
        torch.cuda.set_device("cuda:0")
        torch.cuda.manual_seed(0)

        batch = 8
        nheads = 8
        qkv_len = 8192
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        # the flash and cuDNN attention backends require half-precision inputs
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]

        # run causal SDPA with the chosen backend
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)


    if __name__ == "__main__":
        sdpa_example()

Enable Context Parallel
-----------------------

Now, let's first adapt it to a distributed program where each rank has the same tensor input. Then we apply the context parallel API to
shard the input and distribute the computation across ranks:

.. code:: python

    # file: cp_sdpa_example.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import context_parallel_unshard
    from torch.nn.attention import sdpa_kernel, SDPBackend


    def context_parallel_sdpa_example(world_size: int, rank: int):
        assert torch.cuda.is_available()
        assert dist.is_nccl_available()
        torch.cuda.set_device(f"cuda:{rank}")
        torch.cuda.manual_seed(0)

        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        # a 1-D device mesh over all ranks; context parallelism shards along this mesh
        device_mesh = init_device_mesh(
            device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
        )

        batch = 8
        nheads = 8
        qkv_len = 64
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]
        # identical copies for the context-parallel run (every rank uses the same seed)
        cp_qkv = [t.detach().clone() for t in qkv]

        with sdpa_kernel(backend):
            # context_parallel() shards cp_qkv in place along dim 2 (the sequence dimension)
            with context_parallel(
                device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
            ):
                cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

            # gather the sharded output back to the full sequence for comparison
            (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)

            assert torch.allclose(
                cp_out,
                out,
                atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
            )


    if __name__ == "__main__":
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        try:
            context_parallel_sdpa_example(world_size, rank)
        finally:
            dist.barrier()
            dist.destroy_process_group()

You can use the command ``torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py`` to launch the above context parallel
SDPA on 4 GPUs. We demonstrate the numerical correctness by comparing the output of Ring Attention to that of SDPA on a single GPU.
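
If you want to observe the sharding directly, you can add a small sanity check (our addition, not part of the example above) inside the ``context_parallel`` context in ``cp_sdpa_example.py``: the buffers are sharded in place along their sequence dimension, so each rank holds a shard of length ``qkv_len // world_size`` (16 when running on 4 GPUs).

.. code:: python

    # inside the `with context_parallel(...)` block of cp_sdpa_example.py
    assert cp_qkv[0].shape[2] == qkv_len // world_size  # 64 // 4 == 16 per rank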

Select Rotation Approach
------------------------

You can choose the desired shard rotation approach in Ring Attention by using ``torch.distributed.tensor.experimental._attention.set_rotate_method()``:

.. code:: python

    # file: cp_sdpa_example.py
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # rotate shards using all-to-all

    with sdpa_kernel(backend):
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)
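
To switch back to the default pass-KV rotation, pass ``"allgather"`` instead (to the best of our knowledge, ``"allgather"`` and ``"alltoall"`` are the two values accepted by this experimental API):

.. code:: python

    set_rotate_method("allgather")  # pass-KV rotation, the default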

Conclusion
----------

In this tutorial, we have learned how to parallelize the SDPA computation along the sequence dimension easily with our Context Parallel APIs. For
design and implementation details, performance analysis, and an end-to-end training example in `torchtitan <https://github.com/pytorch/torchtitan>`__,
see our post on `PyTorch native long-context training <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__.

prototype_source/prototype_index.rst

Lines changed: 9 additions & 0 deletions
@@ -239,6 +239,14 @@ Prototype features are not available as part of binary distributions like PyPI o
   :link: ../prototype/flight_recorder_tutorial.html
   :tags: Distributed, Debugging, FlightRecorder

+.. Distributed
+.. customcarditem::
+   :header: Context Parallel Tutorial
+   :card_description: Parallelize the attention computation along sequence dimension
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/context_parallel.html
+   :tags: Distributed, Context Parallel
+
.. Integration
.. customcarditem::
   :header: Out-of-tree extension autoloading in Python
@@ -265,6 +273,7 @@ Prototype features are not available as part of binary distributions like PyPI o
.. toctree::
   :hidden:

+   prototype/context_parallel.html
   prototype/fx_graph_mode_quant_guide.html
   prototype/fx_graph_mode_ptq_dynamic.html
   prototype/fx_graph_mode_ptq_static.html
