Commit bedb16a
enable TP fp8 allgather with PrepareFloat8ModuleInput
This PR is a follow-up to enable fp8 all-gather in TP, after these PRs landed:

* pytorch/pytorch#128431
* pytorch-labs/float8_experimental#275

One needs to update their pytorch and float8_experimental installations to include those changes in order to train with fp8 all-gather. Since fp8 is not yet enabled as part of our integration tests, there should be no issues on CI.
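A minimal usage sketch of how the three strategies returned by get_tp_parallel_strategy feed into a tensor-parallel plan, assuming a distributed run where `tp_mesh` is the tensor-parallel DeviceMesh and `transformer_block` / `job_config` come from the surrounding trainer (names mirror the diff below):

# Sketch only: assumes torch.distributed is initialized and `tp_mesh`,
# `transformer_block`, and `job_config` exist in the surrounding setup.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

row_parallel_strategy, col_parallel_strategy, prepare_module_input = (
    get_tp_parallel_strategy(job_config)
)

layer_plan = {
    # Redistribute the sequence-sharded activation to a replicated layout
    # before attention; with the Float8 variants this all-gather runs in fp8.
    "attention": prepare_module_input(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "attention.wq": col_parallel_strategy(),
    "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
}
parallelize_module(transformer_block, tp_mesh, layer_plan)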
1 parent 763b810 commit bedb16a

File tree: 1 file changed (+11, -8 lines)

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 11 additions & 8 deletions
@@ -114,7 +114,7 @@ def selective_checkpointing_context_fn():
 
 def get_tp_parallel_strategy(
     job_config: JobConfig,
-) -> Tuple[RowwiseParallel, ColwiseParallel]:
+) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
@@ -123,10 +123,11 @@ def get_tp_parallel_strategy(
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
+            PrepareFloat8ModuleInput,
         )
 
-        return Float8RowwiseParallel, Float8ColwiseParallel
-    return RowwiseParallel, ColwiseParallel
+        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
+    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
 
 
 def pipeline_llama(
@@ -299,9 +300,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         )
 
         tp_mesh = world_mesh["tp"]
-        row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
-            job_config
-        )
+        (
+            row_parallel_strategy,
+            col_parallel_strategy,
+            prepare_module_input,
+        ) = get_tp_parallel_strategy(job_config)
         loss_parallel = parallel_dims.loss_parallel_enabled
 
         # 1. Parallelize the first embedding and the last linear proj layer
@@ -327,7 +330,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         # Apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in model.layers.items():
             layer_plan = {
-                "attention": PrepareModuleInput(
+                "attention": prepare_module_input(
                     input_layouts=(Shard(1), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
@@ -336,7 +339,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "attention.wv": col_parallel_strategy(),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
                 "attention_norm": SequenceParallel(),
-                "feed_forward": PrepareModuleInput(
+                "feed_forward": prepare_module_input(
                     input_layouts=(Shard(1),),
                     desired_input_layouts=(Replicate(),),
                 ),
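The payoff of routing the module-input preparation through the float8-aware class is bandwidth: the all-gathered activation moves 1-byte fp8 elements instead of 2-byte bf16. A rough, self-contained illustration of the payload difference (shapes and TP degree below are arbitrary examples, not values from this repo):

# Rough illustration only; sequence length, hidden dim, and TP degree are made up.
import torch

seq_len, dim, tp_degree = 8192, 4096, 8
local_shard = torch.randn(seq_len // tp_degree, dim, dtype=torch.bfloat16)

bf16_payload = local_shard.numel() * local_shard.element_size() * tp_degree
fp8_payload = local_shard.numel() * 1 * tp_degree  # float8 is 1 byte per element

print(f"all-gather payload: {bf16_payload / 2**20:.0f} MiB (bf16) "
      f"vs {fp8_payload / 2**20:.0f} MiB (fp8)")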
