prototype_source/context_parallel.rst
21 additions & 8 deletions
@@ -23,7 +23,7 @@ Introduction to Context Parallel
 Introduction
 ------------
 
-Context Parallel is an approach used in LLM to reduce peak activation size by sharding the long input sequence across multiple devices.
+Context Parallel is an approach used in large language model training to reduce peak activation size by sharding the long input sequence across multiple devices.
 It breaks the constraint on input sequence length resulting from peak memory usage on storing activations in Transformer blocks.
 
 The core of Context Parallel is Ring Attention, a novel parallel implementation of the Attention layer.
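To make the sharding idea above concrete, here is a rough sketch (not part of the tutorial diff) of splitting an activation tensor along its sequence dimension across devices; the tensor shape and the `world_size` value are illustrative assumptions.

```python
# Minimal sketch: shard the sequence dimension of an activation tensor.
# `world_size` and the tensor shape are illustrative assumptions.
import torch

world_size = 4                      # hypothetical number of devices
x = torch.rand(2, 8, 4096, 64)      # [batch, heads, seq, head_dim]

# Each rank would hold one chunk of the sequence, so the peak activation
# size per device drops by roughly a factor of `world_size`.
shards = torch.chunk(x, world_size, dim=2)
print(shards[0].shape)              # torch.Size([2, 8, 1024, 64])
```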
@@ -82,7 +82,7 @@ To better demonstrate the usage of this API, we start with a simple code snippet
     )
     for _ in range(3)
 ]
-
+# specify the SDPABackend to use
 with sdpa_kernel(backend):
     out = F.scaled_dot_product_attention(*qkv, is_causal=True)
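For orientation, a self-contained version of the single-device snippet that this hunk edits could look roughly like the sketch below; the tensor shapes, dtype, device, and backend choice are assumptions rather than the tutorial's exact values.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

backend = SDPBackend.FLASH_ATTENTION   # assumed backend choice
device = "cuda"
dtype = torch.bfloat16

# query, key, value with shape [batch, heads, seq, head_dim] (assumed sizes)
qkv = [
    torch.rand(
        (2, 8, 4096, 64),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    for _ in range(3)
]

# specify the SDPABackend to use
with sdpa_kernel(backend):
    out = F.scaled_dot_product_attention(*qkv, is_causal=True)
```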
@@ -148,22 +148,35 @@ shard to input and distribute the computation across ranks:
     )
     for _ in range(3)
 ]
+# specify the SDPABackend to use
+with sdpa_kernel(backend):
+    out = F.scaled_dot_product_attention(*qkv, is_causal=True)
+
+# make a clean copy of QKV for output comparison
 cp_qkv = [t.detach().clone() for t in qkv]
 
 with sdpa_kernel(backend):
+    # This `context_parallel()` performs two actions:
+    # 1. shard the tensor objects in `buffers` in-place along the dimension
+    #    specified in `buffer_seq_dims`, the tensors in `buffers` and their
+    #    sharding dims in `buffer_seq_dims` are organized in the same order.
+    # 2. replace the execution of `F.scaled_dot_product_attention` with a
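The hunk is cut off mid-comment, so for orientation here is a hedged sketch of how the completed `context_parallel()` call is typically wired up, assuming the prototype API under `torch.distributed.tensor.experimental`, an already-initialized process group, and QKV tensors shaped `[batch, heads, seq, head_dim]` (sequence dim 2); the device-mesh setup and sharding dims are assumptions, not the tutorial's exact code.

```python
# Sketch only: assumes `cp_qkv` and `backend` from the snippet above and a
# distributed environment that has already been initialized on each rank.
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import sdpa_kernel

world_size = dist.get_world_size()
device_mesh = init_device_mesh("cuda", (world_size,))  # hypothetical 1-D mesh

with sdpa_kernel(backend):
    # Shard each QKV buffer in-place along its sequence dimension (dim 2) and
    # run the ring-attention-based replacement of scaled_dot_product_attention.
    with context_parallel(
        device_mesh,
        buffers=tuple(cp_qkv),
        buffer_seq_dims=(2, 2, 2),
    ):
        cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)
```

As the comment in the diff notes, the buffers and their sharding dims are passed in matching order, so each entry of `buffer_seq_dims` tells the context manager which dimension of the corresponding buffer to shard.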