
Commit 81be75e

fduwjj authored and malfet committed
[Dist][Inference] Further TP fix to make sure e2e TP is working (#878)
1 parent 6e608cd, commit 81be75e

File tree

1 file changed: +16, -23 lines


distributed/parallelize_llama.py

Lines changed: 16 additions & 23 deletions
@@ -44,55 +44,48 @@ def apply_tp(
 
     tp_mesh = world_mesh["tp"]
 
-    # TODO: To figure out the TP for the tok_embedding and the linear proj layer.
-    # # 1. Parallelize the first embedding and the last linear proj layer
-    # # 2. Shard the first transformer block's inputs
+    # TODO: The commented part below can further help with scaling, but it
+    # makes inference very slow, so we disable it for now.
+    # Parallelize the token embedding and the last linear proj layer:
     # model = parallelize_module(
     #     model,
     #     tp_mesh,
     #     {
-    #         "tok_embeddings": RowwiseParallel(
-    #             input_layouts=Replicate(),
+    #         "tok_embeddings": ColwiseParallel(
     #             output_layouts=Replicate(),
     #         ),
     #         "output": ColwiseParallel(
-    #             input_layouts=Shard(1),
     #             output_layouts=Replicate(),
-    #             use_local_output=True,
     #         ),
     #     },
     # )
 
+    # NOTE: This is indeed a hack: it assumes that we create the cache after
+    # we apply TP to the model, because we don't want to change model code when
+    # applying TP. This change is needed so that the KVCache has the correct
+    # size for k and v.
+    model.config.n_local_heads = model.config.n_local_heads // tp_mesh.size()
+
     # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
         layer_plan = {
-            "attention": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Replicate(), None),
-            ),
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
             "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(
-                output_layouts=Replicate(),
-                use_local_output=True,
-            ),
-            "feed_forward": PrepareModuleInput(
-                input_layouts=(Replicate(),),
-                desired_input_layouts=(Replicate(),),
-            ),
+            "attention.wo": RowwiseParallel(),
             "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(
-                output_layouts=Replicate(),
-                use_local_output=True
-            ),
+            "feed_forward.w2": RowwiseParallel(),
             "feed_forward.w3": ColwiseParallel(),
         }
 
         # Adjust attention module to use the local number of heads
         attn_layer = transformer_block.attention
+        assert attn_layer.n_heads % tp_mesh.size() == 0
+        assert attn_layer.n_local_heads % tp_mesh.size() == 0
+        assert attn_layer.dim % tp_mesh.size() == 0
         attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
         attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
+        attn_layer.dim = attn_layer.dim // tp_mesh.size()
 
         parallelize_module(
             module=transformer_block,
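
For context, here is a minimal self-contained sketch of how a per-block plan like the one this commit keeps is applied with torch.distributed.tensor.parallel. It is not code from this repo: the toy Attention/FeedForward/Block classes, the dimensions, and the torchrun launch line are illustrative assumptions; only the plan keys and the ColwiseParallel/RowwiseParallel pairing mirror the diff above.

# Launch with e.g.: torchrun --nproc-per-node=2 tp_sketch.py
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class Attention(nn.Module):
    # Toy stand-in: the real attention math (head reshape, SDPA) is elided.
    def __init__(self, dim):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.wo(self.wv(x))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Block(nn.Module):
    def __init__(self, dim=256, hidden=1024):
        super().__init__()
        self.attention = Attention(dim)
        self.feed_forward = FeedForward(dim, hidden)

    def forward(self, x):
        x = x + self.attention(x)
        return x + self.feed_forward(x)


def main():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
    tp_mesh = mesh["tp"]

    block = Block().cuda()
    # Same pairing the commit keeps: q/k/v and w1/w3 sharded column-wise,
    # wo and w2 row-wise, so each colwise -> rowwise pair ends in one all-reduce.
    layer_plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    }
    parallelize_module(module=block, device_mesh=tp_mesh, parallelize_plan=layer_plan)

    x = torch.randn(1, 16, 256, device="cuda")
    out = block(x)  # DTensor inserts the required collectives
    assert out.shape == x.shape


if __name__ == "__main__":
    main()

The defaults of ColwiseParallel (replicated input, sharded output) and RowwiseParallel (sharded input, replicated output) already compose, which is consistent with the diff dropping the explicit PrepareModuleInput entries and layout arguments.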
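
The model.config.n_local_heads adjustment in the diff exists because the KV cache is allocated from the config after TP has been applied. A worked sketch of the head-count arithmetic follows; the numbers (tp_size=2, 8 KV heads, head_dim=64, sequence length 2048) are assumptions for illustration, not values from the repo.

tp_size = 2                    # assumed tensor-parallel degree
n_local_heads = 8              # assumed global KV-head count before TP
head_dim = 64
max_batch, max_seq_len = 1, 2048

# After ColwiseParallel shards wk/wv, each rank emits only its slice of heads:
heads_per_rank = n_local_heads // tp_size          # 4 heads per rank

# A cache built from the *adjusted* config matches the per-rank k/v tensors:
per_rank_cache_shape = (max_batch, heads_per_rank, max_seq_len, head_dim)

# A cache built from the unadjusted config would be tp_size times too large on
# the head dimension and would not line up with k and v during cache updates.
assert heads_per_rank * tp_size == n_local_heads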
