This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 1ba79c8

delete swap_linear_with_dynamic from fsdp2 eager test case
Summary: clean up some tech debt; the `swap_linear_with_dynamic` function is redundant with `swap_linear_with_float8_linear`.

Test Plan:
```
python ./test/test_fsdp2/test_fsdp2_eager.py
```

ghstack-source-id: b6546db
Pull Request resolved: #311
1 parent: 9facff8
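For context, the removed test helper only pinned all three scaling types to dynamic before delegating to the library's swap function. Below is a minimal sketch of the equivalence, assuming the `float8_experimental` import paths at this commit and that `swap_linear_with_float8_linear` already defaults `scaling_type_x`, `scaling_type_w`, and `scaling_type_dL_dY` to `TensorScalingType.DYNAMIC` (which is what makes the helper redundant):

```python
from typing import Any

import torch.nn as nn

# Import paths assumed from float8_experimental's layout at this commit.
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


def swap_linear_with_dynamic(module: nn.Module, **kwargs: Any) -> nn.Module:
    """The helper removed by this commit: force dynamic scaling, then delegate."""
    kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
    kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
    kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
    return swap_linear_with_float8_linear(module, **kwargs)


# If dynamic scaling is already the default, the two calls below do the same
# thing, so the tests can call the library function directly:
#   swap_linear_with_dynamic(model)
#   swap_linear_with_float8_linear(model)
```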


test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 13 additions & 17 deletions
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -73,12 +73,6 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
-        return swap_linear_with_float8_linear(module, **kwargs)
-
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
     @property
@@ -96,11 +90,11 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying)
+        module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        swap_linear_with_float8_linear(ref_module)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            swap_linear_with_float8_linear(module)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -141,7 +135,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            swap_linear_with_float8_linear(model, emulate=True)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -319,7 +313,8 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module_fp32 = swap_linear_with_float8_linear(module_fp32)
+            module = module_fp32
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -347,7 +342,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module)
+            module = swap_linear_with_float8_linear(module)
         for submodule in module:
             fully_shard(submodule)
         fully_shard(module)
@@ -371,10 +366,11 @@ def test_fp32_fp8_single_module_parity(self):
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = copy.deepcopy(module_fp32)
+            ref_module = swap_linear_with_float8_linear(ref_module)
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = swap_linear_with_float8_linear(module_fp32)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -395,11 +391,11 @@ def test_fp32_fp8_multi_module_parity(self):
         multiple modules/FSDP communication groups.
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
-            module = self.init_multi_module()
+            module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = swap_linear_with_float8_linear(ref_module)
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = swap_linear_with_float8_linear(module)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
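After this change, the tests compose the direct swap with FSDP2's `fully_shard` per submodule and then on the root. Below is a condensed sketch of that pattern, using a plain `nn.Sequential` as a stand-in for `init_multi_module()` and assuming an initialized distributed process group plus a CUDA device (which the `FSDPTest` harness provides); default dynamic scaling is assumed for the swap.

```python
import torch.nn as nn

# FSDP2's fully_shard lived under this private path around PyTorch 2.4.
from torch.distributed._composable.fsdp import fully_shard

from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Stand-in for init_multi_module(): a few Linear layers that each become
# their own FSDP2 communication group.
model = nn.Sequential(*[nn.Linear(16, 16, device="cuda") for _ in range(3)])

swap_linear_with_float8_linear(model)  # in-place swap; default (dynamic) scaling assumed
for layer in model:
    fully_shard(layer)  # shard each layer as its own group
fully_shard(model)      # root call picks up anything not already sharded
```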
