
Commit 30b310a

[2D] Update 2d example to use get_local_rank (#1203)
update 2d example to use get_local_rank
1 parent c0b889d

File tree

1 file changed: +1 −5 lines


distributed/tensor_parallelism/fsdp_tp_example.py

Lines changed: 1 addition & 5 deletions
@@ -114,15 +114,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 tp_mesh = device_mesh["tp"]
 dp_mesh = device_mesh["dp"]
 
-# To support identical inputs for TP groups, we need the dp process group
-dp_pg = device_mesh.get_dim_groups()[0]
-
 # For TP, input needs to be same across all TP ranks.
 # while for SP, input can be different across all ranks.
 # We will use dp_rank for setting the random seed
 # to mimic the behavior of the dataloader.
-dp_rank = dist.get_rank(dp_pg)
-
+dp_rank = dp_mesh.get_local_rank()
 
 # create model and move it to GPU with id rank
 _mlp_dim = 1024
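For context, here is a minimal sketch of the before/after (not part of the commit). It assumes PyTorch 2.2+, where init_device_mesh and DeviceMesh.get_local_rank are available; the 4 x 2 mesh shape and the seeding call are illustrative, and the script would need to be launched under torchrun with a matching world size (8 processes here).

# Sketch only: illustrative mesh shape, requires a torchrun launch.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 2D mesh: outer "dp" dim for data parallel, inner "tp" dim for tensor
# parallel, matching the names the example slices with.
device_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
dp_mesh = device_mesh["dp"]

# Removed approach: fetch the dp process group, then ask torch.distributed
# for this process's rank within it. (The commit used the older
# get_dim_groups()[0]; get_group("dp") is the newer equivalent.)
dp_pg = device_mesh.get_group("dp")
dp_rank_old = dist.get_rank(dp_pg)

# New approach: ask the sliced sub-mesh directly.
dp_rank = dp_mesh.get_local_rank()
assert dp_rank == dp_rank_old

# As in the example, seed from dp_rank so all TP ranks in the same dp
# group generate identical inputs, mimicking a sharded dataloader.
torch.manual_seed(dp_rank)

Asking the sub-mesh directly also avoids the positional get_dim_groups()[0] indexing, which silently depends on the mesh dimension ordering rather than the "dp" name.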
