@@ -18,6 +18,7 @@
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     CPUOffload,
     FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
 )
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -28,7 +29,6 @@
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
-    parametrize,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
 )
@@ -141,31 +141,36 @@ def _sync_tp_grads(
         tp_world_size = tp_pg.size()
         fsdp_world_size = self.world_size // tp_world_size
         assert (
-            type(tp_fsdp_model) is FSDP and len(list(tp_fsdp_model.parameters())) == 1
+            type(tp_fsdp_model) is FSDP
+            and len([m for m in tp_fsdp_model.modules() if type(m) is FSDP]) == 1
         ), (
             "The following logic assumes a single top-level-only FSDP wrapping "
             "the model with TP already applied"
         )
-        flat_param = tp_fsdp_model.params[0]
-        splits = tuple(param_name_to_numel.values())
-        # Create a mask over the gradient elements to manually reduce
-        unsharded_size = torch.Size([flat_param.numel() * fsdp_world_size])
-        unsharded_zeros = torch.zeros(unsharded_size, device=flat_param.device)
-        per_param_masks = unsharded_zeros.split(splits)
-        for param_idx, param_name in enumerate(
-            param_name_to_numel.keys()
-        ):  # assumes fixed order
-            if param_name not in non_sharded_param_names:
-                per_param_masks[param_idx][:] = 1
-        unsharded_mask = torch.cat(per_param_masks).contiguous().type(torch.BoolTensor)
-        sharded_mask = unsharded_mask.chunk(fsdp_world_size)[self.rank // tp_world_size]
-        grad_device = flat_param.grad.device
-        grad = flat_param.grad.detach().clone().cuda(self.rank)
-        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_pg)
-        grad = grad.to(grad_device)
-        flat_param.grad[~sharded_mask] = grad[~sharded_mask]
-        # Average *all* gradient elements to match the FSDP only semantics
-        flat_param.grad /= tp_world_size
+        for flat_param in tp_fsdp_model.params:
+            splits = tuple(param_name_to_numel.values())
+            # Create a mask over the gradient elements to manually reduce
+            unsharded_size = torch.Size([flat_param.numel() * fsdp_world_size])
+            unsharded_zeros = torch.zeros(unsharded_size, device=flat_param.device)
+            per_param_masks = unsharded_zeros.split(splits)
+            for param_idx, param_name in enumerate(
+                param_name_to_numel.keys()
+            ):  # assumes fixed order
+                if param_name not in non_sharded_param_names:
+                    per_param_masks[param_idx][:] = 1
+            unsharded_mask = (
+                torch.cat(per_param_masks).contiguous().type(torch.BoolTensor)
+            )
+            sharded_mask = unsharded_mask.chunk(fsdp_world_size)[
+                self.rank // tp_world_size
+            ]
+            grad_device = flat_param.grad.device
+            grad = flat_param.grad.detach().clone().cuda(self.rank)
+            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_pg)
+            grad = grad.to(grad_device)
+            flat_param.grad[~sharded_mask] = grad[~sharded_mask]
+            # Average *all* gradient elements to match the FSDP only semantics
+            flat_param.grad /= tp_world_size

     def _get_grads_as_flattened(
         self,
@@ -182,7 +187,14 @@ def _get_grads_as_flattened(
         returns the same value on all ranks.
         """
         local_grads_as_flattened = (
-            torch.cat([torch.flatten(param.grad) for param in model.parameters()])
+            torch.cat(
+                [
+                    torch.flatten(param.grad)
+                    if param.grad is not None
+                    else torch.zeros_like(torch.flatten(param))
+                    for param in model.parameters()
+                ]
+            )
             .contiguous()
             .cuda(self.rank)
         )
@@ -214,16 +226,27 @@ def _get_grads_as_flattened(
         return torch.cat(all_grads_per_param).contiguous()

     @skip_if_lt_x_gpu(4)
-    @parametrize("tensor_parallel_size", [2, 4])
-    @parametrize(
-        "cpu_offload",
-        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
-    )
-    def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
+    def test_fsdp_tp_integration(self):
+        self.run_subtests(
+            {
+                "cpu_offload": [
+                    CPUOffload(offload_params=False),
+                    CPUOffload(offload_params=True),
+                ],
+                "sharding_strategy": [None, ShardingStrategy.SHARD_GRAD_OP],
+                "use_orig_params": [False, True],
+            },
+            self._test_fsdp_tp_integration,
+        )
+
+    def _test_fsdp_tp_integration(
+        self, cpu_offload, sharding_strategy, use_orig_params
+    ):
         """
         Tests training for TP + FSDP integration by comparing an FSDP-only
         model with a TP + FSDP model.
         """
+        tensor_parallel_size = 2
         LR = 3e-5
         torch.manual_seed(0)
         model = SimpleModel().cuda(self.rank)
@@ -246,7 +269,13 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
         self.assertEqual(model(inp), tp_fsdp_model(inp))  # sanity check

         mesh_1d = init_device_mesh("cuda", (self.world_size,))
-        fsdp_model = FSDP(model, cpu_offload=cpu_offload, device_mesh=mesh_1d)
+        fsdp_model = FSDP(
+            model,
+            cpu_offload=cpu_offload,
+            device_mesh=mesh_1d,
+            sharding_strategy=sharding_strategy,
+            use_orig_params=use_orig_params,
+        )
         mesh_2d = init_device_mesh(
             "cuda",
             (self.world_size // tensor_parallel_size, tensor_parallel_size),
@@ -269,6 +298,8 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
             tp_fsdp_model,
             cpu_offload=cpu_offload,
             device_mesh=mesh_2d["dp"],
+            sharding_strategy=sharding_strategy,
+            use_orig_params=use_orig_params,
         )
         fsdp_pg = mesh_2d["dp"].get_group(mesh_dim=0)