Skip to content

Commit e858ab4

Browse files
kwen2501 authored and wconstab committed
Fix 1D PP tracer test
Forgot to enable the tracer split mode for the tracer test in the last PR.

ghstack-source-id: 1cb1379
Pull Request resolved: #362
1 parent 3bb7bf3 commit e858ab4

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

test_runner.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,12 +123,14 @@ def build_test_list():
123123
"--checkpoint.enable_checkpoint",
124124
"--experimental.pipeline_parallel_degree 2",
125125
"--experimental.pipeline_parallel_split_points layers.1",
126+
"--experimental.pipeline_parallel_split_mode tracer",
126127
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
127128
],
128129
],
129130
"PP tracer frontend test",
130131
"pp_tracer",
131132
requires_seed_checkpoint=True,
133+
ngpu=2,
132134
),
133135
OverrideDefinitions(
134136
[

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -261,23 +261,23 @@ def pipeline_llama_tracer(
261261

262262
pp_mesh = world_mesh["pp"]
263263
pp_rank = pp_mesh.get_local_rank()
264+
microbatches = (
265+
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
266+
)
267+
(input,) = _llama_trace_input(job_config, model_config, device=device)
264268
stage_idx = pp_rank
265-
layers_per_rank = len(model.layers) // parallel_dims.pp
266269
split_spec = {
267-
f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
268-
for i in range(1, parallel_dims.pp)
270+
layer_name: SplitPoint.BEGINNING
271+
for layer_name in job_config.experimental.pipeline_parallel_split_points
269272
}
270-
271273
pipe = pipeline(
272274
model,
273-
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
274-
example_args=_llama_trace_input(job_config, model_config),
275+
mb_args=(input.chunk(microbatches)[0],),
275276
split_spec=split_spec,
276277
)
277278
model = pipe.get_stage_module(stage_idx)
278-
stage = PipelineStage(
279-
pipe,
280-
stage_index=stage_idx,
279+
stage = pipe.build_stage(
280+
stage_idx,
281281
device=device,
282282
group=pp_mesh.get_group(),
283283
)

0 commit comments

Comments
 (0)