
Commit 655ea0f

Address comments
1 parent: ab41031


1 file changed (+2, -3 lines)


distributed/parallelize_llama.py

Lines changed: 2 additions & 3 deletions
@@ -46,8 +46,7 @@ def apply_tp(
 
     # TODO: To figure out the TP for the tok_embedding and the linear proj layer.
     # # 1. Parallelize the first embedding and the last linear proj layer
-    # # 2. Parallelize the root norm layer over the sequence dim
-    # # 3. Shard the first transformer block's inputs
+    # # 2. Shard the first transformer block's inputs
     # model = parallelize_module(
     #     model,
     #     tp_mesh,
@@ -64,7 +63,7 @@ def apply_tp(
     #     },
     # )
 
-    # Apply tensor + sequence parallelism to every transformer block
+    # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
        layer_plan = {
            "attention": PrepareModuleInput(
