@@ -44,55 +44,48 @@ def apply_tp(
 
     tp_mesh = world_mesh["tp"]
 
-    # TODO: To figure out the TP for the tok_embedding and the linear proj layer.
-    # # 1. Parallelize the first embedding and the last linear proj layer
-    # # 2. Shard the first transformer block's inputs
+    # TODO: The commented-out block below can further help with scaling, but it
+    # makes inference very slow, so it is disabled for now.
+    # Parallelize the token embedding and the last linear proj layer
     # model = parallelize_module(
     #     model,
     #     tp_mesh,
     #     {
-    #         "tok_embeddings": RowwiseParallel(
-    #             input_layouts=Replicate(),
+    #         "tok_embeddings": ColwiseParallel(
     #             output_layouts=Replicate(),
     #         ),
     #         "output": ColwiseParallel(
-    #             input_layouts=Shard(1),
     #             output_layouts=Replicate(),
-    #             use_local_output=True,
     #         ),
     #     },
     # )
 
+    # NOTE: This is a hack: it assumes the KV cache is created after TP has been
+    # applied to the model, because we don't want to change model code when
+    # applying TP. A follow-up change is needed so that the KVCache is sized to
+    # match the sharded k and v.
+    model.config.n_local_heads = model.config.n_local_heads // tp_mesh.size()
+
     # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
         layer_plan = {
-            "attention": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Replicate(), None),
-            ),
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
             "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(
-                output_layouts=Replicate(),
-                use_local_output=True,
-            ),
-            "feed_forward": PrepareModuleInput(
-                input_layouts=(Replicate(),),
-                desired_input_layouts=(Replicate(),),
-            ),
+            "attention.wo": RowwiseParallel(),
             "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(
-                output_layouts=Replicate(),
-                use_local_output=True
-            ),
+            "feed_forward.w2": RowwiseParallel(),
             "feed_forward.w3": ColwiseParallel(),
         }
 
         # Adjust attention module to use the local number of heads
         attn_layer = transformer_block.attention
+        assert attn_layer.n_heads % tp_mesh.size() == 0
+        assert attn_layer.n_local_heads % tp_mesh.size() == 0
+        assert attn_layer.dim % tp_mesh.size() == 0
         attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
         attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
+        attn_layer.dim = attn_layer.dim // tp_mesh.size()
 
         parallelize_module(
             module=transformer_block,
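
For context, and not part of the commit above: a minimal sketch of the column-/row-wise pairing that the per-block layer_plan relies on, applied to a toy feed-forward module. The ToyFFN module, the CPU/gloo mesh, and the torchrun launch are illustrative assumptions; only the tensor-parallel APIs (ColwiseParallel, RowwiseParallel, parallelize_module, init_device_mesh) are the same ones used in the diff.

# toy_tp.py -- illustrative sketch, launch with e.g. `torchrun --nproc_per_node=2 toy_tp.py`
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFFN(nn.Module):
    # Hypothetical stand-in for the feed_forward block in the diff.
    def __init__(self, dim: int = 16, hidden: int = 64):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)  # sharded column-wise
        self.w2 = nn.Linear(hidden, dim, bias=False)  # sharded row-wise

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


if __name__ == "__main__":
    # One mesh dimension named "tp", spanning all launched ranks (CPU/gloo here).
    tp_size = int(os.environ.get("WORLD_SIZE", "1"))
    tp_mesh = init_device_mesh("cpu", (tp_size,), mesh_dim_names=("tp",))

    ffn = ToyFFN()
    # Colwise on w1 shards the hidden dimension across ranks; Rowwise on w2
    # all-reduces the partial results back to a replicated output, mirroring
    # the attention/feed_forward plan in apply_tp.
    parallelize_module(ffn, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

    out = ffn(torch.randn(2, 16))
    print(out.shape)  # torch.Size([2, 16]) on every rank

The simplified plan in the diff works because RowwiseParallel already defaults to output_layouts=Replicate() with use_local_output=True, so the explicit overrides removed from "attention.wo" and "feed_forward.w2" were restating the defaults.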