
Commit 8d6822f

mori360 authored and bluenote10 committed
E2E composability testing (pytorch#141398)
Add a 3D (pp + tp + fsdp) test, test_3d_with_tp_dp_pp, to test_pp_composability.

Currently provides @parametrize over (see the sketch below):
- "ScheduleClass" for pp: [ScheduleGPipe, Schedule1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble]
- "MixedPrecisionParam" for fsdp: [torch.bfloat16, torch.float32]

Future work:
1. add fp8
2. add cp (context parallelism) to enable a 4D test

Pull Request resolved: pytorch#141398
Approved by: https://github.com/wconstab, https://github.com/kwen2501
1 parent 1055131 commit 8d6822f
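
As a rough illustration, the sketch below (not part of the commit) enumerates the test matrix that the two @parametrize decorators described above produce: five pipeline schedules crossed with two FSDP mixed-precision dtypes, ten cases in total. The printed names are indicative only; the actual test names are generated by instantiate_parametrized_tests.

import itertools

import torch
from torch.distributed.pipelining.schedules import (
    Schedule1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)

# 5 pipeline schedules x 2 FSDP param dtypes = 10 parametrized test cases
schedules = [
    ScheduleGPipe,
    Schedule1F1B,
    ScheduleInterleaved1F1B,
    ScheduleLoopedBFS,
    ScheduleInterleavedZeroBubble,
]
dtypes = [torch.bfloat16, torch.float32]

for schedule_cls, dtype in itertools.product(schedules, dtypes):
    print(f"test_3d_with_tp_dp_pp: ScheduleClass={schedule_cls.__name__}, MixedPrecisionParam={dtype}")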

File tree

1 file changed: +192 -1 lines changed


test/distributed/_composable/test_composability/test_pp_composability.py

Lines changed: 192 additions & 1 deletion
@@ -6,13 +6,14 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint import FileSystemReader
 from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.pipelining import PipelineStage
 from torch.distributed.pipelining.schedules import (
@@ -23,6 +24,11 @@
     ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -58,6 +64,20 @@ def forward(self, x):
         return x
 
 
+class MLPModuleEven(torch.nn.Module):
+    def __init__(self, d_hid: int):
+        super().__init__()
+        self.net1 = nn.Linear(d_hid, d_hid)
+        self.net2 = nn.Linear(d_hid, d_hid)
+        self.net3 = nn.Linear(d_hid, d_hid * 2)
+
+    def forward(self, x):
+        x = F.relu(self.net1(x))
+        x = F.relu(self.net2(x))
+        x = F.relu(self.net3(x))
+        return x
+
+
 class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
@@ -357,6 +377,177 @@ def _dcp_test(self):
 
         _dcp_test(self)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(8)
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
+    @parametrize(
+        "ScheduleClass",
+        [
+            ScheduleGPipe,
+            Schedule1F1B,
+            ScheduleInterleaved1F1B,
+            ScheduleLoopedBFS,
+            ScheduleInterleavedZeroBubble,
+        ],
+    )
+    @parametrize(
+        "MixedPrecisionParam",
+        [
+            torch.bfloat16,
+            torch.float32,
+        ],
+    )
+    def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
+        device = torch.device("cuda", self.device)
+        torch.cuda.set_device(self.device)
+        store = torch.distributed.FileStore(self.file_name, self.world_size)
+        torch.distributed.init_process_group(
+            backend="nccl",
+            store=store,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        dim = 8
+        tp_size = 2
+        pp_size = 2
+        num_microbatches = 8
+        dp_size = self.world_size // (tp_size * pp_size)
+        device_mesh = init_device_mesh(
+            "cuda",
+            mesh_shape=(dp_size, pp_size, tp_size),
+            mesh_dim_names=("dp", "pp", "tp"),
+        )
+        dp_mesh = device_mesh["dp"]
+        tp_mesh = device_mesh["tp"]
+        pp_mesh = device_mesh["pp"]
+        pp_group = device_mesh["pp"].get_group()
+
+        # create "entire model"
+        total_layers = 8
+        full_model = nn.ModuleList([MLPModuleEven(dim) for _ in range(total_layers)])
+
+        # dummy loss needed just to force backwards to run in schedule step
+        def loss_fn(y, target):
+            return y.sum()
+
+        # Apply DP to stage module
+        def apply_fsdp(partial_model):
+            # apply FSDP
+            mp_policy = MixedPrecisionPolicy(
+                param_dtype=MixedPrecisionParam,
+                reduce_dtype=torch.float32,
+            )
+            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+            for layer_id in range(len(partial_model)):
+                fully_shard(
+                    partial_model[layer_id],
+                    **fsdp_config,
+                    reshard_after_forward=False,
+                )
+            dp_model = fully_shard(partial_model, **fsdp_config)
+            return dp_model
+
+        def apply_tp(
+            model: nn.Module,
+            tp_mesh: DeviceMesh,
+        ):
+            parallelize_plan = {
+                "net1": ColwiseParallel(),
+                "net2": RowwiseParallel(),
+                "net3": ColwiseParallel(),
+            }
+            for layer in model:
+                parallelize_module(layer, tp_mesh, parallelize_plan)
+            return model
+
+        # Attach to a schedule
+        if issubclass(ScheduleClass, PipelineScheduleSingle):
+            stage_idx = pp_group.rank()
+            partial_model = nn.Sequential(
+                *full_model[stage_idx * 2 : stage_idx * 2 + 2]
+            )
+            partial_model.to(self.device)
+
+            tp_model = apply_tp(partial_model, tp_mesh)
+            dp_model = apply_fsdp(tp_model)
+            pipeline_stage = PipelineStage(
+                dp_model,
+                stage_idx,
+                pp_group.size(),
+                self.device,
+                group=pp_group,
+            )
+            partial_models = [pipeline_stage.submod]
+            pipeline_schedule = ScheduleClass(
+                pipeline_stage,
+                n_microbatches=num_microbatches,
+                loss_fn=loss_fn,
+            )
+        else:
+            n_virtual = 2
+            num_stages = pp_group.size() * n_virtual
+            stages = []
+            for i in range(n_virtual):
+                stage_idx = pp_group.rank() + n_virtual * i
+                # divide the model layers by the number of stages
+                partial_model = nn.Sequential(*full_model[stage_idx : stage_idx + 1])
+                partial_model.to(self.device)
+
+                tp_model = apply_tp(partial_model, tp_mesh)
+                dp_model = apply_fsdp(tp_model)
+                stage = PipelineStage(
+                    dp_model,
+                    stage_idx,
+                    num_stages,
+                    self.device,
+                    group=pp_group,
+                )
+
+                stages.append(stage)
+                partial_models = [pipeline_stage.submod for pipeline_stage in stages]
+            pipeline_schedule = ScheduleClass(
+                stages,
+                n_microbatches=num_microbatches,
+                loss_fn=loss_fn,
+            )
+
+        optimizer_kwargs = {
+            "lr": 0.01,
+            "betas": (0.9, 0.95),
+            "weight_decay": 0.1,
+            "fused": False,
+            "foreach": True,
+        }
+        optimizers = [
+            torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
+            for model in partial_models
+        ]
+
+        for train_step in range(5):
+            for optimizer in optimizers:
+                optimizer.zero_grad()
+            inputs = torch.rand((num_microbatches, dim), device=self.device)
+            labels = torch.rand((num_microbatches, dim), device=self.device)
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+            if pp_mesh.get_local_rank() == 0:
+                pipeline_schedule.step(inputs)
+            elif is_last_stage:
+                losses = []
+                pipeline_schedule.step(target=labels, losses=losses)
+            else:
+                pipeline_schedule.step()
+
+            # accumulate losses across pipeline microbatches
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
+            )
+            for optimizer in optimizers:
+                optimizer.step()
+
+        torch.distributed.destroy_process_group()
+
 
 instantiate_parametrized_tests(ComposabilityTest)
 
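For readers mapping ranks onto the 3D mesh used above, here is a minimal sketch (not part of the commit) of how the test's 8 ranks decompose onto the (dp, pp, tp) mesh built by init_device_mesh with mesh_shape=(dp_size, pp_size, tp_size) = (2, 2, 2). It assumes the default row-major rank layout of init_device_mesh; the arithmetic is purely illustrative and not taken from the test itself.

world_size = 8
tp_size, pp_size = 2, 2
dp_size = world_size // (tp_size * pp_size)  # 2 with 8 GPUs

for rank in range(world_size):
    # coordinates under a row-major (dp, pp, tp) layout
    dp_rank = rank // (pp_size * tp_size)
    pp_rank = (rank // tp_size) % pp_size
    tp_rank = rank % tp_size
    print(f"rank {rank}: dp={dp_rank}, pp={pp_rank}, tp={tp_rank}")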