Introduction to Context Parallel
================================
**Authors**: `Xilun Wu <https://github.com/XilunWu>`_, `Chien-Chin Huang <https://github.com/fegin>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/context_parallel.rst>`__.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * `Context Parallel APIs <https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel>`__
       * `1M sequence training in torchtitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.7 or later


Introduction
------------

Context Parallel is an approach used in LLM training to reduce peak activation memory by sharding the long input sequence across multiple devices.
It removes the constraint on input sequence length that comes from the peak memory needed to store activations in Transformer blocks.

The core of Context Parallel is Ring Attention, a novel parallel implementation of the attention layer.
Ring Attention rotates the KV shards across devices and computes partial attention scores, repeating until every KV shard has been used on each device.
We implemented two Ring Attention variants: `pass-KV <https://arxiv.org/abs/2411.01783>`__ and `all-to-all <https://openreview.net/forum?id=WsRHpHH4s0>`__.
The pass-KV approach all-gathers the KV shards while performing the local SDPA (Scaled Dot Product Attention), then computes the remaining attention once the communication completes.
The all-to-all approach uses interleaved all-to-all collectives to ring-shuffle the KV shards, overlapping the SDPA computation with the all-to-all communication
needed for the next SDPA.
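
To make the ring pattern concrete, the snippet below is a schematic, single-process sketch of the idea rather than the PyTorch implementation: every "rank" keeps its query shard fixed while the KV shards rotate around the ring, and the partial attention results are merged with a numerically stable log-sum-exp correction. Causal masking and the actual communication collectives are omitted, and the function name is ours, not part of the API.

.. code:: python

    import torch


    def ring_attention_reference(q_shards, k_shards, v_shards):
        """Schematic single-process sketch of Ring Attention (non-causal)."""
        world_size = len(q_shards)
        scale = q_shards[0].shape[-1] ** -0.5
        outputs = []
        for rank in range(world_size):
            q = q_shards[rank]  # each rank keeps its own Q shard
            out = lse = None
            for step in range(world_size):
                # the KV shard that "arrives" at this rank on this ring step
                src = (rank - step) % world_size
                k, v = k_shards[src], v_shards[src]
                scores = (q @ k.transpose(-2, -1)) * scale
                block_lse = torch.logsumexp(scores, dim=-1, keepdim=True)
                block_out = torch.softmax(scores, dim=-1) @ v
                if out is None:
                    out, lse = block_out, block_lse
                else:
                    # fold the new partial result into the running output
                    new_lse = torch.logaddexp(lse, block_lse)
                    out = out * (lse - new_lse).exp() + block_out * (block_lse - new_lse).exp()
                    lse = new_lse
            outputs.append(out)
        return torch.cat(outputs, dim=-2)


    # sanity check against full (non-causal) attention on one device
    q = torch.rand(2, 4, 32, 16)
    k, v = torch.rand_like(q), torch.rand_like(q)
    ref = torch.softmax((q @ k.transpose(-2, -1)) * 16 ** -0.5, dim=-1) @ v
    out = ring_attention_reference(q.chunk(4, dim=-2), k.chunk(4, dim=-2), v.chunk(4, dim=-2))
    assert torch.allclose(out, ref, atol=1e-5)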

The Context Parallel APIs consist of two parts:

1. ``context_parallel()`` allows users to create a Python context in which the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
   is automatically replaced with Ring Attention. To shard Tensors along a dimension, pass the Tensors and their sharding dimensions to
   the ``buffers`` and ``buffer_seq_dims`` arguments, respectively.
2. ``set_rotate_method()`` allows users to choose between the pass-KV approach and the all-to-all approach, as sketched below.

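Schematically, the two APIs are used together as shown below. This is a condensed sketch rather than a runnable program: ``mesh`` stands for an existing ``DeviceMesh`` and ``q``, ``k``, ``v`` are placeholder tensors whose sequence dimension is dimension 2. A complete, runnable example follows in the next sections.

.. code:: python

    import torch.nn.functional as F
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # optional: choose all-to-all instead of the default pass-KV rotation
    with context_parallel(mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
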
Setup
-----

With ``torch.distributed.tensor.experimental.context_parallel()``, users can easily shard the Tensor input and parallelize the execution of the SDPA function.
To better demonstrate the usage of this API, we start with a simple code snippet that runs SDPA on a single GPU and then parallelize it using the API:

.. code:: python

    import torch
    import torch.nn.functional as F

    from torch.nn.attention import sdpa_kernel, SDPBackend


    def sdpa_example():
        assert torch.cuda.is_available()
        torch.cuda.set_device("cuda:0")
        torch.cuda.manual_seed(0)

        batch = 8
        nheads = 8
        qkv_len = 8192
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        # create query, key, and value tensors of shape (batch, nheads, qkv_len, dim)
        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            for _ in range(3)
        ]

        # run causal SDPA with the chosen attention backend
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)


    if __name__ == "__main__":
        sdpa_example()

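For reference, the SDPA output has the same shape as the query tensor. A quick check that you could append inside ``sdpa_example()``, after the ``with`` block (a hypothetical addition to the snippet above):

.. code:: python

    print(out.shape)  # torch.Size([8, 8, 8192, 32]): (batch, nheads, qkv_len, dim)
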
Enable Context Parallel
-----------------------

Now, let's first adapt it to a distributed program where each rank has the same tensor input. Then we apply the context parallel API to
shard the input and distribute the computation across ranks:

.. code:: python

    # file: cp_sdpa_example.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import context_parallel_unshard
    from torch.nn.attention import sdpa_kernel, SDPBackend


    def context_parallel_sdpa_example(world_size: int, rank: int):
        assert torch.cuda.is_available()
        assert dist.is_nccl_available()
        torch.cuda.set_device(f"cuda:{rank}")
        torch.cuda.manual_seed(0)

        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        device_mesh = init_device_mesh(
            device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
        )

        batch = 8
        nheads = 8
        qkv_len = 64
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        # every rank seeds its RNG identically, so qkv is replicated across ranks
        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            for _ in range(3)
        ]
        # cp_qkv is used for the context-parallel run; qkv is the single-device reference
        cp_qkv = [t.detach().clone() for t in qkv]

        with sdpa_kernel(backend):
            # context_parallel() shards cp_qkv along dim 2 (the sequence dim) and
            # replaces SDPA with Ring Attention inside the context
            with context_parallel(
                device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
            ):
                cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

            # all-gather the sharded output and compare it with the single-device result
            (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)

            assert torch.allclose(
                cp_out,
                out,
                atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
            )


    if __name__ == "__main__":
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        try:
            context_parallel_sdpa_example(world_size, rank)
        finally:
            dist.barrier()
            dist.destroy_process_group()

You can use the command ``torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py`` to launch the above context parallel
SDPA on 4 GPUs. We demonstrate numerical correctness by comparing the output of Ring Attention to that of SDPA on a single GPU.
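
Inside the ``context_parallel`` context, the tensors passed through ``buffers`` are sharded in place along their ``buffer_seq_dims``, so each rank computes attention over only its local slice of the sequence; this is why ``context_parallel_unshard`` is needed before comparing against the single-device output. As a sanity check, you could add the following line inside the context in the example above. This is a hypothetical addition that assumes the in-place sharding behavior of the current experimental API:

.. code:: python

    # with 4 ranks, each rank holds qkv_len // world_size = 16 of the 64 sequence positions
    assert cp_qkv[0].size(2) == qkv_len // world_size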


Select Rotation Approach
------------------------

You can choose the desired shard rotation approach in Ring Attention by using ``torch.distributed.tensor.experimental._attention.set_rotate_method()``:

.. code:: python

    # file: cp_sdpa_example.py
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # rotate shards using all-to-all

    with sdpa_kernel(backend):
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

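The default rotation method is the pass-KV approach. Assuming the accepted values of the current experimental API are ``"allgather"`` (pass-KV) and ``"alltoall"``, you can make the default choice explicit:

.. code:: python

    set_rotate_method("allgather")  # rotate shards by all-gathering KV (pass-KV, the default)
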
Conclusion
----------

In this tutorial, we have learned how to easily parallelize the SDPA computation along the sequence dimension with our Context Parallel APIs. For
design and implementation details, performance analysis, and an end-to-end training example in `torchtitan <https://github.com/pytorch/torchtitan>`__,
see our post on `PyTorch native long-context training <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__.