Commit 21dd980

add comment pointing to Sequence Parallel optimization example
ghstack-source-id: 6fa0dcd
Pull Request resolved: #438
1 parent b0ed7f0 commit 21dd980

1 file changed: 3 additions, 0 deletions


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 3 additions & 0 deletions
@@ -362,6 +362,9 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     )
 
     # Apply tensor + sequence parallelism to every transformer block
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    # by folding (and unfolding) the batch dimension and the sequence dimension.
+    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention": prepare_module_input(
