@@ -72,12 +72,6 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
-        return swap_linear_with_float8_linear(module, **kwargs)
-
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
     @property
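Note: the hunk above deletes the swap_linear_with_dynamic test helper, and the hunks below call swap_linear_with_float8_linear directly with no scaling_type_* kwargs. A minimal sketch of the resulting call-site pattern, assuming dynamic scaling for x, w, and dL_dY is now the default; the import path and the make_fp8_pair helper name here are illustrative only, not taken from the PR:

    # Sketch only: the deleted wrapper pinned scaling_type_x/w/dL_dY to
    # TensorScalingType.DYNAMIC; the new call sites assume that is the default.
    import copy

    import torch.nn as nn

    from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


    def make_fp8_pair(module: nn.Module) -> tuple[nn.Module, nn.Module]:
        # Keep an fp8-swapped reference copy and swap the original in place,
        # mirroring what the updated tests do before applying fully_shard.
        ref_module = copy.deepcopy(module)
        swap_linear_with_float8_linear(ref_module)
        swap_linear_with_float8_linear(module)
        return module, ref_module

The updated tests use both the in-place style and the assigned return value, which suggests the swap mutates the module in place and also returns it.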
@@ -106,11 +100,11 @@ def _test_transformer_parity_dynamic(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying)
+        module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        swap_linear_with_float8_linear(ref_module)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            swap_linear_with_float8_linear(module)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -153,7 +147,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            swap_linear_with_float8_linear(model, emulate=True)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -331,7 +325,8 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module_fp32 = swap_linear_with_float8_linear(module_fp32)
+            module = module_fp32
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -359,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module)
+            module = swap_linear_with_float8_linear(module)
         for submodule in module:
             fully_shard(submodule)
         fully_shard(module)
@@ -383,10 +378,11 @@ def test_fp32_fp8_single_module_parity(self):
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = copy.deepcopy(module_fp32)
+            ref_module = swap_linear_with_float8_linear(ref_module)
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = swap_linear_with_float8_linear(module_fp32)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -407,11 +403,11 @@ def test_fp32_fp8_multi_module_parity(self):
         multiple modules/FSDP communication groups.
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
-            module = self.init_multi_module()
+            module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = swap_linear_with_float8_linear(ref_module)
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = swap_linear_with_float8_linear(module)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)