Commit d60135e

weifengpy authored and pytorchmergebot committed
[FSDP1] fix _same_storage check for DTensor (pytorch#123617)
For FSDP (SHARD_GRAD_OP + use_orig_params) + TP, params in the backward pass are DTensors. However, ``DTensor.untyped_storage().data_ptr()`` does not work in ``_same_storage``, so we desugar to ``DTensor._local_tensor.untyped_storage().data_ptr()`` (pytorch#123272). Credit to @bigning for the original fix. After this lands, the patching in Mosaic Composer (https://github.com/mosaicml/composer/pull/3175/files) is no longer needed.

Pull Request resolved: pytorch#123617
Approved by: https://github.com/awgu
1 parent 37fd547 commit d60135e
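
The fix boils down to unwrapping DTensor arguments to their local shards before comparing storage pointers. Below is a minimal standalone sketch that mirrors the ``_same_storage`` helper patched in torch/distributed/fsdp/_flat_param.py further down; the plain-tensor checks at the end are illustrative only and need no process group.

import torch
from torch.distributed._tensor import DTensor


def same_storage(a, b):
    # DTensor is a wrapper subclass, so compare the storages of the
    # local shards instead of calling untyped_storage() on the wrapper.
    if isinstance(a, DTensor):
        a = a._local_tensor
    if isinstance(b, DTensor):
        b = b._local_tensor
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()


x = torch.randn(4)
assert same_storage(x, x.view(2, 2))   # a view aliases the same storage
assert not same_storage(x, x.clone())  # clone() allocates new storage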

File tree

3 files changed: +70 -30 lines

  .ci/pytorch/test.sh
  test/distributed/fsdp/test_fsdp_tp_integration.py
  torch/distributed/fsdp/_flat_param.py

.ci/pytorch/test.sh
Lines changed: 1 addition & 0 deletions

@@ -316,6 +316,7 @@ test_inductor_distributed() {
   pytest test/distributed/_composable/fsdp/test_fully_shard_frozen.py
   pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype
   pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype
+  pytest test/distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration
 
   # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
   # with if required # gpus aren't available

test/distributed/fsdp/test_fsdp_tp_integration.py
Lines changed: 61 additions & 30 deletions

@@ -18,6 +18,7 @@
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     CPUOffload,
     FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
 )
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -28,7 +29,6 @@
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
-    parametrize,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
 )
@@ -141,31 +141,36 @@ def _sync_tp_grads(
         tp_world_size = tp_pg.size()
         fsdp_world_size = self.world_size // tp_world_size
         assert (
-            type(tp_fsdp_model) is FSDP and len(list(tp_fsdp_model.parameters())) == 1
+            type(tp_fsdp_model) is FSDP
+            and len([m for m in tp_fsdp_model.modules() if type(m) is FSDP]) == 1
         ), (
             "The following logic assumes a single top-level-only FSDP wrapping "
             "the model with TP already applied"
         )
-        flat_param = tp_fsdp_model.params[0]
-        splits = tuple(param_name_to_numel.values())
-        # Create a mask over the gradient elements to manually reduce
-        unsharded_size = torch.Size([flat_param.numel() * fsdp_world_size])
-        unsharded_zeros = torch.zeros(unsharded_size, device=flat_param.device)
-        per_param_masks = unsharded_zeros.split(splits)
-        for param_idx, param_name in enumerate(
-            param_name_to_numel.keys()
-        ):  # assumes fixed order
-            if param_name not in non_sharded_param_names:
-                per_param_masks[param_idx][:] = 1
-        unsharded_mask = torch.cat(per_param_masks).contiguous().type(torch.BoolTensor)
-        sharded_mask = unsharded_mask.chunk(fsdp_world_size)[self.rank // tp_world_size]
-        grad_device = flat_param.grad.device
-        grad = flat_param.grad.detach().clone().cuda(self.rank)
-        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_pg)
-        grad = grad.to(grad_device)
-        flat_param.grad[~sharded_mask] = grad[~sharded_mask]
-        # Average *all* gradient elements to match the FSDP only semantics
-        flat_param.grad /= tp_world_size
+        for flat_param in tp_fsdp_model.params:
+            splits = tuple(param_name_to_numel.values())
+            # Create a mask over the gradient elements to manually reduce
+            unsharded_size = torch.Size([flat_param.numel() * fsdp_world_size])
+            unsharded_zeros = torch.zeros(unsharded_size, device=flat_param.device)
+            per_param_masks = unsharded_zeros.split(splits)
+            for param_idx, param_name in enumerate(
+                param_name_to_numel.keys()
+            ):  # assumes fixed order
+                if param_name not in non_sharded_param_names:
+                    per_param_masks[param_idx][:] = 1
+            unsharded_mask = (
+                torch.cat(per_param_masks).contiguous().type(torch.BoolTensor)
+            )
+            sharded_mask = unsharded_mask.chunk(fsdp_world_size)[
+                self.rank // tp_world_size
+            ]
+            grad_device = flat_param.grad.device
+            grad = flat_param.grad.detach().clone().cuda(self.rank)
+            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_pg)
+            grad = grad.to(grad_device)
+            flat_param.grad[~sharded_mask] = grad[~sharded_mask]
+            # Average *all* gradient elements to match the FSDP only semantics
+            flat_param.grad /= tp_world_size
 
     def _get_grads_as_flattened(
         self,
@@ -182,7 +187,14 @@ def _get_grads_as_flattened(
         returns the same value on all ranks.
         """
         local_grads_as_flattened = (
-            torch.cat([torch.flatten(param.grad) for param in model.parameters()])
+            torch.cat(
+                [
+                    torch.flatten(param.grad)
+                    if param.grad is not None
+                    else torch.zeros_like(torch.flatten(param))
+                    for param in model.parameters()
+                ]
+            )
             .contiguous()
             .cuda(self.rank)
         )
@@ -214,16 +226,27 @@ def _get_grads_as_flattened(
         return torch.cat(all_grads_per_param).contiguous()
 
     @skip_if_lt_x_gpu(4)
-    @parametrize("tensor_parallel_size", [2, 4])
-    @parametrize(
-        "cpu_offload",
-        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
-    )
-    def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
+    def test_fsdp_tp_integration(self):
+        self.run_subtests(
+            {
+                "cpu_offload": [
+                    CPUOffload(offload_params=False),
+                    CPUOffload(offload_params=True),
+                ],
+                "sharding_strategy": [None, ShardingStrategy.SHARD_GRAD_OP],
+                "use_orig_params": [False, True],
+            },
+            self._test_fsdp_tp_integration,
+        )
+
+    def _test_fsdp_tp_integration(
+        self, cpu_offload, sharding_strategy, use_orig_params
+    ):
         """
         Tests training for TP + FSDP integration by comparing an FSDP-only
         model with a TP + FSDP model.
         """
+        tensor_parallel_size = 2
         LR = 3e-5
         torch.manual_seed(0)
         model = SimpleModel().cuda(self.rank)
@@ -246,7 +269,13 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
         self.assertEqual(model(inp), tp_fsdp_model(inp))  # sanity check
 
         mesh_1d = init_device_mesh("cuda", (self.world_size,))
-        fsdp_model = FSDP(model, cpu_offload=cpu_offload, device_mesh=mesh_1d)
+        fsdp_model = FSDP(
+            model,
+            cpu_offload=cpu_offload,
+            device_mesh=mesh_1d,
+            sharding_strategy=sharding_strategy,
+            use_orig_params=use_orig_params,
+        )
         mesh_2d = init_device_mesh(
             "cuda",
             (self.world_size // tensor_parallel_size, tensor_parallel_size),
@@ -269,6 +298,8 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
             tp_fsdp_model,
             cpu_offload=cpu_offload,
             device_mesh=mesh_2d["dp"],
+            sharding_strategy=sharding_strategy,
+            use_orig_params=use_orig_params,
         )
         fsdp_pg = mesh_2d["dp"].get_group(mesh_dim=0)
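
For reference, a hedged sketch of the wrapping pattern the updated test exercises: tensor parallelism on the inner mesh dimension, then FSDP on the outer one with SHARD_GRAD_OP and use_orig_params, the combination in which parameters show up as DTensors during backward. This is not the test code itself; it assumes an already-initialized NCCL process group with one GPU per rank, and the toy two-layer model (submodules "0" and "2") is hypothetical.

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_size = 2
dp_size = dist.get_world_size() // tp_size
# Outer ("dp") dimension for FSDP, inner ("tp") dimension for tensor parallelism.
mesh_2d = init_device_mesh(
    "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()
# Shard the first linear column-wise and the last row-wise across the TP mesh.
model = parallelize_module(
    model, mesh_2d["tp"], {"0": ColwiseParallel(), "2": RowwiseParallel()}
)
# Data-parallel sharding on the outer mesh dimension; with SHARD_GRAD_OP and
# use_orig_params=True, parameters are DTensors in the backward pass, which is
# what the patched _same_storage check has to handle.
model = FSDP(
    model,
    device_mesh=mesh_2d["dp"],
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    use_orig_params=True,
)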

torch/distributed/fsdp/_flat_param.py
Lines changed: 8 additions & 0 deletions

@@ -2711,6 +2711,14 @@ def _warn_use_fake_reduce(log: logging.Logger, warning: str):
 
 
 def _same_storage(a, b):
+    # Params are DTensors in backward
+    # with SHARD_GRAD_OP + TP
+    from torch.distributed._tensor import DTensor
+
+    if isinstance(a, DTensor):
+        a = a._local_tensor
+    if isinstance(b, DTensor):
+        b = b._local_tensor
     return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
 
 
