This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 3fe7c4a

vkuzo authored and facebook-github-bot committed
delete swap_linear_with_dynamic from fsdp2 eager test case (#311)
Summary:
Pull Request resolved: #311

Clean up some tech debt: the `swap_linear_with_dynamic` function is redundant with `swap_linear_with_float8_linear`.

Reviewed By: awgu

Differential Revision: D59685259

fbshipit-source-id: aab8861a5d9932a7bdd3dcb3014a49a0a048c55d
1 parent d44a996 commit 3fe7c4a
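
For context on why the helper was redundant: all it did was force the three scaling-type kwargs to `TensorScalingType.DYNAMIC` before delegating to `swap_linear_with_float8_linear`, so the call sites can either pass those kwargs themselves or omit them if dynamic scaling is the default. A minimal sketch of the old versus new call pattern is below; the toy model and the import paths are assumptions for illustration, not taken from this commit.

import copy

import torch.nn as nn

# Import paths are assumed from the float8_experimental layout at the time of this commit.
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Hypothetical toy model standing in for the test's init_* helpers.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# What the removed test helper amounted to: spell out dynamic scaling explicitly
# for inputs, weights, and output gradients, then delegate to the library swap.
explicit = swap_linear_with_float8_linear(
    copy.deepcopy(model),
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)

# What the updated call sites do: call the library function directly and rely on
# its defaults (assumed here to be dynamic scaling).
direct = swap_linear_with_float8_linear(copy.deepcopy(model))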

File tree

1 file changed: +13, -17 lines changed


test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 13 additions & 17 deletions
@@ -72,12 +72,6 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)

-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
-        return swap_linear_with_float8_linear(module, **kwargs)
-

 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
     @property
@@ -106,11 +100,11 @@ def _test_transformer_parity_dynamic(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying)
+        module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        swap_linear_with_float8_linear(ref_module)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            swap_linear_with_float8_linear(module)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -153,7 +147,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            swap_linear_with_float8_linear(model, emulate=True)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -331,7 +325,8 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module_fp32 = swap_linear_with_float8_linear(module_fp32)
+            module = module_fp32
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -359,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module)
+            module = swap_linear_with_float8_linear(module)
         for submodule in module:
             fully_shard(submodule)
         fully_shard(module)
@@ -383,10 +378,11 @@ def test_fp32_fp8_single_module_parity(self):
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = copy.deepcopy(module_fp32)
+            ref_module = swap_linear_with_float8_linear(ref_module)
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = swap_linear_with_float8_linear(module_fp32)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -407,11 +403,11 @@ def test_fp32_fp8_multi_module_parity(self):
         multiple modules/FSDP communication groups.
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
-            module = self.init_multi_module()
+            module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = swap_linear_with_float8_linear(ref_module)
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = swap_linear_with_float8_linear(module)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
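
Taken together, the updated tests follow one pattern: swap the Linear layers for Float8Linear in place with `swap_linear_with_float8_linear`, then apply FSDP2's `fully_shard` to each submodule and once more to the root. The sketch below mirrors that shape with a hypothetical stand-in model; it assumes a CUDA device and an already-initialized process group (e.g. launched via torchrun), and it omits the test suite's `set_enable_fsdp_fp8_all_gather` helper, which is local to these tests.

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# Import path assumed from the float8_experimental layout at the time of this commit.
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Hypothetical stand-in for the test's multi-module setup.
module = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()

# Swap nn.Linear -> Float8Linear in place; dynamic scaling is the assumed default.
swap_linear_with_float8_linear(module)

# Shard each child, then the root, mirroring the loops in the updated tests.
for submodule in module:
    fully_shard(submodule)
fully_shard(module)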
