@@ -73,12 +73,6 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
-        return swap_linear_with_float8_linear(module, **kwargs)
-
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
     @property
@@ -96,11 +90,11 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying)
+        module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        swap_linear_with_float8_linear(ref_module)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            swap_linear_with_float8_linear(module)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -141,7 +135,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            swap_linear_with_float8_linear(model, emulate=True)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -319,7 +313,8 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module_fp32 = swap_linear_with_float8_linear(module_fp32)
+            module = module_fp32
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -347,7 +342,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module)
+            module = swap_linear_with_float8_linear(module)
         for submodule in module:
             fully_shard(submodule)
         fully_shard(module)
@@ -371,10 +366,11 @@ def test_fp32_fp8_single_module_parity(self):
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = copy.deepcopy(module_fp32)
+            ref_module = swap_linear_with_float8_linear(ref_module)
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = swap_linear_with_float8_linear(module_fp32)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -395,11 +391,11 @@ def test_fp32_fp8_multi_module_parity(self):
         multiple modules/FSDP communication groups.
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
-            module = self.init_multi_module()
+            module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = swap_linear_with_float8_linear(ref_module)
            with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = swap_linear_with_float8_linear(module)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
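
Note on the change: the removed test helper swap_linear_with_dynamic only pinned scaling_type_x, scaling_type_w, and scaling_type_dL_dY to TensorScalingType.DYNAMIC before delegating to swap_linear_with_float8_linear; the call sites now invoke the swap directly. Below is a minimal sketch of the new call pattern, not taken from this PR. It assumes the float8_experimental.float8_linear_utils import path, that dynamic scaling is the library default, and uses a toy nn.Sequential in place of the test's init_single_module()/init_transformer() helpers.

import copy

import torch.nn as nn

# Assumed import path; adjust to wherever swap_linear_with_float8_linear
# lives in your checkout of float8_experimental.
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Toy module standing in for the test fixtures used in this file.
module = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
ref_module = copy.deepcopy(module)

# Old pattern (removed in this diff):
#     module = self.swap_linear_with_dynamic(module)
# which forced scaling_type_x/w/dL_dY to TensorScalingType.DYNAMIC.
# New pattern: call the swap directly. As the diff's call sites suggest, the
# swap replaces nn.Linear submodules in place and also returns the root
# module, so both styles below are used interchangeably in the tests.
swap_linear_with_float8_linear(module)
ref_module = swap_linear_with_float8_linear(ref_module)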