
[Dist][Inference] Further TP fix to make sure e2e TP is working #878

Merged · 1 commit · Jul 3, 2024
distributed/parallelize_llama.py: 39 changes (16 additions, 23 deletions)

@@ -44,55 +44,48 @@ def apply_tp(

     tp_mesh = world_mesh["tp"]

-    # TODO: To figure out the TP for the tok_embedding and the linear proj layer.
-    # # 1. Parallelize the first embedding and the last linear proj layer
-    # # 2. Shard the first transformer block's inputs
+    # TODO: The commented part can further help with scaling but it will
+    # make inference very slow so we disable it for now.
+    # Parallelize the token embedding and the last linear proj layer
     # model = parallelize_module(
     #     model,
     #     tp_mesh,
     #     {
-    #         "tok_embeddings": RowwiseParallel(
-    #             input_layouts=Replicate(),
+    #         "tok_embeddings": ColwiseParallel(
+    #             output_layouts=Replicate(),
     #         ),
     #         "output": ColwiseParallel(
-    #             input_layouts=Shard(1),
     #             output_layouts=Replicate(),
     #             use_local_output=True,
     #         ),
     #     },
     # )

-    # NOTE: This is indeed a hack because it assumes that we create cache
-    # after we apply TP to the model. Because we don't want to change model code
-    # when applying TP. We need to have change to ensure KVCache has the correct
-    # size as k and v.
-    model.config.n_local_heads = model.config.n_local_heads // tp_mesh.size()

     # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
         layer_plan = {
             "attention": PrepareModuleInput(
                 input_layouts=(Replicate(), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
             "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(
-                output_layouts=Replicate(),
-                use_local_output=True,
-            ),
-            "feed_forward": PrepareModuleInput(
-                input_layouts=(Replicate(),),
-                desired_input_layouts=(Replicate(),),
-            ),
+            "attention.wo": RowwiseParallel(),
             "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(
-                output_layouts=Replicate(),
-                use_local_output=True
-            ),
+            "feed_forward.w2": RowwiseParallel(),
             "feed_forward.w3": ColwiseParallel(),
         }

+        # Adjust attention module to use the local number of heads
+        attn_layer = transformer_block.attention
+        assert attn_layer.n_heads % tp_mesh.size() == 0
+        assert attn_layer.n_local_heads % tp_mesh.size() == 0
+        assert attn_layer.dim % tp_mesh.size() == 0
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
+        attn_layer.dim = attn_layer.dim // tp_mesh.size()
+
         parallelize_module(
             module=transformer_block,
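
For context, the plan-and-apply pattern in the hunk above comes from torch.distributed.tensor.parallel. Below is a minimal sketch, not code from this repo, showing how a layer plan of ColwiseParallel/RowwiseParallel entries is applied with parallelize_module over a 1-D device mesh. The ToyFeedForward module, its dimensions, the CPU/gloo setup, and the script name are illustrative assumptions; launch with e.g. `torchrun --nproc-per-node 2 tp_sketch.py`.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFeedForward(nn.Module):
    """Stand-in for a transformer MLP block: w1/w3 project up, w2 projects down."""

    def __init__(self, dim: int = 16, hidden_dim: int = 32) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def main() -> None:
    # torchrun sets WORLD_SIZE/RANK/MASTER_ADDR/MASTER_PORT; this 1-D mesh
    # plays the role of world_mesh["tp"] in apply_tp.
    world_size = int(os.environ["WORLD_SIZE"])
    tp_mesh = init_device_mesh("cpu", (world_size,), mesh_dim_names=("tp",))

    block = ToyFeedForward()  # hidden_dim must be divisible by the TP degree

    # Colwise-shard the up-projections, rowwise-shard the down-projection.
    # RowwiseParallel defaults to output_layouts=Replicate() and
    # use_local_output=True, which is why the PR can drop those arguments.
    plan = {
        "w1": ColwiseParallel(),
        "w3": ColwiseParallel(),
        "w2": RowwiseParallel(),
    }
    parallelize_module(module=block, device_mesh=tp_mesh, parallelize_plan=plan)

    torch.manual_seed(0)  # same replicated input on every rank
    x = torch.randn(2, 16)
    out = block(x)
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()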
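
On the per-block adjustment of n_heads, n_local_heads, and dim: once wq/wk/wv are column-wise sharded, each TP rank computes only its slice of the heads, so the attention module's own bookkeeping, and any KV cache sized from it, has to use local values instead of a globally mutated model.config. A back-of-the-envelope sketch with Llama-style numbers that are assumptions for illustration only:

n_heads, head_dim, max_seq_len, batch_size = 32, 128, 2048, 1  # illustrative sizes
tp_size = 4

assert n_heads % tp_size == 0, "the TP degree must divide the number of heads"
local_heads = n_heads // tp_size       # 8 heads computed on each rank
local_dim = local_heads * head_dim     # 1024 of the original 4096 attention channels

# A KV cache allocated from the module's own attributes must use local sizes,
# otherwise each rank over-allocates and mismatches the sharded k/v tensors.
global_kv_shape = (batch_size, n_heads, max_seq_len, head_dim)
local_kv_shape = (batch_size, local_heads, max_seq_len, head_dim)
print(global_kv_shape, "->", local_kv_shape)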