Commit dcf02de

fix: address comment
1 parent f75c9fd commit dcf02de

File tree: 2 files changed (+21, -9 lines)

prototype_source/context_parallel.rst

Lines changed: 21 additions & 8 deletions
@@ -23,7 +23,7 @@ Introduction to Context Parallel
 Introduction
 ------------
 
-Context Parallel is an approach used in LLM to reduce peak activation size by sharding the long input sequence across multiple devices.
+Context Parallel is an approach used in large language model training to reduce peak activation size by sharding the long input sequence across multiple devices.
 It breaks the constraint on input sequence length resulting from peak memory usage on storing activations in Transformer blocks.
 
 The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
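The reworded sentence describes sharding the long input sequence across devices. As a rough illustration of that idea (not part of this commit), the sketch below splits an activation tensor along its sequence dimension; the rank count, shapes, and use of torch.tensor_split are assumptions for illustration only.

import torch

# Hypothetical sketch of sequence-dimension sharding (not from this commit):
# a (batch, nheads, seq_len, head_dim) activation is split along dim=2 so each
# of `world_size` ranks only materializes its own shard of the sequence.
world_size = 4                              # assumed number of ranks
x = torch.rand(1, 8, 4096, 64)              # assumed full-sequence activation
shards = torch.tensor_split(x, world_size, dim=2)
assert shards[0].shape == (1, 8, 1024, 64)  # each rank holds seq_len / world_size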
@@ -82,7 +82,7 @@ To better demonstrate the usage of this API, we start with a simple code snippet
         )
         for _ in range(3)
     ]
-
+    # specify the SDPABackend to use
     with sdpa_kernel(backend):
         out = F.scaled_dot_product_attention(*qkv, is_causal=True)
 
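The added comment refers to the backend handed to sdpa_kernel. For context, a minimal sketch of how such a backend is typically selected; the SDPBackend.FLASH_ATTENTION choice, shapes, dtype, and device are assumptions, not taken from this diff.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Assumed backend for illustration; any supported SDPBackend member would do.
backend = SDPBackend.FLASH_ATTENTION

q, k, v = (
    torch.rand(2, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
    for _ in range(3)
)
# restrict scaled_dot_product_attention to the chosen backend inside this block
with sdpa_kernel(backend):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)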
@@ -148,22 +148,35 @@ shard to input and distribute the computation across ranks:
         )
         for _ in range(3)
     ]
+    # specify the SDPABackend to use
+    with sdpa_kernel(backend):
+        out = F.scaled_dot_product_attention(*qkv, is_causal=True)
+
+    # make a clean copy of QKV for output comparison
     cp_qkv = [t.detach().clone() for t in qkv]
 
     with sdpa_kernel(backend):
+        # This `context_parallel()` performs two actions:
+        # 1. shard the tensor objects in `buffers` in-place along the dimension
+        #    specified in `buffer_seq_dims`, the tensors in `buffers` and their
+        #    sharding dims in `buffer_seq_dims` are organized in the same order.
+        # 2. replace the execution of `F.scaled_dot_product_attention` with a
+        #    context-paralleled-enabled Ring Attention.
         with context_parallel(
             device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
         ):
             cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)
 
+        # the output `cp_out` is still sharded in the same way as QKV
+        # the `context_parallel_unshard` API allows users to easily
+        # unshard to gain the full tensor.
         (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])
-        out = F.scaled_dot_product_attention(*qkv, is_causal=True)
 
-        assert torch.allclose(
-            cp_out,
-            out,
-            atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
-        )
+    assert torch.allclose(
+        cp_out,
+        out,
+        atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
+    )
 
 
 if __name__ == "__main__":
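Assembled end to end, the changed snippet corresponds roughly to the sketch below. The imports, mesh setup, shapes, dtype, and backend choice are assumptions added for completeness (they come from elsewhere in the tutorial or from current PyTorch, not from this diff), so treat it as a sketch rather than the exact tutorial code.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.distributed.device_mesh import init_device_mesh
# NOTE: these import locations are assumptions based on current PyTorch builds,
# not something this diff establishes.
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard


def context_parallel_sdpa_example(world_size: int, rank: int):
    # Assumed setup: torch.distributed already initialized (e.g. via torchrun),
    # one GPU per rank, bf16 activations, FLASH_ATTENTION backend.
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    backend = SDPBackend.FLASH_ATTENTION
    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))

    # Shapes are illustrative; the sequence dimension (dim 2) must be divisible
    # by world_size so `context_parallel` can shard it evenly.
    qkv = [
        torch.rand(2, 8, 128 * world_size, 64, dtype=dtype, device=device)
        for _ in range(3)
    ]

    # single-device reference output
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(*qkv, is_causal=True)

    # clean copy of QKV for the context-parallel run
    cp_qkv = [t.detach().clone() for t in qkv]
    with sdpa_kernel(backend):
        # shard QKV in-place along dim 2 and dispatch SDPA to Ring Attention
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

        # gather the sharded output back into a full tensor
        (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])

    # looser tolerance for low-precision dtypes, mirroring the diff's assert
    assert torch.allclose(cp_out, out, atol=1e-03 * world_size)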

prototype_source/prototype_index.rst

Lines changed: 0 additions & 1 deletion
@@ -239,7 +239,6 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/flight_recorder_tutorial.html
    :tags: Distributed, Debugging, FlightRecorder
 
-.. Distributed
 .. customcarditem::
    :header: Context Parallel Tutorial
    :card_description: Parallelize the attention computation along sequence dimension
