1 parent ab41031 commit 655ea0f
distributed/parallelize_llama.py
@@ -46,8 +46,7 @@ def apply_tp(
 
     # TODO: To figure out the TP for the tok_embedding and the linear proj layer.
     # # 1. Parallelize the first embedding and the last linear proj layer
-    # # 2. Parallelize the root norm layer over the sequence dim
-    # # 3. Shard the first transformer block's inputs
+    # # 2. Shard the first transformer block's inputs
     # model = parallelize_module(
     #     model,
     #     tp_mesh,
@@ -64,7 +63,7 @@ def apply_tp(
     #     },
     # )
 
-    # Apply tensor + sequence parallelism to every transformer block
+    # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
         layer_plan = {
             "attention": PrepareModuleInput(